diff --git a/demos/regression.py b/demos/regression.py
index 7f61499b251d8efad0b9db45aab69f684e98677e..3d1562b1c44c588c6853b369e186807fdaf2ca45 100644
--- a/demos/regression.py
+++ b/demos/regression.py
@@ -4,38 +4,51 @@ import torch
 import supr
 from supr.utils import drawnow
 from scipy.stats import norm
+from math import sqrt
+import numpy as np
 
 # %% Dataset
-N = 200
+N = 100
 x = torch.linspace(0, 1, N)
-y = 1 -2*x + (torch.rand(N) > 0.5)*(x > 0.5) + torch.randn(N)*0.1
+y = 1 - 2*x + (torch.rand(N) > 0.5)*(x > 0.5) + torch.randn(N)*0.1
 x[x > 0.5] += 0.25
 x[x < 0.5] -= 0.25
 
 x[0] = -1.
-y[0] = 0
+y[0] = -0.5
 
 X = torch.stack((x, y), dim=1)
 
 # %% Grid to evaluate predictive distribution
-x_grid = torch.linspace(-2, 2, 200)
-y_grid = torch.linspace(-2, 2, 200)
-X_grid = torch.stack([x.flatten() for x in torch.meshgrid(x_grid, y_grid, indexing='ij')], dim=1)
+x_res, y_res = 400, 500
+x_min, x_max = -2, 2
+y_min, y_max = -2, 2
+x_grid = torch.linspace(x_min, x_max, x_res)
+y_grid = torch.linspace(y_min, y_max, y_res)
+XY_grid = torch.stack([x.flatten() for x in torch.meshgrid(
+    x_grid, y_grid, indexing='ij')], dim=1)
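+# x-values paired with a dummy y-column; y is marginalized out when
+# evaluating the marginal p(x)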
+X_grid = torch.stack([x_grid, torch.zeros(x_res)]).T
 
 # %% Sum-product network
+# Parameters
 tracks = 1
 variables = 2
 channels = 50
 
 # Priors for variance of x and y
 alpha0 = torch.tensor([[[1], [1]]])
-beta0 = torch.tensor([[[.05], [0.01]]])
+beta0 = torch.tensor([[[.01], [.01]]])
 
+# Construct SPN model
 model = supr.Sequential(
-    supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0., nu0=0, alpha0=alpha0, beta0=beta0),
+    supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0.,
+                    nu0=0, alpha0=alpha0, beta0=beta0),
     supr.Weightsum(tracks, variables, channels)
 )
 
+# Marginalization query: True marks variables to integrate out (here y)
+marginalize_y = torch.tensor([False, True])
+
 # %% Fit model and display results
 epochs = 20
 
@@ -44,28 +57,74 @@ for epoch in range(epochs):
     model[0].marginalize = torch.zeros(variables, dtype=torch.bool)
     logp = model(X).sum()
     print(f"Log-posterior ∝ {logp:.2f} ")
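+    # Gradients of the log-likelihood carry the EM responsibilities,
+    # so clear any stale gradients before the backward pass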
+    model.zero_grad(True)
     logp.backward()
 
-    with torch.no_grad():
-        model.eval()
-        model.em_batch_update()
-        model.zero_grad(True)
+    # Switch to evaluation mode, then apply the batch EM update
+    model.eval()
+    model.em_batch_update()
 
-        p_xy = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T)
+    # Plot data and model
+    # -------------------------------------------------------------------------
+    # Evaluate joint distribution on grid
+    with torch.no_grad():
+        log_p_xy = model(XY_grid)
+        p_xy = torch.exp(log_p_xy).reshape(x_res, y_res)
 
-        model[0].marginalize = torch.tensor([False, True])
-        p_x = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T)
+    # Evaluate marginal distribution on x-grid
+    log_p_x = model(X_grid, marginalize=marginalize_y)
+    p_x = torch.exp(log_p_x)
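+    # Backpropagate log p(x) so the leaf responsibilities (z.grad) are
+    # available for the moment computations in model.mean() and model.var()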
+    model.zero_grad(True)
+    log_p_x.sum().backward()
 
+    with torch.no_grad():
+        # Prior conditional p(y|x): a zero-mean normal acting as Ndx
+        # pseudo-observations in the predictive
+        Ndx = 1
-        p_prior = norm(0, 0.5).pdf(y_grid)[:, None]
+        sig_prior = 1
+        p_y = norm(0, sqrt(sig_prior)).pdf(y_grid)
+
+        # Normal approximation to p(y|x): moments are a convex combination
+        # of the model moments (weight N*p(x)) and the prior moments
+        # (weight Ndx); the prior mean is zero
+        m_pred = (N*model.mean()[:, 1]*p_x + Ndx*0) / (N*p_x + Ndx)
+        v_pred = (N*p_x*(model.var()[:, 1] + model.mean()[:, 1]**2)
+                  + Ndx*sig_prior) / (N*p_x + Ndx) - m_pred**2
+        std_pred = torch.sqrt(v_pred)
 
-        p_predictive = (N*p_xy + Ndx*p_prior)/(N*p_x+Ndx)
+        # Predictive p(y|x): convex combination of the model joint and the
+        # prior, weighted by the effective counts N*p(x) and Ndx
+        p_predictive = (N*p_xy + Ndx*p_y[None, :]) / (N*p_x[:, None] + Ndx)
 
+        # 95% highest-posterior region: for each x, sort the y-bin masses
+        # in descending order and keep the smallest set of bins whose
+        # cumulative mass reaches 0.95
+        hpr = torch.ones((x_res, y_res), dtype=torch.bool)
+        dy_bin = np.gradient(y_grid)
+        for k in range(x_res):
+            bin_mass = p_predictive[k] * dy_bin
+            p_sorted = -np.sort(-bin_mass)
+            i = np.searchsorted(np.cumsum(p_sorted), 0.95)
+            hpr[k, bin_mass < p_sorted[i]] = False
+
+        # Plot posterior
         plt.figure(1).clf()
-        dx = (x_grid[1]-x_grid[0])/2.
-        dy = (y_grid[1]-y_grid[0])/2.
+        plt.title('Posterior distribution')
+        # Half a grid cell (the linspace spacing), so pixels are centered
+        # on the grid points
+        dx = (x_max - x_min)/(x_res - 1)/2
+        dy = (y_max - y_min)/(y_res - 1)/2
         extent = [x_grid[0]-dx, x_grid[-1]+dx, y_grid[0]-dy, y_grid[-1]+dy]
-        plt.imshow(torch.log(p_predictive), extent=extent, aspect='auto', origin='lower', vmin=-4, vmax=1)
-        plt.plot(x, y, '.', color='tab:orange', alpha=.5, markersize=4, markeredgewidth=0)
+        plt.imshow(torch.log(p_predictive).T, extent=extent,
+                   aspect='auto', origin='lower',
+                   vmin=-4, vmax=1, cmap='Blues')
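+        # Outline the 95% highest-posterior region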
+        plt.contour(hpr.T, levels=1, extent=extent)
+        plt.plot(x, y, '.', color='tab:orange', alpha=.5,
+                 markersize=15, markeredgewidth=0)
+        plt.axis('square')
+        plt.xlim([x_min, x_max])
+        plt.ylim([y_min, y_max])
+        drawnow()
+
+        # Plot normal approximation to posterior
+        plt.figure(2).clf()
+        plt.title('Posterior Normal approximation')
+        plt.plot(x, y, '.', color='tab:orange', alpha=.5,
+                 markersize=15, markeredgewidth=0)
+        plt.plot(x_grid, m_pred, color='tab:orange')
+        # Shade the central 95% interval (±1.96 standard deviations)
+        plt.fill_between(x_grid, m_pred + 1.96*std_pred,
+                         m_pred - 1.96*std_pred,
+                         color='tab:orange', alpha=0.1)
         plt.axis('square')
+        plt.xlim([x_min, x_max])
+        plt.ylim([y_min, y_max])
         drawnow()
diff --git a/supr/layers.py b/supr/layers.py
index 2a1274811523f3f9a9b127f086dec59988198030..181f8abd3f228918ba9e9e5b165f100ac6ceac3e 100644
--- a/supr/layers.py
+++ b/supr/layers.py
@@ -50,12 +50,37 @@ class Sequential(nn.Sequential):
                 module.em_batch()
                 module.em_update()
 
+    def em_batch(self):
+        # Accumulate EM sufficient statistics over the current batch
+        with torch.no_grad():
+            for module in self:
+                module.em_batch()
+
+    def em_update(self):
+        # Apply the accumulated EM parameter updates
+        with torch.no_grad():
+            for module in self:
+                module.em_update()
+
     def sample(self):
         value = []
         for module in reversed(self):
             value = module.sample(*value)
         return value
 
+    def mean(self):
+        # Posterior mean from the leaf layer (assumed first in the sequence)
+        return self[0].mean()
+
+    def var(self):
+        # Posterior variance from the leaf layer
+        return self[0].var()
+
+    def forward(self, value, marginalize=None):
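+        # Only leaf layers accept a marginalization mask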
+        for module in self:
+            if isinstance(module, SuprLeaf):
+                value = module(value, marginalize=marginalize)
+            else:
+                value = module(value)
+        return value
+
 
 class Parallel(SuprLayer):
     def __init__(self, nets: List[SuprLayer]):
@@ -257,8 +282,11 @@ class TrackSum(ProductSumLayer):
         y = y[:, None]
         return y
 
+
+class SuprLeaf(SuprLayer):
+    """ Base class for leaf layers """
+
+    def __init__(self):
+        super().__init__()
+
 
-class NormalLeaf(SuprLayer):
+class NormalLeaf(SuprLeaf):
     """ NormalLeaf layer """
 
     def __init__(self, tracks: int, variables: int, channels: int, n: int = 1, mu0: torch.tensor = 0.,
@@ -320,10 +348,19 @@ class NormalLeaf(SuprLayer):
             torch.clamp(sig_marginalize, self.epsilon))
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
-
-    def forward(self, x: torch.Tensor):
+
+    def mean(self):
+        # Responsibility-weighted (z.grad) average of the channel means,
+        # summed over tracks and channels
+        return (torch.clamp(self.z.grad, self.epsilon) * self.mu).sum([1, 3])
+
+    def var(self):
+        # Responsibility-weighted variance via E[x**2] - E[x]**2
+        return (torch.clamp(self.z.grad, self.epsilon) * (self.mu**2 + self.sig)).sum([1, 3]) - self.mean()**2
+
+    def forward(self, x: torch.Tensor, marginalize=None):
         # Get shape
         batch_size = x.shape[0]
+        # Update the marginalization mask if one is supplied with the query
+        if marginalize is not None:
+            self.marginalize = marginalize
         # Store the data
         self.x = x
         # Compute the probability
@@ -339,7 +376,7 @@ class NormalLeaf(SuprLayer):
         return self.z
 
 
-class BernoulliLeaf(SuprLayer):
+class BernoulliLeaf(SuprLeaf):
     """ BernoulliLeaf layer """
 
     def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
@@ -385,7 +422,7 @@ class BernoulliLeaf(SuprLayer):
         r[self.marginalize] = (torch.rand(variables_marginalize).to(self.x.device) < p_marginalize).float()
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, marginalize=None):
+        # Accept the marginalization mask forwarded by Sequential
+        if marginalize is not None:
+            self.marginalize = marginalize
         # Get shape
         batch_size = x.shape[0]
@@ -402,7 +439,7 @@ class BernoulliLeaf(SuprLayer):
         return self.z
 
 
-class CategoricalLeaf(SuprLayer):
+class CategoricalLeaf(SuprLeaf):
     """ CategoricalLeaf layer """
 
     def __init__(self, tracks: int, variables: int, channels: int, dimensions: int, n: int = 1,