diff --git a/demos/regression.py b/demos/regression.py index 52ada7178759f8d9a79bc9d3633d6caf4ac53f99..789e3a6dbe36b2c0a91cb82d7c21c2a88740b693 100644 --- a/demos/regression.py +++ b/demos/regression.py @@ -82,7 +82,7 @@ for epoch in range(epochs): with torch.no_grad(): # Define prior conditional p(y|x) Ndx = 1 - sig_prior = 1 + sig_prior = .5 p_y = norm(0, sqrt(sig_prior)).pdf(y_grid) # Compute normal approximation @@ -101,6 +101,17 @@ for epoch in range(epochs): i = np.searchsorted(np.cumsum(p_sorted), 0.95) idx = (p_predictive[k]*np.gradient(y_grid)) < p_sorted[i] hpr[k, idx] = False + + # Generate posterior conditional samples + y_sample = [] + for x_g, p_x_g in zip(X_grid, p_x): + model(x_g[None], marginalize=marginalize_y) + # Sample from either SPN conditional or background/prior + if torch.rand(1) < N*p_x_g / (N*p_x_g + Ndx): + y_sample.append(model.sample()[1]) + else: + y_sample.append(torch.randn(1)*sig_prior) + y_sample = torch.tensor(y_sample) # Plot posterior plt.figure(1).clf() @@ -131,3 +142,16 @@ for epoch in range(epochs): plt.xlim([x_min, x_max]) plt.ylim([y_min, y_max]) drawnow() + + # Plot samples from posterior conditional + plt.figure(3).clf() + plt.title('Posterior conditional samples') + plt.fill_between(x_grid, m_pred+1.96*std_pred, m_pred - + 1.96*std_pred, color='tab:orange', alpha=0.1) + plt.plot(x_grid, y_sample, '.', color='tab:blue', alpha=.5, + markersize=15, markeredgewidth=0) + plt.axis('square') + plt.xlim([x_min, x_max]) + plt.ylim([y_min, y_max]) + drawnow() +