diff --git a/demos/regression.py b/demos/regression.py
index 7f61499b251d8efad0b9db45aab69f684e98677e..3d1562b1c44c588c6853b369e186807fdaf2ca45 100644
--- a/demos/regression.py
+++ b/demos/regression.py
@@ -4,38 +4,52 @@ import torch
 import supr
 from supr.utils import drawnow
 from scipy.stats import norm
+from math import sqrt
+import numpy as np
 
 # %% Dataset
-N = 200
+N = 100
 x = torch.linspace(0, 1, N)
-y = 1 -2*x + (torch.rand(N) > 0.5)*(x > 0.5) + torch.randn(N)*0.1
+y = 1 - 2*x + (torch.rand(N) > 0.5)*(x > 0.5) + torch.randn(N)*0.1
 x[x > 0.5] += 0.25
 x[x < 0.5] -= 0.25
 x[0] = -1.
-y[0] = 0
+y[0] = -0.5
 X = torch.stack((x, y), dim=1)
 
 # %% Grid to evaluate predictive distribution
-x_grid = torch.linspace(-2, 2, 200)
-y_grid = torch.linspace(-2, 2, 200)
-X_grid = torch.stack([x.flatten() for x in torch.meshgrid(x_grid, y_grid, indexing='ij')], dim=1)
+x_res, y_res = 400, 500
+x_min, x_max = -2, 2
+y_min, y_max = -2, 2
+x_grid = torch.linspace(x_min, x_max, x_res)
+y_grid = torch.linspace(y_min, y_max, y_res)
+XY_grid = torch.stack([x.flatten() for x in torch.meshgrid(
+    x_grid, y_grid, indexing='ij')], dim=1)
+# x-grid paired with a dummy y-column (y is marginalized out when used)
+X_grid = torch.stack([x_grid, torch.zeros(x_res)]).T
 
 # %% Sum-product network
+# Parameters
 tracks = 1
 variables = 2
 channels = 50
 
 # Priors for variance of x and y
 alpha0 = torch.tensor([[[1], [1]]])
-beta0 = torch.tensor([[[.05], [0.01]]])
+beta0 = torch.tensor([[[.01], [.01]]])
 
+# Construct SPN model
 model = supr.Sequential(
     supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0.,
-    supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0., nu0=0, alpha0=alpha0, beta0=beta0),
+                    nu0=0, alpha0=alpha0, beta0=beta0),
     supr.Weightsum(tracks, variables, channels)
 )
 
+# Marginalization query: integrate out y (the second variable)
+marginalize_y = torch.tensor([False, True])
+
 # %% Fit model and display results
 epochs = 20
 
@@ -44,28 +58,76 @@ for epoch in range(epochs):
     model[0].marginalize = torch.zeros(variables, dtype=torch.bool)
     logp = model(X).sum()
     print(f"Log-posterior ∝ {logp:.2f} ")
+    model.zero_grad(True)
     logp.backward()
 
-    with torch.no_grad():
-        model.eval()
-        model.em_batch_update()
-        model.zero_grad(True)
+    model.eval()
+    model.em_batch_update()
 
-    p_xy = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T)
+    # Evaluate distributions and plot results
+    # -------------------------------------------------------------------------
+    # Evaluate joint distribution on grid
+    with torch.no_grad():
+        log_p_xy = model(XY_grid)
+        p_xy = torch.exp(log_p_xy).reshape(x_res, y_res)
 
-    model[0].marginalize = torch.tensor([False, True])
-    p_x = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T)
+    # Evaluate marginal distribution on x-grid; the backward pass below
+    # fills the leaf responsibilities used by model.mean() and model.var()
+    log_p_x = model(X_grid, marginalize=marginalize_y)
+    p_x = torch.exp(log_p_x)
+    model.zero_grad(True)
+    log_p_x.sum().backward()
 
+    with torch.no_grad():
+        # Prior for p(y|x): zero-mean Normal, weighted as Ndx pseudo-observations
         Ndx = 1
-        p_prior = norm(0, 0.5).pdf(y_grid)[:, None]
+        sig_prior = 1.  # prior variance of y
+        p_y = torch.as_tensor(norm(0, sqrt(sig_prior)).pdf(y_grid))
+
+        # Normal approximation to the predictive p(y|x)
+        mu_prior = 0.
+        m_pred = (N*model.mean()[:, 1]*p_x + Ndx*mu_prior)/(N*p_x + Ndx)
+        v_pred = (N*p_x*(model.var()[:, 1] + model.mean()[:, 1]**2)
+                  + Ndx*sig_prior)/(N*p_x + Ndx) - m_pred**2
+        std_pred = torch.sqrt(v_pred)
 
-        p_predictive = (N*p_xy + Ndx*p_prior)/(N*p_x+Ndx)
+        # Compute predictive distribution
+        p_predictive = (N*p_xy + Ndx*p_y[None, :]) / (N*p_x[:, None] + Ndx)
 
+        # Compute 95% highest-density region of p(y|x) for each x
+        hpr = torch.ones((x_res, y_res), dtype=torch.bool)
+        for k in range(x_res):
+            p_sorted = -np.sort(-(p_predictive[k] * np.gradient(y_grid)))  # cell masses, descending
+            i = np.searchsorted(np.cumsum(p_sorted), 0.95)  # smallest set covering 95%
+            idx = (p_predictive[k]*np.gradient(y_grid)) < p_sorted[i]
+            hpr[k, idx] = False
+
+        # Plot posterior
         plt.figure(1).clf()
-        dx = (x_grid[1]-x_grid[0])/2.
-        dy = (y_grid[1]-y_grid[0])/2.
+        plt.title('Posterior distribution')
+        dx = (x_max-x_min)/x_res/2
+        dy = (y_max-y_min)/y_res/2
         extent = [x_grid[0]-dx, x_grid[-1]+dx, y_grid[0]-dy, y_grid[-1]+dy]
-        plt.imshow(torch.log(p_predictive), extent=extent, aspect='auto', origin='lower', vmin=-4, vmax=1)
-        plt.plot(x, y, '.', color='tab:orange', alpha=.5, markersize=4, markeredgewidth=0)
+        plt.imshow(torch.log(p_predictive).T, extent=extent,
+                   aspect='auto', origin='lower',
+                   vmin=-4, vmax=1, cmap='Blues')
+        plt.contour(hpr.T, levels=1, extent=extent)
+        plt.plot(x, y, '.', color='tab:orange', alpha=.5,
+                 markersize=15, markeredgewidth=0)
+        plt.axis('square')
+        plt.xlim([x_min, x_max])
+        plt.ylim([y_min, y_max])
+        drawnow()
+
+        # Plot Normal approximation to posterior
+        plt.figure(2).clf()
+        plt.title('Posterior Normal approximation')
+        plt.plot(x, y, '.', color='tab:orange', alpha=.5,
+                 markersize=15, markeredgewidth=0)
+        plt.plot(x_grid, m_pred, color='tab:orange')
+        plt.fill_between(x_grid, m_pred - 1.96*std_pred,
+                         m_pred + 1.96*std_pred, color='tab:orange', alpha=0.1)
         plt.axis('square')
+        plt.xlim([x_min, x_max])
+        plt.ylim([y_min, y_max])
         drawnow()
diff --git a/supr/layers.py b/supr/layers.py
index 2a1274811523f3f9a9b127f086dec59988198030..181f8abd3f228918ba9e9e5b165f100ac6ceac3e 100644
--- a/supr/layers.py
+++ b/supr/layers.py
@@ -50,12 +50,37 @@ class Sequential(nn.Sequential):
             module.em_batch()
             module.em_update()
 
+    def em_batch(self):
+        with torch.no_grad():
+            for module in self:
+                module.em_batch()
+
+    def em_update(self):
+        with torch.no_grad():
+            for module in self:
+                module.em_update()
+
     def sample(self):
         value = []
         for module in reversed(self):
            value = module.sample(*value)
         return value
 
+    def mean(self):
+        return self[0].mean()
+
+    def var(self):
+        return self[0].var()
+
+    def forward(self, value, marginalize=None):
+        for module in self:
+            if isinstance(module, SuprLeaf) and marginalize is not None:
+                value = module(value, marginalize=marginalize)
+            else:
+                value = module(value)
+        return value
+
+
 class Parallel(SuprLayer):
     def __init__(self, nets: List[SuprLayer]):
 
@@ -257,8 +282,14 @@ class TrackSum(ProductSumLayer):
         y = y[:, None]
         return y
 
+
+class SuprLeaf(SuprLayer):
+    """ Base class for leaf layers """
+    def __init__(self):
+        super().__init__()
+
 
-class NormalLeaf(SuprLayer):
+class NormalLeaf(SuprLeaf):
     """ NormalLeaf layer """
 
     def __init__(self, tracks: int, variables: int, channels: int, n: int = 1, mu0: torch.tensor = 0.,
@@ -320,10 +351,23 @@
                 torch.clamp(sig_marginalize, self.epsilon))
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
-
-    def forward(self, x: torch.Tensor):
+
+    def mean(self):
+        # Responsibility-weighted posterior mean of each variable
+        # (requires self.z.grad to be populated by a preceding backward pass)
+        return (torch.clamp(self.z.grad, self.epsilon) * self.mu).sum([1, 3])
+
+    def var(self):
+        # Responsibility-weighted posterior variance of each variable
+        return (torch.clamp(self.z.grad, self.epsilon)
+                * (self.mu**2 + self.sig)).sum([1, 3]) - self.mean()**2
+
+    def forward(self, x: torch.Tensor, marginalize=None):
         # Get shape
         batch_size = x.shape[0]
+        # Update the marginalization mask if one is given
+        if marginalize is not None:
+            self.marginalize = marginalize
         # Store the data
         self.x = x
         # Compute the probability
@@ -339,7 +383,7 @@
         return self.z
 
 
-class BernoulliLeaf(SuprLayer):
+class BernoulliLeaf(SuprLeaf):
     """ BernoulliLeaf layer """
 
     def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
@@ -385,7 +429,7 @@
         r[self.marginalize] = (torch.rand(variables_marginalize).to(self.x.device) < p_marginalize).float()
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
-
+
     def forward(self, x: torch.Tensor):
         # Get shape
         batch_size = x.shape[0]
@@ -402,7 +446,7 @@
         return self.z
 
 
-class CategoricalLeaf(SuprLayer):
+class CategoricalLeaf(SuprLeaf):
     """ CategoricalLeaf layer """
 
     def __init__(self, tracks: int, variables: int, channels: int, dimensions: int, n: int = 1,
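
For reviewers: a minimal usage sketch of the new `Sequential.forward(value, marginalize=...)` query and the `mean()`/`var()` accessors added in this patch. The construction mirrors `demos/regression.py`; the data and channel count are illustrative, and the leaf's default priors are assumed to be adequate here.

```python
import torch
import supr

# Toy data: N observations of two variables (x, y)
N = 100
X = torch.randn(N, 2)

# Same architecture as demos/regression.py (default leaf priors assumed)
model = supr.Sequential(
    supr.NormalLeaf(tracks=1, variables=2, channels=10, n=N),
    supr.Weightsum(1, 2, 10),
)

# Joint log-density log p(x, y) at the data points
log_p_xy = model(X)

# Marginal log-density log p(x): integrate out the second variable
marginalize_y = torch.tensor([False, True])
log_p_x = model(X, marginalize=marginalize_y)

# The backward pass fills the leaf responsibilities (z.grad)
# that model.mean() and model.var() read
model.zero_grad(True)
log_p_x.sum().backward()
m = model.mean()  # (batch, variables): posterior means per variable
v = model.var()   # (batch, variables): posterior variances per variable
```

Note that `mean()` and `var()` are only meaningful after a forward/backward pass, since they are computed from the gradients of the log-density with respect to the leaf activations.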
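Separately, a worked sketch of the predictive blend the demo builds from these quantities, p(y|x) = (N·p(x,y) + Ndx·p0(y)) / (N·p(x) + Ndx); the density values below are illustrative stand-ins for model output:

```python
import torch

# Predictive blend from demos/regression.py:
#   p(y|x) = (N*p(x, y) + Ndx*p0(y)) / (N*p(x) + Ndx)
# where Ndx acts as a pseudo-count of prior observations of y.
N, Ndx = 100, 1
p_xy = torch.tensor(0.30)  # joint density p(x, y) (illustrative)
p_x = torch.tensor(0.40)   # marginal density p(x) (illustrative)
p0_y = torch.tensor(0.20)  # prior density p0(y)

p_pred = (N*p_xy + Ndx*p0_y) / (N*p_x + Ndx)
print(p_pred)  # ≈ 0.7366, close to the model conditional p_xy/p_x = 0.75

# As p(x) -> 0 (x far from the data), the blend falls back to the prior:
p_far = (N*0.0 + Ndx*p0_y) / (N*0.0 + Ndx)
print(p_far)   # = p0(y) = 0.20
```

This is why the demo plots revert to the wide zero-mean prior band outside the support of the training data.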