diff --git a/demos/regression.py b/demos/regression.py index b6e9e097bbef7e0001618395489961ad9727f421..5253596a1e256e720f5ebd5cbffbc3348a39870c 100644 --- a/demos/regression.py +++ b/demos/regression.py @@ -1,63 +1,62 @@ -#%% Import libraries -import numpy as np +# %% Import libraries import matplotlib.pyplot as plt import torch import supr from supr.utils import drawnow from scipy.stats import norm -#%% Dataset +# %% Dataset N = 100 x = torch.linspace(0, 1, N) -y = -1*x + (torch.rand(N)>0.5)*(x>0.5) + torch.randn(N)*0.1 -x[x>0.5] += 0.25 -x[x<0.5] -= 0.25 +y = -1*x + (torch.rand(N) > 0.5)*(x > 0.5) + torch.randn(N)*0.1 +x[x > 0.5] += 0.25 +x[x < 0.5] -= 0.25 -X = torch.stack((x,y), dim=1) +X = torch.stack((x, y), dim=1) -#%% Grid to evaluate predictive distribution +# %% Grid to evaluate predictive distribution x_grid = torch.linspace(-1, 2, 100) y_grid = torch.linspace(-2, 2, 100) X_grid = torch.stack([x.flatten() for x in torch.meshgrid(x_grid, y_grid, indexing='ij')], dim=1) -#%% Sum-product network +# %% Sum-product network tracks = 1 variables = 2 channels = 20 # Priors for variance of x and y -alpha0 = torch.tensor([[[0.2], [0.1]]]) -beta0 = torch.tensor([[[0.2], [0.1]]]) +alpha0 = torch.tensor([[[0.5], [0.1]]]) +beta0 = torch.tensor([[[0.5], [0.1]]]) model = supr.Sequential( supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0., nu0=0, alpha0=alpha0, beta0=beta0), supr.Weightsum(tracks, variables, channels) - ) +) -#%% Fit model and display results +# %% Fit model and display results epochs = 20 for epoch in range(epochs): model.train() - model[0].marginalize = torch.zeros(variables, dtype=torch.bool) - loss = model(X).sum() - print(f"Loss = {loss}") - loss.backward() - + model[0].marginalize = torch.zeros(variables, dtype=torch.bool) + logp = model(X).sum() + print(f"Log-posterior ∝ {logp:.2f} ") + logp.backward() + with torch.no_grad(): model.eval() model.em_batch_update() model.zero_grad(True) - + p_xy = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T) model[0].marginalize = torch.tensor([False, True]) p_x = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T) - - p_prior = norm(0, 0.5).pdf(y_grid)[:,None] - + + p_prior = norm(0, 0.5).pdf(y_grid)[:, None] + p_predictive = (N*p_xy + p_prior)/(N*p_x+1) - + plt.figure(1).clf() dx = (x_grid[1]-x_grid[0])/2. dy = (y_grid[1]-y_grid[0])/2. @@ -65,4 +64,3 @@ for epoch in range(epochs): plt.imshow(torch.log(p_predictive), extent=extent, aspect='auto', origin='lower', vmin=-3, vmax=1) plt.plot(x, y, '.') drawnow() - diff --git a/supr/layers.py b/supr/layers.py index 6fcace92c5404b77fdbc33b2abe0ef5006046570..d1a59958892c1e5d353cbf44c733f638c8e51f75 100644 --- a/supr/layers.py +++ b/supr/layers.py @@ -5,7 +5,6 @@ import math from supr.utils import discrete_rand, local_scramble_2d from typing import List - # Data: # N x V x D # └───│──│─ N: Data points @@ -40,7 +39,7 @@ class SuprLayer(nn.Module): pass class Sequential(nn.Sequential): - def __init__(self, *args): + def __init__(self, *args: object) -> object: super().__init__(*args) def em_batch_update(self): @@ -256,7 +255,7 @@ class NormalLeaf(SuprLayer): """ NormalLeaf layer """ def __init__(self, tracks: int, variables: int, channels: int, n: int = 1, - mu0: float = 0., nu0: float = 0., alpha0: float = 0., beta0: float = 0.): + mu0: torch.tensor = 0., nu0: torch.tensor = 0., torch.tensor: float = 0., beta0: torch.tensor = 0.): super().__init__() # Dimensions self.T, self.V, self.C = tracks, variables, channels