Skip to content
Snippets Groups Projects
Commit b2616e7e authored by mnsc's avatar mnsc
Browse files

add regression demo

parent 59ae64f5
No related branches found
No related tags found
No related merge requests found
#%% Import libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import supr
from tspn.utils import drawnow, arrange_figs
from scipy.stats import norm
#%% Dataset
N = 100
x = torch.linspace(0, 1, N)
y = -1*x + (torch.rand(N)>0.5)*(x>0.5) + torch.randn(N)*0.1
x[x>0.5] += 0.25
x[x<0.5] -= 0.25
X = torch.stack((x,y), dim=1)
#%% Grid to evaluate predictive distribution
x_grid = torch.linspace(-1, 2, 100)
y_grid = torch.linspace(-2, 2, 100)
X_grid = torch.stack([x.flatten() for x in torch.meshgrid(x_grid, y_grid, indexing='ij')], dim=1)
#%% Sum-product network
tracks = 1
variables = 2
channels = 20
# Priors for variance of x and y
alpha0 = torch.tensor([[[0.2], [0.1]]])
beta0 = torch.tensor([[[0.2], [0.1]]])
model = supr.Sequential(
supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0., nu0=0, alpha0=alpha0, beta0=beta0),
supr.Weightsum(tracks, variables, channels)
)
#%% Fit model and display results
epochs = 20
for epoch in range(epochs):
model.train()
model[0].marginalize = torch.zeros(variables, dtype=torch.bool)
loss = model(X).sum()
print(f"Loss = {loss}")
loss.backward()
with torch.no_grad():
model.eval()
model.em_batch_update()
model.zero_grad(True)
p_xy = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T)
model[0].marginalize = torch.tensor([False, True])
p_x = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T)
p_prior = norm(0, 0.5).pdf(y_grid)[:,None]
p_predictive = (N*p_xy + p_prior)/(N*p_x+1)
plt.figure(1).clf()
dx = (x_grid[1]-x_grid[0])/2.
dy = (y_grid[1]-y_grid[0])/2.
extent = [x_grid[0]-dx, x_grid[-1]+dx, y_grid[0]-dy, y_grid[-1]+dy]
plt.imshow(torch.log(p_predictive), extent=extent, aspect='auto', origin='lower', vmin=-3, vmax=1)
plt.plot(x, y, '.')
drawnow()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment