diff --git a/supr/layers.py b/supr/layers.py
index b1d6cf99dabceed66a1cbcf1ff83833885d16170..2a1274811523f3f9a9b127f086dec59988198030 100644
--- a/supr/layers.py
+++ b/supr/layers.py
@@ -287,7 +287,7 @@ class NormalLeaf(SuprLayer):
         self.register_buffer('z_x_sq_acc', torch.zeros(self.T, self.V, self.C))
 
     def em_batch(self):
-        self.z_acc.data += torch.sum(self.z.grad, dim=0)
+        self.z_acc.data += torch.clamp(torch.sum(self.z.grad, dim=0), self.epsilon)
         self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)
         self.z_x_sq_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None] ** 2, dim=0)
 
@@ -295,12 +295,15 @@ class NormalLeaf(SuprLayer):
         # Sum of weights
         sum_z = torch.clamp(self.z_acc, self.epsilon)
         # Mean
-        mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
+        # mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
+        mu_update = (self.nu0 * self.mu0 + self.z_acc * (self.z_x_acc / sum_z)) / (self.nu0 + self.z_acc)
         self.mu.data *= 1. - learning_rate
         self.mu.data += learning_rate * mu_update
         # Standard deviation
-        sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (
-                self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
+        # sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (
+        #         self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
+        sig_update = (self.z_x_sq_acc - self.z_acc * self.mu ** 2 + 2 * self.beta0 + self.nu0 * (
+                self.mu0 - self.mu) ** 2) / (self.z_acc + 2 * self.alpha0 + 3)
         self.sig.data *= 1 - learning_rate
         self.sig.data += learning_rate * sig_update
         # Reset accumulators
@@ -367,7 +370,8 @@ class BernoulliLeaf(SuprLayer):
     def em_update(self, learning_rate: float = 1.):
         # Probability
         sum_z = torch.clamp(self.z_acc, self.epsilon)
-        p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
+        # p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
+        p_update = (self.z_x_acc + self.alpha0 - 1) / (self.z_acc + self.alpha0 + self.beta0 - 2)
         self.p.data *= 1. - learning_rate
         self.p.data += learning_rate * p_update
         # Reset accumulators
@@ -430,15 +434,16 @@ class CategoricalLeaf(SuprLayer):
     def em_update(self, learning_rate: float = 1.):
         # Probability
         sum_z = torch.clamp(self.z_acc, self.epsilon)
-        p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (
-                self.n + self.D * (self.alpha0 - 1))
+        # p_update = (self.n * self.z_x_acc / sum_z[:, :, :, None] + self.alpha0 - 1) / (
+        #         self.n + self.D * (self.alpha0 - 1))
+        p_update = (self.z_x_acc + self.alpha0 - 1) / (
+                self.z_acc[:, :, :, None] + self.D * (self.alpha0 - 1))
         self.p.data *= 1. - learning_rate
         self.p.data += learning_rate * p_update
         # Reset accumulators
         self.z_acc.zero_()
         self.z_x_acc.zero_()
 
-    # XXX Implement this
     def sample(self, track: int, channel_per_variable: torch.Tensor):
         p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize], :]
         r = torch.empty_like(self.x[0])
@@ -460,4 +465,4 @@ class CategoricalLeaf(SuprLayer):
         # Evaluate log probability
         self.z.data[:, :, ~self.marginalize, :] = \
             torch.distributions.Categorical(probs=p_valid).log_prob(x_valid).float()
-        return self.z
+        return self.z
\ No newline at end of file
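
A minimal sketch of the arithmetic behind this change, kept separate from the patch: the rewritten em_update methods use the accumulated responsibilities z_acc as the effective sample size in the MAP estimates, instead of scaling the responsibility-weighted averages by a fixed count n. The snippet below only illustrates the corrected Bernoulli-leaf update; the prior pseudo-counts and accumulator values are hypothetical stand-ins, not values from the repository.

import torch

# Hypothetical Beta(alpha0, beta0) prior pseudo-counts and accumulator values,
# chosen only to illustrate the arithmetic of the corrected update.
alpha0, beta0 = 2.0, 2.0
z_acc = torch.tensor([4.0, 0.5])    # accumulated responsibilities, summed over the batch
z_x_acc = torch.tensor([3.0, 0.1])  # accumulated responsibility-weighted observations

# MAP estimate of p with z_acc as the effective sample size, as in the patched code
p_update = (z_x_acc + alpha0 - 1) / (z_acc + alpha0 + beta0 - 2)
print(p_update)  # tensor([0.6667, 0.4400])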