Commit de62975e authored by mnsc

pep

parent 642eff87
@@ -5,6 +5,7 @@ import math
from supr.utils import discrete_rand, local_scramble_2d
from typing import List
# Data:
# N x V x D
# └───│──│─ N: Data points
@@ -38,8 +39,9 @@ class SuprLayer(nn.Module):
def em_update(self, *args, **kwargs):
pass
class Sequential(nn.Sequential):
-def __init__(self, *args: object) -> object:
+def __init__(self, *args: object):
super().__init__(*args)
def em_batch_update(self):
@@ -54,6 +56,7 @@ class Sequential(nn.Sequential):
value = module.sample(*value)
return value
class Parallel(SuprLayer):
def __init__(self, nets: List[SuprLayer]):
super().__init__()
@@ -78,6 +81,7 @@ class ScrambleTracks(SuprLayer):
def forward(self, x):
return x[:, torch.arange(x.shape[1])[:, None], self.perm, :]
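The indexing expression in ScrambleTracks.forward above applies a separate variable permutation to every track. A minimal sketch of the same pattern on a toy tensor, assuming x is laid out as N x T x V x C and self.perm stores one permutation of the V variables per track (neither shape is shown in this hunk):

import torch

# Toy demonstration of per-track variable permutation via advanced indexing,
# assuming x has shape (N, T, V, C) and perm has shape (T, V).
N, T, V, C = 2, 3, 4, 1
x = torch.arange(N * T * V * C, dtype=torch.float).reshape(N, T, V, C)
perm = torch.stack([torch.randperm(V) for _ in range(T)])  # one permutation per track

y = x[:, torch.arange(T)[:, None], perm, :]  # same pattern as the forward above

# Each track of each data point is reordered by that track's own permutation.
assert torch.equal(y[0, 1], x[0, 1, perm[1]])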
class ScrambleTracks2d(SuprLayer):
""" Scrambles the variables in each track """
@@ -99,6 +103,7 @@ class VariablesProduct(SuprLayer):
def __init(self):
super().__init__()
self.variables = None
def sample(self, track, channel_per_variable):
return track, torch.full((self.variables,), channel_per_variable[0]).to(channel_per_variable.device)
@@ -111,6 +116,7 @@ class VariablesProduct(SuprLayer):
class ProductSumLayer(SuprLayer):
""" Base class for product-sum layers """
def __init__(self, weight_shape, normalize_dims):
super().__init__()
# Parameters
@@ -192,6 +198,7 @@ class Einsum(ProductSumLayer):
y = a1 + a2 + torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc', exa1, exa2, self.weights))
return y
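The line above evaluates a product-sum node in log space with a max-shift for numerical stability. The definitions of a1, a2, exa1 and exa2 fall outside this hunk; the sketch below assumes they are the per-child maxima and the correspondingly shifted exponentials, and checks the stabilized form against a direct computation on small values:

import torch

# Stabilized log-space product-sum, assuming a1/a2 are per-child maxima and
# exa1/exa2 the shifted exponentials (their definitions are not part of this hunk).
N, T, V, A, B, C = 2, 1, 3, 4, 4, 5
x1 = torch.randn(N, T, V, A)  # log-probabilities of child 1
x2 = torch.randn(N, T, V, B)  # log-probabilities of child 2
w = torch.rand(T, V, C, A, B)  # non-negative mixture weights

a1 = x1.max(dim=3, keepdim=True)[0]
a2 = x2.max(dim=3, keepdim=True)[0]
exa1, exa2 = torch.exp(x1 - a1), torch.exp(x2 - a2)
y = a1 + a2 + torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc', exa1, exa2, w))

# Reference: the same quantity without the max-shift (unstable for large magnitudes).
y_ref = torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc', torch.exp(x1), torch.exp(x2), w))
print(torch.allclose(y, y_ref, atol=1e-5))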
class Weightsum(ProductSumLayer):
""" Weightsum layer """
@@ -254,8 +261,8 @@ class TrackSum(ProductSumLayer):
class NormalLeaf(SuprLayer):
""" NormalLeaf layer """
-def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
-mu0: torch.tensor = 0., nu0: torch.tensor = 0., alpha0: torch.tensor = 0., beta0: torch.tensor = 0.):
+def __init__(self, tracks: int, variables: int, channels: int, n: int = 1, mu0: torch.tensor = 0.,
+nu0: torch.tensor = 0., alpha0: torch.tensor = 0., beta0: torch.tensor = 0.):
super().__init__()
# Dimensions
self.T, self.V, self.C = tracks, variables, channels
@@ -292,7 +299,8 @@ class NormalLeaf(SuprLayer):
self.mu.data *= 1. - learning_rate
self.mu.data += learning_rate * mu_update
# Standard deviation
-sig_update = (self.n*(self.z_x_sq_acc / sum_z - self.mu ** 2) + 2*self.beta0 + self.nu0*(self.mu0-self.mu)**2) / (self.n + 2*self.alpha0 + 3)
+sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (
+self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
self.sig.data *= 1 - learning_rate
self.sig.data += learning_rate * sig_update
# Reset accumulators
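The rewrapped sig_update above is a MAP-style variance re-estimate that mixes the responsibility-weighted second moment with the prior terms mu0, nu0, alpha0 and beta0. A toy single-cell illustration of the arithmetic, assuming the accumulators hold summed responsibilities and responsibility-weighted x and x**2; the mean update itself is outside this hunk, so the plain weighted mean stands in for mu:

import torch

# Toy walk-through of the variance update for one (track, variable, channel) cell.
# Assumed: z_acc, z_x_acc, z_x_sq_acc accumulate responsibilities, responsibility-weighted
# x and responsibility-weighted x**2; the prior hyperparameters below are made up.
x = torch.tensor([0.9, 1.1, 1.4])
z = torch.tensor([0.8, 1.0, 0.6])  # responsibilities for this channel
n, mu0, nu0, alpha0, beta0 = 3, 0.0, 1.0, 1.0, 0.1

sum_z = torch.clamp(z.sum(), 1e-6)
z_x_acc = (z * x).sum()
z_x_sq_acc = (z * x ** 2).sum()

mu = z_x_acc / sum_z  # stand-in for the updated mean
sig_update = (n * (z_x_sq_acc / sum_z - mu ** 2) + 2 * beta0 + nu0 * (mu0 - mu) ** 2) / (n + 2 * alpha0 + 3)
print(float(mu), float(sig_update))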
@@ -305,7 +313,8 @@ class NormalLeaf(SuprLayer):
mu_marginalize = self.mu[track, self.marginalize, channel_per_variable[self.marginalize]]
sig_marginalize = self.sig[track, self.marginalize, channel_per_variable[self.marginalize]]
r = torch.empty_like(self.x[0])
-r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(torch.clamp(sig_marginalize, self.epsilon))
+r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(
+torch.clamp(sig_marginalize, self.epsilon))
r[~self.marginalize] = self.x[0][~self.marginalize]
return r
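The lines above draw the marginalized variables from the selected channel's Normal and copy the observed ones through unchanged. A minimal stand-alone sketch of that conditional fill-in for a single track, with made-up means, variances and mask:

import torch

# Conditional fill-in as in NormalLeaf.sample: impute masked variables, keep the rest.
epsilon = 1e-6  # assumed clamp floor, standing in for self.epsilon
x = torch.tensor([0.5, 0.0, 2.0])  # one data point; index 1 is to be imputed
marginalize = torch.tensor([False, True, False])
mu = torch.tensor([0.0, 1.0, 0.0])  # per-variable means of the chosen channels (made up)
sig = torch.tensor([1.0, 0.25, 1.0])  # per-variable variances of the chosen channels (made up)

r = torch.empty_like(x)
r[marginalize] = mu[marginalize] + torch.randn(int(marginalize.sum())) * torch.sqrt(
    torch.clamp(sig[marginalize], epsilon))
r[~marginalize] = x[~marginalize]
print(r)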
@@ -322,9 +331,11 @@ class NormalLeaf(SuprLayer):
x_valid = self.x[:, None, ~self.marginalize, None]
# Evaluate log probability
self.z.data[:, :, ~self.marginalize, :] = \
-torch.distributions.Normal(mu_valid, torch.sqrt(torch.clamp(sig_valid, self.epsilon))).log_prob(x_valid).float()
+torch.distributions.Normal(mu_valid, torch.sqrt(torch.clamp(sig_valid, self.epsilon))).log_prob(
+x_valid).float()
return self.z
class BernoulliLeaf(SuprLayer):
""" BernoulliLeaf layer """
@@ -386,7 +397,7 @@ class BernoulliLeaf(SuprLayer):
torch.distributions.Bernoulli(probs=p_valid).log_prob(x_valid).float()
return self.z
# TODO: This is not tested properly.
class CategoricalLeaf(SuprLayer):
""" CategoricalLeaf layer """
@@ -419,7 +430,8 @@ class CategoricalLeaf(SuprLayer):
def em_update(self, learning_rate: float = 1.):
# Probability
sum_z = torch.clamp(self.z_acc, self.epsilon)
-p_update = (self.n * self.z_x_acc / sum_z[:,:,:,None] + self.alpha0 - 1) / (self.n + self.D*(self.alpha0 - 1))
+p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (
+self.n + self.D * (self.alpha0 - 1))
self.p.data *= 1. - learning_rate
self.p.data += learning_rate * p_update
# Reset accumulators
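The rewrapped p_update above is a Dirichlet-MAP style re-estimate of the categorical probabilities. A small sanity sketch, assuming z_x_acc accumulates responsibility-weighted one-hot counts over the D categories (so its entries sum to z_acc), in which case each updated probability vector sums to one:

import torch

# Toy check of the categorical update for one (track, variable, channel) cell with D categories.
# Assumed: z_x_acc holds responsibility-weighted one-hot counts, so its entries sum to z_acc.
D, n, alpha0 = 4, 10, 1.5
z_x_acc = torch.tensor([3.0, 2.0, 4.0, 1.0])  # made-up accumulated counts
z_acc = z_x_acc.sum()

sum_z = torch.clamp(z_acc, 1e-6)
p_update = (n * z_x_acc / sum_z + alpha0 - 1) / (n + D * (alpha0 - 1))
print(p_update, float(p_update.sum()))  # the updated probabilities sum to 1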
@@ -449,4 +461,3 @@ class CategoricalLeaf(SuprLayer):
self.z.data[:, :, ~self.marginalize, :] = \
torch.distributions.Categorical(probs=p_valid).log_prob(x_valid).float()
return self.z