Skip to content
Snippets Groups Projects
Commit 82a57cd2 authored by mnsc's avatar mnsc
Browse files

minor adjustments

parent 02e1103c
Branches
No related tags found
No related merge requests found
...@@ -6,27 +6,30 @@ from supr.utils import drawnow ...@@ -6,27 +6,30 @@ from supr.utils import drawnow
from scipy.stats import norm from scipy.stats import norm
# %% Dataset # %% Dataset
N = 100 N = 200
x = torch.linspace(0, 1, N) x = torch.linspace(0, 1, N)
y = -1*x + (torch.rand(N) > 0.5)*(x > 0.5) + torch.randn(N)*0.1 y = 1 -2*x + (torch.rand(N) > 0.5)*(x > 0.5) + torch.randn(N)*0.1
x[x > 0.5] += 0.25 x[x > 0.5] += 0.25
x[x < 0.5] -= 0.25 x[x < 0.5] -= 0.25
x[0] = -1.
y[0] = 0
X = torch.stack((x, y), dim=1) X = torch.stack((x, y), dim=1)
# %% Grid to evaluate predictive distribution # %% Grid to evaluate predictive distribution
x_grid = torch.linspace(-1, 2, 100) x_grid = torch.linspace(-2, 2, 200)
y_grid = torch.linspace(-2, 2, 100) y_grid = torch.linspace(-2, 2, 200)
X_grid = torch.stack([x.flatten() for x in torch.meshgrid(x_grid, y_grid, indexing='ij')], dim=1) X_grid = torch.stack([x.flatten() for x in torch.meshgrid(x_grid, y_grid, indexing='ij')], dim=1)
# %% Sum-product network # %% Sum-product network
tracks = 1 tracks = 1
variables = 2 variables = 2
channels = 20 channels = 50
# Priors for variance of x and y # Priors for variance of x and y
alpha0 = torch.tensor([[[0.5], [0.1]]]) alpha0 = torch.tensor([[[1], [1]]])
beta0 = torch.tensor([[[0.5], [0.1]]]) beta0 = torch.tensor([[[.05], [0.01]]])
model = supr.Sequential( model = supr.Sequential(
supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0., nu0=0, alpha0=alpha0, beta0=beta0), supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0., nu0=0, alpha0=alpha0, beta0=beta0),
...@@ -53,14 +56,16 @@ for epoch in range(epochs): ...@@ -53,14 +56,16 @@ for epoch in range(epochs):
model[0].marginalize = torch.tensor([False, True]) model[0].marginalize = torch.tensor([False, True])
p_x = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T) p_x = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T)
Ndx = 1
p_prior = norm(0, 0.5).pdf(y_grid)[:, None] p_prior = norm(0, 0.5).pdf(y_grid)[:, None]
p_predictive = (N*p_xy + p_prior)/(N*p_x+1) p_predictive = (N*p_xy + Ndx*p_prior)/(N*p_x+Ndx)
plt.figure(1).clf() plt.figure(1).clf()
dx = (x_grid[1]-x_grid[0])/2. dx = (x_grid[1]-x_grid[0])/2.
dy = (y_grid[1]-y_grid[0])/2. dy = (y_grid[1]-y_grid[0])/2.
extent = [x_grid[0]-dx, x_grid[-1]+dx, y_grid[0]-dy, y_grid[-1]+dy] extent = [x_grid[0]-dx, x_grid[-1]+dx, y_grid[0]-dy, y_grid[-1]+dy]
plt.imshow(torch.log(p_predictive), extent=extent, aspect='auto', origin='lower', vmin=-3, vmax=1) plt.imshow(torch.log(p_predictive), extent=extent, aspect='auto', origin='lower', vmin=-4, vmax=1)
plt.plot(x, y, '.') plt.plot(x, y, '.', color='tab:orange', alpha=.5, markersize=4, markeredgewidth=0)
plt.axis('square')
drawnow() drawnow()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment