From 8b606b3ebaa30149c04ae6a83d02c25bf9e61d2c Mon Sep 17 00:00:00 2001
From: "Mikkel N. Schmidt" <mnsc@dtu.dk>
Date: Wed, 11 May 2022 16:03:41 +0200
Subject: [PATCH] pep8

---
 supr/layers.py | 129 ++++++++++++++++++++++++++++++-------------------
 1 file changed, 78 insertions(+), 51 deletions(-)

diff --git a/supr/layers.py b/supr/layers.py
index 181f8ab..441dc47 100644
--- a/supr/layers.py
+++ b/supr/layers.py
@@ -59,7 +59,7 @@ class Sequential(nn.Sequential):
         with torch.no_grad():
             for module in self:
                 module.em_update()
-                
+
     def sample(self):
         value = []
         for module in reversed(self):
@@ -71,7 +71,7 @@ class Sequential(nn.Sequential):
 
     def var(self):
         return self[0].var()
-    
+
     def forward(self, value, marginalize=None):
         for module in self:
             if isinstance(module, SuprLeaf):
@@ -81,7 +81,6 @@ class Sequential(nn.Sequential):
         return value
 
 
-
 class Parallel(SuprLayer):
     def __init__(self, nets: List[SuprLayer]):
         super().__init__()
@@ -113,7 +112,8 @@ class ScrambleTracks2d(SuprLayer):
     def __init__(self, tracks: int, variables: int, distance: float, dims: tuple):
         super().__init__()
         # Permutation for each track
-        perm = torch.stack([local_scramble_2d(distance, dims) for _ in range(tracks)])
+        perm = torch.stack([local_scramble_2d(distance, dims)
+                           for _ in range(tracks)])
         self.register_buffer('perm', perm)
 
     def sample(self, track, channel_per_variable):
@@ -146,7 +146,8 @@ class ProductSumLayer(SuprLayer):
         super().__init__()
         # Parameters
         self.weights = nn.Parameter(torch.rand(*weight_shape))
-        self.weights.data /= torch.clamp(self.weights.sum(dim=normalize_dims, keepdim=True), self.epsilon)
+        self.weights.data /= torch.clamp(self.weights.sum(
+            dim=normalize_dims, keepdim=True), self.epsilon)
         # Normalize dimensions
         self.normalize_dims = normalize_dims
         # EM accumulator
@@ -157,7 +158,8 @@ class ProductSumLayer(SuprLayer):
 
     def em_update(self, learning_rate: float = 1.):
         weights_grad = torch.clamp(self.weights_acc, self.epsilon)
-        weights_grad /= torch.clamp(weights_grad.sum(dim=self.normalize_dims, keepdim=True), self.epsilon)
+        weights_grad /= torch.clamp(weights_grad.sum(
+            dim=self.normalize_dims, keepdim=True), self.epsilon)
         if learning_rate < 1.:
             self.weights.data *= 1. - learning_rate
             self.weights.data += learning_rate * weights_grad
@@ -188,7 +190,6 @@ class Einsum(ProductSumLayer):
             self.x2_pad[-1] = True
         else:
             self.pad = False
-        # TODO: Implement choice of left, right, or both augmentation. Both returns 2 times the number of tracks
 
     def sample(self, track: int, channel_per_variable: torch.Tensor):
         r = []
@@ -218,9 +219,12 @@ class Einsum(ProductSumLayer):
         # Compute maximum
         a1, a2 = [torch.max(x, dim=3, keepdim=True)[0] for x in [x1, x2]]
         # Subtract maximum and compute exponential
-        exa1, exa2 = [torch.clamp(torch.exp(x - a), self.epsilon) for x, a in [(x1, a1), (x2, a2)]]
+        exa1, exa2 = [torch.clamp(torch.exp(x - a), self.epsilon)
+                      for x, a in [(x1, a1), (x2, a2)]]
         # Compute the contraction
-        y = a1 + a2 + torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc', exa1, exa2, self.weights))
+        y = a1 + a2 + \
+            torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc',
+                      exa1, exa2, self.weights))
         return y
 
 
@@ -233,7 +237,8 @@ class Weightsum(ProductSumLayer):
         super().__init__((tracks, channels), (0, 1))
 
     def sample(self):
-        prob = self.weights * torch.exp(self.x_sum[0] - torch.max(self.x_sum[0]))
+        prob = self.weights * \
+            torch.exp(self.x_sum[0] - torch.max(self.x_sum[0]))
         s = discrete_rand(prob)[0]
         return s[0], torch.full((self.variables,), s[1]).to(self.weights.device)
 
@@ -262,7 +267,8 @@ class TrackSum(ProductSumLayer):
         super().__init__((tracks, channels), (0,))
 
     def sample(self, track: int, channel_per_variable: torch.Tensor):
-        prob = self.weights[:, None] * torch.exp(self.x[0] - torch.max(self.x[0], dim=0)[0])
+        prob = self.weights[:, None] * \
+            torch.exp(self.x[0] - torch.max(self.x[0], dim=0)[0])
         s = discrete_rand(prob)[0]
         return s[0], channel_per_variable
 
@@ -282,15 +288,18 @@ class TrackSum(ProductSumLayer):
         y = y[:, None]
         return y
 
+
 class SuprLeaf(SuprLayer):
     def __init__(self):
         super().__init__()
 
+
 class NormalLeaf(SuprLeaf):
     """ NormalLeaf layer """
 
-    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1, mu0: torch.tensor = 0.,
-                 nu0: torch.tensor = 0., alpha0: torch.tensor = 0., beta0: torch.tensor = 0.):
+    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
+                 mu0: torch.tensor = 0., nu0: torch.tensor = 0.,
+                 alpha0: torch.tensor = 0., beta0: torch.tensor = 0.):
         super().__init__()
         # Dimensions
         self.T, self.V, self.C = tracks, variables, channels
@@ -299,12 +308,11 @@ class NormalLeaf(SuprLeaf):
         # Prior
         self.mu0, self.nu0, self.alpha0, self.beta0 = mu0, nu0, alpha0, beta0
         # Parameters
-        # self.mu = nn.Parameter(torch.randn(self.T, self.V, self.C))
-        # self.mu = nn.Parameter(torch.linspace(0, 1, self.C)[None, None, :].repeat((self.T, self.V, 1)))
         self.mu = nn.Parameter(torch.rand(self.T, self.V, self.C))
         self.sig = nn.Parameter(torch.ones(self.T, self.V, self.C) * 0.5)
         # Which variables to marginalize
-        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
+        self.register_buffer('marginalize', torch.zeros(
+            variables, dtype=torch.bool))
         # Input
         self.register_buffer('x', torch.Tensor())
         # Output
@@ -315,23 +323,26 @@ class NormalLeaf(SuprLeaf):
         self.register_buffer('z_x_sq_acc', torch.zeros(self.T, self.V, self.C))
 
     def em_batch(self):
-        self.z_acc.data += torch.clamp(torch.sum(self.z.grad, dim=0), self.epsilon)
-        self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)
-        self.z_x_sq_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None] ** 2, dim=0)
+        self.z_acc.data += torch.clamp(torch.sum(self.z.grad,
+                                       dim=0), self.epsilon)
+        self.z_x_acc.data += torch.sum(self.z.grad *
+                                       self.x[:, None, :, None], dim=0)
+        self.z_x_sq_acc.data += torch.sum(self.z.grad *
+                                          self.x[:, None, :, None] ** 2, dim=0)
 
     def em_update(self, learning_rate: float = 1.):
         # Sum of weights
         sum_z = torch.clamp(self.z_acc, self.epsilon)
         # Mean
-        # mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
-        mu_update = (self.nu0 * self.mu0 + self.z_acc * (self.z_x_acc / sum_z)) / (self.nu0 + self.z_acc)
+        mu_update = (self.nu0 * self.mu0 + self.z_acc * (self.z_x_acc / sum_z)
+                     ) / (self.nu0 + self.z_acc)
         self.mu.data *= 1. - learning_rate
         self.mu.data += learning_rate * mu_update
         # Standard deviation
-        # sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (
-        #           self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
-        sig_update = (self.z_x_sq_acc - self.z_acc * self.mu ** 2 + 2 * self.beta0 + self.nu0 * (
-                  self.mu0 - self.mu) ** 2) / (self.z_acc + 2 * self.alpha0 + 3)
+        sig_update = (self.z_x_sq_acc -
+                      self.z_acc * self.mu ** 2 + 2 * self.beta0 +
+                      self.nu0 * (self.mu0 - self.mu) ** 2
+                      ) / (self.z_acc + 2 * self.alpha0 + 3)
         self.sig.data *= 1 - learning_rate
         self.sig.data += learning_rate * sig_update
         # Reset accumulators
@@ -341,19 +352,22 @@ class NormalLeaf(SuprLeaf):
 
     def sample(self, track: int, channel_per_variable: torch.Tensor):
         variables_marginalize = torch.sum(self.marginalize).int()
-        mu_marginalize = self.mu[track, self.marginalize, channel_per_variable[self.marginalize]]
-        sig_marginalize = self.sig[track, self.marginalize, channel_per_variable[self.marginalize]]
+        mu_marginalize = self.mu[track, self.marginalize,
+                                 channel_per_variable[self.marginalize]]
+        sig_marginalize = self.sig[track, self.marginalize,
+                                   channel_per_variable[self.marginalize]]
         r = torch.empty_like(self.x[0])
         r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(
             torch.clamp(sig_marginalize, self.epsilon))
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
-    
+
     def mean(self):
         return (torch.clamp(self.z.grad, self.epsilon) * self.mu).sum([1, 3])
-    
+
     def var(self):
-        return (torch.clamp(self.z.grad, self.epsilon) * (self.mu**2 + self.sig)).sum([1, 3]) - self.mean()**2
+        return (torch.clamp(self.z.grad, self.epsilon) *
+                (self.mu**2 + self.sig)).sum([1, 3]) - self.mean()**2
 
     def forward(self, x: torch.Tensor, marginalize=None):
         # Get shape
@@ -364,7 +378,8 @@ class NormalLeaf(SuprLeaf):
         # Store the data
         self.x = x
         # Compute the probability
-        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
+        self.z = torch.zeros(batch_size, self.T, self.V,
+                             self.C, requires_grad=True, device=x.device)
         # Get non-marginalized parameters and data
         mu_valid = self.mu[None, :, ~self.marginalize, :]
         sig_valid = self.sig[None, :, ~self.marginalize, :]
@@ -391,7 +406,8 @@ class BernoulliLeaf(SuprLeaf):
         # Parameters
         self.p = nn.Parameter(torch.rand(self.T, self.V, self.C))
         # Which variables to marginalize
-        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
+        self.register_buffer('marginalize', torch.zeros(
+            variables, dtype=torch.bool))
         # Input
         self.register_buffer('x', torch.Tensor())
         # Output
@@ -402,13 +418,13 @@ class BernoulliLeaf(SuprLeaf):
 
     def em_batch(self):
         self.z_acc.data += torch.sum(self.z.grad, dim=0)
-        self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)
+        self.z_x_acc.data += torch.sum(self.z.grad *
+                                       self.x[:, None, :, None], dim=0)
 
     def em_update(self, learning_rate: float = 1.):
         # Probability
-        sum_z = torch.clamp(self.z_acc, self.epsilon)
-        # p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
-        p_update = (self.z_x_acc + self.alpha0 - 1) / (self.z_acc + self.alpha0 + self.beta0 - 2)
+        p_update = (self.z_x_acc + self.alpha0 - 1) / \
+            (self.z_acc + self.alpha0 + self.beta0 - 2)
         self.p.data *= 1. - learning_rate
         self.p.data += learning_rate * p_update
         # Reset accumulators
@@ -417,25 +433,29 @@ class BernoulliLeaf(SuprLeaf):
 
     def sample(self, track: int, channel_per_variable: torch.Tensor):
         variables_marginalize = torch.sum(self.marginalize).int()
