Commit 48793abd authored by mnsc

map estimate

parent 4cc5c3c5
@@ -7,9 +7,10 @@ from typing import List
 # Data:
-# N x V
-# └───│── N: Data points
-#     └── V: Variables
+# N x V x D
+# └───│──│─ N: Data points
+#     └──│─ V: Variables
+#        └─ D: Dimensions
 #
 # Probability:
 # N x T x V x C
@@ -23,6 +24,7 @@ class Supr(nn.Module):
         super().__init__()

     def sample(self):
+        pass


 class SuprLayer(nn.Module):
@@ -253,10 +255,15 @@ class TrackSum(ProductSumLayer):

 class NormalLeaf(SuprLayer):
     """ NormalLeaf layer """
-    def __init__(self, tracks: int, variables: int, channels: int):
+    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
+                 mu0: float = 0., nu0: float = 0., alpha0: float = 0., beta0: float = 0.):
         super().__init__()
         # Dimensions
         self.T, self.V, self.C = tracks, variables, channels
+        # Number of data points
+        self.n = n
+        # Prior
+        self.mu0, self.nu0, self.alpha0, self.beta0 = mu0, nu0, alpha0, beta0
         # Parameters
         # self.mu = nn.Parameter(torch.randn(self.T, self.V, self.C))
         # self.mu = nn.Parameter(torch.linspace(0, 1, self.C)[None, None, :].repeat((self.T, self.V, 1)))
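The four new constructor arguments are not documented in the commit. Read together with the update rules further down, (mu0, nu0, alpha0, beta0) appear to parameterize a Normal-Inverse-Gamma prior on each leaf's mean and variance, with n acting as the (effective) number of data points that weights the data against the prior. A sketch of that assumed prior (an interpretation, not stated in the commit):

    \sigma^2 \sim \mathrm{Inv\text{-}Gamma}(\alpha_0, \beta_0), \qquad
    \mu \mid \sigma^2 \sim \mathcal{N}\!\big(\mu_0,\; \sigma^2 / \nu_0\big)

With the defaults nu0 = alpha0 = beta0 = 0 the mean update below reduces to the plain responsibility-weighted average.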
@@ -279,13 +286,16 @@ class NormalLeaf(SuprLayer):
         self.z_x_sq_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None] ** 2, dim=0)

     def em_update(self, learning_rate: float = 1.):
-        # Mean
+        # Sum of weights
         sum_z = torch.clamp(self.z_acc, self.epsilon)
+        # Mean
+        mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
         self.mu.data *= 1. - learning_rate
-        self.mu.data += learning_rate * self.z_x_acc / sum_z
+        self.mu.data += learning_rate * mu_update
         # Standard deviation
+        sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
         self.sig.data *= 1 - learning_rate
-        self.sig.data += learning_rate * torch.sqrt(torch.clamp(self.z_x_sq_acc / sum_z - self.mu ** 2, self.epsilon + 0.01))
+        self.sig.data += learning_rate * sig_update
         # Reset accumulators
         self.z_acc.zero_()
         self.z_x_acc.zero_()
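A reading of the new update (not spelled out in the commit): with \bar{x} = z_x_acc / sum_z and \overline{x^2} = z_x_sq_acc / sum_z denoting the responsibility-weighted first and second moments accumulated in em_batch, the mu_update and sig_update lines compute

    \hat{\mu} = \frac{\nu_0 \mu_0 + n\,\bar{x}}{\nu_0 + n}, \qquad
    \hat{\sigma}^2 = \frac{n\,(\overline{x^2} - \hat{\mu}^2) + 2\beta_0 + \nu_0 (\mu_0 - \hat{\mu})^2}{n + 2\alpha_0 + 3}

which has the form of the joint posterior mode (the MAP estimate of the commit message) of a Normal-Inverse-Gamma model, with the weighted moments standing in for the usual sufficient statistics. Note that self.sig now holds a variance rather than a standard deviation, consistent with the torch.sqrt(torch.clamp(...)) wrappers added in sample() and forward() below.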
@@ -296,7 +306,7 @@ class NormalLeaf(SuprLayer):
         mu_marginalize = self.mu[track, self.marginalize, channel_per_variable[self.marginalize]]
         sig_marginalize = self.sig[track, self.marginalize, channel_per_variable[self.marginalize]]
         r = torch.empty_like(self.x[0])
-        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * sig_marginalize
+        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(torch.clamp(sig_marginalize, self.epsilon))
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
@@ -313,16 +323,21 @@ class NormalLeaf(SuprLayer):
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-            torch.distributions.Normal(mu_valid, sig_valid).log_prob(x_valid).float()
+            torch.distributions.Normal(mu_valid, torch.sqrt(torch.clamp(sig_valid, self.epsilon))).log_prob(x_valid).float()
         return self.z


 class BernoulliLeaf(SuprLayer):
     """ BernoulliLeaf layer """
-    def __init__(self, tracks: int, variables: int, channels: int):
+    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
+                 alpha0: float = 1., beta0: float = 1.):
         super().__init__()
         # Dimensions
         self.T, self.V, self.C = tracks, variables, channels
+        # Number of data points
+        self.n = n
+        # Prior
+        self.alpha0, self.beta0 = alpha0, beta0
         # Parameters
         self.p = nn.Parameter(torch.rand(self.T, self.V, self.C))
         # Which variables to marginalize
@@ -342,8 +357,9 @@ class BernoulliLeaf(SuprLayer):
     def em_update(self, learning_rate: float = 1.):
         # Probability
         sum_z = torch.clamp(self.z_acc, self.epsilon)
+        p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
         self.p.data *= 1. - learning_rate
-        self.p.data += learning_rate * self.z_x_acc / sum_z
+        self.p.data += learning_rate * p_update
         # Reset accumulators
         self.z_acc.zero_()
         self.z_x_acc.zero_()
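Similarly for the Bernoulli leaf: with \bar{x} = z_x_acc / sum_z the weighted success frequency, p_update is

    \hat{p} = \frac{n\,\bar{x} + \alpha_0 - 1}{n + \alpha_0 + \beta_0 - 2}

i.e. the mode of the Beta(\alpha_0 + n\bar{x},\; \beta_0 + n(1 - \bar{x})) posterior. With the defaults alpha0 = beta0 = 1 (a uniform prior) this reduces to the previous maximum-likelihood update \bar{x}.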
@@ -368,6 +384,70 @@ class BernoulliLeaf(SuprLayer):
         x_valid = self.x[:, None, ~self.marginalize, None]
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
-            p_valid*(x_valid==1) + (1-p_valid)*(x_valid==0)
+            torch.distributions.Bernoulli(probs=p_valid).log_prob(x_valid).float()
         return self.z
+
+
+# TODO: This is not tested properly.
+class CategoricalLeaf(SuprLayer):
+    """ CategoricalLeaf layer """
+    def __init__(self, tracks: int, variables: int, channels: int, dimensions: int, n: int = 1,
+                 alpha0: float = 1.):
+        super().__init__()
+        # Dimensions
+        self.T, self.V, self.C, self.D = tracks, variables, channels, dimensions
+        # Number of data points
+        self.n = n
+        # Prior
+        self.alpha0 = alpha0
+        # Parameters
+        self.p = nn.Parameter(torch.rand(self.T, self.V, self.C, self.D))
+        # Which variables to marginalize
+        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
+        # Input
+        self.register_buffer('x', torch.Tensor())
+        # Output
+        self.register_buffer('z', torch.Tensor())
+        # EM accumulator
+        self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
+        self.register_buffer('z_x_acc', torch.zeros(self.T, self.V, self.C, self.D))
+
+    def em_batch(self):
+        self.z_acc.data += torch.sum(self.z.grad, dim=0)
+        x_onehot = torch.eye(self.D, dtype=bool)[self.x]
+        self.z_x_acc.data += torch.sum(self.z.grad[:, :, :, :, None] * x_onehot[:, None, :, None, :], dim=0)
+
+    def em_update(self, learning_rate: float = 1.):
+        # Probability
+        sum_z = torch.clamp(self.z_acc, self.epsilon)
+        p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (self.n + self.D * (self.alpha0 - 1))
+        self.p.data *= 1. - learning_rate
+        self.p.data += learning_rate * p_update
+        # Reset accumulators
+        self.z_acc.zero_()
+        self.z_x_acc.zero_()
+
+    # XXX Implement this
+    def sample(self, track: int, channel_per_variable: torch.Tensor):
+        p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize], :]
+        r = torch.empty_like(self.x[0])
+        r_sample = torch.distributions.Categorical(probs=p_marginalize).sample()
+        r[self.marginalize] = r_sample
+        r[~self.marginalize] = self.x[0][~self.marginalize]
+        return r
+
+    def forward(self, x: torch.Tensor):
+        # Get shape
+        batch_size = x.shape[0]
+        # Store the data
+        self.x = x
+        # Compute the probability
+        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
+        # Get non-marginalized parameters and data
+        p_valid = self.p[None, :, ~self.marginalize, :, :]
+        x_valid = self.x[:, None, ~self.marginalize, None]
+        # Evaluate log probability
+        self.z.data[:, :, ~self.marginalize, :] = \
+            torch.distributions.Categorical(probs=p_valid).log_prob(x_valid).float()
+        return self.z
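The new CategoricalLeaf follows the same pattern with a symmetric Dirichlet(\alpha_0) prior over the D categories: writing \bar{x}_k = (z_x_acc / z_acc)_k for the weighted frequency of category k, p_update in em_update computes

    \hat{p}_k = \frac{n\,\bar{x}_k + \alpha_0 - 1}{n + D(\alpha_0 - 1)}

which is the mode of the Dirichlet(\alpha_0 + n\bar{x}_1, \ldots, \alpha_0 + n\bar{x}_D) posterior; with the default alpha0 = 1 it again reduces to the weighted relative frequencies.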