Skip to content
Snippets Groups Projects
Commit 642eff87 authored by mnsc's avatar mnsc
Browse files

pep

parent 21371596
No related branches found
No related tags found
No related merge requests found
# %% Import libraries # %% Import libraries
import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
import supr import supr
...@@ -26,8 +25,8 @@ variables = 2 ...@@ -26,8 +25,8 @@ variables = 2
channels = 20 channels = 20
# Priors for variance of x and y # Priors for variance of x and y
alpha0 = torch.tensor([[[0.2], [0.1]]]) alpha0 = torch.tensor([[[0.5], [0.1]]])
beta0 = torch.tensor([[[0.2], [0.1]]]) beta0 = torch.tensor([[[0.5], [0.1]]])
model = supr.Sequential( model = supr.Sequential(
supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0., nu0=0, alpha0=alpha0, beta0=beta0), supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0., nu0=0, alpha0=alpha0, beta0=beta0),
...@@ -40,9 +39,9 @@ epochs = 20 ...@@ -40,9 +39,9 @@ epochs = 20
for epoch in range(epochs): for epoch in range(epochs):
model.train() model.train()
model[0].marginalize = torch.zeros(variables, dtype=torch.bool) model[0].marginalize = torch.zeros(variables, dtype=torch.bool)
loss = model(X).sum() logp = model(X).sum()
print(f"Loss = {loss}") print(f"Log-posterior ∝ {logp:.2f} ")
loss.backward() logp.backward()
with torch.no_grad(): with torch.no_grad():
model.eval() model.eval()
...@@ -65,4 +64,3 @@ for epoch in range(epochs): ...@@ -65,4 +64,3 @@ for epoch in range(epochs):
plt.imshow(torch.log(p_predictive), extent=extent, aspect='auto', origin='lower', vmin=-3, vmax=1) plt.imshow(torch.log(p_predictive), extent=extent, aspect='auto', origin='lower', vmin=-3, vmax=1)
plt.plot(x, y, '.') plt.plot(x, y, '.')
drawnow() drawnow()
...@@ -5,7 +5,6 @@ import math ...@@ -5,7 +5,6 @@ import math
from supr.utils import discrete_rand, local_scramble_2d from supr.utils import discrete_rand, local_scramble_2d
from typing import List from typing import List
# Data: # Data:
# N x V x D # N x V x D
# └───│──│─ N: Data points # └───│──│─ N: Data points
...@@ -40,7 +39,7 @@ class SuprLayer(nn.Module): ...@@ -40,7 +39,7 @@ class SuprLayer(nn.Module):
pass pass
class Sequential(nn.Sequential): class Sequential(nn.Sequential):
def __init__(self, *args): def __init__(self, *args: object) -> object:
super().__init__(*args) super().__init__(*args)
def em_batch_update(self): def em_batch_update(self):
...@@ -256,7 +255,7 @@ class NormalLeaf(SuprLayer): ...@@ -256,7 +255,7 @@ class NormalLeaf(SuprLayer):
""" NormalLeaf layer """ """ NormalLeaf layer """
def __init__(self, tracks: int, variables: int, channels: int, n: int = 1, def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
mu0: float = 0., nu0: float = 0., alpha0: float = 0., beta0: float = 0.): mu0: torch.tensor = 0., nu0: torch.tensor = 0., torch.tensor: float = 0., beta0: torch.tensor = 0.):
super().__init__() super().__init__()
# Dimensions # Dimensions
self.T, self.V, self.C = tracks, variables, channels self.T, self.V, self.C = tracks, variables, channels
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment