diff --git a/demos/grid_xy.py b/demos/grid_xy.py new file mode 100644 index 0000000000000000000000000000000000000000..85b5c980aee4663735bc5e5ba5186f3fc381aff4 --- /dev/null +++ b/demos/grid_xy.py @@ -0,0 +1,91 @@ +# %% Import libraries +import matplotlib.pyplot as plt +import torch +import supr +from supr.utils import drawnow +from scipy.stats import norm +from math import sqrt +import numpy as np + +# %% Default settings +plt.ion() + +# %% Dataset +Nc = 10 +K = 3 +N = K**2*Nc +sigma_x = torch.tensor([[[.02], [.04], [.08]], [[.02], [.04], [.08]],[[.02], [.04], [.08]]]) +sigma_y = torch.tensor([[[.02], [.04], [.08]], [[.02], [.04], [.08]],[[.02], [.04], [.08]]]) +mu_x = torch.linspace(0, 1, K) +mu_y = torch.linspace(0, 1, K) +x, y = torch.meshgrid((mu_x, mu_y)) +x = torch.flatten(x[:,:,None] + torch.randn((K, K, Nc))*sigma_x) +y = torch.flatten(y[:,:,None] + torch.randn((K, K, Nc))*sigma_y) + +X = torch.stack((x, y), dim=1) + +# %% Grid to evaluate predictive distribution +x_res, y_res = 400, 400 +x_min, x_max = -.25, 1.25 +y_min, y_max = -.25, 1.25 +x_grid = torch.linspace(x_min, x_max, x_res) +y_grid = torch.linspace(y_min, y_max, y_res) +XY_grid = torch.stack([x.flatten() for x in torch.meshgrid( + x_grid, y_grid, indexing='ij')], dim=1) +X_grid = torch.stack([x_grid, torch.zeros(x_res)]).T + +# %% Sum-product network +# Parameters +tracks = 1 +variables = 2 +channels = 3 + +# Priors for variance of x and y +alpha0 = torch.tensor([[[1], [1]]]) +beta0 = torch.tensor([[[.01], [.01]]]) + +# Construct SPN model +model = supr.Sequential( + supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0., + nu0=0, alpha0=alpha0, beta0=beta0), + supr.Einsum(tracks, variables, channels, 1), + supr.Weightsum(tracks, variables, 1) +) + +# %% Fit model and display results +epochs = 20 + +for epoch in range(epochs): + model.train() + model[0].marginalize = torch.zeros(variables, dtype=torch.bool) + logp = model(X).sum() + print(f"Log-posterior ∝ {logp:.2f} ") + model.zero_grad(True) + logp.backward() + + model.eval() # swap? + model.em_batch_update() + + # Plot data and model + # ------------------------------------------------------------------------- + # Evaluate joint distribution on grid + with torch.no_grad(): + log_p_xy = model(XY_grid) + p_xy = torch.exp(log_p_xy).reshape(x_res, y_res) + + # Plot posterior + plt.figure(1).clf() + plt.title('Posterior distribution') + dx = (x_max-x_min)/x_res/2 + dy = (y_max-y_min)/y_res/2 + extent = [x_grid[0]-dx, x_grid[-1]+dx, y_grid[0]-dy, y_grid[-1]+dy] + plt.imshow(torch.log(p_xy).T, extent=extent, + aspect='auto', origin='lower', + vmin=-4, vmax=1, cmap='Blues') + plt.plot(x, y, '.', color='tab:orange', alpha=.5, + markersize=15, markeredgewidth=0) + plt.axis('square') + plt.xlim([x_min, x_max]) + plt.ylim([y_min, y_max]) + drawnow() +