Skip to content
Snippets Groups Projects
Select Git revision
  • 82a57cd2cfe15075d20747e5e5d8030259e86afb
  • main default protected
2 results

regression.py

Blame
  • Mikkel N Schmidt's avatar
    mnsc authored
    82a57cd2
    History
    Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    regression.py 1.99 KiB
    # %% Import libraries
    import matplotlib.pyplot as plt
    import torch
    import supr
    from supr.utils import drawnow
    from scipy.stats import norm
    
    # %% Dataset
    # Noisy line y = 1 - 2x in which points with x > 0.5 are lifted by +1 with
    # probability 1/2; the two halves are then pulled apart and a single
    # isolated point is placed at (-1, 0).
    N = 200
    x = torch.linspace(0, 1, N)
    on_right = x > 0.5
    # torch.rand must be drawn before torch.randn to keep the same RNG order.
    lifted = torch.rand(N) > 0.5
    noise = 0.1 * torch.randn(N)
    y = (1 - 2 * x) + lifted * on_right + noise

    # Open a gap between the halves: right half moves +0.25, left half -0.25.
    on_left = x < 0.5
    x[on_right] += 0.25
    x[on_left] -= 0.25

    # Isolated point far to the left of the data.
    x[0], y[0] = -1., 0

    # Data matrix with one (x, y) row per point.
    X = torch.column_stack((x, y))
    
    # %% Grid on which the predictive distribution is evaluated
    x_grid = torch.linspace(-2, 2, 200)
    y_grid = torch.linspace(-2, 2, 200)
    # All (x, y) pairs as rows; x varies slowest and y fastest, so row
    # i * len(y_grid) + j holds (x_grid[i], y_grid[j]).
    X_grid = torch.cartesian_prod(x_grid, y_grid)
    
    # %% Sum-product network hyper-parameters
    tracks, variables, channels = 1, 2, 50

    # Prior parameters for the variances of x and y, shaped (1, variables, 1):
    # position 0 on the middle axis is x, position 1 is y.
    alpha0 = torch.ones(1, variables, 1, dtype=torch.long)
    beta0 = torch.tensor([.05, 0.01]).view(1, variables, 1)
    
    # Layer of Normal leaves (with the priors above) chained into a Weightsum
    # layer. n=N presumably informs the leaves of the dataset size for their
    # prior-weighted updates — TODO confirm against supr.NormalLeaf.
    leaf_layer = supr.NormalLeaf(
        tracks, variables, channels,
        n=N, mu0=0., nu0=0, alpha0=alpha0, beta0=beta0,
    )
    sum_layer = supr.Weightsum(tracks, variables, channels)
    model = supr.Sequential(leaf_layer, sum_layer)
    
    # %% Fit model and display results
    epochs = 20

    # Weight of the prior relative to the N data points when blending the
    # predictive density below.
    Ndx = 1
    # Prior density of y on the grid as a (len(y_grid), 1) column so it
    # broadcasts across the x-columns of the density images. It never changes,
    # so compute it once, and in torch so it is not mixed with NumPy float64.
    p_prior = torch.distributions.Normal(0., 0.5).log_prob(y_grid).exp()[:, None]
    # imshow extent measured at pixel edges (half a grid step beyond the first
    # and last grid points), as plain Python floats.
    dx = (x_grid[1] - x_grid[0]).item() / 2.
    dy = (y_grid[1] - y_grid[0]).item() / 2.
    extent = [x_grid[0].item() - dx, x_grid[-1].item() + dx,
              y_grid[0].item() - dy, y_grid[-1].item() + dy]

    for epoch in range(epochs):
        # Training pass over the full joint: no variable is marginalized.
        model.train()
        model[0].marginalize = torch.zeros(variables, dtype=torch.bool)
        logp = model(X).sum()
        print(f"Log-posterior ∝ {logp:.2f} ")
        logp.backward()

        with torch.no_grad():
            # EM parameter update (presumably consuming the gradients from
            # logp.backward() above — confirm in supr), then drop the gradients.
            model.eval()
            model.em_batch_update()
            model.zero_grad(set_to_none=True)

            # Joint density p(x, y) on the grid. X_grid enumerates (x, y) with y
            # varying fastest, so reshape to (x, y) and transpose: rows index y
            # and columns index x, matching imshow(origin='lower').
            p_xy = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T)

            # Marginal density p(x): marginalize out y (variable 1).
            model[0].marginalize = torch.tensor([False, True])
            p_x = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T)
            # Restore the full-joint setting so the model is not left
            # marginalizing y once the loop ends.
            model[0].marginalize = torch.zeros(variables, dtype=torch.bool)

            # Predictive density of y given x: tends to p(x, y) / p(x) where the
            # data are dense (N * p_x >> Ndx) and to p_prior where they are sparse.
            p_predictive = (N*p_xy + Ndx*p_prior)/(N*p_x+Ndx)

            plt.figure(1).clf()
            plt.imshow(torch.log(p_predictive), extent=extent, aspect='auto', origin='lower', vmin=-4, vmax=1)
            plt.plot(x, y, '.', color='tab:orange', alpha=.5, markersize=4, markeredgewidth=0)
            plt.axis('square')
            drawnow()