From 8b606b3ebaa30149c04ae6a83d02c25bf9e61d2c Mon Sep 17 00:00:00 2001
From: "Mikkel N. Schmidt" <mnsc@dtu.dk>
Date: Wed, 11 May 2022 16:03:41 +0200
Subject: [PATCH] pep8

---
 supr/layers.py | 129 ++++++++++++++++++++++++++++++-------------------
 1 file changed, 78 insertions(+), 51 deletions(-)

diff --git a/supr/layers.py b/supr/layers.py
index 181f8ab..441dc47 100644
--- a/supr/layers.py
+++ b/supr/layers.py
@@ -59,7 +59,7 @@ class Sequential(nn.Sequential):
         with torch.no_grad():
             for module in self:
                 module.em_update()
-    
+
     def sample(self):
         value = []
         for module in reversed(self):
@@ -71,7 +71,7 @@ class Sequential(nn.Sequential):
 
     def var(self):
         return self[0].var()
-    
+
     def forward(self, value, marginalize=None):
         for module in self:
             if isinstance(module, SuprLeaf):
@@ -81,7 +81,6 @@ class Sequential(nn.Sequential):
         return value
 
 
-
 class Parallel(SuprLayer):
     def __init__(self, nets: List[SuprLayer]):
         super().__init__()
@@ -113,7 +112,8 @@ class ScrambleTracks2d(SuprLayer):
     def __init__(self, tracks: int, variables: int, distance: float, dims: tuple):
         super().__init__()
         # Permutation for each track
-        perm = torch.stack([local_scramble_2d(distance, dims) for _ in range(tracks)])
+        perm = torch.stack([local_scramble_2d(distance, dims)
+                            for _ in range(tracks)])
         self.register_buffer('perm', perm)
 
     def sample(self, track, channel_per_variable):
@@ -146,7 +146,8 @@ class ProductSumLayer(SuprLayer):
         super().__init__()
         # Parameters
         self.weights = nn.Parameter(torch.rand(*weight_shape))
-        self.weights.data /= torch.clamp(self.weights.sum(dim=normalize_dims, keepdim=True), self.epsilon)
+        self.weights.data /= torch.clamp(self.weights.sum(
+            dim=normalize_dims, keepdim=True), self.epsilon)
         # Normalize dimensions
         self.normalize_dims = normalize_dims
         # EM accumulator
@@ -157,7 +158,8 @@ class ProductSumLayer(SuprLayer):
 
     def em_update(self, learning_rate: float = 1.):
         weights_grad = torch.clamp(self.weights_acc, self.epsilon)
-        weights_grad /= torch.clamp(weights_grad.sum(dim=self.normalize_dims, keepdim=True), self.epsilon)
+        weights_grad /= torch.clamp(weights_grad.sum(
+            dim=self.normalize_dims, keepdim=True), self.epsilon)
         if learning_rate < 1.:
             self.weights.data *= 1. - learning_rate
         self.weights.data += learning_rate * weights_grad
@@ -188,7 +190,6 @@ class Einsum(ProductSumLayer):
             self.x2_pad[-1] = True
         else:
             self.pad = False
-        # TODO: Implement choice of left, right, or both augmentation. Both returns 2 times the number of tracks
 
     def sample(self, track: int, channel_per_variable: torch.Tensor):
         r = []
@@ -218,9 +219,12 @@ class Einsum(ProductSumLayer):
         # Compute maximum
         a1, a2 = [torch.max(x, dim=3, keepdim=True)[0] for x in [x1, x2]]
         # Subtract maximum and compute exponential
-        exa1, exa2 = [torch.clamp(torch.exp(x - a), self.epsilon) for x, a in [(x1, a1), (x2, a2)]]
+        exa1, exa2 = [torch.clamp(torch.exp(x - a), self.epsilon)
+                      for x, a in [(x1, a1), (x2, a2)]]
         # Compute the contraction
-        y = a1 + a2 + torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc', exa1, exa2, self.weights))
+        y = a1 + a2 + \
+            torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc',
+                                   exa1, exa2, self.weights))
         return y
 
 
