diff --git a/demos/regression.py b/demos/regression.py
index b6e9e097bbef7e0001618395489961ad9727f421..5253596a1e256e720f5ebd5cbffbc3348a39870c 100644
--- a/demos/regression.py
+++ b/demos/regression.py
@@ -1,63 +1,67 @@
-#%% Import libraries
-import numpy as np
+# %% Import libraries
 import matplotlib.pyplot as plt
 import torch
 import supr
 from supr.utils import drawnow
 from scipy.stats import norm
 
-#%% Dataset
+# %% Dataset
 N = 100
 x = torch.linspace(0, 1, N)
-y = -1*x + (torch.rand(N)>0.5)*(x>0.5) + torch.randn(N)*0.1
-x[x>0.5] += 0.25
-x[x<0.5] -= 0.25
+y = -1*x + (torch.rand(N) > 0.5)*(x > 0.5) + torch.randn(N)*0.1
+x[x > 0.5] += 0.25
+x[x < 0.5] -= 0.25
 
-X = torch.stack((x,y), dim=1)
+X = torch.stack((x, y), dim=1)
 
-#%% Grid to evaluate predictive distribution
+# %% Grid to evaluate predictive distribution
 x_grid = torch.linspace(-1, 2, 100)
 y_grid = torch.linspace(-2, 2, 100)
 X_grid = torch.stack([x.flatten() for x in torch.meshgrid(x_grid, y_grid, indexing='ij')], dim=1)
 
-#%% Sum-product network
+# %% Sum-product network
 tracks = 1
 variables = 2
 channels = 20
 
 # Priors for variance of x and y
-alpha0 = torch.tensor([[[0.2], [0.1]]])
-beta0 =  torch.tensor([[[0.2], [0.1]]])
+alpha0 = torch.tensor([[[0.5], [0.1]]])
+beta0 = torch.tensor([[[0.5], [0.1]]])
 
 model = supr.Sequential(
     supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0., nu0=0, alpha0=alpha0, beta0=beta0),
     supr.Weightsum(tracks, variables, channels)
-    )
+)
 
-#%% Fit model and display results
+# %% Fit model and display results
 epochs = 20
 
 for epoch in range(epochs):
     model.train()
-    model[0].marginalize = torch.zeros(variables, dtype=torch.bool)    
-    loss = model(X).sum()
-    print(f"Loss = {loss}")
-    loss.backward()
-    
+    model[0].marginalize = torch.zeros(variables, dtype=torch.bool)
+    logp = model(X).sum()
+    print(f"Log-posterior ∝ {logp:.2f} ")
+    logp.backward()
+
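+    # The gradients computed by backward() carry the expected sufficient
+    # statistics that em_batch_update() uses below to take one EM step.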
     with torch.no_grad():
         model.eval()
         model.em_batch_update()
         model.zero_grad(True)
-        
+
         p_xy = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T)
 
         model[0].marginalize = torch.tensor([False, True])
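+        # With y (variable index 1) marginalized out, the model returns log p(x).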
         p_x = torch.exp(model(X_grid).reshape(len(x_grid), len(y_grid)).T)
-        
-        p_prior = norm(0, 0.5).pdf(y_grid)[:,None]
-        
+
+        p_prior = norm(0, 0.5).pdf(y_grid)[:, None]
+
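+        # Predictive p(y|x): a mixture that puts weight N on the fitted SPN and
+        # weight 1 on a fixed Normal(0, 0.5) prior over y.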
         p_predictive = (N*p_xy + p_prior)/(N*p_x+1)
-    
+
         plt.figure(1).clf()
         dx = (x_grid[1]-x_grid[0])/2.
         dy = (y_grid[1]-y_grid[0])/2.
@@ -65,4 +69,3 @@ for epoch in range(epochs):
         plt.imshow(torch.log(p_predictive), extent=extent, aspect='auto', origin='lower', vmin=-3, vmax=1)
         plt.plot(x, y, '.')
         drawnow()
-
diff --git a/supr/layers.py b/supr/layers.py
index 6fcace92c5404b77fdbc33b2abe0ef5006046570..d1a59958892c1e5d353cbf44c733f638c8e51f75 100644
--- a/supr/layers.py
+++ b/supr/layers.py
@@ -5,7 +5,6 @@ import math
 from supr.utils import discrete_rand, local_scramble_2d
 from typing import List
 
-
 # Data:
 # N x V x D
 # └───│──│─ N: Data points
@@ -40,7 +39,7 @@ class SuprLayer(nn.Module):
         pass
     
 class Sequential(nn.Sequential):
-    def __init__(self, *args):
+    def __init__(self, *args: object) -> object:
         super().__init__(*args)
 
     def em_batch_update(self):
@@ -256,7 +255,7 @@ class NormalLeaf(SuprLayer):
     """ NormalLeaf layer """
 
     def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
-                 mu0: float = 0., nu0: float = 0., alpha0: float = 0., beta0: float = 0.):
+                 mu0: torch.Tensor = 0., nu0: torch.Tensor = 0., alpha0: torch.Tensor = 0., beta0: torch.Tensor = 0.):
         super().__init__()
         # Dimensions
         self.T, self.V, self.C = tracks, variables, channels