From a9091451d3bcc043c8bc799d043ac794f74ceaa2 Mon Sep 17 00:00:00 2001
From: Mikkel N Schmidt <mnsc@dtu.dk>
Date: Tue, 22 Feb 2022 12:37:47 +0100
Subject: [PATCH] added two moons demos

---
 demos/two_moons.py            |  90 ++++++++++++++++++++++++++++++
 demos/two_moons_classifier.py | 100 ++++++++++++++++++++++++++++++++++
 2 files changed, 190 insertions(+)
 create mode 100644 demos/two_moons.py
 create mode 100644 demos/two_moons_classifier.py

diff --git a/demos/two_moons.py b/demos/two_moons.py
new file mode 100644
index 0000000..865896f
--- /dev/null
+++ b/demos/two_moons.py
@@ -0,0 +1,90 @@
+# %% Imports
+from sklearn import datasets
+import matplotlib.pyplot as plt
+import numpy as np
+from supr.utils import drawnow, arrange_figs
+from supr.layers import *
+
+#%% Create data
+N = 1000
+x, y = datasets.make_moons(n_samples=N, noise=0.1)
+x = (x - x.mean()) / x.std()
+X = torch.tensor(x).float()
+
+#%% Plot data
+plt.figure(1).clf()
+arrange_figs()
+xlim = (-3., 3.)
+ylim = (-2., 2.)
+plt.plot(x[:, 0], x[:, 1], '.', markersize=1, color='tab:red')
+plt.xlim(*xlim)
+plt.ylim(*ylim)
+drawnow()
+
+#%% Make sum-product network
+variables = 2
+channels = 4
+tracks = 4
+network = nn.Sequential(
+    NormalLeaf(tracks, variables, channels),
+    Weightsum(tracks, 1, channels)
+)
+print(f"Total parameters: {sum(p.numel() for p in network.parameters() if p.requires_grad)}")
+
+#%% Fit model
+epochs = 20
+for r in range(epochs):
+    network.train()
+    P = torch.sum(network(X))
+    print(float(P))
+    P.backward()
+
+    with torch.no_grad():
+        for m in network:
+            m.em_batch()
+            m.em_update()
+        network.zero_grad(True)
+
+        # Plot fit
+        R = 150
+        prx = torch.linspace(*xlim, R)
+        pry = torch.linspace(*ylim, R)
+        Xm, Ym = torch.meshgrid(prx, pry)
+        XY = torch.stack((Xm.flatten(), Ym.flatten())).T
+        XY = torch.cat((XY, torch.ones((R ** 2, variables - 2))), dim=1)
+        P = network(XY)
+        p = P.reshape((R, R)).detach().numpy()
+        plt.figure(2).clf()
+        arrange_figs()
+        plt.imshow(p.T, extent=(xlim[0], xlim[1], ylim[0], ylim[1]), origin='lower', clim=[np.max(p) - 10, np.max(p)])
+        plt.plot(X[:, 0], X[:, 1], '.', markersize=1, color='tab:red', alpha=1.)
+        drawnow()
+
+#%% Plot samples
+with torch.no_grad():
+    network.eval()
+    network[0].marginalize = torch.ones(variables, dtype=torch.bool)
+    Z = torch.zeros(1, variables)
+    # network[0].marginalize[1] = False
+    # Z[0, 1] = -0.5
+    network(Z)
+
+    R = 1000
+
+    Z = []
+    for r in range(R):
+        arg = []
+        for m in reversed(network):
+            arg = m.sample(*arg)
+        Z.append(arg)
+    Z = torch.stack(Z)
+
+    plt.figure(3).clf()
+    arrange_figs()
+    plt.plot(Z[:, 0], Z[:, 1], '.', markersize=1, color='tab:red', alpha=1.)
+    plt.grid(True)
+    plt.xlim(*xlim)
+    plt.ylim(*ylim)
+    drawnow()
+
+    network[0].marginalize = torch.zeros(variables, dtype=torch.bool)
diff --git a/demos/two_moons_classifier.py b/demos/two_moons_classifier.py
new file mode 100644
index 0000000..cdcef30
--- /dev/null
+++ b/demos/two_moons_classifier.py
@@ -0,0 +1,100 @@
+# %% Imports
+from sklearn import datasets
+import matplotlib.pyplot as plt
+import numpy as np
+from supr.utils import drawnow, arrange_figs
+from supr.layers import *
+
+#%% Create data
+N = 1000
+x, y = datasets.make_moons(n_samples=N, noise=0.05)
+x = (x - x.mean()) / x.std()
+X = torch.tensor(x).float()
+
+#%% Plot data
+plt.figure(1).clf()
+arrange_figs()
+xlim = (-3., 3.)
+ylim = (-2., 2.)
+plt.plot(x[:, 0], x[:, 1], '.', markersize=1, color='tab:red')
+plt.xlim(*xlim)
+plt.ylim(*ylim)
+drawnow()
+
+#%% Make sum-product network
+variables = 2
+channels = 2
+tracks = 5
+classes = 10
+network = nn.Sequential(NormalLeaf(tracks, variables, channels),
+                        ScrambleTracks(tracks, variables),
+                        Einsum(tracks, variables, channels, classes),
+                        VariablesProduct(),
+                        TrackSum(tracks, classes))
+print(f"Total parameters: {sum(p.numel() for p in network.parameters() if p.requires_grad)}")
+
+#%% Fit model
+epochs = 20
+for r in range(epochs):
+    network.train()
+    p = network(X)
+    P = torch.sum(p[np.arange(N),0,0,y])
+    print(float(P))
+    P.backward()
+
+    with torch.no_grad():
+        for m in network:
+            m.em_batch()
+            m.em_update()
+        network.zero_grad(True)
+
+        # Plot fit
+        R = 150
+        prx = torch.linspace(*xlim, R)
+        pry = torch.linspace(*ylim, R)
+        Xm, Ym = torch.meshgrid(prx, pry)
+        XY = torch.stack((Xm.flatten(), Ym.flatten())).T
+        XY = torch.cat((XY, torch.ones((R ** 2, variables - 2))), dim=1)
+        P = network(XY)[:,0,0,0]
+        p = P.reshape((R, R)).detach().numpy()
+        plt.figure(2).clf()
+        arrange_figs()
+        plt.imshow(p.T, extent=(xlim[0], xlim[1], ylim[0], ylim[1]), origin='lower', clim=[np.max(p) - 10, np.max(p)])
+        plt.plot(X[:, 0], X[:, 1], '.', markersize=1, color='tab:red', alpha=1.)
+        drawnow()
+
+        # Compute accuracy
+        p = network(X)
+        train_correct = torch.count_nonzero(torch.argmax(p[:,0,0,:], dim=1) == torch.tensor(y))
+        train_accuracy = train_correct / N
+        print(f'Train accuracy: {train_accuracy} = {train_correct}/{N}')
+
+
+#%% Plot samples
+with torch.no_grad():
+    network.eval()
+    network[0].marginalize = torch.ones(variables, dtype=torch.bool)
+    Z = torch.zeros(1, variables)
+    # network[0].marginalize[1] = False
+    # Z[0, 1] = 0.
+    network(Z)
+
+    R = 1000
+
+    Z = []
+    for r in range(R):
+        arg = [0, torch.tensor([0])]
+        for m in reversed(network):
+            arg = m.sample(*arg)
+        Z.append(arg)
+    Z = torch.stack(Z)
+
+    plt.figure(3).clf()
+    arrange_figs()
+    plt.plot(Z[:, 0], Z[:, 1], '.', markersize=1, color='tab:red', alpha=1.)
+    plt.grid(True)
+    plt.xlim(*xlim)
+    plt.ylim(*ylim)
+    drawnow()
+
+    network[0].marginalize = torch.zeros(variables, dtype=torch.bool)
-- 
GitLab