Skip to content
Snippets Groups Projects
Commit a9091451 authored by mnsc's avatar mnsc
Browse files

added two moons demos

parent ccce14dc
No related branches found
No related tags found
No related merge requests found
# %% Imports
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
from supr.utils import drawnow, arrange_figs
from supr.layers import *
#%% Create data
N = 1000
x, y = datasets.make_moons(n_samples=N, noise=0.1)
x = (x - x.mean()) / x.std()
X = torch.tensor(x).float()
#%% Plot data
plt.figure(1).clf()
arrange_figs()
xlim = (-3., 3.)
ylim = (-2., 2.)
plt.plot(x[:, 0], x[:, 1], '.', markersize=1, color='tab:red')
plt.xlim(*xlim)
plt.ylim(*ylim)
drawnow()
#%% Make sum-product network
variables = 2
channels = 4
tracks = 4
network = nn.Sequential(
NormalLeaf(tracks, variables, channels),
Weightsum(tracks, 1, channels)
)
print(f"Total parameters: {sum(p.numel() for p in network.parameters() if p.requires_grad)}")
#%% Fit model
epochs = 20
for r in range(epochs):
network.train()
P = torch.sum(network(X))
print(float(P))
P.backward()
with torch.no_grad():
for m in network:
m.em_batch()
m.em_update()
network.zero_grad(True)
# Plot fit
R = 150
prx = torch.linspace(*xlim, R)
pry = torch.linspace(*ylim, R)
Xm, Ym = torch.meshgrid(prx, pry)
XY = torch.stack((Xm.flatten(), Ym.flatten())).T
XY = torch.cat((XY, torch.ones((R ** 2, variables - 2))), dim=1)
P = network(XY)
p = P.reshape((R, R)).detach().numpy()
plt.figure(2).clf()
arrange_figs()
plt.imshow(p.T, extent=(xlim[0], xlim[1], ylim[0], ylim[1]), origin='lower', clim=[np.max(p) - 10, np.max(p)])
plt.plot(X[:, 0], X[:, 1], '.', markersize=1, color='tab:red', alpha=1.)
drawnow()
#%% Plot samples
with torch.no_grad():
network.eval()
network[0].marginalize = torch.ones(variables, dtype=torch.bool)
Z = torch.zeros(1, variables)
# network[0].marginalize[1] = False
# Z[0, 1] = -0.5
network(Z)
R = 1000
Z = []
for r in range(R):
arg = []
for m in reversed(network):
arg = m.sample(*arg)
Z.append(arg)
Z = torch.stack(Z)
plt.figure(3).clf()
arrange_figs()
plt.plot(Z[:, 0], Z[:, 1], '.', markersize=1, color='tab:red', alpha=1.)
plt.grid(True)
plt.xlim(*xlim)
plt.ylim(*ylim)
drawnow()
network[0].marginalize = torch.zeros(variables, dtype=torch.bool)
# %% Imports
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
from supr.utils import drawnow, arrange_figs
from supr.layers import *
#%% Create data
N = 1000
x, y = datasets.make_moons(n_samples=N, noise=0.05)
x = (x - x.mean()) / x.std()
X = torch.tensor(x).float()
#%% Plot data
plt.figure(1).clf()
arrange_figs()
xlim = (-3., 3.)
ylim = (-2., 2.)
plt.plot(x[:, 0], x[:, 1], '.', markersize=1, color='tab:red')
plt.xlim(*xlim)
plt.ylim(*ylim)
drawnow()
#%% Make sum-product network
variables = 2
channels = 2
tracks = 5
classes = 10
network = nn.Sequential(NormalLeaf(tracks, variables, channels),
ScrambleTracks(tracks, variables),
Einsum(tracks, variables, channels, classes),
VariablesProduct(),
TrackSum(tracks, classes))
print(f"Total parameters: {sum(p.numel() for p in network.parameters() if p.requires_grad)}")
#%% Fit model
epochs = 20
for r in range(epochs):
network.train()
p = network(X)
P = torch.sum(p[np.arange(N),0,0,y])
print(float(P))
P.backward()
with torch.no_grad():
for m in network:
m.em_batch()
m.em_update()
network.zero_grad(True)
# Plot fit
R = 150
prx = torch.linspace(*xlim, R)
pry = torch.linspace(*ylim, R)
Xm, Ym = torch.meshgrid(prx, pry)
XY = torch.stack((Xm.flatten(), Ym.flatten())).T
XY = torch.cat((XY, torch.ones((R ** 2, variables - 2))), dim=1)
P = network(XY)[:,0,0,0]
p = P.reshape((R, R)).detach().numpy()
plt.figure(2).clf()
arrange_figs()
plt.imshow(p.T, extent=(xlim[0], xlim[1], ylim[0], ylim[1]), origin='lower', clim=[np.max(p) - 10, np.max(p)])
plt.plot(X[:, 0], X[:, 1], '.', markersize=1, color='tab:red', alpha=1.)
drawnow()
# Compute accuracy
p = network(X)
train_correct = torch.count_nonzero(torch.argmax(p[:,0,0,:], dim=1) == torch.tensor(y))
train_accuracy = train_correct / N
print(f'Train accuracy: {train_accuracy} = {train_correct}/{N}')
#%% Plot samples
with torch.no_grad():
network.eval()
network[0].marginalize = torch.ones(variables, dtype=torch.bool)
Z = torch.zeros(1, variables)
# network[0].marginalize[1] = False
# Z[0, 1] = 0.
network(Z)
R = 1000
Z = []
for r in range(R):
arg = [0, torch.tensor([0])]
for m in reversed(network):
arg = m.sample(*arg)
Z.append(arg)
Z = torch.stack(Z)
plt.figure(3).clf()
arrange_figs()
plt.plot(Z[:, 0], Z[:, 1], '.', markersize=1, color='tab:red', alpha=1.)
plt.grid(True)
plt.xlim(*xlim)
plt.ylim(*ylim)
drawnow()
network[0].marginalize = torch.zeros(variables, dtype=torch.bool)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment