diff --git a/supr/layers.py b/supr/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b34df5591518eb2cedd5b6f26b09a954c8c52b3
--- /dev/null
+++ b/supr/layers.py
@@ -0,0 +1,352 @@
+import torch
+import torch.nn as nn
+from torch.nn.functional import pad
+import math
+from supr.utils import discrete_rand, local_scramble_2d
+from typing import List
+
+
+# Data:
+# N x V x C
+# └───│───│─ N: Data points
+#     └───│─ V: Variables
+#         └─ C: Channels
+# Probability:
+# N x T x V x C
+# └───│───│───│─ N: Data points
+#     └───│───│─ T: Tracks
+#         └───│─ V: Variables
+#             └─ C: Channels
+
+class SuprLayer(nn.Module):
+    epsilon = 1e-12
+
+    def __init__(self):
+        super().__init__()
+
+    def em_batch(self):
+        pass
+
+    def em_update(self, *args, **kwargs):
+        pass
+
+
+class Parallel(SuprLayer):
+    def __init__(self, nets: List[SuprLayer]):
+        super().__init__()
+        # Register the sub-networks as submodules so their parameters are tracked
+        self.nets = nn.ModuleList(nets)
+
+    def forward(self, x: torch.Tensor):
+        return [n(x) for n, x in zip(self.nets, x)]
+
+
+class ScrambleTracks(SuprLayer):
+    """ Scrambles the variables in each track """
+
+    def __init__(self, tracks: int, variables: int):
+        super().__init__()
+        # Permutation for each track
+        perm = torch.stack([torch.randperm(variables) for _ in range(tracks)])
+        self.register_buffer('perm', perm)
+
+    def sample(self, track, channel_per_variable):
+        return track, torch.scatter(channel_per_variable, 0, self.perm[track], channel_per_variable)
+
+    def forward(self, x):
+        return x[:, torch.arange(x.shape[1])[:, None], self.perm, :]
+
+
+class ScrambleTracks2d(SuprLayer):
+    """ Scrambles the variables in each track using a local 2D permutation """
+
+    def __init__(self, tracks: int, variables: int, distance: float, dims: tuple):
+        super().__init__()
+        # Permutation for each track
+        perm = torch.stack([local_scramble_2d(distance, dims) for _ in range(tracks)])
+        self.register_buffer('perm', perm)
+
+    def sample(self, track, channel_per_variable):
+        return track, torch.scatter(channel_per_variable, 0, self.perm[track], channel_per_variable)
+
+    def forward(self, x):
+        return x[:, torch.arange(x.shape[1])[:, None], self.perm, :]
+
+
+class VariablesProduct(SuprLayer):
+    """ Product over all variables """
+
+    def __init__(self):
+        super().__init__()
+
+    def sample(self, track, channel_per_variable):
+        return track, torch.full((self.variables, ), channel_per_variable[0]).to(channel_per_variable.device)
+
+    def forward(self, x):
+        if not self.training:
+            self.variables = x.shape[2]
+        return torch.sum(x, dim=2, keepdim=True)
+
+
+class ProductSumLayer(SuprLayer):
+    """ Base class for product-sum layers """
+
+    def __init__(self, weight_shape, normalize_dims):
+        super().__init__()
+        # Parameters
+        self.weights = nn.Parameter(torch.rand(*weight_shape))
+        self.weights.data /= torch.clamp(self.weights.sum(dim=normalize_dims, keepdim=True), self.epsilon)
+        # Normalize dimensions
+        self.normalize_dims = normalize_dims
+        # EM accumulator
+        self.register_buffer('weights_acc', torch.zeros(*weight_shape))
+
+    def em_batch(self):
+        self.weights_acc.data += self.weights * self.weights.grad
+
+    def em_update(self, learning_rate: float = 1.):
+        weights_grad = torch.clamp(self.weights_acc, self.epsilon)
+        weights_grad /= torch.clamp(weights_grad.sum(dim=self.normalize_dims, keepdim=True), self.epsilon)
+        if learning_rate < 1.:
+            self.weights.data *= 1. - learning_rate
+            self.weights.data += learning_rate * weights_grad
+        else:
+            self.weights.data = weights_grad
+        # Reset accumulator
+        self.weights_acc.zero_()
+
+
+class Einsum(ProductSumLayer):
+    """ Einsum layer """
+
+    def __init__(self, tracks: int, variables: int, channels: int, channels_out: int = None):
+        # Dimensions
+        variables_out = math.ceil(variables / 2)
+        if channels_out is None:
+            channels_out = channels
+        # Initialize super
+        super().__init__((tracks, variables_out, channels_out, channels, channels), (3, 4))
+        # Padding
+        self.x1_pad = torch.zeros(variables_out, dtype=torch.bool)
+        self.x2_pad = torch.zeros(variables_out, dtype=torch.bool)
+        # Zero-pad if necessary
+        if variables % 2 == 1:
+            # Pad on the right
+            self.pad = True
+            self.x2_padding = [0, 0, 0, 1]
+            self.x2_pad[-1] = True
+        else:
+            self.pad = False
+        # TODO: Implement choice of left, right, or both augmentation. Both returns 2 times the number of tracks
+
+    def sample(self, track: int, channel_per_variable: torch.Tensor):
+        r = []
+        for v, c in enumerate(channel_per_variable):
+            # Probability matrix
+            px1 = torch.exp(self.x1[0, track, v, :][:, None])
+            px2 = torch.exp(self.x2[0, track, v, :][None, :])
+            prob = self.weights[track, v, c] * px1 * px2
+            # Sample
+            idx = discrete_rand(prob)[0]
+            # Remove indices of padding
+            idx_valid = idx[[not self.x1_pad[v], not self.x2_pad[v]]]
+            # Store in list
+            r.append(idx_valid)
+        # Concatenate and return indices
+        return track, torch.cat(r)
+
+    def forward(self, x: torch.Tensor):
+        # Split the input variables in two and apply padding if necessary
+        x1 = x[:, :, 0::2, :]
+        x2 = x[:, :, 1::2, :]
+        if self.pad:
+            x2 = pad(x2, self.x2_padding)
+        # Store the inputs for use in the sampling routine
+        if not self.training:
+            self.x1, self.x2 = x1, x2
+        # Compute maximum
+        a1, a2 = [torch.max(x, dim=3, keepdim=True)[0] for x in [x1, x2]]
+        # Subtract maximum and compute exponential
+        exa1, exa2 = [torch.clamp(torch.exp(x - a), self.epsilon) for x, a in [(x1, a1), (x2, a2)]]
+        # Compute the contraction
+        y = a1 + a2 + torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc', exa1, exa2, self.weights))
+        return y
+
+
+class Weightsum(ProductSumLayer):
+    """ Weightsum layer """
+
+    # Product over all variables and weighted sum over tracks and channels
+    def __init__(self, tracks: int, variables: int, channels: int):
+        # Initialize super
+        super().__init__((tracks, channels), (0, 1))
+
+    def sample(self):
+        prob = self.weights * torch.exp(self.x_sum[0] - torch.max(self.x_sum[0]))
+        s = discrete_rand(prob)[0]
+        return s[0], torch.full((self.variables,), s[1]).to(self.weights.device)
+
+    def forward(self, x: torch.Tensor):
+        # Product over variables
+        x_sum = torch.sum(x, 2)
+        # Store the inputs for use in the sampling routine
+        if not self.training:
+            self.x_sum = x_sum
+            self.variables = x.shape[2]
+        # Compute maximum
+        a = torch.max(torch.max(x_sum, dim=1)[0], dim=1)[0]
+        # Subtract maximum and compute exponential
+        exa = torch.clamp(torch.exp(x_sum - a[:, None, None]), self.epsilon)
+        # Compute the contraction
+        y = a + torch.log(torch.einsum('ntc,tc->n', exa, self.weights))
+        return y
+
+
+class TrackSum(ProductSumLayer):
+    """ TrackSum layer """
+
+    # Weighted sum over tracks
+    def __init__(self, tracks: int, channels: int):
+        # Initialize super
+        super().__init__((tracks, channels), (0, ))
+
+    def sample(self, track: int, channel_per_variable: torch.Tensor):
+        prob = self.weights[:, None] * torch.exp(self.x[0] - torch.max(self.x[0], dim=0)[0])
+        s = discrete_rand(prob)[0]
+        return s[0], channel_per_variable
+
+    def forward(self, x: torch.Tensor):
+        # Module is only valid when the number of variables is 1
+        assert x.shape[2] == 1
+        # Store the inputs for use in the sampling routine
+        if not self.training:
+            self.x = x
+        # Compute maximum
+        a = torch.max(x, dim=1)[0]
+        # Subtract maximum and compute exponential
+        exa = torch.clamp(torch.exp(x - a[:, None]), self.epsilon)
+        # Compute the contraction
+        y = a + torch.log(torch.einsum('ntvc,tc->nvc', exa, self.weights))
+        # Insert track dimension
+        y = y[:, None]
+        return y
+
+
+class NormalLeaf(SuprLayer):
+    """ NormalLeaf layer """
+
+    def __init__(self, tracks: int, variables: int, channels: int):
+        super().__init__()
+        # Dimensions
+        self.T, self.V, self.C = tracks, variables, channels
+        # Parameters
+        # self.mu = nn.Parameter(torch.randn(self.T, self.V, self.C))
+        # self.mu = nn.Parameter(torch.linspace(0, 1, self.C)[None, None, :].repeat((self.T, self.V, 1)))
+        self.mu = nn.Parameter(torch.rand(self.T, self.V, self.C))
+        self.sig = nn.Parameter(torch.ones(self.T, self.V, self.C) * 0.5)
+        # Which variables to marginalize
+        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
+        # Input
+        self.register_buffer('x', torch.Tensor())
+        # Output
+        self.register_buffer('z', torch.Tensor())
+        # EM accumulators
+        self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
+        self.register_buffer('z_x_acc', torch.zeros(self.T, self.V, self.C))
+        self.register_buffer('z_x_sq_acc', torch.zeros(self.T, self.V, self.C))
+
+    def em_batch(self):
+        self.z_acc.data += torch.sum(self.z.grad, dim=0)
+        self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)
+        self.z_x_sq_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None] ** 2, dim=0)
+
+    def em_update(self, learning_rate: float = 1.):
+        # Mean
+        sum_z = torch.clamp(self.z_acc, self.epsilon)
+        self.mu.data *= 1. - learning_rate
+        self.mu.data += learning_rate * self.z_x_acc / sum_z
+        # Standard deviation
+        self.sig.data *= 1. - learning_rate
+        self.sig.data += learning_rate * torch.sqrt(torch.clamp(self.z_x_sq_acc / sum_z - self.mu ** 2, self.epsilon + 0.01))
+        # Reset accumulators
+        self.z_acc.zero_()
+        self.z_x_acc.zero_()
+        self.z_x_sq_acc.zero_()
+
+    def sample(self, track: int, channel_per_variable: torch.Tensor):
+        variables_marginalize = torch.sum(self.marginalize).int()
+        mu_marginalize = self.mu[track, self.marginalize, channel_per_variable[self.marginalize]]
+        sig_marginalize = self.sig[track, self.marginalize, channel_per_variable[self.marginalize]]
+        r = torch.empty_like(self.x[0])
+        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * sig_marginalize
+        r[~self.marginalize] = self.x[0][~self.marginalize]
+        return r
+
+    def forward(self, x: torch.Tensor):
+        # Get shape
+        batch_size = x.shape[0]
+        # Store the data
+        self.x = x
+        # Compute the probability
+        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
+        # Get non-marginalized parameters and data
+        mu_valid = self.mu[None, :, ~self.marginalize, :]
+        sig_valid = self.sig[None, :, ~self.marginalize, :]
+        x_valid = self.x[:, None, ~self.marginalize, None]
+        # Evaluate log probability
+        self.z.data[:, :, ~self.marginalize, :] = \
+            torch.distributions.Normal(mu_valid, sig_valid).log_prob(x_valid).float()
+        return self.z
+
+
+class BernoulliLeaf(SuprLayer):
+    """ BernoulliLeaf layer """
+
+    def __init__(self, tracks: int, variables: int, channels: int):
+        super().__init__()
+        # Dimensions
+        self.T, self.V, self.C = tracks, variables, channels
+        # Parameters
+        self.p = nn.Parameter(torch.rand(self.T, self.V, self.C))
+        # Which variables to marginalize
+        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
+        # Input
+        self.register_buffer('x', torch.Tensor())
+        # Output
+        self.register_buffer('z', torch.Tensor())
+        # EM accumulators
+        self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
+        self.register_buffer('z_x_acc', torch.zeros(self.T, self.V, self.C))
+
+    def em_batch(self):
+        self.z_acc.data += torch.sum(self.z.grad, dim=0)
+        self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)
+
+    def em_update(self, learning_rate: float = 1.):
+        # Probability
+        sum_z = torch.clamp(self.z_acc, self.epsilon)
+        self.p.data *= 1. - learning_rate
+        self.p.data += learning_rate * self.z_x_acc / sum_z
+        # Reset accumulators
+        self.z_acc.zero_()
+        self.z_x_acc.zero_()
+
+    def sample(self, track: int, channel_per_variable: torch.Tensor):
+        variables_marginalize = torch.sum(self.marginalize).int()
+        p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize]]
+        r = torch.empty_like(self.x[0])
+        r[self.marginalize] = (torch.rand(variables_marginalize).to(self.x.device) < p_marginalize).float()
+        r[~self.marginalize] = self.x[0][~self.marginalize]
+        return r
+
+    def forward(self, x: torch.Tensor):
+        # Get shape
+        batch_size = x.shape[0]
+        # Store the data
+        self.x = x
+        # Compute the probability
+        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
+        # Get non-marginalized parameters and data
+        p_valid = self.p[None, :, ~self.marginalize, :]
+        x_valid = self.x[:, None, ~self.marginalize, None]
+        # Evaluate probability
+        self.z.data[:, :, ~self.marginalize, :] = \
+            p_valid * (x_valid == 1) + (1 - p_valid) * (x_valid == 0)
+        return self.z
diff --git a/supr/utils.py b/supr/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..18c24317bcab9802dbc09b715cff2917c2d5071c
--- /dev/null
+++ b/supr/utils.py
@@ -0,0 +1,65 @@
+# %% Imports
+import matplotlib.pyplot as plt
+from typing import Tuple
+import torch
+import numpy as np
+import math
+from PyQt5 import QtWidgets
+
+
+# %%
+def drawnow():
+    plt.gcf().canvas.draw()
+    plt.gcf().canvas.flush_events()
+
+
+def arrange_figs(cols=1, min_rows=3, toolbar=False, x0=1400, y0=28, x1=1920, y1=1200):
+    try:
+        current_fig_num = plt.gcf().number
+        extra = 37
+        w = x1 - x0
+        h = y1 - y0
+        fignums = plt.get_fignums()
+        n = len(fignums)
+        rows = np.maximum(math.ceil(n / cols), min_rows)
+        height = int(h / rows - extra)
+        width = int(w / cols)
+        for i, fn in enumerate(fignums):
+            r = i % rows
+            c = int(i / rows)
+            plt.figure(fn)
+            win = plt.get_current_fig_manager().window
+            win.findChild(QtWidgets.QToolBar).setVisible(toolbar)
+            win.setGeometry(x0 + width * c, y0 + int(h / rows * r) + extra, width, height)
+        plt.figure(current_fig_num)
+    except Exception:
+        pass
+
+
+def unravel_indices(indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.LongTensor:
+    r"""Converts flat indices into unraveled coordinates in a target shape.
+
+    Args:
+        indices: A tensor of (flat) indices, (*, N).
+        shape: The targeted shape, (D,).
+
+    Returns:
+        The unraveled coordinates, (*, N, D).
+    """
+    coord = []
+    for dim in reversed(shape):
+        coord.append(indices % dim)
+        indices = indices // dim
+    coord = torch.stack(coord[::-1], dim=-1)
+    return coord
+
+
+def discrete_rand(v: torch.Tensor, n: int = 1):
+    # Draw n samples from the (unnormalized) discrete distribution given by the tensor v
+    idx = torch.sum(torch.rand(n)[:, None].to(v.device) > torch.cumsum(v.flatten(), 0)[None, :] / torch.sum(v), dim=1)
+    return unravel_indices(idx, v.shape)
+
+
+def local_scramble_2d(dist: float, dim: tuple):
+    # Permutation of a 2D grid that shuffles indices only locally, with spatial scale dist
+    grid = torch.meshgrid(*[torch.arange(d) for d in dim])
+    n = [torch.argsort(m + torch.randn(dim) * dist, dim=i) for i, m in enumerate(grid)]
+    idx = torch.reshape(torch.arange(torch.tensor(dim).prod()), dim)
+    return idx[n[0], grid[1]][grid[0], n[1]].flatten()
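
The patch itself contains no usage code, so the following is a minimal sketch of how the new layers might be composed and trained. The layer names, shapes, and the em_batch/em_update interface come from the code above; the choice of a torch.nn.Sequential stack, the network depth, the data, and the learning rate are illustrative assumptions, not part of this change.

    import torch
    import torch.nn as nn
    from supr.layers import NormalLeaf, Einsum, Weightsum

    # Assumed toy configuration: 2 tracks, 8 variables, 4 channels per variable
    T, V, C = 2, 8, 4

    # Leaves produce N x T x V x C log-densities; each Einsum halves the number
    # of variables; Weightsum reduces to one log-likelihood per data point.
    model = nn.Sequential(
        NormalLeaf(T, V, C),
        Einsum(T, V, C),    # 8 variables -> 4
        Einsum(T, 4, C),    # 4 -> 2
        Einsum(T, 2, C),    # 2 -> 1
        Weightsum(T, 1, C),
    )

    x = torch.rand(16, V)       # N x V data
    logp = model(x)             # N log-likelihoods
    logp.sum().backward()       # gradients feed the EM statistics

    # One EM iteration: accumulate sufficient statistics, then update parameters
    for layer in model:
        layer.em_batch()
    for layer in model:
        layer.em_update(learning_rate=1.)

To marginalize out selected variables at evaluation time, entries of the leaf's `marginalize` buffer can be set to True before the forward pass; the corresponding positions in the output are then left at zero log-probability by NormalLeaf and BernoulliLeaf.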