Commit 48793abd authored by mnsc

MAP estimate

parent 4cc5c3c5
@@ -7,9 +7,10 @@ from typing import List
# Data:
# N x V
# └───│── N: Data points
# └── V: Variables
# N x V x D
# └───│──│─ N: Data points
# └──│─ V: Variables
# └─ D: Dimensions
#
# Probability:
# N x T x V x C
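The shape convention above (data as N x V, or N x V x D for multi-dimensional variables; leaf output as N x T x V x C) can be checked with a small broadcast. The sketch below mirrors how NormalLeaf.forward evaluates per-channel log-probabilities; all sizes and names are illustrative, not part of the commit.
import torch
N, T, V, C = 8, 2, 5, 3                    # hypothetical sizes
x = torch.randn(N, V)                      # N data points, V variables
mu = torch.randn(T, V, C)                  # one Normal per (track, variable, channel)
var = torch.ones(T, V, C)
z = torch.distributions.Normal(mu[None], torch.sqrt(var)[None]).log_prob(x[:, None, :, None])
print(z.shape)                             # torch.Size([8, 2, 5, 3]) == N x T x V x C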
@@ -23,6 +24,7 @@ class Supr(nn.Module):
super().__init__()
def sample(self):
pass
class SuprLayer(nn.Module):
@@ -253,10 +255,15 @@ class TrackSum(ProductSumLayer):
class NormalLeaf(SuprLayer):
""" NormalLeaf layer """
def __init__(self, tracks: int, variables: int, channels: int):
def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
mu0: float = 0., nu0: float = 0., alpha0: float = 0., beta0: float = 0.):
super().__init__()
# Dimensions
self.T, self.V, self.C = tracks, variables, channels
# Number of data points
self.n = n
# Prior
self.mu0, self.nu0, self.alpha0, self.beta0 = mu0, nu0, alpha0, beta0
# Parameters
# self.mu = nn.Parameter(torch.randn(self.T, self.V, self.C))
# self.mu = nn.Parameter(torch.linspace(0, 1, self.C)[None, None, :].repeat((self.T, self.V, 1)))
@@ -279,13 +286,16 @@ class NormalLeaf(SuprLayer):
self.z_x_sq_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None] ** 2, dim=0)
def em_update(self, learning_rate: float = 1.):
# Mean
# Sum of weights
sum_z = torch.clamp(self.z_acc, self.epsilon)
# Mean
mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
self.mu.data *= 1. - learning_rate
self.mu.data += learning_rate * self.z_x_acc / sum_z
self.mu.data += learning_rate * mu_update
# Standard deviation
sig_update = (self.n*(self.z_x_sq_acc / sum_z - self.mu ** 2) + 2*self.beta0 + self.nu0*(self.mu0-self.mu)**2) / (self.n + 2*self.alpha0 + 3)
self.sig.data *= 1 - learning_rate
self.sig.data += learning_rate * torch.sqrt(torch.clamp(self.z_x_sq_acc / sum_z - self.mu ** 2, self.epsilon + 0.01))
self.sig.data += learning_rate * sig_update
# Reset accumulators
self.z_acc.zero_()
self.z_x_acc.zero_()
@@ -296,7 +306,7 @@ class NormalLeaf(SuprLayer):
mu_marginalize = self.mu[track, self.marginalize, channel_per_variable[self.marginalize]]
sig_marginalize = self.sig[track, self.marginalize, channel_per_variable[self.marginalize]]
r = torch.empty_like(self.x[0])
r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * sig_marginalize
r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(torch.clamp(sig_marginalize, self.epsilon))
r[~self.marginalize] = self.x[0][~self.marginalize]
return r
@@ -313,16 +323,21 @@ class NormalLeaf(SuprLayer):
x_valid = self.x[:, None, ~self.marginalize, None]
# Evaluate log probability
self.z.data[:, :, ~self.marginalize, :] = \
torch.distributions.Normal(mu_valid, sig_valid).log_prob(x_valid).float()
torch.distributions.Normal(mu_valid, torch.sqrt(torch.clamp(sig_valid, self.epsilon))).log_prob(x_valid).float()
return self.z
class BernoulliLeaf(SuprLayer):
""" BernoulliLeaf layer """
def __init__(self, tracks: int, variables: int, channels: int):
def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
alpha0: float = 1., beta0: float = 1.):
super().__init__()
# Dimensions
self.T, self.V, self.C = tracks, variables, channels
# Number of data points
self.n = n
# Prior
self.alpha0, self.beta0 = alpha0, beta0
# Parameters
self.p = nn.Parameter(torch.rand(self.T, self.V, self.C))
# Which variables to marginalize
@@ -342,8 +357,9 @@ class BernoulliLeaf(SuprLayer):
def em_update(self, learning_rate: float = 1.):
# Probability
sum_z = torch.clamp(self.z_acc, self.epsilon)
p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
self.p.data *= 1. - learning_rate
self.p.data += learning_rate * self.z_x_acc / sum_z
self.p.data += learning_rate * p_update
# Reset accumulators
self.z_acc.zero_()
self.z_x_acc.zero_()
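For comparison, a minimal sketch of the Beta-prior MAP update computed in p_update above; with the default alpha0 = beta0 = 1 it reduces to the plain weighted frequency that the replaced line used. Names are illustrative, not part of the commit.
import torch

def bernoulli_map_update(z_acc, z_x_acc, n, alpha0, beta0, eps=1e-8):
    sum_z = torch.clamp(z_acc, min=eps)
    x_bar = z_x_acc / sum_z                      # weighted fraction of ones per (track, variable, channel)
    return (n * x_bar + alpha0 - 1) / (n + alpha0 + beta0 - 2)   # mode of Beta(n*x_bar + alpha0, n*(1 - x_bar) + beta0)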
@@ -368,6 +384,70 @@ class BernoulliLeaf(SuprLayer):
x_valid = self.x[:, None, ~self.marginalize, None]
# Evaluate log probability
self.z.data[:, :, ~self.marginalize, :] = \
p_valid*(x_valid==1) + (1-p_valid)*(x_valid==0)
torch.distributions.Bernoulli(probs=p_valid).log_prob(x_valid).float()
return self.z
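An illustrative check of the change just above: the replaced expression returned raw probabilities, while torch.distributions.Bernoulli.log_prob returns the logarithm of the same quantity.
import torch
p = torch.tensor([0.2, 0.7])
x = torch.tensor([1., 0.])
old = p * (x == 1) + (1 - p) * (x == 0)                      # probabilities: [0.2, 0.3]
new = torch.distributions.Bernoulli(probs=p).log_prob(x)     # log-probabilities
print(torch.allclose(new, old.log()))                        # True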
# TODO: This is not tested properly.
class CategoricalLeaf(SuprLayer):
""" CategoricalLeaf layer """
def __init__(self, tracks: int, variables: int, channels: int, dimensions: int, n: int = 1,
alpha0: float = 1.):
super().__init__()
# Dimensions
self.T, self.V, self.C, self.D = tracks, variables, channels, dimensions
# Number of data points
self.n = n
# Prior
self.alpha0 = alpha0
# Parameters
self.p = nn.Parameter(torch.rand(self.T, self.V, self.C, self.D))
# Which variables to marginalize
self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
# Input
self.register_buffer('x', torch.Tensor())
# Output
self.register_buffer('z', torch.Tensor())
# EM accumulator
self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
self.register_buffer('z_x_acc', torch.zeros(self.T, self.V, self.C, self.D))
def em_batch(self):
self.z_acc.data += torch.sum(self.z.grad, dim=0)
x_onehot = torch.eye(self.D, dtype=bool)[self.x]
self.z_x_acc.data += torch.sum(self.z.grad[:, :, :, :, None] * x_onehot[:, None, :, None, :], dim=0)
def em_update(self, learning_rate: float = 1.):
# Probability
sum_z = torch.clamp(self.z_acc, self.epsilon)
p_update = (self.n * self.z_x_acc / sum_z[:,:,:,None] + self.alpha0 - 1) / (self.n + self.D*(self.alpha0 - 1))
self.p.data *= 1. - learning_rate
self.p.data += learning_rate * p_update
# Reset accumulators
self.z_acc.zero_()
self.z_x_acc.zero_()
# XXX Implement this
def sample(self, track: int, channel_per_variable: torch.Tensor):
p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize], :]
r = torch.empty_like(self.x[0])
r_sample = torch.distributions.Categorical(probs=p_marginalize).sample()
r[self.marginalize] = r_sample
r[~self.marginalize] = self.x[0][~self.marginalize]
return r
def forward(self, x: torch.Tensor):
# Get shape
batch_size = x.shape[0]
# Store the data
self.x = x
# Compute the probability
self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
# Get non-marginalized parameters and data
p_valid = self.p[None, :, ~self.marginalize, :, :]
x_valid = self.x[:, None, ~self.marginalize, None]
# Evaluate log probability
self.z.data[:, :, ~self.marginalize, :] = \
torch.distributions.Categorical(probs=p_valid).log_prob(x_valid).float()
return self.z
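Finally, a minimal sketch of the Dirichlet-prior MAP update applied in CategoricalLeaf.em_update; with the default alpha0 = 1 it reduces to the weighted class frequencies. Since the class is still flagged as not properly tested, this only restates the formula from the diff; names are illustrative.
import torch

def categorical_map_update(z_acc, z_x_acc, n, alpha0, eps=1e-8):
    # z_acc: (T, V, C) responsibilities; z_x_acc: (T, V, C, D) responsibility-weighted one-hot counts
    sum_z = torch.clamp(z_acc, min=eps)[..., None]
    x_bar = z_x_acc / sum_z                      # weighted class frequencies over the D categories
    D = z_x_acc.shape[-1]
    return (n * x_bar + alpha0 - 1) / (n + D * (alpha0 - 1))   # mode of Dirichlet(n*x_bar + alpha0)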