Commit 02e1103c authored by mnsc
fixed and stabilized updates

parent de62975e
@@ -287,7 +287,7 @@ class NormalLeaf(SuprLayer):
         self.register_buffer('z_x_sq_acc', torch.zeros(self.T, self.V, self.C))

     def em_batch(self):
-        self.z_acc.data += torch.sum(self.z.grad, dim=0)
+        self.z_acc.data += torch.clamp(torch.sum(self.z.grad, dim=0), self.epsilon)
         self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)
         self.z_x_sq_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None] ** 2, dim=0)
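The clamp matters because z.grad carries the EM responsibilities (expected counts) after the backward pass: a leaf that receives numerically zero responsibility in a batch would otherwise leave z_acc at zero and make the later division by sum_z unstable. A minimal sketch of the accumulation pattern, with toy shapes standing in for the repo's T (tracks), V (variables), C (channels); all numbers here are assumptions, not the repo's values:

import torch

# Toy stand-ins for the layer's buffers; shapes and epsilon are assumed.
T, V, C, N, eps = 2, 3, 4, 8, 1e-8
z_acc = torch.zeros(T, V, C)

# Responsibilities for one batch; near-zero values simulate a starved leaf.
z_grad = torch.rand(N, T, V, C) * 1e-12

# As in the commit: each batch contributes at least eps per entry,
# so z_acc stays strictly positive and z_x_acc / clamp(z_acc, eps) is safe.
z_acc += torch.clamp(torch.sum(z_grad, dim=0), eps)
assert (z_acc >= eps).all()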
@@ -295,12 +295,15 @@
         # Sum of weights
         sum_z = torch.clamp(self.z_acc, self.epsilon)
         # Mean
-        mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
+        # mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
+        mu_update = (self.nu0 * self.mu0 + self.z_acc * (self.z_x_acc / sum_z)) / (self.nu0 + self.z_acc)
         self.mu.data *= 1. - learning_rate
         self.mu.data += learning_rate * mu_update
         # Standard deviation
-        sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (
-                self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
+        # sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (
+        #         self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
+        sig_update = (self.z_x_sq_acc - self.z_acc * self.mu ** 2 + 2 * self.beta0 + self.nu0 * (
+                self.mu0 - self.mu) ** 2) / (self.z_acc + 2 * self.alpha0 + 3)
         self.sig.data *= 1 - learning_rate
         self.sig.data += learning_rate * sig_update
         # Reset accumulators
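For reference, the new update reads as a responsibility-weighted MAP estimate. Assuming the usual Normal-Inverse-Gamma prior parameterized by mu0, nu0, alpha0, beta0 (the names suggest this; the prior definition is outside this diff), and writing gamma_n for the responsibilities accumulated from z.grad:

\hat{\mu} = \frac{\nu_0 \mu_0 + \sum_n \gamma_n x_n}{\nu_0 + \sum_n \gamma_n},
\qquad
\hat{\sigma}^2 = \frac{\sum_n \gamma_n x_n^2 - \big(\sum_n \gamma_n\big)\,\hat{\mu}^2 + 2\beta_0 + \nu_0 (\mu_0 - \hat{\mu})^2}{\sum_n \gamma_n + 2\alpha_0 + 3}

Here z_acc = sum_n gamma_n, z_x_acc = sum_n gamma_n x_n, and z_x_sq_acc = sum_n gamma_n x_n^2, so z_acc * (z_x_acc / sum_z) equals z_x_acc except where the clamp is active (and the sig formula holds exactly when learning_rate = 1, since self.mu has already been blended to mu_update). The old code weighted these statistics by the fixed dataset size self.n, which over-counts evidence for leaves that receive little responsibility.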
@@ -367,7 +370,8 @@ class BernoulliLeaf(SuprLayer):
     def em_update(self, learning_rate: float = 1.):
         # Probability
         sum_z = torch.clamp(self.z_acc, self.epsilon)
-        p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
+        # p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
+        p_update = (self.z_x_acc + self.alpha0 - 1) / (self.z_acc + self.alpha0 + self.beta0 - 2)
         self.p.data *= 1. - learning_rate
         self.p.data += learning_rate * p_update
         # Reset accumulators
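The Bernoulli change has the same shape: with a Beta(alpha0, beta0) prior, (z_x_acc + alpha0 - 1) / (z_acc + alpha0 + beta0 - 2) is the standard MAP estimate with responsibility-weighted counts. A quick numerical illustration (made-up numbers, not from the repo) of why tying the update to z_acc rather than the global count self.n stabilizes a rarely-used leaf:

import torch

# Toy check: a leaf that received almost no responsibility this epoch.
alpha0, beta0, eps = 2.0, 2.0, 1e-8
n = 10_000.0                  # global dataset size used by the old update
z_acc = torch.tensor(1e-3)    # responsibility mass the leaf actually saw
z_x_acc = torch.tensor(1e-3)  # all of it on x = 1 (a very noisy estimate)
sum_z = torch.clamp(z_acc, eps)

p_old = (n * z_x_acc / sum_z + alpha0 - 1) / (n + alpha0 + beta0 - 2)
p_new = (z_x_acc + alpha0 - 1) / (z_acc + alpha0 + beta0 - 2)

print(p_old)  # ~1.0: the noisy ratio is weighted by the full n, drowning the prior
print(p_new)  # ~0.5: with little evidence, the Beta(2, 2) prior mode dominates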
@@ -430,15 +434,16 @@ class CategoricalLeaf(SuprLayer):
     def em_update(self, learning_rate: float = 1.):
         # Probability
         sum_z = torch.clamp(self.z_acc, self.epsilon)
-        p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (
-                self.n + self.D * (self.alpha0 - 1))
+        # p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (
+        #         self.n + self.D * (self.alpha0 - 1))
+        p_update = (self.z_x_acc + self.alpha0 - 1) / (
+                self.z_acc[:, :, :, None] + self.D * (self.alpha0 - 1))
         self.p.data *= 1. - learning_rate
         self.p.data += learning_rate * p_update
         # Reset accumulators
         self.z_acc.zero_()
         self.z_x_acc.zero_()

     # XXX Implement this
     def sample(self, track: int, channel_per_variable: torch.Tensor):
         p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize], :]
         r = torch.empty_like(self.x[0])
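The categorical leaf gets the same treatment. Assuming a symmetric Dirichlet(alpha0) prior over the D states (again suggested by the parameter names rather than confirmed in this diff), the new update is the standard MAP estimate with responsibility-weighted counts:

\hat{p}_d = \frac{\sum_n \gamma_n\,[x_n = d] + \alpha_0 - 1}{\sum_n \gamma_n + D(\alpha_0 - 1)}

where z_x_acc holds the per-state weighted counts and z_acc[:, :, :, None] broadcasts the total responsibility mass over the state dimension.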