diff --git a/supr/layers.py b/supr/layers.py
index d1a59958892c1e5d353cbf44c733f638c8e51f75..b1d6cf99dabceed66a1cbcf1ff83833885d16170 100644
--- a/supr/layers.py
+++ b/supr/layers.py
@@ -5,6 +5,7 @@ import math
 from supr.utils import discrete_rand, local_scramble_2d
 from typing import List
 
+
 # Data:
 # N x V x D
 # └───│──│─ N: Data points
@@ -21,10 +22,10 @@ from typing import List
 class Supr(nn.Module):
     def __init__(self):
         super().__init__()
-    
+
     def sample(self):
         pass
-        
+
 
 class SuprLayer(nn.Module):
     epsilon = 1e-12
@@ -37,9 +38,10 @@ class SuprLayer(nn.Module):
 
     def em_update(self, *args, **kwargs):
         pass
-    
+
+
 class Sequential(nn.Sequential):
-    def __init__(self, *args: object) -> object:
+    def __init__(self, *args: object):
         super().__init__(*args)
 
     def em_batch_update(self):
@@ -54,6 +56,7 @@ class Sequential(nn.Sequential):
             value = module.sample(*value)
         return value
 
+
 class Parallel(SuprLayer):
     def __init__(self, nets: List[SuprLayer]):
         super().__init__()
@@ -78,6 +81,7 @@ class ScrambleTracks(SuprLayer):
     def forward(self, x):
         return x[:, torch.arange(x.shape[1])[:, None], self.perm, :]
 
+
 class ScrambleTracks2d(SuprLayer):
     """ Scrambles the variables in each track """
 
@@ -99,9 +103,10 @@ class VariablesProduct(SuprLayer):
 
-    def __init(self):
+    def __init__(self):
         super().__init__()
+        self.variables = None
 
     def sample(self, track, channel_per_variable):
-        return track, torch.full((self.variables, ), channel_per_variable[0]).to(channel_per_variable.device)
+        return track, torch.full((self.variables,), channel_per_variable[0]).to(channel_per_variable.device)
 
     def forward(self, x):
         if not self.training:
@@ -111,6 +116,7 @@ class VariablesProduct(SuprLayer):
 
 class ProductSumLayer(SuprLayer):
     """ Base class for product-sum layers """
+
     def __init__(self, weight_shape, normalize_dims):
         super().__init__()
         # Parameters
@@ -192,6 +198,7 @@ class Einsum(ProductSumLayer):
         y = a1 + a2 + torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc', exa1, exa2, self.weights))
         return y
 
+
 class Weightsum(ProductSumLayer):
     """ Weightsum layer """
 
@@ -227,7 +234,7 @@ class TrackSum(ProductSumLayer):
     # Weighted sum over tracks
     def __init__(self, tracks: int, channels: int):
         # Initialize super
-        super().__init__((tracks, channels), (0, ))
+        super().__init__((tracks, channels), (0,))
 
     def sample(self, track: int, channel_per_variable: torch.Tensor):
         prob = self.weights[:, None] * torch.exp(self.x[0] - torch.max(self.x[0], dim=0)[0])
@@ -254,8 +261,8 @@ class TrackSum(ProductSumLayer):
 class NormalLeaf(SuprLayer):
     """ NormalLeaf layer """
 
-    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
-                 mu0: torch.tensor = 0., nu0: torch.tensor = 0., torch.tensor: float = 0., beta0: torch.tensor = 0.):
+    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1, mu0: torch.tensor = 0.,
+                 nu0: torch.tensor = 0., alpha0: torch.tensor = 0., beta0: torch.tensor = 0.):
         super().__init__()
         # Dimensions
         self.T, self.V, self.C = tracks, variables, channels
@@ -292,7 +299,8 @@ class NormalLeaf(SuprLayer):
         self.mu.data *= 1. - learning_rate
         self.mu.data += learning_rate * mu_update
         # Standard deviation
-        sig_update = (self.n*(self.z_x_sq_acc / sum_z - self.mu ** 2) + 2*self.beta0 + self.nu0*(self.mu0-self.mu)**2) / (self.n + 2*self.alpha0 + 3)
+        sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (
+                    self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
         self.sig.data *= 1 - learning_rate
         self.sig.data += learning_rate * sig_update
         # Reset accumulators
@@ -305,7 +313,8 @@ class NormalLeaf(SuprLayer):
         mu_marginalize = self.mu[track, self.marginalize, channel_per_variable[self.marginalize]]
         sig_marginalize = self.sig[track, self.marginalize, channel_per_variable[self.marginalize]]
         r = torch.empty_like(self.x[0])
-        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(torch.clamp(sig_marginalize, self.epsilon))
+        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(
+            torch.clamp(sig_marginalize, self.epsilon))
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
 
@@ -322,9 +331,11 @@ class NormalLeaf(SuprLayer):
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-            torch.distributions.Normal(mu_valid, torch.sqrt(torch.clamp(sig_valid, self.epsilon))).log_prob(x_valid).float()
+            torch.distributions.Normal(mu_valid, torch.sqrt(torch.clamp(sig_valid, self.epsilon))).log_prob(
+                x_valid).float()
         return self.z
 
+
 class BernoulliLeaf(SuprLayer):
     """ BernoulliLeaf layer """
 
@@ -336,7 +347,7 @@ class BernoulliLeaf(SuprLayer):
         # Number of data points
         self.n = n
         # Prior
-        self.alpha0, self.beta0 = alpha0, beta0        
+        self.alpha0, self.beta0 = alpha0, beta0
-        # Parametes
+        # Parameters
         self.p = nn.Parameter(torch.rand(self.T, self.V, self.C))
-        # Which variables to marginalized
+        # Which variables to marginalize
@@ -383,10 +394,10 @@ class BernoulliLeaf(SuprLayer):
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-            torch.distributions.Bernoulli(probs=p_valid).log_prob(x_valid).float() 
+            torch.distributions.Bernoulli(probs=p_valid).log_prob(x_valid).float()
         return self.z
 
-# TODO: This is not tested properly. 
+
 class CategoricalLeaf(SuprLayer):
     """ CategoricalLeaf layer """
 
@@ -398,7 +409,7 @@ class CategoricalLeaf(SuprLayer):
         # Number of data points
         self.n = n
         # Prior
-        self.alpha0 = alpha0      
+        self.alpha0 = alpha0
-        # Parametes
+        # Parameters
         self.p = nn.Parameter(torch.rand(self.T, self.V, self.C, self.D))
-        # Which variables to marginalized
+        # Which variables to marginalize
@@ -419,7 +430,8 @@ class CategoricalLeaf(SuprLayer):
     def em_update(self, learning_rate: float = 1.):
         # Probability
         sum_z = torch.clamp(self.z_acc, self.epsilon)
-        p_update = (self.n * self.z_x_acc / sum_z[:,:,:,None] + self.alpha0 - 1) / (self.n + self.D*(self.alpha0 - 1))
+        p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (
+                    self.n + self.D * (self.alpha0 - 1))
         self.p.data *= 1. - learning_rate
         self.p.data += learning_rate * p_update
         # Reset accumulators
@@ -447,6 +459,5 @@ class CategoricalLeaf(SuprLayer):
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-            torch.distributions.Categorical(probs=p_valid).log_prob(x_valid).float() 
+            torch.distributions.Categorical(probs=p_valid).log_prob(x_valid).float()
         return self.z
-
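
A minimal, hypothetical usage sketch of the corrected NormalLeaf signature. The
sizes, keyword values, and the Sequential/TrackSum wiring below are assumptions
for illustration only; this patch does not prescribe or test them.

    import torch
    from supr.layers import Sequential, NormalLeaf, TrackSum

    N, T, V, C = 1000, 2, 8, 4   # hypothetical: data points, tracks, variables, channels
    x = torch.randn(N, V)        # assumes forward() accepts an N x V batch of scalar variables

    # Before this patch the fourth prior argument was misdeclared
    # ("torch.tensor: float = 0.") and could not be passed as alpha0.
    leaf = NormalLeaf(T, V, C, n=N, mu0=0., nu0=1., alpha0=1., beta0=1.)

    model = Sequential(leaf, TrackSum(T, C))
    logp = model(x)                    # forward pass: log-probabilities
    model.em_batch_update()            # assumes this accumulates per-batch EM statistics
    leaf.em_update(learning_rate=1.)   # apply the accumulated M-step update to the leaf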