diff --git a/supr/layers.py b/supr/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b34df5591518eb2cedd5b6f26b09a954c8c52b3
--- /dev/null
+++ b/supr/layers.py
@@ -0,0 +1,352 @@
+import torch
+import torch.nn as nn
+from torch.nn.functional import pad
+import math
+from supr.utils import discrete_rand, local_scramble_2d
+from typing import List, Optional
+
+
+# Data:
+# N x V x C
+# └───│───│─ N: Data points
+#     └───│─ V: Variables
+#         └─ C: Channels
+# Probability:
+# N x T x V x C
+# └───│───│───│─ N: Data points
+#     └───│───│─ T: Tracks
+#         └───│─ V: Variables
+#             └─ C: Channels
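+#
+# Illustrative example (assumed sizes): with N=32 data points, V=10 variables,
+# T=4 tracks and C=8 channels, a leaf layer below produces a log-probability
+# tensor of shape 32 x 4 x 10 x 8 (N x T x V x C).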
+
+class SuprLayer(nn.Module):
+    epsilon = 1e-12
+
+    def __init__(self):
+        super().__init__()
+
+    def em_batch(self):
+        pass
+
+    def em_update(self, *args, **kwargs):
+        pass
+
+
+class Parallel(SuprLayer):
+    """ Applies a list of sub-layers to a matching list of inputs """
+
+    def __init__(self, nets: List[SuprLayer]):
+        super().__init__()
+        # Use a ModuleList so that the sub-layers' parameters are registered
+        self.nets = nn.ModuleList(nets)
+
+    def forward(self, x: List[torch.Tensor]):
+        return [n(xi) for n, xi in zip(self.nets, x)]
+
+
+class ScrambleTracks(SuprLayer):
+    """ Scrambles the variables in each track """
+
+    def __init__(self, tracks: int, variables: int):
+        super().__init__()
+        # Permutation for each track
+        perm = torch.stack([torch.randperm(variables) for _ in range(tracks)])
+        self.register_buffer('perm', perm)
+
+    def sample(self, track, channel_per_variable):
+        return track, torch.scatter(channel_per_variable, 0, self.perm[track], channel_per_variable)
+
+    def forward(self, x):
+        return x[:, torch.arange(x.shape[1])[:, None], self.perm, :]
+
+class ScrambleTracks2d(SuprLayer):
+    """ Scrambles the variables in each track """
+
+    def __init__(self, tracks: int, variables: int, distance: float, dims: tuple):
+        super().__init__()
+        # Permutation for each track
+        perm = torch.stack([local_scramble_2d(distance, dims) for _ in range(tracks)])
+        self.register_buffer('perm', perm)
+
+    def sample(self, track, channel_per_variable):
+        return track, torch.scatter(channel_per_variable, 0, self.perm[track], channel_per_variable)
+
+    def forward(self, x):
+        return x[:, torch.arange(x.shape[1])[:, None], self.perm, :]
+
+
+class VariablesProduct(SuprLayer):
+    """ Product over all variables """
+
+    def __init__(self):
+        super().__init__()
+
+    def sample(self, track, channel_per_variable):
+        return track, torch.full((self.variables, ), channel_per_variable[0]).to(channel_per_variable.device)
+
+    def forward(self, x):
+        if not self.training:
+            self.variables = x.shape[2]
+        return torch.sum(x, dim=2, keepdim=True)
+
+
+class ProductSumLayer(SuprLayer):
+    """ Base class for product-sum layers """
+    def __init__(self, weight_shape, normalize_dims):
+        super().__init__()
+        # Parameters
+        self.weights = nn.Parameter(torch.rand(*weight_shape))
+        self.weights.data /= torch.clamp(self.weights.sum(dim=normalize_dims, keepdim=True), self.epsilon)
+        # Normalize dimensions
+        self.normalize_dims = normalize_dims
+        # EM accumulator
+        self.register_buffer('weights_acc', torch.zeros(*weight_shape))
+
+    def em_batch(self):
+        self.weights_acc.data += self.weights * self.weights.grad
+
+    def em_update(self, learning_rate: float = 1.):
+        weights_grad = torch.clamp(self.weights_acc, self.epsilon)
+        weights_grad /= torch.clamp(weights_grad.sum(dim=self.normalize_dims, keepdim=True), self.epsilon)
+        if learning_rate < 1.:
+            self.weights.data *= 1. - learning_rate
+            self.weights.data += learning_rate * weights_grad
+        else:
+            self.weights.data = weights_grad
+        # Reset accumulator
+        self.weights_acc.zero_()
+
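+# Note on the assumed EM workflow (inferred from the em_* hooks above): one EM
+# step consists of a forward pass, a backward pass on the summed log-likelihood
+# (which populates the gradients consumed by em_batch), a call to em_batch() on
+# every layer to accumulate the expected statistics, and a call to em_update()
+# to renormalize the weights. A runnable sketch is given under __main__ at the
+# end of this file.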
+
+class Einsum(ProductSumLayer):
+    """ Einsum layer """
+
+    def __init__(self, tracks: int, variables: int, channels: int, channels_out: Optional[int] = None):
+        # Dimensions
+        variables_out = math.ceil(variables / 2)
+        if channels_out is None:
+            channels_out = channels
+        # Initialize super
+        super().__init__((tracks, variables_out, channels_out, channels, channels), (3, 4))
+        # Padding
+        self.x1_pad = torch.zeros(variables_out, dtype=torch.bool)
+        self.x2_pad = torch.zeros(variables_out, dtype=torch.bool)
+        # Zero-pad if necessary
+        if variables % 2 == 1:
+            # Pad on the right
+            self.pad = True
+            self.x2_padding = [0, 0, 0, 1]
+            self.x2_pad[-1] = True
+        else:
+            self.pad = False
+        # TODO: Implement choice of left, right, or both augmentation. Both returns 2 times the number of tracks
+
+    def sample(self, track: int, channel_per_variable: torch.Tensor):
+        r = []
+        for v, c in enumerate(channel_per_variable):
+            # Probability matrix
+            px1 = torch.exp(self.x1[0, track, v, :][:, None])
+            px2 = torch.exp(self.x2[0, track, v, :][None, :])
+            prob = self.weights[track, v, c] * px1 * px2
+            # Sample
+            idx = discrete_rand(prob)[0]
+            # Remove indices of padding
+            idx_valid = idx[[not self.x1_pad[v], not self.x2_pad[v]]]
+            # Store on list
+            r.append(idx_valid)
+        # Concatenate and return indices
+        return track, torch.cat(r)
+
+    def forward(self, x: torch.Tensor):
+        # Split the input variables in two and apply padding if necessary
+        x1 = x[:, :, 0::2, :]
+        x2 = x[:, :, 1::2, :]
+        if self.pad:
+            x2 = pad(x2, self.x2_padding)
+        # Store the inputs for use in sampling routine
+        if not self.training:
+            self.x1, self.x2 = x1, x2
+        # Compute maximum
+        a1, a2 = [torch.max(x, dim=3, keepdim=True)[0] for x in [x1, x2]]
+        # Subtract maximum and compute exponential
+        exa1, exa2 = [torch.clamp(torch.exp(x - a), self.epsilon) for x, a in [(x1, a1), (x2, a2)]]
+        # Compute the contraction
+        y = a1 + a2 + torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc', exa1, exa2, self.weights))
+        return y
+
+class Weightsum(ProductSumLayer):
+    """ Weightsum layer """
+
+    # Product over all variables and weighted sum over tracks and channels
+    def __init__(self, tracks: int, variables: int, channels: int):
+        # Initialize super
+        super().__init__((tracks, channels), (0, 1))
+
+    def sample(self):
+        prob = self.weights * torch.exp(self.x_sum[0] - torch.max(self.x_sum[0]))
+        s = discrete_rand(prob)[0]
+        return s[0], torch.full((self.variables,), s[1]).to(self.weights.device)
+
+    def forward(self, x: torch.Tensor):
+        # Product over variables
+        x_sum = torch.sum(x, 2)
+        # Store the inputs for use in sampling routine
+        if not self.training:
+            self.x_sum = x_sum
+            self.variables = x.shape[2]
+        # Compute maximum
+        a = torch.max(torch.max(x_sum, dim=1)[0], dim=1)[0]
+        # Subtract maximum and compute exponential
+        exa = torch.clamp(torch.exp(x_sum - a[:, None, None]), self.epsilon)
+        # Compute the contraction
+        y = a + torch.log(torch.einsum('ntc,tc->n', exa, self.weights))
+        return y
+
+
+class TrackSum(ProductSumLayer):
+    """ TrackSum layer """
+
+    # Weighted sum over tracks
+    def __init__(self, tracks: int, channels: int):
+        # Initialize super
+        super().__init__((tracks, channels), (0, ))
+
+    def sample(self, track: int, channel_per_variable: torch.Tensor):
+        prob = self.weights[:, None] * torch.exp(self.x[0] - torch.max(self.x[0], dim=0)[0])
+        s = discrete_rand(prob)[0]
+        return s[0], channel_per_variable
+
+    def forward(self, x: torch.Tensor):
+        # Module is only valid when number of variables is 1
+        assert x.shape[2] == 1
+        # Store the inputs for use in sampling routine
+        if not self.training:
+            self.x = x
+        # Compute maximum
+        a = torch.max(x, dim=1)[0]
+        # Subtract maximum and compute exponential
+        exa = torch.clamp(torch.exp(x - a[:, None]), self.epsilon)
+        # Compute the contraction
+        y = a + torch.log(torch.einsum('ntvc,tc->nvc', exa, self.weights))
+        # Insert track dimension
+        y = y[:, None]
+        return y
+
+
+class NormalLeaf(SuprLayer):
+    """ NormalLeaf layer """
+
+    def __init__(self, tracks: int, variables: int, channels: int):
+        super().__init__()
+        # Dimensions
+        self.T, self.V, self.C = tracks, variables, channels
+        # Parameters
+        # self.mu = nn.Parameter(torch.randn(self.T, self.V, self.C))
+        # self.mu = nn.Parameter(torch.linspace(0, 1, self.C)[None, None, :].repeat((self.T, self.V, 1)))
+        self.mu = nn.Parameter(torch.rand(self.T, self.V, self.C))
+        self.sig = nn.Parameter(torch.ones(self.T, self.V, self.C) * 0.5)
+        # Which variables to marginalize
+        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
+        # Input
+        self.register_buffer('x', torch.Tensor())
+        # Output
+        self.register_buffer('z', torch.Tensor())
+        # EM accumulator
+        self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
+        self.register_buffer('z_x_acc', torch.zeros(self.T, self.V, self.C))
+        self.register_buffer('z_x_sq_acc', torch.zeros(self.T, self.V, self.C))
+
+    def em_batch(self):
+        self.z_acc.data += torch.sum(self.z.grad, dim=0)
+        self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)
+        self.z_x_sq_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None] ** 2, dim=0)
+
+    def em_update(self, learning_rate: float = 1.):
+        # Mean
+        sum_z = torch.clamp(self.z_acc, self.epsilon)
+        self.mu.data *= 1. - learning_rate
+        self.mu.data += learning_rate * self.z_x_acc / sum_z
+        # Standard deviation
+        self.sig.data *= 1 - learning_rate
+        self.sig.data += learning_rate * torch.sqrt(torch.clamp(self.z_x_sq_acc / sum_z - self.mu ** 2, self.epsilon + 0.01))
+        # Reset accumulators
+        self.z_acc.zero_()
+        self.z_x_acc.zero_()
+        self.z_x_sq_acc.zero_()
+
+    def sample(self, track: int, channel_per_variable: torch.Tensor):
+        variables_marginalize = torch.sum(self.marginalize).int()
+        mu_marginalize = self.mu[track, self.marginalize, channel_per_variable[self.marginalize]]
+        sig_marginalize = self.sig[track, self.marginalize, channel_per_variable[self.marginalize]]
+        r = torch.empty_like(self.x[0])
+        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * sig_marginalize
+        r[~self.marginalize] = self.x[0][~self.marginalize]
+        return r
+
+    def forward(self, x: torch.Tensor):
+        # Get shape
+        batch_size = x.shape[0]
+        # Store the data
+        self.x = x
+        # Compute the probability
+        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
+        # Get non-marginalized parameters and data
+        mu_valid = self.mu[None, :, ~self.marginalize, :]
+        sig_valid = self.sig[None, :, ~self.marginalize, :]
+        x_valid = self.x[:, None, ~self.marginalize, None]
+        # Evaluate log probability
+        self.z.data[:, :, ~self.marginalize, :] = \
+            torch.distributions.Normal(mu_valid, sig_valid).log_prob(x_valid).float()
+        return self.z
+
+class BernoulliLeaf(SuprLayer):
+    """ BernoulliLeaf layer """
+
+    def __init__(self, tracks: int, variables: int, channels: int):
+        super().__init__()
+        # Dimensions
+        self.T, self.V, self.C = tracks, variables, channels
+        # Parameters
+        self.p = nn.Parameter(torch.rand(self.T, self.V, self.C))
+        # Which variables to marginalize
+        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
+        # Input
+        self.register_buffer('x', torch.Tensor())
+        # Output
+        self.register_buffer('z', torch.Tensor())
+        # EM accumulator
+        self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
+        self.register_buffer('z_x_acc', torch.zeros(self.T, self.V, self.C))
+
+    def em_batch(self):
+        self.z_acc.data += torch.sum(self.z.grad, dim=0)
+        self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)
+
+    def em_update(self, learning_rate: float = 1.):
+        # Probability
+        sum_z = torch.clamp(self.z_acc, self.epsilon)
+        self.p.data *= 1. - learning_rate
+        self.p.data += learning_rate * self.z_x_acc / sum_z
+        # Reset accumulators
+        self.z_acc.zero_()
+        self.z_x_acc.zero_()
+
+    def sample(self, track: int, channel_per_variable: torch.Tensor):
+        variables_marginalize = torch.sum(self.marginalize).int()
+        p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize]]
+        r = torch.empty_like(self.x[0])
+        r[self.marginalize] = (torch.rand(variables_marginalize).to(self.x.device) < p_marginalize).float()
+        r[~self.marginalize] = self.x[0][~self.marginalize]
+        return r
+
+    def forward(self, x: torch.Tensor):
+        # Get shape
+        batch_size = x.shape[0]
+        # Store the data
+        self.x = x
+        # Compute the probability
+        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
+        # Get non-marginalized parameters and data
+        p_valid = self.p[None, :, ~self.marginalize, :]
+        x_valid = self.x[:, None, ~self.marginalize, None]
+        # Evaluate log probability
+        self.z.data[:, :, ~self.marginalize, :] = \
+            torch.log(torch.clamp(p_valid * (x_valid == 1) + (1 - p_valid) * (x_valid == 0), self.epsilon))
+        return self.z
+
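+
+if __name__ == '__main__':
+    # Minimal usage sketch (an assumption, not part of the library API): stack a
+    # few of the layers above, evaluate the log-likelihood, and run one
+    # gradient-driven EM step as suggested by the em_batch()/em_update() hooks.
+    # All sizes are illustrative.
+    N, T, V, C = 32, 2, 8, 4
+    x = torch.randn(N, V)                 # leaf layers index the data as N x V
+    layers = [
+        NormalLeaf(T, V, C),              # -> N x T x V x C log-probabilities
+        Einsum(T, V, C),                  # pairs up variables: V -> ceil(V / 2)
+        Einsum(T, V // 2, C),
+        Einsum(T, V // 4, C),             # -> a single remaining variable
+        Weightsum(T, 1, C),               # product over variables, weighted sum over tracks and channels
+    ]
+    model = nn.Sequential(*layers)
+    log_p = model(x)                      # per-datapoint log-likelihood, shape (N,)
+    # One EM step: backpropagation fills the gradients that em_batch() accumulates
+    log_p.sum().backward()
+    for layer in layers:
+        layer.em_batch()
+        layer.em_update()
+    print(log_p.shape)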
diff --git a/supr/utils.py b/supr/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..18c24317bcab9802dbc09b715cff2917c2d5071c
--- /dev/null
+++ b/supr/utils.py
@@ -0,0 +1,65 @@
+# %% Imports
+import matplotlib.pyplot as plt
+from typing import Tuple
+import torch
+import numpy as np
+import math
+from PyQt5 import QtWidgets
+
+
+# %%
+def drawnow():
+    plt.gcf().canvas.draw()
+    plt.gcf().canvas.flush_events()
+
+
+def arrange_figs(cols=1, min_rows=3, toolbar=False, x0=1400, y0=28, x1=1920, y1=1200):
+    try:
+        current_fig_num = plt.gcf().number
+        extra = 37
+        w = x1 - x0
+        h = y1 - y0
+        fignums = plt.get_fignums()
+        n = len(fignums)
+        rows = np.maximum(math.ceil(n / cols), min_rows)
+        height = int(h / rows - extra)
+        width = int(w / cols)
+        for i, fn in enumerate(fignums):
+            r = i % rows
+            c = int(i / rows)
+            plt.figure(fn)
+            win = plt.get_current_fig_manager().window
+            win.findChild(QtWidgets.QToolBar).setVisible(toolbar)
+            win.setGeometry(x0 + width * c, y0 + int(h / rows * r) + extra, width, height)
+        plt.figure(current_fig_num)
+    except Exception:
+        pass
+
+
+def unravel_indices(indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.LongTensor:
+    r"""Converts flat indices into unraveled coordinates in a target shape.
+    Args:
+        indices: A tensor of (flat) indices, (*, N).
+        shape: The targeted shape, (D,).
+
+    Returns:
+        The unraveled coordinates, (*, N, D).
+    """
+    coord = []
+    for dim in reversed(shape):
+        coord.append(indices % dim)
+        indices = indices // dim
+    coord = torch.stack(coord[::-1], dim=-1)
+    return coord
+
+
+def discrete_rand(v: torch.Tensor, n: int = 1):
+    idx = torch.sum(torch.rand(n)[:, None].to(v.device) > torch.cumsum(v.flatten(), 0)[None, :] / torch.sum(v), dim=1)
+    return unravel_indices(idx, v.shape)
+
+
+def local_scramble_2d(dist: float, dim: tuple):
+    grid = torch.meshgrid(*[torch.arange(d) for d in dim])
+    n = [torch.argsort(m + torch.randn(dim) * dist, dim=i) for i, m in enumerate(grid)]
+    idx = torch.reshape(torch.arange(torch.tensor(dim).prod()), dim)
+    return idx[n[0], grid[1]][grid[0], n[1]].flatten()
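+
+
+if __name__ == '__main__':
+    # Usage sketch with illustrative values (an assumption, not part of the
+    # module): draw samples from an unnormalized probability table with
+    # discrete_rand, and build a locally scrambled permutation of a 4 x 4 grid
+    # with local_scramble_2d.
+    table = torch.tensor([[0.1, 0.4], [0.2, 0.3]])
+    print(discrete_rand(table, n=5))        # five (row, column) index pairs
+    print(local_scramble_2d(1.0, (4, 4)))   # permutation of the 16 grid positions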