diff --git a/supr/layers.py b/supr/layers.py
index 35938a3d94a8cff5fb5a7dcb10ad925da911946b..dde2fb7f18743d5cec576269c06f8ae8b910ea3f 100644
--- a/supr/layers.py
+++ b/supr/layers.py
@@ -7,9 +7,10 @@ from typing import List
 
 
 # Data:
-# N x V
-# └───│── N: Data points
-#     └── V: Variables
+# N x V x D
+# └───│──│─ N: Data points
+#     └──│─ V: Variables
+#        └─ D: Dimensions
 #
 # Probability:
 # N x T x V x C
@@ -23,6 +24,7 @@ class Supr(nn.Module):
         super().__init__()
 
     def sample(self):
+        pass
 
 
 class SuprLayer(nn.Module):
@@ -253,10 +255,15 @@ class TrackSum(ProductSumLayer):
 class NormalLeaf(SuprLayer):
     """ NormalLeaf layer """
 
-    def __init__(self, tracks: int, variables: int, channels: int):
+    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
+                 mu0: float = 0., nu0: float = 0., alpha0: float = 0., beta0: float = 0.):
         super().__init__()
         # Dimensions
         self.T, self.V, self.C = tracks, variables, channels
+        # Number of data points
+        self.n = n
+        # Prior
+        self.mu0, self.nu0, self.alpha0, self.beta0 = mu0, nu0, alpha0, beta0
         # Parametes
         # self.mu = nn.Parameter(torch.randn(self.T, self.V, self.C))
         # self.mu = nn.Parameter(torch.linspace(0, 1, self.C)[None, None, :].repeat((self.T, self.V, 1)))
@@ -279,13 +286,16 @@
         self.z_x_sq_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None] ** 2, dim=0)
 
     def em_update(self, learning_rate: float = 1.):
-        # Mean
+        # Sum of weights
         sum_z = torch.clamp(self.z_acc, self.epsilon)
+        # Mean
+        mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
         self.mu.data *= 1. - learning_rate
-        self.mu.data += learning_rate * self.z_x_acc / sum_z
+        self.mu.data += learning_rate * mu_update
         # Standard deviation
+        sig_update = (self.n*(self.z_x_sq_acc / sum_z - self.mu ** 2) + 2*self.beta0 + self.nu0*(self.mu0-self.mu)**2) / (self.n + 2*self.alpha0 + 3)
         self.sig.data *= 1 - learning_rate
-        self.sig.data += learning_rate * torch.sqrt(torch.clamp(self.z_x_sq_acc / sum_z - self.mu ** 2, self.epsilon + 0.01))
+        self.sig.data += learning_rate * sig_update
         # Reset accumulators
         self.z_acc.zero_()
         self.z_x_acc.zero_()
@@ -296,7 +306,7 @@
         mu_marginalize = self.mu[track, self.marginalize, channel_per_variable[self.marginalize]]
         sig_marginalize = self.sig[track, self.marginalize, channel_per_variable[self.marginalize]]
         r = torch.empty_like(self.x[0])
-        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * sig_marginalize
+        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(torch.clamp(sig_marginalize, self.epsilon))
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
 
@@ -313,16 +323,21 @@
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-            torch.distributions.Normal(mu_valid, sig_valid).log_prob(x_valid).float()
+            torch.distributions.Normal(mu_valid, torch.sqrt(torch.clamp(sig_valid, self.epsilon))).log_prob(x_valid).float()
         return self.z
 
 
 class BernoulliLeaf(SuprLayer):
     """ BernoulliLeaf layer """
-    def __init__(self, tracks: int, variables: int, channels: int):
+    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
+                 alpha0: float = 1., beta0: float = 1.):
         super().__init__()
         # Dimensions
         self.T, self.V, self.C = tracks, variables, channels
+        # Number of data points
+        self.n = n
+        # Prior
+        self.alpha0, self.beta0 = alpha0, beta0
         # Parametes
         self.p = nn.Parameter(torch.rand(self.T, self.V, self.C))
         # Which variables to marginalized
@@ -342,8 +357,9 @@
     def em_update(self, learning_rate: float = 1.):
         # Probability
         sum_z = torch.clamp(self.z_acc, self.epsilon)
+        p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
         self.p.data *= 1. - learning_rate
-        self.p.data += learning_rate * self.z_x_acc / sum_z
+        self.p.data += learning_rate * p_update
         # Reset accumulators
         self.z_acc.zero_()
         self.z_x_acc.zero_()
@@ -368,6 +384,70 @@
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-            p_valid*(x_valid==1) + (1-p_valid)*(x_valid==0)
+            torch.distributions.Bernoulli(probs=p_valid).log_prob(x_valid).float()
+        return self.z
+
+# TODO: This is not tested properly.
+class CategoricalLeaf(SuprLayer):
+    """ CategoricalLeaf layer """
+
+    def __init__(self, tracks: int, variables: int, channels: int, dimensions: int, n: int = 1,
+                 alpha0: float = 1.):
+        super().__init__()
+        # Dimensions
+        self.T, self.V, self.C, self.D = tracks, variables, channels, dimensions
+        # Number of data points
+        self.n = n
+        # Prior
+        self.alpha0 = alpha0
+        # Parametes
+        self.p = nn.Parameter(torch.rand(self.T, self.V, self.C, self.D))
+        # Which variables to marginalized
+        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
+        # Input
+        self.register_buffer('x', torch.Tensor())
+        # Output
+        self.register_buffer('z', torch.Tensor())
+        # EM accumulator
+        self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
+        self.register_buffer('z_x_acc', torch.zeros(self.T, self.V, self.C, self.D))
+
+    def em_batch(self):
+        self.z_acc.data += torch.sum(self.z.grad, dim=0)
+        x_onehot = torch.eye(self.D, dtype=bool)[self.x]
+        self.z_x_acc.data += torch.sum(self.z.grad[:, :, :, :, None] * x_onehot[:, None, :, None, :], dim=0)
+
+    def em_update(self, learning_rate: float = 1.):
+        # Probability
+        sum_z = torch.clamp(self.z_acc, self.epsilon)
+        p_update = (self.n * self.z_x_acc / sum_z[:,:,:,None] + self.alpha0 - 1) / (self.n + self.D*(self.alpha0 - 1))
+        self.p.data *= 1. - learning_rate
+        self.p.data += learning_rate * p_update
+        # Reset accumulators
+        self.z_acc.zero_()
+        self.z_x_acc.zero_()
+
+    # XXX Implement this
+    def sample(self, track: int, channel_per_variable: torch.Tensor):
+        p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize], :]
+        r = torch.empty_like(self.x[0])
+        r_sample = torch.distributions.Categorical(probs=p_marginalize).sample()
+        r[self.marginalize] = r_sample
+        r[~self.marginalize] = self.x[0][~self.marginalize]
+        return r
+
+    def forward(self, x: torch.Tensor):
+        # Get shape
+        batch_size = x.shape[0]
+        # Store the data
+        self.x = x
+        # Compute the probability
+        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
+        # Get non-marginalized parameters and data
+        p_valid = self.p[None, :, ~self.marginalize, :, :]
+        x_valid = self.x[:, None, ~self.marginalize, None]
+        # Evaluate log probability
+        self.z.data[:, :, ~self.marginalize, :] = \
+            torch.distributions.Categorical(probs=p_valid).log_prob(x_valid).float()
         return self.z
 
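
A note on the new NormalLeaf update: `mu_update` and `sig_update` are MAP (posterior-mode) style estimates under a Normal-Gamma type prior with hyperparameters `mu0`, `nu0`, `alpha0`, `beta0`, evaluated on the responsibility-weighted statistics accumulated in `em_batch`; `sig` now holds a variance, which is why `forward` and `sample` apply `torch.sqrt` to it. Below is a minimal standalone sanity check of the mean update, assuming unit responsibilities (so `z_x_acc / sum_z` is just the sample mean) and illustrative prior values; the variable names only mirror the accumulators and nothing here depends on the supr package.

import torch

# Toy setup (assumed values): n points fully assigned to one component,
# so sum_z = n and z_x_acc / sum_z is the sample mean.
torch.manual_seed(0)
n = 1000
x = torch.randn(n) * 2.0 + 3.0                 # data with true mu = 3, sigma = 2
mu0, nu0, alpha0, beta0 = 0.0, 10.0, 1.0, 1.0  # illustrative prior hyperparameters

sum_z = torch.tensor(float(n))                 # accumulated responsibilities
z_x_acc = x.sum()                              # accumulated responsibility-weighted x
z_x_sq_acc = (x ** 2).sum()                    # accumulated responsibility-weighted x**2

# Mean update as written in NormalLeaf.em_update: shrinks the sample mean
# towards mu0 with strength nu0.
mu_update = (nu0 * mu0 + n * (z_x_acc / sum_z)) / (nu0 + n)

# The same estimate written as the standard conjugate posterior mean of mu.
mu_map = (nu0 * mu0 + x.sum()) / (nu0 + n)
assert torch.isclose(mu_update, mu_map)

# Variance update as written in NormalLeaf.em_update (sig now holds a variance).
sig_update = (n * (z_x_sq_acc / sum_z - mu_update ** 2) + 2 * beta0
              + nu0 * (mu0 - mu_update) ** 2) / (n + 2 * alpha0 + 3)
print(float(mu_update), float(sig_update))     # roughly 3 and 4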
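
The discrete leaves follow the same pattern: the added `p_update` lines are posterior modes under a Beta(`alpha0`, `beta0`) prior for `BernoulliLeaf` and a symmetric Dirichlet(`alpha0`) prior for `CategoricalLeaf` (the denominator `n + D*(alpha0 - 1)` is the Dirichlet-mode normaliser). A standalone check of the Bernoulli case against the Beta posterior mode, again assuming unit responsibilities and illustrative prior values:

import torch

torch.manual_seed(0)
n = 500
x = (torch.rand(n) < 0.2).float()   # Bernoulli(0.2) data
alpha0, beta0 = 2.0, 5.0            # illustrative Beta prior hyperparameters

sum_z = torch.tensor(float(n))      # accumulated responsibilities
z_x_acc = x.sum()                   # accumulated responsibility-weighted x

# Update as written in BernoulliLeaf.em_update
p_update = (n * z_x_acc / sum_z + alpha0 - 1) / (n + alpha0 + beta0 - 2)

# Mode of the Beta(alpha0 + sum(x), beta0 + n - sum(x)) posterior
p_map = (alpha0 + x.sum() - 1) / (alpha0 + beta0 + n - 2)
assert torch.isclose(p_update, p_map)
print(float(p_update))              # close to 0.2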