diff --git a/supr/layers.py b/supr/layers.py
index 35938a3d94a8cff5fb5a7dcb10ad925da911946b..dde2fb7f18743d5cec576269c06f8ae8b910ea3f 100644
--- a/supr/layers.py
+++ b/supr/layers.py
@@ -7,9 +7,10 @@ from typing import List
 
 
 # Data:
-# N x V 
-# └───│── N: Data points
-#     └── V: Variables
+# N x V x D
+# └───│───│── N: Data points
+#     └───│── V: Variables
+#         └── D: Dimensions
 #
 # Probability:
 # N x T x V x C
@@ -23,6 +24,7 @@ class Supr(nn.Module):
         super().__init__()
     
     def sample(self):
+        pass
         
 
 class SuprLayer(nn.Module):
@@ -253,10 +255,15 @@ class TrackSum(ProductSumLayer):
 class NormalLeaf(SuprLayer):
     """ NormalLeaf layer """
 
-    def __init__(self, tracks: int, variables: int, channels: int):
+    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
+                 mu0: float = 0., nu0: float = 0., alpha0: float = 0., beta0: float = 0.):
         super().__init__()
         # Dimensions
         self.T, self.V, self.C = tracks, variables, channels
+        # Number of data points
+        self.n = n
+        # Prior
+        self.mu0, self.nu0, self.alpha0, self.beta0 = mu0, nu0, alpha0, beta0
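+        # (Interpreted as normal-inverse-gamma hyperparameters: mu0 and nu0 are the prior mean
+        #  and its pseudo-count, alpha0 and beta0 the inverse-gamma shape and scale of the variance.)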
         # Parametes
         # self.mu = nn.Parameter(torch.randn(self.T, self.V, self.C))
         # self.mu = nn.Parameter(torch.linspace(0, 1, self.C)[None, None, :].repeat((self.T, self.V, 1)))
@@ -279,13 +286,16 @@ class NormalLeaf(SuprLayer):
         self.z_x_sq_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None] ** 2, dim=0)
 
     def em_update(self, learning_rate: float = 1.):
-        # Mean
+        # Sum of weights
         sum_z = torch.clamp(self.z_acc, self.epsilon)
+        # Mean
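+        # (Posterior-mean style update: the batch mean z_x_acc / sum_z is shrunk towards mu0
+        #  with prior strength nu0 relative to n; nu0 = 0 recovers the plain EM update.)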
+        mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
         self.mu.data *= 1. - learning_rate
-        self.mu.data += learning_rate * self.z_x_acc / sum_z
+        self.mu.data += learning_rate * mu_update
-        # Standard deviation
+        # Variance (sig now stores sigma^2; the square root is taken where it is used)
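+        # (MAP-style estimate of the variance under the inverse-gamma prior;
+        #  self.mu is the mean as updated above.)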
+        sig_update = (self.n*(self.z_x_sq_acc / sum_z - self.mu ** 2) + 2*self.beta0 + self.nu0*(self.mu0-self.mu)**2) / (self.n + 2*self.alpha0 + 3)
         self.sig.data *= 1 - learning_rate
-        self.sig.data += learning_rate * torch.sqrt(torch.clamp(self.z_x_sq_acc / sum_z - self.mu ** 2, self.epsilon + 0.01))
+        self.sig.data += learning_rate * sig_update
         # Reset accumulators
         self.z_acc.zero_()
         self.z_x_acc.zero_()
@@ -296,7 +306,7 @@ class NormalLeaf(SuprLayer):
         mu_marginalize = self.mu[track, self.marginalize, channel_per_variable[self.marginalize]]
         sig_marginalize = self.sig[track, self.marginalize, channel_per_variable[self.marginalize]]
         r = torch.empty_like(self.x[0])
-        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * sig_marginalize
+        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(torch.clamp(sig_marginalize, self.epsilon))
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
 
@@ -313,16 +323,21 @@ class NormalLeaf(SuprLayer):
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-            torch.distributions.Normal(mu_valid, sig_valid).log_prob(x_valid).float()
+            torch.distributions.Normal(mu_valid, torch.sqrt(torch.clamp(sig_valid, self.epsilon))).log_prob(x_valid).float()
         return self.z
 
 class BernoulliLeaf(SuprLayer):
     """ BernoulliLeaf layer """
 
-    def __init__(self, tracks: int, variables: int, channels: int):
+    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
+                 alpha0: float = 1., beta0: float = 1.):
         super().__init__()
         # Dimensions
         self.T, self.V, self.C = tracks, variables, channels
+        # Number of data points
+        self.n = n
+        # Prior
+        self.alpha0, self.beta0 = alpha0, beta0
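+        # (Interpreted as Beta(alpha0, beta0) pseudo-counts; the defaults give a uniform prior.)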
         # Parametes
         self.p = nn.Parameter(torch.rand(self.T, self.V, self.C))
         # Which variables to marginalized
@@ -342,8 +357,9 @@ class BernoulliLeaf(SuprLayer):
     def em_update(self, learning_rate: float = 1.):
         # Probability
         sum_z = torch.clamp(self.z_acc, self.epsilon)
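+        # (MAP-style update under the Beta prior; alpha0 = beta0 = 1 recovers the plain EM estimate.)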
+        p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
         self.p.data *= 1. - learning_rate
-        self.p.data += learning_rate * self.z_x_acc / sum_z
+        self.p.data += learning_rate * p_update
         # Reset accumulators
         self.z_acc.zero_()
         self.z_x_acc.zero_()
@@ -368,6 +384,70 @@ class BernoulliLeaf(SuprLayer):
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-            p_valid*(x_valid==1) + (1-p_valid)*(x_valid==0)
+            torch.distributions.Bernoulli(probs=p_valid).log_prob(x_valid).float() 
+        return self.z
+
+# TODO: This is not tested properly. 
+class CategoricalLeaf(SuprLayer):
+    """ CategoricalLeaf layer """
+
+    def __init__(self, tracks: int, variables: int, channels: int, dimensions: int, n: int = 1,
+                 alpha0: float = 1.):
+        super().__init__()
+        # Dimensions
+        self.T, self.V, self.C, self.D = tracks, variables, channels, dimensions
+        # Number of data points
+        self.n = n
+        # Prior
+        self.alpha0 = alpha0
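+        # (Interpreted as a symmetric Dirichlet(alpha0) prior; alpha0 = 1 corresponds to a uniform prior.)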
+        # Parameters
+        self.p = nn.Parameter(torch.rand(self.T, self.V, self.C, self.D))
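+        # (The random initialization is not normalized over D; torch.distributions.Categorical
+        #  normalizes probs along the last dimension wherever these parameters are consumed.)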
+        # Which variables to marginalize
+        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
+        # Input
+        self.register_buffer('x', torch.Tensor())
+        # Output
+        self.register_buffer('z', torch.Tensor())
+        # EM accumulator
+        self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
+        self.register_buffer('z_x_acc', torch.zeros(self.T, self.V, self.C, self.D))
+
+    def em_batch(self):
+        self.z_acc.data += torch.sum(self.z.grad, dim=0)
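+        # One-hot encode the integer class indices so the responsibilities in z.grad
+        # can be accumulated per category (assumes x holds values in [0, D))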
+        x_onehot = torch.eye(self.D, dtype=torch.bool, device=self.x.device)[self.x]
+        self.z_x_acc.data += torch.sum(self.z.grad[:, :, :, :, None] * x_onehot[:, None, :, None, :], dim=0)
+
+    def em_update(self, learning_rate: float = 1.):
+        # Probability
+        sum_z = torch.clamp(self.z_acc, self.epsilon)
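+        # (MAP-style update under the Dirichlet prior; alpha0 = 1 recovers the plain EM estimate.)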
+        p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (self.n + self.D * (self.alpha0 - 1))
+        self.p.data *= 1. - learning_rate
+        self.p.data += learning_rate * p_update
+        # Reset accumulators
+        self.z_acc.zero_()
+        self.z_x_acc.zero_()
+
+    # TODO: Sampling is implemented but has not been tested
+    def sample(self, track: int, channel_per_variable: torch.Tensor):
+        p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize], :]
+        r = torch.empty_like(self.x[0])
+        r_sample = torch.distributions.Categorical(probs=p_marginalize).sample()
+        r[self.marginalize] = r_sample
+        r[~self.marginalize] = self.x[0][~self.marginalize]
+        return r
+
+    def forward(self, x: torch.Tensor):
+        # Get shape
+        batch_size = x.shape[0]
+        # Store the data
+        self.x = x
+        # Compute the probability
+        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
+        # Get non-marginalized parameters and data
+        p_valid = self.p[None, :, ~self.marginalize, :, :]
+        x_valid = self.x[:, None, ~self.marginalize, None]
+        # Evaluate log probability
+        self.z.data[:, :, ~self.marginalize, :] = \
+            torch.distributions.Categorical(probs=p_valid).log_prob(x_valid).float() 
         return self.z