Commit 0c2aae1b authored by mnsc

added mean/var, updated regression, major other changes
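Note on the updated regression: the script now forms the posterior predictive by mixing the SPN's joint density with a fixed Normal prior on y, weighting the SPN by the N observed points and the prior by a pseudo-count Ndx. A minimal sketch of that mixing rule at a single grid cell, with made-up densities and names borrowed from the diff below (illustrative only, not an excerpt of the script):

from scipy.stats import norm

# Illustrative numbers for one (x, y) grid cell
N, Ndx = 100, 1            # observed points vs. prior pseudo-count
p_xy, p_x = 0.30, 0.50     # SPN joint p(x, y) and marginal p(x)
p_y = norm(0, 1).pdf(0.2)  # Normal prior density at y = 0.2

# Predictive mixes SPN evidence and the prior by their effective counts
p_pred = (N * p_xy + Ndx * p_y) / (N * p_x + Ndx)
print(p_pred)  # ~0.596 -- close to p_xy / p_x = 0.6, since N >> Ndx

With Ndx -> 0 this reduces to the SPN conditional p(y | x); a larger Ndx pulls the prediction toward the prior wherever the model assigns x little mass.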

parent 82a57cd2
@@ -4,38 +4,51 @@ import torch
import supr
from supr.utils import drawnow
from scipy.stats import norm
from math import sqrt
import numpy as np

# %% Dataset
N = 100
x = torch.linspace(0, 1, N)
y = 1 - 2*x + (torch.rand(N) > 0.5)*(x > 0.5) + torch.randn(N)*0.1
x[x > 0.5] += 0.25
x[x < 0.5] -= 0.25
x[0] = -1.
y[0] = -0.5
X = torch.stack((x, y), dim=1)

# %% Grid to evaluate predictive distribution
x_res, y_res = 400, 500
x_min, x_max = -2, 2
y_min, y_max = -2, 2
x_grid = torch.linspace(x_min, x_max, x_res)
y_grid = torch.linspace(y_min, y_max, y_res)
XY_grid = torch.stack([x.flatten() for x in torch.meshgrid(
    x_grid, y_grid, indexing='ij')], dim=1)
X_grid = torch.stack([x_grid, torch.zeros(x_res)]).T

# %% Sum-product network
# Parameters
tracks = 1
variables = 2
channels = 50

# Priors for variance of x and y
alpha0 = torch.tensor([[[1], [1]]])
beta0 = torch.tensor([[[.01], [.01]]])

# Construct SPN model
model = supr.Sequential(
    supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0.,
                    nu0=0, alpha0=alpha0, beta0=beta0),
    supr.Weightsum(tracks, variables, channels)
)

# Marginalization query
marginalize_y = torch.tensor([False, True])

# %% Fit model and display results
epochs = 20
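For reference, the grid setup above produces XY_grid with one row per (x, y) pair for the joint density, while X_grid pads the x-grid with a dummy y column that the marginalization query ignores. A quick standalone shape check (values copied from the hunk above):

import torch

x_res, y_res = 400, 500
x_grid = torch.linspace(-2, 2, x_res)
y_grid = torch.linspace(-2, 2, y_res)
XY_grid = torch.stack([g.flatten() for g in torch.meshgrid(
    x_grid, y_grid, indexing='ij')], dim=1)
X_grid = torch.stack([x_grid, torch.zeros(x_res)]).T
print(XY_grid.shape)  # torch.Size([200000, 2]) -- every (x, y) combination
print(X_grid.shape)   # torch.Size([400, 2]) -- the y column is a placeholder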
@@ -44,28 +57,74 @@ for epoch in range(epochs):
    model[0].marginalize = torch.zeros(variables, dtype=torch.bool)
    logp = model(X).sum()
    print(f"Log-posterior ∝ {logp:.2f} ")
    model.zero_grad(True)
    logp.backward()

    model.eval()  # swap?
    model.em_batch_update()
    model.zero_grad(True)

    # Plot data and model
    # -------------------------------------------------------------------------
    # Evaluate joint distribution on grid
    with torch.no_grad():
        log_p_xy = model(XY_grid)
        p_xy = torch.exp(log_p_xy).reshape(x_res, y_res)

    # Evaluate marginal distribution on x-grid
    log_p_x = model(X_grid, marginalize=marginalize_y)
    p_x = torch.exp(log_p_x)
    model.zero_grad(True)
    log_p_x.sum().backward()

    with torch.no_grad():
        # Define prior conditional p(y|x)
        Ndx = 1
        sig_prior = 1
        p_y = norm(0, sqrt(sig_prior)).pdf(y_grid)

        # Compute normal approximation
        m_pred = (N*(model.mean())[:, 1]*p_x + Ndx*0)/(N*p_x+Ndx)
        v_pred = (N*p_x*(model.var()[:, 1]+model.mean()[:, 1]
                         ** 2) + Ndx*sig_prior)/(N*p_x+Ndx) - m_pred**2
        std_pred = torch.sqrt(v_pred)

        # Compute predictive distribution
        p_predictive = (N*p_xy + Ndx*p_y[None, :]) / (N*p_x[:, None] + Ndx)

        # Compute 95% highest-posterior region
        hpr = torch.ones((x_res, y_res), dtype=torch.bool)
        for k in range(x_res):
            p_sorted = -np.sort(-(p_predictive[k] * np.gradient(y_grid)))
            i = np.searchsorted(np.cumsum(p_sorted), 0.95)
            idx = (p_predictive[k]*np.gradient(y_grid)) < p_sorted[i]
            hpr[k, idx] = False

    # Plot posterior
    plt.figure(1).clf()
    plt.title('Posterior distribution')
    dx = (x_max-x_min)/x_res/2
    dy = (y_max-y_min)/y_res/2
    extent = [x_grid[0]-dx, x_grid[-1]+dx, y_grid[0]-dy, y_grid[-1]+dy]
    plt.imshow(torch.log(p_predictive).T, extent=extent,
               aspect='auto', origin='lower',
               vmin=-4, vmax=1, cmap='Blues')
    plt.contour(hpr.T, levels=1, extent=extent)
    plt.plot(x, y, '.', color='tab:orange', alpha=.5,
             markersize=15, markeredgewidth=0)
    plt.axis('square')
    plt.xlim([x_min, x_max])
    plt.ylim([y_min, y_max])
    drawnow()

    # Plot normal approximation to posterior
    plt.figure(2).clf()
    plt.title('Posterior Normal approximation')
    plt.plot(x, y, '.', color='tab:orange', alpha=.5,
             markersize=15, markeredgewidth=0)
    plt.plot(x_grid, m_pred, color='tab:orange')
    plt.fill_between(x_grid, m_pred+1.96*std_pred, m_pred -
                     1.96*std_pred, color='tab:orange', alpha=0.1)
    plt.axis('square')
    plt.xlim([x_min, x_max])
    plt.ylim([y_min, y_max])
    drawnow()
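The 95% highest-posterior-region loop above works per x-column: it converts densities to per-cell masses via np.gradient, sorts them in descending order, and keeps cells until the cumulative mass reaches 0.95. Applying the same logic to a standard Normal recovers the familiar ±1.96 interval, which makes a convenient standalone sanity check:

import numpy as np
from scipy.stats import norm

y = np.linspace(-4, 4, 1001)
p = norm(0, 1).pdf(y)
mass = p * np.gradient(y)                       # per-cell probability mass
p_sorted = -np.sort(-mass)                      # descending
i = np.searchsorted(np.cumsum(p_sorted), 0.95)  # cells needed for 95% mass
inside = mass >= p_sorted[i]                    # same test as `idx` above, negated
print(y[inside].min(), y[inside].max())         # approx. -1.96, 1.96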
@@ -50,12 +50,37 @@ class Sequential(nn.Sequential):
                module.em_batch()
                module.em_update()

    def em_batch(self):
        with torch.no_grad():
            for module in self:
                module.em_batch()

    def em_update(self):
        with torch.no_grad():
            for module in self:
                module.em_update()

    def sample(self):
        value = []
        for module in reversed(self):
            value = module.sample(*value)
        return value

    def mean(self):
        return self[0].mean()

    def var(self):
        return self[0].var()

    def forward(self, value, marginalize=None):
        for module in self:
            if isinstance(module, SuprLeaf):
                value = module(value, marginalize=marginalize)
            else:
                value = module(value)
        return value


class Parallel(SuprLayer):
    def __init__(self, nets: List[SuprLayer]):
@@ -257,8 +282,11 @@ class TrackSum(ProductSumLayer):
        y = y[:, None]
        return y


class SuprLeaf(SuprLayer):
    def __init__(self):
        super().__init__()


class NormalLeaf(SuprLeaf):
    """ NormalLeaf layer """

    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1, mu0: torch.tensor = 0.,
@@ -321,9 +349,18 @@ class NormalLeaf(SuprLayer):
        r[~self.marginalize] = self.x[0][~self.marginalize]
        return r

    def mean(self):
        return (torch.clamp(self.z.grad, self.epsilon) * self.mu).sum([1, 3])

    def var(self):
        return (torch.clamp(self.z.grad, self.epsilon) * (self.mu**2 + self.sig)).sum([1, 3]) - self.mean()**2

    def forward(self, x: torch.Tensor, marginalize=None):
        # Get shape
        batch_size = x.shape[0]
        # Marginalize variables
        if marginalize is not None:
            self.marginalize = marginalize
        # Store the data
        self.x = x
        # Compute the probability
@@ -339,7 +376,7 @@ class NormalLeaf(SuprLayer):
        return self.z


class BernoulliLeaf(SuprLeaf):
    """ BernoulliLeaf layer """

    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,

@@ -402,7 +439,7 @@ class BernoulliLeaf(SuprLayer):
        return self.z


class CategoricalLeaf(SuprLeaf):
    """ CategoricalLeaf layer """

    def __init__(self, tracks: int, variables: int, channels: int, dimensions: int, n: int = 1,

...
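About the new mean()/var() leaf methods: after logp.backward(), self.z.grad holds d log p(x) / d z for each channel, which for a mixture is exactly that channel's posterior responsibility (the clamp by epsilon just guards degenerate channels), so the returned moments are responsibility-weighted mixture moments. A hedged sketch of the same trick on a plain two-component 1-D Gaussian mixture (illustrative names only, not the supr API):

import math
import torch

mu = torch.tensor([-1.0, 2.0])   # component means
sig = torch.tensor([0.25, 1.0])  # component variances
w = torch.tensor([0.3, 0.7])     # mixture weights

x = torch.tensor(0.5)
# Per-component log-densities; an autograd leaf playing the role of self.z
z = (-0.5 * ((x - mu)**2 / sig)
     - 0.5 * torch.log(2 * math.pi * sig)).detach().requires_grad_()
logp = torch.logsumexp(torch.log(w) + z, dim=0)  # log p(x)
logp.backward()

r = z.grad                                   # responsibilities; r.sum() == 1
mean = (r * mu).sum()                        # cf. NormalLeaf.mean()
var = (r * (mu**2 + sig)).sum() - mean**2    # cf. NormalLeaf.var()
print(r, mean.item(), var.item())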