@@ -233,7 +237,8 @@ class Weightsum(ProductSumLayer):
         super().__init__((tracks, channels), (0, 1))
 
     def sample(self):
-        prob = self.weights * torch.exp(self.x_sum[0] - torch.max(self.x_sum[0]))
+        prob = self.weights * \
+            torch.exp(self.x_sum[0] - torch.max(self.x_sum[0]))
         s = discrete_rand(prob)[0]
         return s[0], torch.full((self.variables,), s[1]).to(self.weights.device)
 
@@ -262,7 +267,8 @@ class TrackSum(ProductSumLayer):
         super().__init__((tracks, channels), (0,))
 
     def sample(self, track: int, channel_per_variable: torch.Tensor):
-        prob = self.weights[:, None] * torch.exp(self.x[0] - torch.max(self.x[0], dim=0)[0])
+        prob = self.weights[:, None] * \
+            torch.exp(self.x[0] - torch.max(self.x[0], dim=0)[0])
         s = discrete_rand(prob)[0]
         return s[0], channel_per_variable
 
@@ -282,15 +288,18 @@ class TrackSum(ProductSumLayer):
         y = y[:, None]
         return y
 
+
 class SuprLeaf(SuprLayer):
     def __init__(self):
         super().__init__()
 
+
 class NormalLeaf(SuprLeaf):
     """ NormalLeaf layer """
 
-    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1, mu0: torch.tensor = 0.,
-                 nu0: torch.tensor = 0., alpha0: torch.tensor = 0., beta0: torch.tensor = 0.):
+    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
+                 mu0: torch.tensor = 0., nu0: torch.tensor = 0.,
+                 alpha0: torch.tensor = 0., beta0: torch.tensor = 0.):
         super().__init__()
         # Dimensions
         self.T, self.V, self.C = tracks, variables, channels
@@ -299,12 +308,11 @@ class NormalLeaf(SuprLeaf):
         # Prior
         self.mu0, self.nu0, self.alpha0, self.beta0 = mu0, nu0, alpha0, beta0
         # Parametes
-        # self.mu = nn.Parameter(torch.randn(self.T, self.V, self.C))
-        # self.mu = nn.Parameter(torch.linspace(0, 1, self.C)[None, None, :].repeat((self.T, self.V, 1)))
         self.mu = nn.Parameter(torch.rand(self.T, self.V, self.C))
         self.sig = nn.Parameter(torch.ones(self.T, self.V, self.C) * 0.5)
         # Which variables to marginalized
-        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
+        self.register_buffer('marginalize', torch.zeros(
+            variables, dtype=torch.bool))
         # Input
         self.register_buffer('x', torch.Tensor())
         # Output
@@ -315,23 +323,26 @@ class NormalLeaf(SuprLeaf):
         self.register_buffer('z_x_sq_acc', torch.zeros(self.T, self.V, self.C))
 
     def em_batch(self):
-        self.z_acc.data += torch.clamp(torch.sum(self.z.grad, dim=0), self.epsilon)
-        self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)
-        self.z_x_sq_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None] ** 2, dim=0)
+        self.z_acc.data += torch.clamp(torch.sum(self.z.grad,
+                                                 dim=0), self.epsilon)
+        self.z_x_acc.data += torch.sum(self.z.grad *
+                                       self.x[:, None, :, None], dim=0)
+        self.z_x_sq_acc.data += torch.sum(self.z.grad *
+                                          self.x[:, None, :, None] ** 2, dim=0)
 
     def em_update(self, learning_rate: float = 1.):
         # Sum of weights
         sum_z = torch.clamp(self.z_acc, self.epsilon)
         # Mean
-        # mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
-        mu_update = (self.nu0 * self.mu0 + self.z_acc * (self.z_x_acc / sum_z)) / (self.nu0 + self.z_acc)
+        mu_update = (self.nu0 * self.mu0 + self.z_acc * (self.z_x_acc / sum_z)
+                     ) / (self.nu0 + self.z_acc)
         self.mu.data *= 1. - learning_rate
         self.mu.data += learning_rate * mu_update
         # Standard deviation
