From de62975e098ac971b14871a2ffb22c8eaf4ba69a Mon Sep 17 00:00:00 2001
From: "Mikkel N. Schmidt" <mnsc@dtu.dk>
Date: Thu, 28 Apr 2022 15:40:58 +0200
Subject: [PATCH] pep

---
 supr/layers.py | 47 +++++++++++++++++++++++++++++------------------
 1 file changed, 29 insertions(+), 18 deletions(-)

diff --git a/supr/layers.py b/supr/layers.py
index d1a5995..b1d6cf9 100644
--- a/supr/layers.py
+++ b/supr/layers.py
@@ -5,6 +5,7 @@ import math
 from supr.utils import discrete_rand, local_scramble_2d
 from typing import List
 
+
 # Data:
 # N x V x D
 # └───│──│─ N: Data points
@@ -21,10 +22,10 @@ from typing import List
 class Supr(nn.Module):
     def __init__(self):
         super().__init__()
-    
+
     def sample(self):
         pass
-    
+
 
 class SuprLayer(nn.Module):
     epsilon = 1e-12
@@ -37,9 +38,10 @@ class SuprLayer(nn.Module):
 
     def em_update(self, *args, **kwargs):
         pass
-    
+
+
 class Sequential(nn.Sequential):
-    def __init__(self, *args: object) -> object:
+    def __init__(self, *args: object):
         super().__init__(*args)
 
     def em_batch_update(self):
@@ -54,6 +56,7 @@
             value = module.sample(*value)
         return value
 
+
 class Parallel(SuprLayer):
     def __init__(self, nets: List[SuprLayer]):
         super().__init__()
@@ -78,6 +81,7 @@ class ScrambleTracks(SuprLayer):
     def forward(self, x):
         return x[:, torch.arange(x.shape[1])[:, None], self.perm, :]
 
+
 class ScrambleTracks2d(SuprLayer):
     """ Scrambles the variables in each track """
 
@@ -99,9 +103,10 @@
 class VariablesProduct(SuprLayer):
     def __init(self):
         super().__init__()
+        self.variables = None
 
     def sample(self, track, channel_per_variable):
-        return track, torch.full((self.variables, ), channel_per_variable[0]).to(channel_per_variable.device)
+        return track, torch.full((self.variables,), channel_per_variable[0]).to(channel_per_variable.device)
 
     def forward(self, x):
         if not self.training:
@@ -111,6 +116,7 @@
 
 class ProductSumLayer(SuprLayer):
     """ Base class for product-sum layers """
+
     def __init__(self, weight_shape, normalize_dims):
         super().__init__()
         # Parameters
@@ -192,6 +198,7 @@
         y = a1 + a2 + torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc', exa1, exa2, self.weights))
         return y
 
+
class Weightsum(ProductSumLayer):
     """ Weightsum layer """
 
@@ -227,7 +234,7 @@ class TrackSum(ProductSumLayer):
     # Weighted sum over tracks
     def __init__(self, tracks: int, channels: int):
         # Initialize super
-        super().__init__((tracks, channels), (0, ))
+        super().__init__((tracks, channels), (0,))
 
     def sample(self, track: int, channel_per_variable: torch.Tensor):
         prob = self.weights[:, None] * torch.exp(self.x[0] - torch.max(self.x[0], dim=0)[0])
@@ -254,8 +261,8 @@
 
 class NormalLeaf(SuprLayer):
     """ NormalLeaf layer """
-    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
-                 mu0: torch.tensor = 0., nu0: torch.tensor = 0., torch.tensor: float = 0., beta0: torch.tensor = 0.):
+    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1, mu0: torch.tensor = 0.,
+                 nu0: torch.tensor = 0., alpha0: torch.tensor = 0., beta0: torch.tensor = 0.):
         super().__init__()
         # Dimensions
         self.T, self.V, self.C = tracks, variables, channels
@@ -292,7 +299,8 @@ class NormalLeaf(SuprLayer):
         self.mu.data *= 1. - learning_rate
         self.mu.data += learning_rate * mu_update
         # Standard deviation
-        sig_update = (self.n*(self.z_x_sq_acc / sum_z - self.mu ** 2) + 2*self.beta0 + self.nu0*(self.mu0-self.mu)**2) / (self.n + 2*self.alpha0 + 3)
+        sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (
+                self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
         self.sig.data *= 1 - learning_rate
         self.sig.data += learning_rate * sig_update
         # Reset accumulators
@@ -305,7 +313,8 @@
         mu_marginalize = self.mu[track, self.marginalize, channel_per_variable[self.marginalize]]
         sig_marginalize = self.sig[track, self.marginalize, channel_per_variable[self.marginalize]]
         r = torch.empty_like(self.x[0])
-        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(torch.clamp(sig_marginalize, self.epsilon))
+        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(
+            torch.clamp(sig_marginalize, self.epsilon))
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
 
@@ -322,9 +331,11 @@
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-            torch.distributions.Normal(mu_valid, torch.sqrt(torch.clamp(sig_valid, self.epsilon))).log_prob(x_valid).float()
+            torch.distributions.Normal(mu_valid, torch.sqrt(torch.clamp(sig_valid, self.epsilon))).log_prob(
+                x_valid).float()
         return self.z
 
+
 class BernoulliLeaf(SuprLayer):
     """ BernoulliLeaf layer """
 
@@ -336,7 +347,7 @@
         # Number of data points
         self.n = n
         # Prior
-        self.alpha0, self.beta0 = alpha0, beta0 
+        self.alpha0, self.beta0 = alpha0, beta0
         # Parametes
         self.p = nn.Parameter(torch.rand(self.T, self.V, self.C))
         # Which variables to marginalized
@@ -383,10 +394,10 @@
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-        torch.distributions.Bernoulli(probs=p_valid).log_prob(x_valid).float()
+            torch.distributions.Bernoulli(probs=p_valid).log_prob(x_valid).float()
         return self.z
 
-# TODO: This is not tested properly.
+
 class CategoricalLeaf(SuprLayer):
     """ CategoricalLeaf layer """
 
@@ -398,7 +409,7 @@
         # Number of data points
         self.n = n
         # Prior
-        self.alpha0 = alpha0 
+        self.alpha0 = alpha0
         # Parametes
         self.p = nn.Parameter(torch.rand(self.T, self.V, self.C, self.D))
         # Which variables to marginalized
@@ -419,7 +430,8 @@
     def em_update(self, learning_rate: float = 1.):
         # Probability
        sum_z = torch.clamp(self.z_acc, self.epsilon)
-        p_update = (self.n * self.z_x_acc / sum_z[:,:,:,None] + self.alpha0 - 1) / (self.n + self.D*(self.alpha0 - 1))
+        p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (
+                self.n + self.D * (self.alpha0 - 1))
         self.p.data *= 1. - learning_rate
         self.p.data += learning_rate * p_update
         # Reset accumulators
@@ -447,6 +459,5 @@
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-        torch.distributions.Categorical(probs=p_valid).log_prob(x_valid).float()
+            torch.distributions.Categorical(probs=p_valid).log_prob(x_valid).float()
         return self.z
-
-- 
GitLab
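Usage sketch (not part of the patch): one way the layers touched in this diff might be
composed and trained. The constructor signatures of NormalLeaf, TrackSum and Sequential,
and the em_batch_update call, are taken from the hunks above; the layer composition, the
data shapes, and the training loop are untested assumptions about the intended API.

    # Hypothetical example -- only the constructor signatures are taken from the
    # diff above; composition, shapes, and training behaviour are assumptions.
    import torch
    from supr.layers import Sequential, NormalLeaf, TrackSum

    tracks, variables, channels = 2, 8, 4
    n = 500  # number of data points, passed to the leaf layer

    model = Sequential(
        NormalLeaf(tracks, variables, channels, n=n, mu0=0., nu0=1., alpha0=1., beta0=1.),
        TrackSum(tracks, channels),
    )

    x = torch.randn(n, variables)  # N data points over V variables
    model.train()
    for epoch in range(10):
        logp = model(x)            # forward pass through the layers
        model.em_batch_update()    # apply the accumulated EM update in each layer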