Commit 02e1103c authored by mnsc

Fixed and stabilized updates

parent de62975e
@@ -287,7 +287,7 @@ class NormalLeaf(SuprLayer):
         self.register_buffer('z_x_sq_acc', torch.zeros(self.T, self.V, self.C))

     def em_batch(self):
-        self.z_acc.data += torch.sum(self.z.grad, dim=0)
+        self.z_acc.data += torch.clamp(torch.sum(self.z.grad, dim=0), self.epsilon)
         self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)
         self.z_x_sq_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None] ** 2, dim=0)
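The new clamp floors each per-batch responsibility sum at self.epsilon before it is accumulated, so a component that receives no responsibility mass in a batch cannot leave a zero in z_acc and blow up the divisions in em_update. A minimal standalone sketch of that failure mode; the epsilon value and tensor shapes are illustrative, not the repository's:

    import torch

    epsilon = 1e-12
    # (batch, T, V, C): a component that got no responsibility at all
    z_grad = torch.zeros(8, 2, 5, 4)
    unclamped = torch.sum(z_grad, dim=0)
    clamped = torch.clamp(torch.sum(z_grad, dim=0), epsilon)
    print(torch.isfinite(1.0 / unclamped).all())  # False: division by zero gives inf
    print(torch.isfinite(1.0 / clamped).all())    # True: accumulator is floored at epsilon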
@@ -295,12 +295,15 @@ class NormalLeaf(SuprLayer):
         # Sum of weights
         sum_z = torch.clamp(self.z_acc, self.epsilon)
         # Mean
-        mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
+        # mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
+        mu_update = (self.nu0 * self.mu0 + self.z_acc * (self.z_x_acc / sum_z)) / (self.nu0 + self.z_acc)
         self.mu.data *= 1. - learning_rate
         self.mu.data += learning_rate * mu_update
         # Standard deviation
-        sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (
-                self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
+        # sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (
+        #         self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
+        sig_update = (self.z_x_sq_acc - self.z_acc * self.mu ** 2 + 2 * self.beta0 + self.nu0 * (
+                self.mu0 - self.mu) ** 2) / (self.z_acc + 2 * self.alpha0 + 3)
         self.sig.data *= 1 - learning_rate
         self.sig.data += learning_rate * sig_update
         # Reset accumulators
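Replacing the batch size self.n with the accumulated responsibility mass self.z_acc makes each component's update depend on its own effective sample count, which is what stabilizes components that see little data. The new lines then read as posterior-mode (MAP) updates under what looks like a conjugate normal prior with hyperparameters nu0, mu0, alpha0, beta0; that reading, and all shapes and values below, are illustrative assumptions rather than the repository's code:

    import torch

    nu0, mu0, alpha0, beta0, epsilon = 1.0, 0.0, 2.0, 1.0, 1e-12
    torch.manual_seed(0)
    x = torch.randn(100)                               # observations
    r = torch.rand(100, 3)                             # responsibilities for 3 components
    z_acc = torch.clamp(r.sum(dim=0), epsilon)         # effective counts N_k
    z_x_acc = (r * x[:, None]).sum(dim=0)              # sum_i r_ik * x_i
    z_x_sq_acc = (r * x[:, None] ** 2).sum(dim=0)      # sum_i r_ik * x_i ** 2

    # Posterior-mode mean: prior mean and weighted sample mean, blended by N_k
    mu = (nu0 * mu0 + z_acc * (z_x_acc / z_acc)) / (nu0 + z_acc)
    # Posterior-mode spread, mirroring the new sig_update line
    sig = (z_x_sq_acc - z_acc * mu ** 2 + 2 * beta0 + nu0 * (mu0 - mu) ** 2) / (z_acc + 2 * alpha0 + 3)
    assert (sig > 0).all()  # stays positive thanks to the beta0 term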
@@ -367,7 +370,8 @@ class BernoulliLeaf(SuprLayer):
     def em_update(self, learning_rate: float = 1.):
         # Probability
         sum_z = torch.clamp(self.z_acc, self.epsilon)
-        p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
+        # p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
+        p_update = (self.z_x_acc + self.alpha0 - 1) / (self.z_acc + self.alpha0 + self.beta0 - 2)
         self.p.data *= 1. - learning_rate
         self.p.data += learning_rate * p_update
         # Reset accumulators
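The Bernoulli update becomes the mode of a Beta posterior computed from each component's own counts: z_x_acc accumulates responsibility-weighted ones and z_acc the total responsibility mass, so the global batch size n no longer mixes into per-component statistics. A minimal sketch; the counts and the choice alpha0 = beta0 = 2 are illustrative assumptions:

    import torch

    alpha0, beta0, epsilon = 2.0, 2.0, 1e-12
    z_acc = torch.clamp(torch.tensor([10.0, 0.5, 3.0]), epsilon)  # effective counts per component
    z_x_acc = torch.tensor([7.0, 0.25, 1.0])                      # responsibility-weighted successes

    # Mode of Beta(alpha0 + successes, beta0 + failures)
    p = (z_x_acc + alpha0 - 1) / (z_acc + alpha0 + beta0 - 2)
    assert ((p >= 0) & (p <= 1)).all()  # a valid probability for alpha0, beta0 >= 1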
@@ -430,15 +434,16 @@ class CategoricalLeaf(SuprLayer):
     def em_update(self, learning_rate: float = 1.):
         # Probability
         sum_z = torch.clamp(self.z_acc, self.epsilon)
-        p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (
-                self.n + self.D * (self.alpha0 - 1))
+        # p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (
+        #         self.n + self.D * (self.alpha0 - 1))
+        p_update = (self.z_x_acc + self.alpha0 - 1) / (
+                self.z_acc[:,:,:,None] + self.D * (self.alpha0 - 1))
         self.p.data *= 1. - learning_rate
         self.p.data += learning_rate * p_update
         # Reset accumulators
         self.z_acc.zero_()
         self.z_x_acc.zero_()

-    # XXX Implement this
     def sample(self, track: int, channel_per_variable: torch.Tensor):
         p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize], :]
         r = torch.empty_like(self.x[0])
...
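The categorical case is the matching Dirichlet-style MAP mode; the one subtlety is the broadcast: z_acc has shape (T, V, C), so it is expanded with [:,:,:,None] to divide all D category counts in z_x_acc. A standalone sketch with illustrative shapes, checking that each category distribution still normalizes (here z_acc is derived as the row sum so the check is exact; in the real buffers it is accumulated separately):

    import torch

    T, V, C, D = 2, 4, 3, 5
    alpha0, epsilon = 2.0, 1e-12
    z_x_acc = torch.rand(T, V, C, D)                    # weighted counts per category
    z_acc = torch.clamp(z_x_acc.sum(dim=-1), epsilon)   # per-component totals, shape (T, V, C)

    p = (z_x_acc + alpha0 - 1) / (z_acc[:, :, :, None] + D * (alpha0 - 1))
    assert torch.allclose(p.sum(dim=-1), torch.ones(T, V, C))  # rows normalize for alpha0 >= 1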