-        # sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (
-        #     self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
-        sig_update = (self.z_x_sq_acc - self.z_acc * self.mu ** 2 + 2 * self.beta0 + self.nu0 * (
-                self.mu0 - self.mu) ** 2) / (self.z_acc + 2 * self.alpha0 + 3)
+        sig_update = (self.z_x_sq_acc -
+                      self.z_acc * self.mu ** 2 + 2 * self.beta0 +
+                      self.nu0 * (self.mu0 - self.mu) ** 2
+                      ) / (self.z_acc + 2 * self.alpha0 + 3)
         self.sig.data *= 1 - learning_rate
         self.sig.data += learning_rate * sig_update
         # Reset accumulators
@@ -341,19 +352,22 @@ class NormalLeaf(SuprLeaf):
 
     def sample(self, track: int, channel_per_variable: torch.Tensor):
         variables_marginalize = torch.sum(self.marginalize).int()
-        mu_marginalize = self.mu[track, self.marginalize, channel_per_variable[self.marginalize]]
-        sig_marginalize = self.sig[track, self.marginalize, channel_per_variable[self.marginalize]]
+        mu_marginalize = self.mu[track, self.marginalize,
+                                 channel_per_variable[self.marginalize]]
+        sig_marginalize = self.sig[track, self.marginalize,
+                                   channel_per_variable[self.marginalize]]
         r = torch.empty_like(self.x[0])
         r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(
             torch.clamp(sig_marginalize, self.epsilon))
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
-    
+
     def mean(self):
         return (torch.clamp(self.z.grad, self.epsilon) * self.mu).sum([1, 3])
-    
+
     def var(self):
-        return (torch.clamp(self.z.grad, self.epsilon) * (self.mu**2 + self.sig)).sum([1, 3]) - self.mean()**2
+        return (torch.clamp(self.z.grad, self.epsilon) *
+                (self.mu**2 + self.sig)).sum([1, 3]) - self.mean()**2
 
     def forward(self, x: torch.Tensor, marginalize=None):
         # Get shape
@@ -364,7 +378,8 @@ class NormalLeaf(SuprLeaf):
         # Store the data
         self.x = x
         # Compute the probability
-        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
+        self.z = torch.zeros(batch_size, self.T, self.V,
+                             self.C, requires_grad=True, device=x.device)
         # Get non-marginalized parameters and data
         mu_valid = self.mu[None, :, ~self.marginalize, :]
         sig_valid = self.sig[None, :, ~self.marginalize, :]
@@ -391,7 +406,8 @@ class BernoulliLeaf(SuprLeaf):
         # Parametes
         self.p = nn.Parameter(torch.rand(self.T, self.V, self.C))
         # Which variables to marginalized
-        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
+        self.register_buffer('marginalize', torch.zeros(
+            variables, dtype=torch.bool))
         # Input
         self.register_buffer('x', torch.Tensor())
         # Output
@@ -402,13 +418,13 @@ class BernoulliLeaf(SuprLeaf):
 
     def em_batch(self):
         self.z_acc.data += torch.sum(self.z.grad, dim=0)
-        self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)
+        self.z_x_acc.data += torch.sum(self.z.grad *
+                                       self.x[:, None, :, None], dim=0)
 
     def em_update(self, learning_rate: float = 1.):
         # Probability
-        sum_z = torch.clamp(self.z_acc, self.epsilon)
-        # p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
-        p_update = (self.z_x_acc + self.alpha0 - 1) / (self.z_acc + self.alpha0 + self.beta0 - 2)
+        p_update = (self.z_x_acc + self.alpha0 - 1) / \
+            (self.z_acc + self.alpha0 + self.beta0 - 2)
         self.p.data *= 1. - learning_rate
         self.p.data += learning_rate * p_update
         # Reset accumulators
@@ -417,25 +433,29 @@ class BernoulliLeaf(SuprLeaf):
 
     def sample(self, track: int, channel_per_variable: torch.Tensor):
         variables_marginalize = torch.sum(self.marginalize).int()
-        p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize]]
+        p_marginalize = self.p[track, self.marginalize,
+                               channel_per_variable[self.marginalize]]
         r = torch.empty_like(self.x[0])
-        r[self.marginalize] = (torch.rand(variables_marginalize).to(self.x.device) < p_marginalize).float()
+        r[self.marginalize] = (torch.rand(variables_marginalize).to(
+            self.x.device) < p_marginalize).float()
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
-    
+
     def forward(self, x: torch.Tensor):
         # Get shape
         batch_size = x.shape[0]
         # Store the data
         self.x = x
         # Compute the probability
-        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
+        self.z = torch.zeros(batch_size, self.T, self.V,
+                             self.C, requires_grad=True, device=x.device)
         # Get non-marginalized parameters and data
         p_valid = self.p[None, :, ~self.marginalize, :]
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-            torch.distributions.Bernoulli(probs=p_valid).log_prob(x_valid).float()
+            torch.distributions.Bernoulli(
+                probs=p_valid).log_prob(x_valid).float()
         return self.z
 
 
@@ -454,27 +474,30 @@ class CategoricalLeaf(SuprLeaf):
         # Parametes
         self.p = nn.Parameter(torch.rand(self.T, self.V, self.C, self.D))
         # Which variables to marginalized
-        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
+        self.register_buffer('marginalize', torch.zeros(
+            variables, dtype=torch.bool))
         # Input
         self.register_buffer('x', torch.Tensor())
         # Output
         self.register_buffer('z', torch.Tensor())
         # EM accumulator
         self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
-        self.register_buffer('z_x_acc', torch.zeros(self.T, self.V, self.C, self.D))
+        self.register_buffer('z_x_acc', torch.zeros(
+            self.T, self.V, self.C, self.D))
 
     def em_batch(self):
         self.z_acc.data += torch.sum(self.z.grad, dim=0)
         x_onehot = torch.eye(self.D, dtype=bool)[self.x]
-        self.z_x_acc.data += torch.sum(self.z.grad[:, :, :, :, None] * x_onehot[:, None, :, None, :], dim=0)
+        self.z_x_acc.data += torch.sum(
+            self.z.grad[:, :, :, :, None] * x_onehot[:, None, :, None, :], dim=0)
 
     def em_update(self, learning_rate: float = 1.):
         # Probability
        sum_z = torch.clamp(self.z_acc, self.epsilon)
         # p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (
-        #         self.n + self.D * (self.alpha0 - 1))
+        #     self.n + self.D * (self.alpha0 - 1))
         p_update = (self.z_x_acc + self.alpha0 - 1) / (
-                self.z_acc[:,:,:,None] + self.D * (self.alpha0 - 1))
+            self.z_acc[:, :, :, None] + self.D * (self.alpha0 - 1))
         self.p.data *= 1. - learning_rate
         self.p.data += learning_rate * p_update
         # Reset accumulators
@@ -482,9 +505,11 @@ class CategoricalLeaf(SuprLeaf):
         self.z_x_acc.zero_()
 
     def sample(self, track: int, channel_per_variable: torch.Tensor):
-        p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize], :]
+        p_marginalize = self.p[track, self.marginalize,
+                               channel_per_variable[self.marginalize], :]
         r = torch.empty_like(self.x[0])
-        r_sample = torch.distributions.Categorical(probs=p_marginalize).sample()
+        r_sample = torch.distributions.Categorical(
+            probs=p_marginalize).sample()
         r[self.marginalize] = r_sample
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
@@ -495,11 +520,13 @@ class CategoricalLeaf(SuprLeaf):
         # Store the data
         self.x = x
         # Compute the probability
-        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
+        self.z = torch.zeros(batch_size, self.T, self.V,
+                             self.C, requires_grad=True, device=x.device)
         # Get non-marginalized parameters and data
         p_valid = self.p[None, :, ~self.marginalize, :, :]
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-            torch.distributions.Categorical(probs=p_valid).log_prob(x_valid).float()
-        return self.z
\ No newline at end of file
+            torch.distributions.Categorical(
+                probs=p_valid).log_prob(x_valid).float()
+        return self.z
-- 
GitLab