From 42c3ba55f6b19b69be5363c5d1a93455ba595718 Mon Sep 17 00:00:00 2001
From: "Mikkel N. Schmidt" <mnsc@dtu.dk>
Date: Fri, 7 Jun 2024 12:59:07 +0200
Subject: [PATCH] move to src and add demos

---
 demos/binary_sweep.py                  |  58 +++++++++++++
 demos/categorical_sweep.py             |  98 +++++++++++++++++++++
 demos/gaussian_mixture.py              |  79 +++++++++++++++++
 demos/profiles_binary.py               | 116 +++++++++++++++++++++++++
 demos/regression.py                    |   3 +-
 demos/sampling.py                      |  42 +++++++++
 demos/two_moons.py                     |   4 +-
 pyproject.toml                         |  27 ++++++
 src/supr.egg-info/PKG-INFO             |  16 ++++
 src/supr.egg-info/SOURCES.txt          |  10 +++
 src/supr.egg-info/dependency_links.txt |   1 +
 src/supr.egg-info/requires.txt         |   1 +
 src/supr.egg-info/top_level.txt        |   1 +
 {supr => src/supr}/__init__.py         |   0
 {supr => src/supr}/layers.py           |  12 +--
 {supr => src/supr}/utils.py            |   0
 16 files changed, 459 insertions(+), 9 deletions(-)
 create mode 100644 demos/binary_sweep.py
 create mode 100644 demos/categorical_sweep.py
 create mode 100644 demos/gaussian_mixture.py
 create mode 100644 demos/profiles_binary.py
 create mode 100644 demos/sampling.py
 create mode 100644 pyproject.toml
 create mode 100644 src/supr.egg-info/PKG-INFO
 create mode 100644 src/supr.egg-info/SOURCES.txt
 create mode 100644 src/supr.egg-info/dependency_links.txt
 create mode 100644 src/supr.egg-info/requires.txt
 create mode 100644 src/supr.egg-info/top_level.txt
 rename {supr => src/supr}/__init__.py (100%)
 rename {supr => src/supr}/layers.py (98%)
 rename {supr => src/supr}/utils.py (100%)

