Skip to content
Snippets Groups Projects
Commit 42c3ba55 authored by mnsc's avatar mnsc
Browse files

move to src and add demos

parent 1b0d6980
Branches main
No related tags found
No related merge requests found
# %% Imports
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
from supr.utils import drawnow, arrange_figs
from supr.layers import *
#%% Create data
N = 20
V = 10
p = np.linspace(0,1,V)
X = np.random.rand(N, V) < p
X = torch.tensor(X).float()
# Plot data
plt.figure(1).clf()
arrange_figs()
plt.plot(range(V), X.mean(0))
drawnow()
#%% Make sum-product network
variables = V
channels = 4
tracks = 4
network = nn.Sequential(
BernoulliLeaf(tracks, variables, channels, n=N, alpha0=2, beta0=2),
Weightsum(tracks, 1, channels)
)
print(f"Total parameters: {sum(p.numel() for p in network.parameters() if p.requires_grad)}")
#%% Fit model
epochs = 20
for r in range(epochs):
network.train()
P = torch.sum(network(X))
print(float(P))
P.backward()
with torch.no_grad():
for m in network:
m.em_batch()
m.em_update()
network.zero_grad(True)
# Plot fit
plt.figure(2).clf()
arrange_figs()
network.eval()
Y = torch.ones(1, variables)
p_est = torch.zeros(variables)
for v in range(variables):
marginalize = torch.ones(variables, dtype=torch.bool)
marginalize[v] = False
network[0].marginalize = marginalize
p_est[v] = network(Y)
plt.plot(range(V), X.mean(0))
plt.plot(range(V), torch.exp(p_est), '--')
drawnow()
# %% Imports
import matplotlib.pyplot as plt
import numpy as np
from supr.utils import drawnow, arrange_figs
from supr.layers import *
fig_args = {'x0':2800, 'y0':28, 'x1':3840, 'y1':2160}
# fig_args = {'x0':1400, 'y0':28, 'x1':1920, 'y1':1080}
#%% Create data
N = 1000
V = 10
D = 2
P = np.zeros((V,D))
for d in range(D):
P[:,d] = np.linspace(d/(D-1), 1-d/(D-1), V)
P /= P.sum(1, keepdims=True)
x = [np.random.choice(range(D), size=int(N/2), p=p) for p in P]
y = [np.random.choice(range(D-1,-1,-1), size=int(N/2), p=p) for p in P]
X = torch.tensor(np.concatenate((np.array(x).T, np.array(y).T)))
# Plot data
plt.figure(1).clf()
arrange_figs(**fig_args)
plt.plot(range(V), X.sum(0)/N, '.-')
plt.ylim((0,1))
drawnow()
#%% Make sum-product network
dimensions = D
variables = V
channels = 1
tracks = 2
network = nn.Sequential(
CategoricalLeaf(tracks, variables, channels, dimensions, n=N, alpha0=2),
# Einsum(tracks, variables, channels, channels),
Weightsum(tracks, 1, channels)
)
print(f"Total parameters: {sum(p.numel() for p in network.parameters() if p.requires_grad)}")
#%% Fit model
epochs = 10
for r in range(epochs):
network.train()
network[0].marginalize = torch.zeros(variables, dtype=torch.bool)
P = torch.sum(network(X))
print(float(P))
P.backward()
with torch.no_grad():
for m in network:
m.em_batch()
m.em_update()
network.zero_grad(True)
# Plot fit
plt.figure(2).clf()
arrange_figs(**fig_args)
network.eval()
p_est = torch.zeros(variables, dimensions)
for d in range(dimensions):
for v in range(variables):
Y = torch.zeros(1, variables, dtype=int)
marginalize = torch.ones(variables, dtype=torch.bool)
marginalize[v] = False
network[0].marginalize = marginalize
Y[0, v] = d
p_est[v, d] = network(Y)
# plt.plot(range(V), X.sum(0)/N, '.-')
plt.plot(range(V), torch.exp(p_est), '.--')
plt.ylim((0,1))
drawnow()
#%% Plot samples
with torch.no_grad():
network.eval()
network[0].marginalize = torch.ones(variables, dtype=torch.bool)
Z = torch.zeros(1, variables, dtype=torch.long)
Z[0,0] = 1
network[0].marginalize[0] = False
network(Z)
R = 200
smp = []
for r in range(R):
arg = []
for m in reversed(network):
arg = m.sample(*arg)
smp.append(arg)
smp = torch.stack(smp)
plt.figure(3).clf()
arrange_figs(**fig_args)
plt.imshow(smp, aspect='auto', clim=(0,D-1), interpolation='none')
drawnow()
network[0].marginalize = torch.zeros(variables, dtype=torch.bool)
# %% Imports
import matplotlib.pyplot as plt
import numpy as np
import os
from supr.utils import drawnow, arrange_figs, bsgen
from supr.layers import *
fig_args = {'x0':1400, 'y0':28, 'x1':1920, 'y1':1080}
#%%
N = 500
x = []
y = []
for n in range(N):
alpha = np.random.rand()
if alpha<.33:
x.append(np.random.randn()*.8)
y.append(np.random.randn()*.1)
elif alpha<.66:
x.append(np.random.randn()*.1+3)
y.append(np.random.randn()*.1+4)
else:
x.append(np.random.randn()*.1)
y.append(np.random.randn()*.8+4)
X = torch.tensor(np.stack((np.array(x),np.array(y))).T)
T = 2
V = 2
C = 2
network = torch.nn.Sequential(NormalLeaf(T,V,C,n=N), Einsum(T,V,C), Weightsum(T,V,C))
plt.figure(1).clf()
plt.plot(X[:,0], X[:,1], '.')
arrange_figs(**fig_args)
plt.xlim((-2, 4))
plt.ylim((-1, 7))
#%%
network[0].marginalize = torch.zeros(V, dtype=bool)
num_epochs = 50
for epoch in range(num_epochs):
network.train()
P = torch.mean(network(X))
print('P: {}'.format(P))
P.backward()
with torch.no_grad():
for m in network:
m.em_batch()
m.em_update()
network.zero_grad(True)
#%%
with torch.no_grad():
R = N
smp = []
network.eval()
Y = torch.tensor([[0.,0]])
network[0].marginalize = torch.tensor([True, True])
network(Y)
for r in range(R):
arg = []
for m in reversed(network):
arg = m.sample(*arg)
smp.append(arg)
smp = torch.stack(smp)
plt.figure(2)
plt.clf()
plt.plot(smp[:,0], smp[:,1], '.')
arrange_figs(**fig_args)
plt.xlim((-2, 4))
plt.ylim((-1, 7))
# %% Imports
import matplotlib.pyplot as plt
import numpy as np
import os
from supr.utils import drawnow, arrange_figs, bsgen
from supr.layers import *
fig_args = {'x0':2800, 'y0':28, 'x1':3840, 'y1':2160}
# %%
N = 200
V = 10
n_profiles = 2
np.random.seed(0)
profiles = (np.cos(np.pi/4+.5*np.pi*np.arange(V)/V*(1+np.arange(n_profiles)[:,None]))>0).astype(float)
idx_profile = np.arange(N) % n_profiles
std_profile = np.array([.1 + .00 * x for x in range(n_profiles)])
x = (profiles[idx_profile] + ((np.random.rand(N, V))<0.02).astype(float)) % 2
plt.figure(1).clf()
arrange_figs(**fig_args)
plt.plot(x.T, color='r', alpha=0.01)
drawnow()
# %% Device and environment
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
X = torch.tensor(x).float().to(device)
# %%
S = 5
T = 5
network = nn.Sequential(BernoulliLeaf(T, V, S),
ScrambleTracks2d(T, V, 1., (V,1)),
*[Einsum(T, v, S) for v in bsgen(V, 2)],
Einsum(T, 2, S, 1),
Weightsum(T, 1, 1)).to(device)
# %%
num_epochs = 10
for epoch in range(num_epochs):
network.train()
P = torch.mean(network(X))
print('P: {}'.format(P))
P.backward()
with torch.no_grad():
for m in network:
m.em_batch()
m.em_update()
network.zero_grad(True)
with torch.no_grad():
network.eval()
# R = 100
# ylim = [-2, 2]
# yvals = torch.linspace(*ylim, R)
# PM = np.zeros((R, V))
# for v in range(V):
# network[0].marginalize = torch.ones(V, dtype=torch.bool)
# network[0].marginalize[v] = False
# Y = torch.zeros(R, V).to(device)
# Y[:, v] = yvals
# P = network(Y)
# PM[:, v] = P.detach().cpu().numpy()
# network[0].marginalize = torch.zeros(V, dtype=torch.bool)
# plt.figure(2).clf()
# arrange_figs()
# plt.imshow(np.exp(PM) / (np.max(np.exp(PM), axis=0) + 1e-10), cmap='Blues', vmax=2, origin='lower',
# extent=(-.5, V - .5, *ylim), interpolation='none', aspect='auto')
# plt.plot(X.T, color='r', alpha=0.01)
# # plt.plot(profiles.T, '.-', color='k')
# # plt.grid(True)
# drawnow()
with torch.no_grad():
network[0].marginalize = torch.ones(V, dtype=torch.bool).to(device)
Z = torch.zeros(1, V).to(device)
# network[0].marginalize[0:3] = False
# Z[0, 0:3] = torch.tensor([2.1, 2.2, 2.4])
# network[0].marginalize[[0,9]] = False
# Z[0, [0,9]] = torch.tensor([-2, -0.9])
network(Z)
R = 200
Z = []
for epoch in range(R):
arg = []
for m in reversed(network):
arg = m.sample(*arg)
Z.append(arg)
Z = torch.stack(Z)
plt.figure(3).clf()
arrange_figs(**fig_args)
plt.plot(Z.cpu().numpy().T, color='r', alpha=0.01)
# plt.plot(profiles.T, '.-', color='k')
# plt.grid(True)
drawnow()
# plt.figure(4).clf()
# arrange_figs()
# plt.imshow(network[2].weights.reshape(100, 400).detach(), aspect='auto')
# drawnow()
network[0].marginalize = torch.zeros(V, dtype=torch.bool).to(device)
print('Sample P: {}'.format(torch.mean(network(Z))))
......@@ -53,7 +53,7 @@ model = supr.Sequential(
marginalize_y = torch.tensor([False, True])
# %% Fit model and display results
epochs = 20
epochs = 10
for epoch in range(epochs):
model.train()
......@@ -154,4 +154,3 @@ for epoch in range(epochs):
plt.xlim([x_min, x_max])
plt.ylim([y_min, y_max])
drawnow()
# Imports
import supr
import torch
import numpy as np
import matplotlib.pyplot as plt
# Gaussian model
# Data
N = 10000
X = torch.randn(N, 10)*0.2 + torch.arange(10)
X[1:-1:2] = 10-X[1:-1:2]
tracks = 1
variables = 10
channels = 5
model = supr.Sequential(
supr.NormalLeaf(tracks, variables, channels, n=N),
supr.Weightsum(tracks, variables, channels))
marg = torch.ones(10, dtype=torch.bool)
marg[2] = False
model.train()
epochs = 10
for e in range(epochs):
loss = model(X).sum()
model.zero_grad()
loss.backward()
model.em_batch_update()
print(loss)
#%%
plt.figure(1).clf()
with torch.no_grad():
model.eval()
model(X[0:3], marginalize=marg)
plt.plot(X[0], '-.')
for k in range(100):
Y = model.sample()
plt.plot(Y, '.', color='tab:orange')
\ No newline at end of file
......@@ -26,13 +26,13 @@ variables = 2
channels = 4
tracks = 4
network = nn.Sequential(
NormalLeaf(tracks, variables, channels, n=N, alpha0=1, beta0=1),
NormalLeaf(tracks, variables, channels, n=N, alpha0=.01, beta0=.01),
Weightsum(tracks, 1, channels)
)
print(f"Total parameters: {sum(p.numel() for p in network.parameters() if p.requires_grad)}")
#%% Fit model
epochs = 20
epochs = 50
for r in range(epochs):
network.train()
P = torch.sum(network(X))
......
# pyproject.toml
[build-system]
requires = ["setuptools>=61.0.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "supr"
version = "0.0.1"
description = "SUm-PRoduct network Pytorch layers"
readme = "README.md"
authors = [{ name = "Mikkel N. Schmidt", email = "mnsc@dtu.dk" }]
license = { file = "LICENSE" }
classifiers = [
"License :: OSI Approved :: MIT License",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
]
keywords = ["sum-product network"]
dependencies = [
"torch",
]
requires-python = ">=3.9"
#[project.urls]
#Homepage = "https://github.com/"
Metadata-Version: 2.1
Name: supr
Version: 0.0.1
Summary: SUm-PRoduct network Pytorch layers
Author-email: "Mikkel N. Schmidt" <mnsc@dtu.dk>
Keywords: sum-product network
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.9
Description-Content-Type: text/markdown
Requires-Dist: torch
# supr
Python toolbox for tensor-based sum product networks
README.md
pyproject.toml
src/supr/__init__.py
src/supr/layers.py
src/supr/utils.py
src/supr.egg-info/PKG-INFO
src/supr.egg-info/SOURCES.txt
src/supr.egg-info/dependency_links.txt
src/supr.egg-info/requires.txt
src/supr.egg-info/top_level.txt
\ No newline at end of file
torch
supr
File moved
......@@ -55,10 +55,10 @@ class Sequential(nn.Sequential):
for module in self:
module.em_batch()
def em_update(self):
def em_update(self, *args, **kwargs):
with torch.no_grad():
for module in self:
module.em_update()
module.em_update(*args, **kwargs)
def sample(self):
value = []
......@@ -308,7 +308,7 @@ class NormalLeaf(SuprLeaf):
# Prior
self.mu0, self.nu0, self.alpha0, self.beta0 = mu0, nu0, alpha0, beta0
# Parametes
self.mu = nn.Parameter(torch.rand(self.T, self.V, self.C))
self.mu = nn.Parameter(torch.rand(self.T, self.V, self.C))*3
self.sig = nn.Parameter(torch.ones(self.T, self.V, self.C) * 0.5)
# Which variables to marginalized
self.register_buffer('marginalize', torch.zeros(
......@@ -424,9 +424,11 @@ class BernoulliLeaf(SuprLeaf):
def em_update(self, learning_rate: float = 1.):
# Probability
p_update = (self.z_x_acc + self.alpha0 - 1) / \
(self.z_acc + self.alpha0 + self.beta0 - 2)
(self.z_acc + self.alpha0 + self.beta0 - 2 + self.epsilon)
self.p.data *= 1. - learning_rate
self.p.data += learning_rate * p_update
# Reset accumulators
self.z_acc.zero_()
self.z_x_acc.zero_()
......@@ -441,7 +443,7 @@ class BernoulliLeaf(SuprLeaf):
r[~self.marginalize] = self.x[0][~self.marginalize]
return r
def forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor, marginalize=None):
# Get shape
batch_size = x.shape[0]
# Store the data
......
File moved
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment