Skip to content
Snippets Groups Projects
Commit 1b0d6980 authored by mnsc's avatar mnsc
Browse files

added sampling to regression demo

parent e3ac4ca9
No related branches found
No related tags found
No related merge requests found
......@@ -82,7 +82,7 @@ for epoch in range(epochs):
with torch.no_grad():
# Define prior conditional p(y|x)
Ndx = 1
sig_prior = 1
sig_prior = .5
p_y = norm(0, sqrt(sig_prior)).pdf(y_grid)
# Compute normal approximation
......@@ -102,6 +102,17 @@ for epoch in range(epochs):
idx = (p_predictive[k]*np.gradient(y_grid)) < p_sorted[i]
hpr[k, idx] = False
# Generate posterior conditional samples
y_sample = []
for x_g, p_x_g in zip(X_grid, p_x):
model(x_g[None], marginalize=marginalize_y)
# Sample from either SPN conditional or background/prior
if torch.rand(1) < N*p_x_g / (N*p_x_g + Ndx):
y_sample.append(model.sample()[1])
else:
y_sample.append(torch.randn(1)*sig_prior)
y_sample = torch.tensor(y_sample)
# Plot posterior
plt.figure(1).clf()
plt.title('Posterior distribution')
......@@ -131,3 +142,16 @@ for epoch in range(epochs):
plt.xlim([x_min, x_max])
plt.ylim([y_min, y_max])
drawnow()
# Plot samples from posterior conditional
plt.figure(3).clf()
plt.title('Posterior conditional samples')
plt.fill_between(x_grid, m_pred+1.96*std_pred, m_pred -
1.96*std_pred, color='tab:orange', alpha=0.1)
plt.plot(x_grid, y_sample, '.', color='tab:blue', alpha=.5,
markersize=15, markeredgewidth=0)
plt.axis('square')
plt.xlim([x_min, x_max])
plt.ylim([y_min, y_max])
drawnow()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment