Commit 8b606b3e authored by mnsc

pep8

parent 0c2aae1b
@@ -81,7 +81,6 @@ class Sequential(nn.Sequential):
        return value


class Parallel(SuprLayer):
    def __init__(self, nets: List[SuprLayer]):
        super().__init__()
@@ -113,7 +112,8 @@ class ScrambleTracks2d(SuprLayer):
    def __init__(self, tracks: int, variables: int, distance: float, dims: tuple):
        super().__init__()
        # Permutation for each track
        perm = torch.stack([local_scramble_2d(distance, dims)
                            for _ in range(tracks)])
        self.register_buffer('perm', perm)

    def sample(self, track, channel_per_variable):
@@ -146,7 +146,8 @@ class ProductSumLayer(SuprLayer):
        super().__init__()
        # Parameters
        self.weights = nn.Parameter(torch.rand(*weight_shape))
        self.weights.data /= torch.clamp(self.weights.sum(
            dim=normalize_dims, keepdim=True), self.epsilon)
        # Normalize dimensions
        self.normalize_dims = normalize_dims
        # EM accumulator
@@ -157,7 +158,8 @@ class ProductSumLayer(SuprLayer):
    def em_update(self, learning_rate: float = 1.):
        weights_grad = torch.clamp(self.weights_acc, self.epsilon)
        weights_grad /= torch.clamp(weights_grad.sum(
            dim=self.normalize_dims, keepdim=True), self.epsilon)
        if learning_rate < 1.:
            self.weights.data *= 1. - learning_rate
        self.weights.data += learning_rate * weights_grad
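For orientation, the update above is a damped EM step for the sum weights: the accumulated expected counts in self.weights_acc are renormalized over normalize_dims and then blended into the current weights as a convex combination. A minimal scalar sketch with made-up numbers (not part of the library):

import torch

weights = torch.tensor([0.5, 0.3, 0.2])        # current sum-node weights
weights_acc = torch.tensor([8.0, 1.0, 1.0])    # stand-in for accumulated expected counts
learning_rate = 0.5

weights_grad = weights_acc / weights_acc.sum()                          # renormalized counts
weights = (1. - learning_rate) * weights + learning_rate * weights_grad
print(weights)   # tensor([0.6500, 0.2000, 0.1500]); still sums to 1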
@@ -188,7 +190,6 @@ class Einsum(ProductSumLayer):
            self.x2_pad[-1] = True
        else:
            self.pad = False
-        # TODO: Implement choice of left, right, or both augmentation. Both returns 2 times the number of tracks

    def sample(self, track: int, channel_per_variable: torch.Tensor):
        r = []
@@ -218,9 +219,12 @@ class Einsum(ProductSumLayer):
        # Compute maximum
        a1, a2 = [torch.max(x, dim=3, keepdim=True)[0] for x in [x1, x2]]
        # Subtract maximum and compute exponential
        exa1, exa2 = [torch.clamp(torch.exp(x - a), self.epsilon)
                      for x, a in [(x1, a1), (x2, a2)]]
        # Compute the contraction
        y = a1 + a2 + \
            torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc',
                                   exa1, exa2, self.weights))
        return y
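The contraction is done in log space with a max-subtraction (log-sum-exp style) stabilization: for any constants a1 and a2, log sum_ab w_cab exp(x1_a) exp(x2_b) = a1 + a2 + log sum_ab w_cab exp(x1_a - a1) exp(x2_b - a2), and choosing the per-node maxima keeps the exponentials in (0, 1]. A small stand-alone check of the identity on toy shapes (not the layer's real tensors):

import torch

x1 = torch.randn(4) * 10        # log-values of the first child, large magnitudes
x2 = torch.randn(5) * 10        # log-values of the second child
w = torch.rand(3, 4, 5)         # weights over (output channel, left, right)

naive = torch.log(torch.einsum('a,b,cab->c', x1.exp(), x2.exp(), w))

a1, a2 = x1.max(), x2.max()
stable = a1 + a2 + torch.log(
    torch.einsum('a,b,cab->c', (x1 - a1).exp(), (x2 - a2).exp(), w))

print(torch.allclose(naive, stable))   # True whenever the naive version does not overflow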
@@ -233,7 +237,8 @@ class Weightsum(ProductSumLayer):
        super().__init__((tracks, channels), (0, 1))

    def sample(self):
        prob = self.weights * \
            torch.exp(self.x_sum[0] - torch.max(self.x_sum[0]))
        s = discrete_rand(prob)[0]
        return s[0], torch.full((self.variables,), s[1]).to(self.weights.device)
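As in the Einsum layer, subtracting the maximum inside exp only rescales the unnormalized probabilities by a constant factor, so the draw from discrete_rand (assumed to normalize internally) is unchanged; it merely keeps exp finite. For instance:

import torch

logits = torch.tensor([500., 499., 490.])              # would overflow exp() directly
rescaled = torch.exp(logits - logits.max())            # finite, proportional to the true probabilities
print(torch.allclose(rescaled / rescaled.sum(),
                     torch.softmax(logits, dim=0)))    # True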
@@ -262,7 +267,8 @@ class TrackSum(ProductSumLayer):
        super().__init__((tracks, channels), (0,))

    def sample(self, track: int, channel_per_variable: torch.Tensor):
        prob = self.weights[:, None] * \
            torch.exp(self.x[0] - torch.max(self.x[0], dim=0)[0])
        s = discrete_rand(prob)[0]
        return s[0], channel_per_variable
@@ -282,15 +288,18 @@ class TrackSum(ProductSumLayer):
        y = y[:, None]
        return y


class SuprLeaf(SuprLayer):
    def __init__(self):
        super().__init__()


class NormalLeaf(SuprLeaf):
    """ NormalLeaf layer """
    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
                 mu0: torch.tensor = 0., nu0: torch.tensor = 0.,
                 alpha0: torch.tensor = 0., beta0: torch.tensor = 0.):
        super().__init__()
        # Dimensions
        self.T, self.V, self.C = tracks, variables, channels
@@ -299,12 +308,11 @@ class NormalLeaf(SuprLeaf):
        # Prior
        self.mu0, self.nu0, self.alpha0, self.beta0 = mu0, nu0, alpha0, beta0
        # Parameters
-        # self.mu = nn.Parameter(torch.randn(self.T, self.V, self.C))
-        # self.mu = nn.Parameter(torch.linspace(0, 1, self.C)[None, None, :].repeat((self.T, self.V, 1)))
        self.mu = nn.Parameter(torch.rand(self.T, self.V, self.C))
        self.sig = nn.Parameter(torch.ones(self.T, self.V, self.C) * 0.5)
        # Which variables to marginalize
        self.register_buffer('marginalize', torch.zeros(
            variables, dtype=torch.bool))
        # Input
        self.register_buffer('x', torch.Tensor())
        # Output
@@ -315,23 +323,26 @@ class NormalLeaf(SuprLeaf):
        self.register_buffer('z_x_sq_acc', torch.zeros(self.T, self.V, self.C))

    def em_batch(self):
        self.z_acc.data += torch.clamp(torch.sum(self.z.grad,
                                                 dim=0), self.epsilon)
        self.z_x_acc.data += torch.sum(self.z.grad *
                                       self.x[:, None, :, None], dim=0)
        self.z_x_sq_acc.data += torch.sum(self.z.grad *
                                          self.x[:, None, :, None] ** 2, dim=0)

    def em_update(self, learning_rate: float = 1.):
        # Sum of weights
        sum_z = torch.clamp(self.z_acc, self.epsilon)
        # Mean
-        # mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
        mu_update = (self.nu0 * self.mu0 + self.z_acc * (self.z_x_acc / sum_z)
                     ) / (self.nu0 + self.z_acc)
        self.mu.data *= 1. - learning_rate
        self.mu.data += learning_rate * mu_update
        # Standard deviation
-        # sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (
-        #     self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
        sig_update = (self.z_x_sq_acc -
                      self.z_acc * self.mu ** 2 + 2 * self.beta0 +
                      self.nu0 * (self.mu0 - self.mu) ** 2
                      ) / (self.z_acc + 2 * self.alpha0 + 3)
        self.sig.data *= 1 - learning_rate
        self.sig.data += learning_rate * sig_update
        # Reset accumulators
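Read as MAP estimation, these updates match a per-component Gaussian with a Normal-Inverse-Gamma prior (mu0, nu0, alpha0, beta0), where z_acc plays the role of the effective count n, z_x_acc / z_acc is the responsibility-weighted mean, and z_x_sq_acc the weighted second moment; with all hyperparameters at zero they reduce to the weighted sample mean and a variance with n + 3 in the denominator. A tiny sanity check with hard responsibilities and a flat prior (illustrative numbers only):

import torch

x = torch.tensor([1.0, 2.0, 3.0, 6.0])
n = torch.tensor(4.)           # plays the role of z_acc
z_x = x.sum()                  # plays the role of z_x_acc
z_x_sq = (x ** 2).sum()        # plays the role of z_x_sq_acc

mu = (0. + n * (z_x / n)) / (0. + n)                    # 3.0 == x.mean()
sig = (z_x_sq - n * mu ** 2 + 0. + 0.) / (n + 0. + 3)   # 2.0 == ((x - mu) ** 2).sum() / (n + 3)
print(mu, sig)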
@@ -341,8 +352,10 @@ class NormalLeaf(SuprLeaf):
    def sample(self, track: int, channel_per_variable: torch.Tensor):
        variables_marginalize = torch.sum(self.marginalize).int()
        mu_marginalize = self.mu[track, self.marginalize,
                                 channel_per_variable[self.marginalize]]
        sig_marginalize = self.sig[track, self.marginalize,
                                   channel_per_variable[self.marginalize]]
        r = torch.empty_like(self.x[0])
        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(
            torch.clamp(sig_marginalize, self.epsilon))
@@ -353,7 +366,8 @@ class NormalLeaf(SuprLeaf):
        return (torch.clamp(self.z.grad, self.epsilon) * self.mu).sum([1, 3])

    def var(self):
        return (torch.clamp(self.z.grad, self.epsilon) *
                (self.mu**2 + self.sig)).sum([1, 3]) - self.mean()**2

    def forward(self, x: torch.Tensor, marginalize=None):
        # Get shape
@@ -364,7 +378,8 @@ class NormalLeaf(SuprLeaf):
        # Store the data
        self.x = x
        # Compute the probability
        self.z = torch.zeros(batch_size, self.T, self.V,
                             self.C, requires_grad=True, device=x.device)
        # Get non-marginalized parameters and data
        mu_valid = self.mu[None, :, ~self.marginalize, :]
        sig_valid = self.sig[None, :, ~self.marginalize, :]
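A note on the z buffer: it is created with requires_grad=True but filled through .data, so the log-densities themselves stay outside the autograd graph; instead, backpropagating the root log-likelihood deposits the posterior responsibilities of each leaf component in self.z.grad, which em_batch then turns into expected sufficient statistics. A hypothetical outer loop (model, loader, and the batch sum are assumptions, not part of this file) might look like:

for x, in loader:                        # assumed DataLoader yielding input batches
    logp = model(x).sum()                # root log-likelihood of the batch
    logp.backward()                      # fills each leaf's z.grad with responsibilities
    for layer in model.modules():
        if hasattr(layer, 'em_batch'):
            layer.em_batch()             # accumulate expected statistics
for layer in model.modules():
    if hasattr(layer, 'em_update'):
        layer.em_update(learning_rate=1.)   # M-step from the accumulators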
@@ -391,7 +406,8 @@ class BernoulliLeaf(SuprLeaf):
        # Parameters
        self.p = nn.Parameter(torch.rand(self.T, self.V, self.C))
        # Which variables to marginalize
        self.register_buffer('marginalize', torch.zeros(
            variables, dtype=torch.bool))
        # Input
        self.register_buffer('x', torch.Tensor())
        # Output
@@ -402,13 +418,13 @@ class BernoulliLeaf(SuprLeaf):
    def em_batch(self):
        self.z_acc.data += torch.sum(self.z.grad, dim=0)
        self.z_x_acc.data += torch.sum(self.z.grad *
                                       self.x[:, None, :, None], dim=0)

    def em_update(self, learning_rate: float = 1.):
        # Probability
-        sum_z = torch.clamp(self.z_acc, self.epsilon)
-        # p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
        p_update = (self.z_x_acc + self.alpha0 - 1) / \
            (self.z_acc + self.alpha0 + self.beta0 - 2)
        self.p.data *= 1. - learning_rate
        self.p.data += learning_rate * p_update
        # Reset accumulators
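This is the usual MAP estimate of a Bernoulli parameter under a Beta(alpha0, beta0) prior, with z_x_acc the expected number of ones and z_acc the expected total count: p = (k + alpha0 - 1) / (n + alpha0 + beta0 - 2). With hard counts (illustrative numbers only):

import torch

alpha0, beta0 = 2., 2.                          # assumed prior pseudo-counts
k, n = torch.tensor(7.), torch.tensor(10.)      # 7 ones out of 10 observations
p_map = (k + alpha0 - 1) / (n + alpha0 + beta0 - 2)
print(p_map)                                    # tensor(0.6667), pulled toward 0.5 by the prior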
@@ -417,9 +433,11 @@ class BernoulliLeaf(SuprLeaf):
    def sample(self, track: int, channel_per_variable: torch.Tensor):
        variables_marginalize = torch.sum(self.marginalize).int()
        p_marginalize = self.p[track, self.marginalize,
                               channel_per_variable[self.marginalize]]
        r = torch.empty_like(self.x[0])
        r[self.marginalize] = (torch.rand(variables_marginalize).to(
            self.x.device) < p_marginalize).float()
        r[~self.marginalize] = self.x[0][~self.marginalize]
        return r
@@ -429,13 +447,15 @@ class BernoulliLeaf(SuprLeaf):
        # Store the data
        self.x = x
        # Compute the probability
        self.z = torch.zeros(batch_size, self.T, self.V,
                             self.C, requires_grad=True, device=x.device)
        # Get non-marginalized parameters and data
        p_valid = self.p[None, :, ~self.marginalize, :]
        x_valid = self.x[:, None, ~self.marginalize, None]
        # Evaluate log probability
        self.z.data[:, :, ~self.marginalize, :] = \
            torch.distributions.Bernoulli(
                probs=p_valid).log_prob(x_valid).float()
        return self.z
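Note that z starts as zeros and only the non-marginalized variables are overwritten, so a marginalized variable contributes log 1 = 0 to every product above it, which appears to be how these leaves integrate variables out; for a Bernoulli leaf this is exact because the probabilities of 0 and 1 sum to one:

import torch

b = torch.distributions.Bernoulli(probs=torch.tensor(0.3))
total = b.log_prob(torch.tensor(0.)).exp() + b.log_prob(torch.tensor(1.)).exp()
print(total)   # tensor(1.), i.e. log-density 0 once the variable is summed out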
@@ -454,19 +474,22 @@ class CategoricalLeaf(SuprLeaf):
        # Parameters
        self.p = nn.Parameter(torch.rand(self.T, self.V, self.C, self.D))
        # Which variables to marginalize
        self.register_buffer('marginalize', torch.zeros(
            variables, dtype=torch.bool))
        # Input
        self.register_buffer('x', torch.Tensor())
        # Output
        self.register_buffer('z', torch.Tensor())
        # EM accumulator
        self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
        self.register_buffer('z_x_acc', torch.zeros(
            self.T, self.V, self.C, self.D))

    def em_batch(self):
        self.z_acc.data += torch.sum(self.z.grad, dim=0)
        x_onehot = torch.eye(self.D, dtype=bool)[self.x]
        self.z_x_acc.data += torch.sum(
            self.z.grad[:, :, :, :, None] * x_onehot[:, None, :, None, :], dim=0)

    def em_update(self, learning_rate: float = 1.):
        # Probability
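The one-hot product in em_batch accumulates expected counts per category: z.grad has shape (N, T, V, C) and x_onehot has shape (N, V, D), so broadcasting the two and summing over the batch yields a (T, V, C, D) table of expected category counts. A shape-only sketch with toy sizes (assuming integer inputs in [0, D)):

import torch

N, T, V, C, D = 2, 3, 4, 5, 6
z_grad = torch.rand(N, T, V, C)                 # stand-in for self.z.grad
x = torch.randint(0, D, (N, V))                 # stand-in for self.x
x_onehot = torch.eye(D, dtype=torch.bool)[x]    # (N, V, D)
counts = torch.sum(z_grad[:, :, :, :, None] *
                   x_onehot[:, None, :, None, :], dim=0)
print(counts.shape)                             # torch.Size([3, 4, 5, 6])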
@@ -482,9 +505,11 @@ class CategoricalLeaf(SuprLeaf):
        self.z_x_acc.zero_()

    def sample(self, track: int, channel_per_variable: torch.Tensor):
        p_marginalize = self.p[track, self.marginalize,
                               channel_per_variable[self.marginalize], :]
        r = torch.empty_like(self.x[0])
        r_sample = torch.distributions.Categorical(
            probs=p_marginalize).sample()
        r[self.marginalize] = r_sample
        r[~self.marginalize] = self.x[0][~self.marginalize]
        return r
@@ -495,11 +520,13 @@ class CategoricalLeaf(SuprLeaf):
        # Store the data
        self.x = x
        # Compute the probability
        self.z = torch.zeros(batch_size, self.T, self.V,
                             self.C, requires_grad=True, device=x.device)
        # Get non-marginalized parameters and data
        p_valid = self.p[None, :, ~self.marginalize, :, :]
        x_valid = self.x[:, None, ~self.marginalize, None]
        # Evaluate log probability
        self.z.data[:, :, ~self.marginalize, :] = \
            torch.distributions.Categorical(
                probs=p_valid).log_prob(x_valid).float()
        return self.z