Skip to content
Snippets Groups Projects
Commit de62975e authored by mnsc's avatar mnsc
Browse files

pep

parent 642eff87
Branches
No related tags found
No related merge requests found
......@@ -5,6 +5,7 @@ import math
from supr.utils import discrete_rand, local_scramble_2d
from typing import List
# Data:
# N x V x D
# └───│──│─ N: Data points
......@@ -38,8 +39,9 @@ class SuprLayer(nn.Module):
def em_update(self, *args, **kwargs):
pass
class Sequential(nn.Sequential):
def __init__(self, *args: object) -> object:
def __init__(self, *args: object):
super().__init__(*args)
def em_batch_update(self):
......@@ -54,6 +56,7 @@ class Sequential(nn.Sequential):
value = module.sample(*value)
return value
class Parallel(SuprLayer):
def __init__(self, nets: List[SuprLayer]):
super().__init__()
......@@ -78,6 +81,7 @@ class ScrambleTracks(SuprLayer):
def forward(self, x):
return x[:, torch.arange(x.shape[1])[:, None], self.perm, :]
class ScrambleTracks2d(SuprLayer):
""" Scrambles the variables in each track """
......@@ -99,6 +103,7 @@ class VariablesProduct(SuprLayer):
def __init(self):
super().__init__()
self.variables = None
def sample(self, track, channel_per_variable):
return track, torch.full((self.variables,), channel_per_variable[0]).to(channel_per_variable.device)
......@@ -111,6 +116,7 @@ class VariablesProduct(SuprLayer):
class ProductSumLayer(SuprLayer):
""" Base class for product-sum layers """
def __init__(self, weight_shape, normalize_dims):
super().__init__()
# Parameters
......@@ -192,6 +198,7 @@ class Einsum(ProductSumLayer):
y = a1 + a2 + torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc', exa1, exa2, self.weights))
return y
class Weightsum(ProductSumLayer):
""" Weightsum layer """
......@@ -254,8 +261,8 @@ class TrackSum(ProductSumLayer):
class NormalLeaf(SuprLayer):
""" NormalLeaf layer """
def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
mu0: torch.tensor = 0., nu0: torch.tensor = 0., torch.tensor: float = 0., beta0: torch.tensor = 0.):
def __init__(self, tracks: int, variables: int, channels: int, n: int = 1, mu0: torch.tensor = 0.,
nu0: torch.tensor = 0., alpha0: torch.tensor = 0., beta0: torch.tensor = 0.):
super().__init__()
# Dimensions
self.T, self.V, self.C = tracks, variables, channels
......@@ -292,7 +299,8 @@ class NormalLeaf(SuprLayer):
self.mu.data *= 1. - learning_rate
self.mu.data += learning_rate * mu_update
# Standard deviation
sig_update = (self.n*(self.z_x_sq_acc / sum_z - self.mu ** 2) + 2*self.beta0 + self.nu0*(self.mu0-self.mu)**2) / (self.n + 2*self.alpha0 + 3)
sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (
self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
self.sig.data *= 1 - learning_rate
self.sig.data += learning_rate * sig_update
# Reset accumulators
......@@ -305,7 +313,8 @@ class NormalLeaf(SuprLayer):
mu_marginalize = self.mu[track, self.marginalize, channel_per_variable[self.marginalize]]
sig_marginalize = self.sig[track, self.marginalize, channel_per_variable[self.marginalize]]
r = torch.empty_like(self.x[0])
r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(torch.clamp(sig_marginalize, self.epsilon))
r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(
torch.clamp(sig_marginalize, self.epsilon))
r[~self.marginalize] = self.x[0][~self.marginalize]
return r
......@@ -322,9 +331,11 @@ class NormalLeaf(SuprLayer):
x_valid = self.x[:, None, ~self.marginalize, None]
# Evaluate log probability
self.z.data[:, :, ~self.marginalize, :] = \
torch.distributions.Normal(mu_valid, torch.sqrt(torch.clamp(sig_valid, self.epsilon))).log_prob(x_valid).float()
torch.distributions.Normal(mu_valid, torch.sqrt(torch.clamp(sig_valid, self.epsilon))).log_prob(
x_valid).float()
return self.z
class BernoulliLeaf(SuprLayer):
""" BernoulliLeaf layer """
......@@ -386,7 +397,7 @@ class BernoulliLeaf(SuprLayer):
torch.distributions.Bernoulli(probs=p_valid).log_prob(x_valid).float()
return self.z
# TODO: This is not tested properly.
class CategoricalLeaf(SuprLayer):
""" CategoricalLeaf layer """
......@@ -419,7 +430,8 @@ class CategoricalLeaf(SuprLayer):
def em_update(self, learning_rate: float = 1.):
# Probability
sum_z = torch.clamp(self.z_acc, self.epsilon)
p_update = (self.n * self.z_x_acc / sum_z[:,:,:,None] + self.alpha0 - 1) / (self.n + self.D*(self.alpha0 - 1))
p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (
self.n + self.D * (self.alpha0 - 1))
self.p.data *= 1. - learning_rate
self.p.data += learning_rate * p_update
# Reset accumulators
......@@ -449,4 +461,3 @@ class CategoricalLeaf(SuprLayer):
self.z.data[:, :, ~self.marginalize, :] = \
torch.distributions.Categorical(probs=p_valid).log_prob(x_valid).float()
return self.z
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment