diff --git a/demos/regression.py b/demos/regression.py
index 5253596a1e256e720f5ebd5cbffbc3348a39870c..7f61499b251d8efad0b9db45aab69f684e98677e 100644
--- a/demos/regression.py
+++ b/demos/regression.py
@@ -6,27 +6,30 @@ from supr.utils import drawnow
 from scipy.stats import norm
 
 # %% Dataset
-N = 100
+N = 200
 x = torch.linspace(0, 1, N)
-y = -1*x + (torch.rand(N) > 0.5)*(x > 0.5) + torch.randn(N)*0.1
+y = 1 -2*x + (torch.rand(N) > 0.5)*(x > 0.5) + torch.randn(N)*0.1
 x[x > 0.5] += 0.25
 x[x < 0.5] -= 0.25
 
+x[0] = -1.
+y[0] = 0
+
 X = torch.stack((x, y), dim=1)
 
 # %% Grid to evaluate predictive distribution
-x_grid = torch.linspace(-1, 2, 100)
-y_grid = torch.linspace(-2, 2, 100)
+x_grid = torch.linspace(-2, 2, 200)
+y_grid = torch.linspace(-2, 2, 200)
 X_grid = torch.stack([x.flatten() for x in torch.meshgrid(x_grid, y_grid, indexing='ij')], dim=1)
 
 # %% Sum-product network
 tracks = 1
 variables = 2
-channels = 20
+channels = 50
 
 # Priors for variance of x and y
-alpha0 = torch.tensor([[[0.5], [0.1]]])
-beta0 = torch.tensor([[[0.5], [0.1]]])
+alpha0 = torch.tensor([[[1], [1]]])
+beta0 = torch.tensor([[[.05], [0.01]]])
 
 model = supr.Sequential(
     supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0., nu0=0, alpha0=alpha0, beta0=beta0),
@@ -53,14 +56,16 @@ for epoch in range(epochs):
         model[0].marginalize = torch.tensor([False, True])
         p_x = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T)
 
+        Ndx = 1
         p_prior = norm(0, 0.5).pdf(y_grid)[:, None]
 
-        p_predictive = (N*p_xy + p_prior)/(N*p_x+1)
+        p_predictive = (N*p_xy + Ndx*p_prior)/(N*p_x+Ndx)
 
         plt.figure(1).clf()
         dx = (x_grid[1]-x_grid[0])/2.
         dy = (y_grid[1]-y_grid[0])/2.
         extent = [x_grid[0]-dx, x_grid[-1]+dx, y_grid[0]-dy, y_grid[-1]+dy]
-        plt.imshow(torch.log(p_predictive), extent=extent, aspect='auto', origin='lower', vmin=-3, vmax=1)
-        plt.plot(x, y, '.')
+        plt.imshow(torch.log(p_predictive), extent=extent, aspect='auto', origin='lower', vmin=-4, vmax=1)
+        plt.plot(x, y, '.', color='tab:orange', alpha=.5, markersize=4, markeredgewidth=0)
+        plt.axis('square')
         drawnow()