diff --git a/supr/layers.py b/supr/layers.py
index b1d6cf99dabceed66a1cbcf1ff83833885d16170..2a1274811523f3f9a9b127f086dec59988198030 100644
--- a/supr/layers.py
+++ b/supr/layers.py
@@ -287,7 +287,7 @@ class NormalLeaf(SuprLayer):
         self.register_buffer('z_x_sq_acc', torch.zeros(self.T, self.V, self.C))
 
     def em_batch(self):
-        self.z_acc.data += torch.sum(self.z.grad, dim=0)
+        self.z_acc.data += torch.clamp(torch.sum(self.z.grad, dim=0), self.epsilon)
         self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)
         self.z_x_sq_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None] ** 2, dim=0)
 
@@ -295,12 +295,15 @@ class NormalLeaf(SuprLayer):
         # Sum of weights
         sum_z = torch.clamp(self.z_acc, self.epsilon)
         # Mean
-        mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
+        # mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
+        mu_update = (self.nu0 * self.mu0 + self.z_acc * (self.z_x_acc / sum_z)) / (self.nu0 + self.z_acc)
         self.mu.data *= 1. - learning_rate
         self.mu.data += learning_rate * mu_update
         # Standard deviation
-        sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (
-                    self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
+        # sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (
+        #             self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
+        sig_update = (self.z_x_sq_acc - self.z_acc * self.mu ** 2 + 2 * self.beta0 + self.nu0 * (
+                    self.mu0 - self.mu) ** 2) / (self.z_acc + 2 * self.alpha0 + 3)
         self.sig.data *= 1 - learning_rate
         self.sig.data += learning_rate * sig_update
         # Reset accumulators
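
Note on the NormalLeaf change above: the rewritten update replaces the fixed sample count self.n with the responsibility mass accumulated per component in z_acc. Writing N = z_acc, S_1 = z_x_acc and S_2 = z_x_sq_acc for the responsibility-weighted sufficient statistics (notation introduced here for readability, not present in the code), the new lines appear to implement a MAP-style update consistent with a Normal-Inverse-Gamma prior with hyperparameters (mu0, nu0, alpha0, beta0):

    \mu \leftarrow \frac{\nu_0 \mu_0 + S_1}{\nu_0 + N}, \qquad
    \sigma^2 \leftarrow \frac{S_2 - N \mu^2 + 2\beta_0 + \nu_0 (\mu_0 - \mu)^2}{N + 2\alpha_0 + 3}

In the code, S_1 is expressed as z_acc * (z_x_acc / sum_z), which equals z_x_acc everywhere except where z_acc falls below epsilon.
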
@@ -367,7 +370,8 @@ class BernoulliLeaf(SuprLayer):
     def em_update(self, learning_rate: float = 1.):
         # Probability
         sum_z = torch.clamp(self.z_acc, self.epsilon)
-        p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
+        # p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
+        p_update = (self.z_x_acc + self.alpha0 - 1) / (self.z_acc + self.alpha0 + self.beta0 - 2)
         self.p.data *= 1. - learning_rate
         self.p.data += learning_rate * p_update
         # Reset accumulators
@@ -430,15 +434,16 @@ class CategoricalLeaf(SuprLayer):
     def em_update(self, learning_rate: float = 1.):
         # Probability
         sum_z = torch.clamp(self.z_acc, self.epsilon)
-        p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (
-                    self.n + self.D * (self.alpha0 - 1))
+        # p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (
+        #             self.n + self.D * (self.alpha0 - 1))
+        p_update = (self.z_x_acc + self.alpha0 - 1) / (
+                    self.z_acc[:, :, :, None] + self.D * (self.alpha0 - 1))
         self.p.data *= 1. - learning_rate
         self.p.data += learning_rate * p_update
         # Reset accumulators
         self.z_acc.zero_()
         self.z_x_acc.zero_()
 
-    # XXX Implement this
     def sample(self, track: int, channel_per_variable: torch.Tensor):
         p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize], :]
         r = torch.empty_like(self.x[0])
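
And for CategoricalLeaf.em_update above: assuming a symmetric Dirichlet(alpha0) prior over the D states and writing S_d for the per-state slice of z_x_acc, the new update reads

    p_d \leftarrow \frac{S_d + \alpha_0 - 1}{N + D(\alpha_0 - 1)}

with z_acc broadcast over the state dimension via [:, :, :, None].
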