diff --git a/demos/regression.py b/demos/regression.py new file mode 100644 index 0000000000000000000000000000000000000000..48d477c53ac3b488546235682b6f0b91e626316a --- /dev/null +++ b/demos/regression.py @@ -0,0 +1,68 @@ +#%% Import libraries +import numpy as np +import matplotlib.pyplot as plt +import torch +import supr +from tspn.utils import drawnow, arrange_figs +from scipy.stats import norm + +#%% Dataset +N = 100 +x = torch.linspace(0, 1, N) +y = -1*x + (torch.rand(N)>0.5)*(x>0.5) + torch.randn(N)*0.1 +x[x>0.5] += 0.25 +x[x<0.5] -= 0.25 + +X = torch.stack((x,y), dim=1) + +#%% Grid to evaluate predictive distribution +x_grid = torch.linspace(-1, 2, 100) +y_grid = torch.linspace(-2, 2, 100) +X_grid = torch.stack([x.flatten() for x in torch.meshgrid(x_grid, y_grid, indexing='ij')], dim=1) + +#%% Sum-product network +tracks = 1 +variables = 2 +channels = 20 + +# Priors for variance of x and y +alpha0 = torch.tensor([[[0.2], [0.1]]]) +beta0 = torch.tensor([[[0.2], [0.1]]]) + +model = supr.Sequential( + supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0., nu0=0, alpha0=alpha0, beta0=beta0), + supr.Weightsum(tracks, variables, channels) + ) + +#%% Fit model and display results +epochs = 20 + +for epoch in range(epochs): + model.train() + model[0].marginalize = torch.zeros(variables, dtype=torch.bool) + loss = model(X).sum() + print(f"Loss = {loss}") + loss.backward() + + with torch.no_grad(): + model.eval() + model.em_batch_update() + model.zero_grad(True) + + p_xy = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T) + + model[0].marginalize = torch.tensor([False, True]) + p_x = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T) + + p_prior = norm(0, 0.5).pdf(y_grid)[:,None] + + p_predictive = (N*p_xy + p_prior)/(N*p_x+1) + + plt.figure(1).clf() + dx = (x_grid[1]-x_grid[0])/2. + dy = (y_grid[1]-y_grid[0])/2. + extent = [x_grid[0]-dx, x_grid[-1]+dx, y_grid[0]-dy, y_grid[-1]+dy] + plt.imshow(torch.log(p_predictive), extent=extent, aspect='auto', origin='lower', vmin=-3, vmax=1) + plt.plot(x, y, '.') + drawnow() +