-        p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize]]
+        p_marginalize = self.p[track, self.marginalize,
+                               channel_per_variable[self.marginalize]]
         r = torch.empty_like(self.x[0])
-        r[self.marginalize] = (torch.rand(variables_marginalize).to(self.x.device) < p_marginalize).float()
+        r[self.marginalize] = (torch.rand(variables_marginalize).to(
+            self.x.device) < p_marginalize).float()
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
-    
+
     def forward(self, x: torch.Tensor):
         # Get shape
         batch_size = x.shape[0]
         # Store the data
         self.x = x
         # Compute the probability
-        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
+        self.z = torch.zeros(batch_size, self.T, self.V,
+                             self.C, requires_grad=True, device=x.device)
         # Get non-marginalized parameters and data
         p_valid = self.p[None, :, ~self.marginalize, :]
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-            torch.distributions.Bernoulli(probs=p_valid).log_prob(x_valid).float()
+            torch.distributions.Bernoulli(
+                probs=p_valid).log_prob(x_valid).float()
         return self.z
 
 
@@ -454,27 +474,30 @@ class CategoricalLeaf(SuprLeaf):
         # Parameters
         self.p = nn.Parameter(torch.rand(self.T, self.V, self.C, self.D))
         # Which variables to marginalize
-        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
+        self.register_buffer('marginalize', torch.zeros(
+            variables, dtype=torch.bool))
         # Input
         self.register_buffer('x', torch.Tensor())
         # Output
         self.register_buffer('z', torch.Tensor())
         # EM accumulator
         self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
-        self.register_buffer('z_x_acc', torch.zeros(self.T, self.V, self.C, self.D))
+        self.register_buffer('z_x_acc', torch.zeros(
+            self.T, self.V, self.C, self.D))
 
     def em_batch(self):
         self.z_acc.data += torch.sum(self.z.grad, dim=0)
         x_onehot = torch.eye(self.D, dtype=bool)[self.x]
-        self.z_x_acc.data += torch.sum(self.z.grad[:, :, :, :, None] * x_onehot[:, None, :, None, :], dim=0)
+        self.z_x_acc.data += torch.sum(
+            self.z.grad[:, :, :, :, None] * x_onehot[:, None, :, None, :], dim=0)
 
     def em_update(self, learning_rate: float = 1.):
         # Probability
         sum_z = torch.clamp(self.z_acc, self.epsilon)
         # p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (
-                    # self.n + self.D * (self.alpha0 - 1))
+        # self.n + self.D * (self.alpha0 - 1))
         p_update = (self.z_x_acc + self.alpha0 - 1) / (
-                    self.z_acc[:,:,:,None] + self.D * (self.alpha0 - 1))
+            self.z_acc[:, :, :, None] + self.D * (self.alpha0 - 1))
         self.p.data *= 1. - learning_rate
         self.p.data += learning_rate * p_update
         # Reset accumulators
@@ -482,9 +505,11 @@ class CategoricalLeaf(SuprLeaf):
         self.z_x_acc.zero_()
 
     def sample(self, track: int, channel_per_variable: torch.Tensor):
-        p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize], :]
+        p_marginalize = self.p[track, self.marginalize,
+                               channel_per_variable[self.marginalize], :]
         r = torch.empty_like(self.x[0])
-        r_sample = torch.distributions.Categorical(probs=p_marginalize).sample()
+        r_sample = torch.distributions.Categorical(
+            probs=p_marginalize).sample()
         r[self.marginalize] = r_sample
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
@@ -495,11 +520,13 @@ class CategoricalLeaf(SuprLeaf):
         # Store the data
         self.x = x
         # Compute the probability
-        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
+        self.z = torch.zeros(batch_size, self.T, self.V,
+                             self.C, requires_grad=True, device=x.device)
         # Get non-marginalized parameters and data
         p_valid = self.p[None, :, ~self.marginalize, :, :]
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-            torch.distributions.Categorical(probs=p_valid).log_prob(x_valid).float()
-        return self.z
\ No newline at end of file
+            torch.distributions.Categorical(
+                probs=p_valid).log_prob(x_valid).float()
+        return self.z
-- 
GitLab