diff --git a/demos/grid_xy.py b/demos/grid_xy.py
new file mode 100644
index 0000000000000000000000000000000000000000..85b5c980aee4663735bc5e5ba5186f3fc381aff4
--- /dev/null
+++ b/demos/grid_xy.py
@@ -0,0 +1,91 @@
+# %% Import libraries
+import matplotlib.pyplot as plt
+import torch
+import supr
+from supr.utils import drawnow
+
+# %% Default settings
+plt.ion()
+
+# %% Dataset
+Nc = 10         # Points per cluster
+K = 3           # Clusters per axis (K x K grid of cluster centres)
+N = K**2 * Nc   # Total number of data points
+sigma_x = torch.tensor([[[.02], [.04], [.08]], [[.02], [.04], [.08]], [[.02], [.04], [.08]]])
+sigma_y = torch.tensor([[[.02], [.04], [.08]], [[.02], [.04], [.08]], [[.02], [.04], [.08]]])
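+# sigma_x and sigma_y have shape (K, K, 1): one noise level per cluster,
+# broadcast over the Nc points drawn from that cluster.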
+mu_x = torch.linspace(0, 1, K)
+mu_y = torch.linspace(0, 1, K)
+x, y = torch.meshgrid(mu_x, mu_y, indexing='ij')
+x = torch.flatten(x[:, :, None] + torch.randn((K, K, Nc))*sigma_x)
+y = torch.flatten(y[:, :, None] + torch.randn((K, K, Nc))*sigma_y)
+
+X = torch.stack((x, y), dim=1)
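+# X has shape (N, 2): N points scattered around a K x K grid of cluster centres.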
+
+# %% Grid to evaluate predictive distribution
+x_res, y_res = 400, 400
+x_min, x_max = -.25, 1.25
+y_min, y_max = -.25, 1.25
+x_grid = torch.linspace(x_min, x_max, x_res)
+y_grid = torch.linspace(y_min, y_max, y_res)
+XY_grid = torch.stack([x.flatten() for x in torch.meshgrid(
+    x_grid, y_grid, indexing='ij')], dim=1)
+X_grid = torch.stack([x_grid, torch.zeros(x_res)]).T
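+# XY_grid holds all (x, y) grid points as an (x_res*y_res, 2) matrix;
+# X_grid pairs x_grid with y = 0, i.e. a horizontal slice through the plane.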
+
+# %% Sum-product network
+# Parameters
+tracks = 1
+variables = 2     # x and y
+channels = 3      # number of leaf densities per variable
+
+# Priors for variance of x and y
+alpha0 = torch.tensor([[[1], [1]]])
+beta0 = torch.tensor([[[.01], [.01]]])
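+# Both priors have shape (1, variables, 1), so the same prior on the leaf
+# variances is presumably shared across tracks and channels by broadcasting.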
+
+# Construct SPN model
+model = supr.Sequential(
+    supr.NormalLeaf(tracks, variables, channels, n=N, mu0=0.,
+                    nu0=0, alpha0=alpha0, beta0=beta0),
+    supr.Einsum(tracks, variables, channels, 1),
+    supr.Weightsum(tracks, variables, 1)
+)
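+# The model stacks Gaussian leaves (one density per track, variable and
+# channel), an Einsum layer that combines the per-variable leaves into
+# products, and a Weightsum root that mixes them into a single density over
+# (x, y). This reading follows the usual einsum-style SPN layout.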
+
+# %% Fit model and display results
+epochs = 20
+
+for epoch in range(epochs):
+    model.train()
+    # Evaluate the joint density of both variables (nothing marginalized out)
+    model[0].marginalize = torch.zeros(variables, dtype=torch.bool)
+    logp = model(X).sum()
+    print(f"Log-posterior ∝ {logp:.2f}")
+    model.zero_grad(True)
+    logp.backward()
+
+    model.eval()
+    model.em_batch_update()
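+    # The EM update presumably consumes the statistics gathered during the
+    # backward pass: in SPNs, gradients of the log-likelihood provide the
+    # expected sufficient statistics for the M-step.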
+
+    # Plot data and model
+    # -------------------------------------------------------------------------
+    # Evaluate joint distribution on grid
+    with torch.no_grad():
+        log_p_xy = model(XY_grid)
+        p_xy = torch.exp(log_p_xy).reshape(x_res, y_res)
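+        # Optional sanity check: a Riemann sum of the density over the grid
+        # should be close to 1, since the grid covers essentially all of the
+        # probability mass.
+        cell = (x_grid[1] - x_grid[0]) * (y_grid[1] - y_grid[0])
+        print(f"Grid mass ≈ {(p_xy.sum()*cell).item():.3f}")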
+
+        # Plot posterior
+        plt.figure(1).clf()
+        plt.title('Posterior distribution')
+        dx = (x_max-x_min)/x_res/2
+        dy = (y_max-y_min)/y_res/2
+        extent = [x_grid[0]-dx, x_grid[-1]+dx, y_grid[0]-dy, y_grid[-1]+dy]
+        plt.imshow(torch.log(p_xy).T, extent=extent,
+                   aspect='auto', origin='lower',
+                   vmin=-4, vmax=1, cmap='Blues')
+        plt.plot(x, y, '.', color='tab:orange', alpha=.5,
+                 markersize=15, markeredgewidth=0)
+        plt.axis('square')
+        plt.xlim([x_min, x_max])
+        plt.ylim([y_min, y_max])
+        drawnow()
+