diff --git a/demos/binary_sweep.py b/demos/binary_sweep.py
new file mode 100644
index 0000000..f521983
--- /dev/null
+++ b/demos/binary_sweep.py
@@ -0,0 +1,58 @@
+# %% Imports
+from sklearn import datasets
+import matplotlib.pyplot as plt
+import numpy as np
+from supr.utils import drawnow, arrange_figs
+from supr.layers import *
+
+#%% Create data
+N = 20
+V = 10
+p = np.linspace(0,1,V)
+X = np.random.rand(N, V) < p
+X = torch.tensor(X).float()
+
+# Plot data
+plt.figure(1).clf()
+arrange_figs()
+plt.plot(range(V), X.mean(0))
+drawnow()
+
+#%% Make sum-product network
+variables = V
+channels = 4
+tracks = 4
+network = nn.Sequential(
+    BernoulliLeaf(tracks, variables, channels, n=N, alpha0=2, beta0=2),
+    Weightsum(tracks, 1, channels)
+)
+print(f"Total parameters: {sum(p.numel() for p in network.parameters() if p.requires_grad)}")
+
+#%% Fit model
+epochs = 20
+for r in range(epochs):
+    network.train()
+    P = torch.sum(network(X))
+    print(float(P))
+    P.backward()
+
+    with torch.no_grad():
+        for m in network:
+            m.em_batch()
+            m.em_update()
+        network.zero_grad(True)
+
+        # Plot fit
+        plt.figure(2).clf()
+        arrange_figs()
+        network.eval()
+        Y = torch.ones(1, variables)
+        p_est = torch.zeros(variables)
+        for v in range(variables):
+            marginalize = torch.ones(variables, dtype=torch.bool)
+            marginalize[v] = False
+            network[0].marginalize = marginalize
+            p_est[v] = network(Y)
+        plt.plot(range(V), X.mean(0))
+        plt.plot(range(V), torch.exp(p_est), '--')
+        drawnow()
diff --git a/demos/categorical_sweep.py b/demos/categorical_sweep.py
new file mode 100644
index 0000000..c2a7b57
--- /dev/null
+++ b/demos/categorical_sweep.py
@@ -0,0 +1,98 @@
+# %% Imports
+import matplotlib.pyplot as plt
+import numpy as np
+from supr.utils import drawnow, arrange_figs
+from supr.layers import *
+
+fig_args = {'x0':2800, 'y0':28, 'x1':3840, 'y1':2160}
+# fig_args = {'x0':1400, 'y0':28, 'x1':1920, 'y1':1080}
+
+#%% Create data
+N = 1000
+V = 10
+D = 2
+P = np.zeros((V,D))
+for d in range(D):
+    P[:,d] = np.linspace(d/(D-1), 1-d/(D-1), V)
+P /= P.sum(1, keepdims=True)
+x = [np.random.choice(range(D), size=int(N/2), p=p) for p in P]
+y = [np.random.choice(range(D-1,-1,-1), size=int(N/2), p=p) for p in P]
+
+X = torch.tensor(np.concatenate((np.array(x).T, np.array(y).T)))
+
+# Plot data
+plt.figure(1).clf()
+arrange_figs(**fig_args)
+plt.plot(range(V), X.sum(0)/N, '.-')
+plt.ylim((0,1))
+drawnow()
+
+#%% Make sum-product network
+dimensions = D
+variables = V
+channels = 1
+tracks = 2
+network = nn.Sequential(
+    CategoricalLeaf(tracks, variables, channels, dimensions, n=N, alpha0=2),
+    # Einsum(tracks, variables, channels, channels),
+    Weightsum(tracks, 1, channels)
+)
+print(f"Total parameters: {sum(p.numel() for p in network.parameters() if p.requires_grad)}")
+
+#%% Fit model
+epochs = 10
+for r in range(epochs):
+    network.train()
+    network[0].marginalize = torch.zeros(variables, dtype=torch.bool)
+    P = torch.sum(network(X))
+    print(float(P))
+    P.backward()
+
+    with torch.no_grad():
+        for m in network:
+            m.em_batch()
+            m.em_update()
+        network.zero_grad(True)
+
+        # Plot fit
+        plt.figure(2).clf()
+        arrange_figs(**fig_args)
+        network.eval()
+        p_est = torch.zeros(variables, dimensions)
+        for d in range(dimensions):
+            for v in range(variables):
+                Y = torch.zeros(1, variables, dtype=int)
+                marginalize = torch.ones(variables, dtype=torch.bool)
+                marginalize[v] = False
+                network[0].marginalize = marginalize
+                Y[0, v] = d
+                p_est[v, d] = network(Y)
+        # plt.plot(range(V), X.sum(0)/N, '.-')
+        plt.plot(range(V), torch.exp(p_est), '.--')
+        plt.ylim((0,1))
+        drawnow()
+
+#%% Plot samples
+with torch.no_grad():
+    network.eval()
+    network[0].marginalize = torch.ones(variables, dtype=torch.bool)
+    Z = torch.zeros(1, variables, dtype=torch.long)
+    Z[0,0] = 1
+    network[0].marginalize[0] = False
+    network(Z)
+
+    R = 200
+    smp = []
+    for r in range(R):
+        arg = []
+        for m in reversed(network):
+            arg = m.sample(*arg)
+        smp.append(arg)
+    smp = torch.stack(smp)
+
+    plt.figure(3).clf()
+    arrange_figs(**fig_args)
+    plt.imshow(smp, aspect='auto', clim=(0,D-1), interpolation='none')
+    drawnow()
+
+    network[0].marginalize = torch.zeros(variables, dtype=torch.bool)
diff --git a/demos/gaussian_mixture.py b/demos/gaussian_mixture.py
new file mode 100644
index 0000000..b72d0c6
--- /dev/null
+++ b/demos/gaussian_mixture.py
@@ -0,0 +1,79 @@
+# %% Imports
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+from supr.utils import drawnow, arrange_figs, bsgen
+from supr.layers import *
+
+fig_args = {'x0':1400, 'y0':28, 'x1':1920, 'y1':1080}
+
+#%%
+
+N = 500
+
+x = []
+y = []
+for n in range(N):
+    alpha = np.random.rand()
+    if alpha<.33:
+        x.append(np.random.randn()*.8)
+        y.append(np.random.randn()*.1)        
+    elif alpha<.66:
+        x.append(np.random.randn()*.1+3)
+        y.append(np.random.randn()*.1+4)
+    else:
+        x.append(np.random.randn()*.1)
+        y.append(np.random.randn()*.8+4)
+
+X = torch.tensor(np.stack((np.array(x),np.array(y))).T)        
+
+T = 2
+V = 2
+C = 2
+network = torch.nn.Sequential(NormalLeaf(T,V,C,n=N), Einsum(T,V,C), Weightsum(T,V,C))
+
+plt.figure(1).clf()
+plt.plot(X[:,0], X[:,1], '.')
+arrange_figs(**fig_args)
+plt.xlim((-2, 4))
+plt.ylim((-1, 7))
+
+#%%    
+network[0].marginalize = torch.zeros(V, dtype=bool)
+num_epochs = 50
+for epoch in range(num_epochs):
+    network.train()
+    P = torch.mean(network(X))
+    print('P: {}'.format(P))
+
+    P.backward()
+
+    with torch.no_grad():
+        for m in network:
+            m.em_batch()
+            m.em_update()
+        network.zero_grad(True)
+
+#%%
+
+with torch.no_grad():
+    R = N
+    smp = []
+    network.eval()
+    Y = torch.tensor([[0.,0]])
+    network[0].marginalize = torch.tensor([True, True])
+    network(Y)
+    for r in range(R):
+        arg = []
+        for m in reversed(network):
+            arg = m.sample(*arg)
+        smp.append(arg)
+    smp = torch.stack(smp)
+    
+    plt.figure(2)
+    plt.clf()
+    plt.plot(smp[:,0], smp[:,1], '.')
+    arrange_figs(**fig_args)
+    plt.xlim((-2, 4))
+    plt.ylim((-1, 7))
+    
diff --git a/demos/profiles_binary.py b/demos/profiles_binary.py
new file mode 100644
index 0000000..7a53b24
--- /dev/null
+++ b/demos/profiles_binary.py
@@ -0,0 +1,116 @@
+# %% Imports
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+from supr.utils import drawnow, arrange_figs, bsgen
+from supr.layers import *
+
+fig_args = {'x0':2800, 'y0':28, 'x1':3840, 'y1':2160}
+
+# %%
+
+N = 200
+V = 10
+n_profiles = 2
+np.random.seed(0)
+profiles = (np.cos(np.pi/4+.5*np.pi*np.arange(V)/V*(1+np.arange(n_profiles)[:,None]))>0).astype(float)
+
+idx_profile = np.arange(N) % n_profiles
+std_profile = np.array([.1 + .00 * x for x in range(n_profiles)])
+x = (profiles[idx_profile] + ((np.random.rand(N, V))<0.02).astype(float)) % 2
+
+plt.figure(1).clf()
+arrange_figs(**fig_args)
+plt.plot(x.T, color='r', alpha=0.01)
+drawnow()
+
+# %% Device and environment
+os.environ["CUDA_VISIBLE_DEVICES"] = "6"
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+X = torch.tensor(x).float().to(device)
+
+# %%
+S = 5
+T = 5
+network = nn.Sequential(BernoulliLeaf(T, V, S),
+                        ScrambleTracks2d(T, V, 1., (V,1)),
+                        *[Einsum(T, v, S) for v in bsgen(V, 2)],
+                        Einsum(T, 2, S, 1),
+                        Weightsum(T, 1, 1)).to(device)
+
+# %%
+num_epochs = 10
+for epoch in range(num_epochs):
+    network.train()
+    P = torch.mean(network(X))
+    print('P: {}'.format(P))
+
+    P.backward()
+
+    with torch.no_grad():
+        for m in network:
+            m.em_batch()
+            m.em_update()
+        network.zero_grad(True)
+
+    with torch.no_grad():
+        network.eval()
+        # R = 100
+        # ylim = [-2, 2]
+        # yvals = torch.linspace(*ylim, R)
+        # PM = np.zeros((R, V))
+        # for v in range(V):
+        #     network[0].marginalize = torch.ones(V, dtype=torch.bool)
+        #     network[0].marginalize[v] = False
+        #     Y = torch.zeros(R, V).to(device)
+        #     Y[:, v] = yvals
+        #     P = network(Y)
+        #     PM[:, v] = P.detach().cpu().numpy()
+        # network[0].marginalize = torch.zeros(V, dtype=torch.bool)
+
+        # plt.figure(2).clf()
+        # arrange_figs()
+        # plt.imshow(np.exp(PM) / (np.max(np.exp(PM), axis=0) + 1e-10), cmap='Blues', vmax=2, origin='lower',
+        #            extent=(-.5, V - .5, *ylim), interpolation='none', aspect='auto')
+        # plt.plot(X.T, color='r', alpha=0.01)
+        # # plt.plot(profiles.T, '.-', color='k')
+        # # plt.grid(True)
+        # drawnow()
+
+        with torch.no_grad():
+            network[0].marginalize = torch.ones(V, dtype=torch.bool).to(device)
+            Z = torch.zeros(1, V).to(device)
+
+            # network[0].marginalize[0:3] = False
+            # Z[0, 0:3] = torch.tensor([2.1, 2.2, 2.4])
+            # network[0].marginalize[[0,9]] = False
+            # Z[0, [0,9]] = torch.tensor([-2, -0.9])
+            network(Z)
+
+            R = 200
+
+            Z = []
+            for epoch in range(R):
+                arg = []
+                for m in reversed(network):
+                    arg = m.sample(*arg)
+                Z.append(arg)
+            Z = torch.stack(Z)
+
+            plt.figure(3).clf()
+            arrange_figs(**fig_args)
+            plt.plot(Z.cpu().numpy().T, color='r', alpha=0.01)
+            # plt.plot(profiles.T, '.-', color='k')
+            # plt.grid(True)
+            drawnow()
+
+            # plt.figure(4).clf()
+            # arrange_figs()
+            # plt.imshow(network[2].weights.reshape(100, 400).detach(), aspect='auto')
+            # drawnow()
+
+        network[0].marginalize = torch.zeros(V, dtype=torch.bool).to(device)
+        print('Sample P: {}'.format(torch.mean(network(Z))))
+
+
diff --git a/demos/regression.py b/demos/regression.py
index 789e3a6..5642631 100644
--- a/demos/regression.py
+++ b/demos/regression.py
@@ -53,7 +53,7 @@ model = supr.Sequential(
 marginalize_y = torch.tensor([False, True])
 
 # %% Fit model and display results
-epochs = 20
+epochs = 10
 
 for epoch in range(epochs):
     model.train()
@@ -154,4 +154,3 @@ for epoch in range(epochs):
         plt.xlim([x_min, x_max])
         plt.ylim([y_min, y_max])
         drawnow()
-        
diff --git a/demos/sampling.py b/demos/sampling.py
new file mode 100644
index 0000000..55e393f
--- /dev/null
+++ b/demos/sampling.py
@@ -0,0 +1,42 @@
+# Imports
+import supr
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+
+# Gaussian model
+
+# Data
+N = 10000
+X = torch.randn(N, 10)*0.2 + torch.arange(10)
+X[1:-1:2]  = 10-X[1:-1:2]
+ 
+tracks = 1
+variables = 10
+channels = 5
+model = supr.Sequential(
+    supr.NormalLeaf(tracks, variables, channels, n=N),
+    supr.Weightsum(tracks, variables, channels))
+
+marg = torch.ones(10, dtype=torch.bool)
+marg[2] = False
+
+model.train()
+epochs = 10
+for e in range(epochs):
+    loss = model(X).sum()
+    model.zero_grad()
+    loss.backward()
+    model.em_batch_update()
+    print(loss)
+
+#%%
+plt.figure(1).clf()
+with torch.no_grad():
+    model.eval()
+    model(X[0:3], marginalize=marg)
+    plt.plot(X[0], '-.')    
+
+    for k in range(100):
+        Y = model.sample()    
+        plt.plot(Y, '.', color='tab:orange')
\ No newline at end of file
diff --git a/demos/two_moons.py b/demos/two_moons.py
index 95716d5..75d420d 100644
--- a/demos/two_moons.py
+++ b/demos/two_moons.py
@@ -26,13 +26,13 @@ variables = 2
 channels = 4
 tracks = 4
 network = nn.Sequential(
-    NormalLeaf(tracks, variables, channels, n=N, alpha0=1, beta0=1),
+    NormalLeaf(tracks, variables, channels, n=N, alpha0=.01, beta0=.01),
     Weightsum(tracks, 1, channels)
 )
 print(f"Total parameters: {sum(p.numel() for p in network.parameters() if p.requires_grad)}")
 
 #%% Fit model
-epochs = 20
+epochs = 50
 for r in range(epochs):
     network.train()
     P = torch.sum(network(X))
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..c345b90
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,27 @@
+# pyproject.toml
+
+[build-system]
+requires = ["setuptools>=61.0.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "supr"
+version = "0.0.1"
+description = "SUm-PRoduct network Pytorch layers"
+readme = "README.md"
+authors = [{ name = "Mikkel N. Schmidt", email = "mnsc@dtu.dk" }]
+license = { file = "LICENSE" }
+classifiers = [
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: Python",
+    "Programming Language :: Python :: 3",
+]
+keywords = ["sum-product network"]
+dependencies = [
+    "torch",
+]
+requires-python = ">=3.9"
+
+#[project.urls]
+#Homepage = "https://github.com/"
+
diff --git a/src/supr.egg-info/PKG-INFO b/src/supr.egg-info/PKG-INFO
new file mode 100644
index 0000000..23bdd5d
--- /dev/null
+++ b/src/supr.egg-info/PKG-INFO
@@ -0,0 +1,16 @@
+Metadata-Version: 2.1
+Name: supr
+Version: 0.0.1
+Summary: SUm-PRoduct network Pytorch layers
+Author-email: "Mikkel N. Schmidt" <mnsc@dtu.dk>
+Keywords: sum-product network
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 3
+Requires-Python: >=3.9
+Description-Content-Type: text/markdown
+Requires-Dist: torch
+
+# supr
+
+Python toolbox for tensor-based sum product networks
diff --git a/src/supr.egg-info/SOURCES.txt b/src/supr.egg-info/SOURCES.txt
new file mode 100644
index 0000000..3583f90
--- /dev/null
+++ b/src/supr.egg-info/SOURCES.txt
@@ -0,0 +1,10 @@
+README.md
+pyproject.toml
+src/supr/__init__.py
+src/supr/layers.py
+src/supr/utils.py
+src/supr.egg-info/PKG-INFO
+src/supr.egg-info/SOURCES.txt
+src/supr.egg-info/dependency_links.txt
+src/supr.egg-info/requires.txt
+src/supr.egg-info/top_level.txt
\ No newline at end of file
diff --git a/src/supr.egg-info/dependency_links.txt b/src/supr.egg-info/dependency_links.txt
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/src/supr.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/src/supr.egg-info/requires.txt b/src/supr.egg-info/requires.txt
new file mode 100644
index 0000000..12c6d5d
--- /dev/null
+++ b/src/supr.egg-info/requires.txt
@@ -0,0 +1 @@
+torch
diff --git a/src/supr.egg-info/top_level.txt b/src/supr.egg-info/top_level.txt
new file mode 100644
index 0000000..9e0f0ec
--- /dev/null
+++ b/src/supr.egg-info/top_level.txt
@@ -0,0 +1 @@
+supr
diff --git a/supr/__init__.py b/src/supr/__init__.py
similarity index 100%
rename from supr/__init__.py
rename to src/supr/__init__.py
diff --git a/supr/layers.py b/src/supr/layers.py
similarity index 98%
rename from supr/layers.py
rename to src/supr/layers.py
index 441dc47..ddf864b 100644
--- a/supr/layers.py
+++ b/src/supr/layers.py
@@ -55,10 +55,10 @@ class Sequential(nn.Sequential):
             for module in self:
                 module.em_batch()
 
-    def em_update(self):
+    def em_update(self, *args, **kwargs):
         with torch.no_grad():
             for module in self:
-                module.em_update()
+                module.em_update(*args, **kwargs)
 
     def sample(self):
         value = []
@@ -308,7 +308,7 @@ class NormalLeaf(SuprLeaf):
         # Prior
         self.mu0, self.nu0, self.alpha0, self.beta0 = mu0, nu0, alpha0, beta0
         # Parametes
-        self.mu = nn.Parameter(torch.rand(self.T, self.V, self.C))
+        self.mu = nn.Parameter(torch.rand(self.T, self.V, self.C))*3
         self.sig = nn.Parameter(torch.ones(self.T, self.V, self.C) * 0.5)
         # Which variables to marginalized
         self.register_buffer('marginalize', torch.zeros(
@@ -424,9 +424,11 @@ class BernoulliLeaf(SuprLeaf):
     def em_update(self, learning_rate: float = 1.):
         # Probability
         p_update = (self.z_x_acc + self.alpha0 - 1) / \
-            (self.z_acc + self.alpha0 + self.beta0 - 2)
+            (self.z_acc + self.alpha0 + self.beta0 - 2 + self.epsilon)
+
         self.p.data *= 1. - learning_rate
         self.p.data += learning_rate * p_update
+        
         # Reset accumulators
         self.z_acc.zero_()
         self.z_x_acc.zero_()
@@ -441,7 +443,7 @@ class BernoulliLeaf(SuprLeaf):
         r[~self.marginalize] = self.x[0][~self.marginalize]
         return r
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, marginalize=None):
         # Get shape
         batch_size = x.shape[0]
         # Store the data
diff --git a/supr/utils.py b/src/supr/utils.py
similarity index 100%
rename from supr/utils.py
rename to src/supr/utils.py
-- 
GitLab