From a9091451d3bcc043c8bc799d043ac794f74ceaa2 Mon Sep 17 00:00:00 2001 From: Mikkel N Schmidt <mnsc@dtu.dk> Date: Tue, 22 Feb 2022 12:37:47 +0100 Subject: [PATCH] added two moons demos --- demos/two_moons.py | 90 ++++++++++++++++++++++++++++++ demos/two_moons_classifier.py | 100 ++++++++++++++++++++++++++++++++++ 2 files changed, 190 insertions(+) create mode 100644 demos/two_moons.py create mode 100644 demos/two_moons_classifier.py diff --git a/demos/two_moons.py b/demos/two_moons.py new file mode 100644 index 0000000..865896f --- /dev/null +++ b/demos/two_moons.py @@ -0,0 +1,90 @@ +# %% Imports +from sklearn import datasets +import matplotlib.pyplot as plt +import numpy as np +from supr.utils import drawnow, arrange_figs +from supr.layers import * + +#%% Create data +N = 1000 +x, y = datasets.make_moons(n_samples=N, noise=0.1) +x = (x - x.mean()) / x.std() +X = torch.tensor(x).float() + +#%% Plot data +plt.figure(1).clf() +arrange_figs() +xlim = (-3., 3.) +ylim = (-2., 2.) +plt.plot(x[:, 0], x[:, 1], '.', markersize=1, color='tab:red') +plt.xlim(*xlim) +plt.ylim(*ylim) +drawnow() + +#%% Make sum-product network +variables = 2 +channels = 4 +tracks = 4 +network = nn.Sequential( + NormalLeaf(tracks, variables, channels), + Weightsum(tracks, 1, channels) +) +print(f"Total parameters: {sum(p.numel() for p in network.parameters() if p.requires_grad)}") + +#%% Fit model +epochs = 20 +for r in range(epochs): + network.train() + P = torch.sum(network(X)) + print(float(P)) + P.backward() + + with torch.no_grad(): + for m in network: + m.em_batch() + m.em_update() + network.zero_grad(True) + + # Plot fit + R = 150 + prx = torch.linspace(*xlim, R) + pry = torch.linspace(*ylim, R) + Xm, Ym = torch.meshgrid(prx, pry) + XY = torch.stack((Xm.flatten(), Ym.flatten())).T + XY = torch.cat((XY, torch.ones((R ** 2, variables - 2))), dim=1) + P = network(XY) + p = P.reshape((R, R)).detach().numpy() + plt.figure(2).clf() + arrange_figs() + plt.imshow(p.T, extent=(xlim[0], xlim[1], ylim[0], ylim[1]), origin='lower', clim=[np.max(p) - 10, np.max(p)]) + plt.plot(X[:, 0], X[:, 1], '.', markersize=1, color='tab:red', alpha=1.) + drawnow() + +#%% Plot samples +with torch.no_grad(): + network.eval() + network[0].marginalize = torch.ones(variables, dtype=torch.bool) + Z = torch.zeros(1, variables) + # network[0].marginalize[1] = False + # Z[0, 1] = -0.5 + network(Z) + + R = 1000 + + Z = [] + for r in range(R): + arg = [] + for m in reversed(network): + arg = m.sample(*arg) + Z.append(arg) + Z = torch.stack(Z) + + plt.figure(3).clf() + arrange_figs() + plt.plot(Z[:, 0], Z[:, 1], '.', markersize=1, color='tab:red', alpha=1.) + plt.grid(True) + plt.xlim(*xlim) + plt.ylim(*ylim) + drawnow() + + network[0].marginalize = torch.zeros(variables, dtype=torch.bool) diff --git a/demos/two_moons_classifier.py b/demos/two_moons_classifier.py new file mode 100644 index 0000000..cdcef30 --- /dev/null +++ b/demos/two_moons_classifier.py @@ -0,0 +1,100 @@ +# %% Imports +from sklearn import datasets +import matplotlib.pyplot as plt +import numpy as np +from supr.utils import drawnow, arrange_figs +from supr.layers import * + +#%% Create data +N = 1000 +x, y = datasets.make_moons(n_samples=N, noise=0.05) +x = (x - x.mean()) / x.std() +X = torch.tensor(x).float() + +#%% Plot data +plt.figure(1).clf() +arrange_figs() +xlim = (-3., 3.) +ylim = (-2., 2.) +plt.plot(x[:, 0], x[:, 1], '.', markersize=1, color='tab:red') +plt.xlim(*xlim) +plt.ylim(*ylim) +drawnow() + +#%% Make sum-product network +variables = 2 +channels = 2 +tracks = 5 +classes = 10 +network = nn.Sequential(NormalLeaf(tracks, variables, channels), + ScrambleTracks(tracks, variables), + Einsum(tracks, variables, channels, classes), + VariablesProduct(), + TrackSum(tracks, classes)) +print(f"Total parameters: {sum(p.numel() for p in network.parameters() if p.requires_grad)}") + +#%% Fit model +epochs = 20 +for r in range(epochs): + network.train() + p = network(X) + P = torch.sum(p[np.arange(N),0,0,y]) + print(float(P)) + P.backward() + + with torch.no_grad(): + for m in network: + m.em_batch() + m.em_update() + network.zero_grad(True) + + # Plot fit + R = 150 + prx = torch.linspace(*xlim, R) + pry = torch.linspace(*ylim, R) + Xm, Ym = torch.meshgrid(prx, pry) + XY = torch.stack((Xm.flatten(), Ym.flatten())).T + XY = torch.cat((XY, torch.ones((R ** 2, variables - 2))), dim=1) + P = network(XY)[:,0,0,0] + p = P.reshape((R, R)).detach().numpy() + plt.figure(2).clf() + arrange_figs() + plt.imshow(p.T, extent=(xlim[0], xlim[1], ylim[0], ylim[1]), origin='lower', clim=[np.max(p) - 10, np.max(p)]) + plt.plot(X[:, 0], X[:, 1], '.', markersize=1, color='tab:red', alpha=1.) + drawnow() + + # Compute accuracy + p = network(X) + train_correct = torch.count_nonzero(torch.argmax(p[:,0,0,:], dim=1) == torch.tensor(y)) + train_accuracy = train_correct / N + print(f'Train accuracy: {train_accuracy} = {train_correct}/{N}') + + +#%% Plot samples +with torch.no_grad(): + network.eval() + network[0].marginalize = torch.ones(variables, dtype=torch.bool) + Z = torch.zeros(1, variables) + # network[0].marginalize[1] = False + # Z[0, 1] = 0. + network(Z) + + R = 1000 + + Z = [] + for r in range(R): + arg = [0, torch.tensor([0])] + for m in reversed(network): + arg = m.sample(*arg) + Z.append(arg) + Z = torch.stack(Z) + + plt.figure(3).clf() + arrange_figs() + plt.plot(Z[:, 0], Z[:, 1], '.', markersize=1, color='tab:red', alpha=1.) + plt.grid(True) + plt.xlim(*xlim) + plt.ylim(*ylim) + drawnow() + + network[0].marginalize = torch.zeros(variables, dtype=torch.bool) -- GitLab