import math
from typing import List

import torch
import torch.nn as nn
from torch.nn.functional import pad

from supr.utils import discrete_rand, local_scramble_2d
# Data:
#  N x V x C
#  └───│───│─ N: Data points
#      └───│─ V: Variables
#          └─ C: Channels
#
# Probability:
#  N x T x V x C
#  └───│───│───│─ N: Data points
#      └───│───│─ T: Tracks
#          └───│─ V: Variables
#              └─ C: Channels
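#
# A minimal shape sketch (illustrative only, not part of the library): a batch
# of 16 data points with 8 variables enters as an N x V tensor, and a leaf
# layer expands it to the N x T x V x C probability layout above, e.g.
#
#   x = torch.rand(16, 8)                          # N x V
#   leaf = NormalLeaf(tracks=4, variables=8, channels=3)
#   z = leaf(x)                                    # 16 x 4 x 8 x 3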
class SuprLayer(nn.Module):
    epsilon = 1e-12

    def __init__(self):
        super().__init__()

    def em_batch(self):
        pass

    def em_update(self, *args, **kwargs):
        pass

class Parallel(SuprLayer):
    def __init__(self, nets: List[SuprLayer]):
        super().__init__()
        # Use a ModuleList so the sub-networks' parameters are registered
        self.nets = nn.ModuleList(nets)

    def forward(self, x: List[torch.Tensor]):
        return [n(xi) for n, xi in zip(self.nets, x)]

class ScrambleTracks(SuprLayer):
    """ Scrambles the variables in each track """

    def __init__(self, tracks: int, variables: int):
        super().__init__()
        # Permutation for each track
        perm = torch.stack([torch.randperm(variables) for _ in range(tracks)])
        self.register_buffer('perm', perm)

    def sample(self, track, channel_per_variable):
        return track, torch.scatter(channel_per_variable, 0, self.perm[track], channel_per_variable)

    def forward(self, x):
        return x[:, torch.arange(x.shape[1])[:, None], self.perm, :]

class ScrambleTracks2d(SuprLayer):
    """ Scrambles the variables in each track, preserving 2-d locality """

    def __init__(self, tracks: int, variables: int, distance: float, dims: tuple):
        super().__init__()
        # Permutation for each track
        perm = torch.stack([local_scramble_2d(distance, dims) for _ in range(tracks)])
        self.register_buffer('perm', perm)

    def sample(self, track, channel_per_variable):
        return track, torch.scatter(channel_per_variable, 0, self.perm[track], channel_per_variable)

    def forward(self, x):
        return x[:, torch.arange(x.shape[1])[:, None], self.perm, :]

class VariablesProduct(SuprLayer):
    """ Product over all variables """

    def __init__(self):
        super().__init__()

    def sample(self, track, channel_per_variable):
        return track, torch.full((self.variables,), channel_per_variable[0]).to(channel_per_variable.device)

    def forward(self, x):
        if not self.training:
            # Store the number of variables for use in the sampling routine
            self.variables = x.shape[2]
        return torch.sum(x, dim=2, keepdim=True)

class ProductSumLayer(SuprLayer):
    """ Base class for product-sum layers """

    def __init__(self, weight_shape, normalize_dims):
        super().__init__()
        # Parameters
        self.weights = nn.Parameter(torch.rand(*weight_shape))
        self.weights.data /= torch.clamp(self.weights.sum(dim=normalize_dims, keepdim=True), self.epsilon)
        # Normalize dimensions
        self.normalize_dims = normalize_dims
        # EM accumulator
        self.register_buffer('weights_acc', torch.zeros(*weight_shape))

    def em_batch(self):
        self.weights_acc.data += self.weights * self.weights.grad

    def em_update(self, learning_rate: float = 1.):
        weights_grad = torch.clamp(self.weights_acc, self.epsilon)
        weights_grad /= torch.clamp(weights_grad.sum(dim=self.normalize_dims, keepdim=True), self.epsilon)
        if learning_rate < 1.:
            self.weights.data *= 1. - learning_rate
            self.weights.data += learning_rate * weights_grad
        else:
            self.weights.data = weights_grad
        # Reset accumulator
        self.weights_acc.zero_()

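# A minimal EM training sketch (illustrative; `model` and `data` are assumed
# to be a network of SuprLayer modules and an iterable of N x V batches):
#
#   for x in data:
#       model(x).sum().backward()        # E-step: gradients give expected counts
#       for m in model.modules():
#           if isinstance(m, SuprLayer):
#               m.em_batch()             # accumulate sufficient statistics
#       model.zero_grad()
#   for m in model.modules():
#       if isinstance(m, SuprLayer):
#           m.em_update()                # M-step: renormalize from accumulators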
class Einsum(ProductSumLayer):
    """ Einsum layer """

    def __init__(self, tracks: int, variables: int, channels: int, channels_out: int = None):
        # Dimensions
        variables_out = math.ceil(variables / 2)
        if channels_out is None:
            channels_out = channels
        # Initialize super
        super().__init__((tracks, variables_out, channels_out, channels, channels), (3, 4))
        # Padding
        self.x1_pad = torch.zeros(variables_out, dtype=torch.bool)
        self.x2_pad = torch.zeros(variables_out, dtype=torch.bool)
        # Zero-pad if necessary
        if variables % 2 == 1:
            # Pad on the right
            self.pad = True
            self.x2_padding = [0, 0, 0, 1]
            self.x2_pad[-1] = True
        else:
            self.pad = False

    # TODO: Implement choice of left, right, or both augmentation. Both returns 2 times the number of tracks
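    # Shape sketch (derived from the code above, for illustration): the layer
    # pairs adjacent variables, so an input of shape N x T x V x C produces an
    # output of shape N x T x ceil(V/2) x C_out; e.g. with V = 5 the odd
    # variable is zero-padded and the output has 3 variables.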
    def sample(self, track: int, channel_per_variable: torch.Tensor):
        r = []
        for v, c in enumerate(channel_per_variable):
            # Probability matrix
            px1 = torch.exp(self.x1[0, track, v, :][:, None])
            px2 = torch.exp(self.x2[0, track, v, :][None, :])
            prob = self.weights[track, v, c] * px1 * px2
            # Sample
            idx = discrete_rand(prob)[0]
            # Remove indices of padding
            idx_valid = idx[[not self.x1_pad[v], not self.x2_pad[v]]]
            # Store in list
            r.append(idx_valid)
        # Concatenate and return indices
        return track, torch.cat(r)
    def forward(self, x: torch.Tensor):
        # Split the input variables in two and apply padding if necessary
        x1 = x[:, :, 0::2, :]
        x2 = x[:, :, 1::2, :]
        if self.pad:
            x2 = pad(x2, self.x2_padding)
        # Store the inputs for use in the sampling routine
        if not self.training:
            self.x1, self.x2 = x1, x2
        # Compute maximum
        a1, a2 = [torch.max(t, dim=3, keepdim=True)[0] for t in [x1, x2]]
        # Subtract maximum and compute exponential
        exa1, exa2 = [torch.clamp(torch.exp(t - a), self.epsilon) for t, a in [(x1, a1), (x2, a2)]]
        # Compute the contraction
        y = a1 + a2 + torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc', exa1, exa2, self.weights))
        return y

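# Note on Einsum.forward (a restatement of the code above, not new behavior):
# since the inputs are log-probabilities, the contraction computes, per output
# channel c,
#   y_c = log sum_{a,b} exp(x1_a) * exp(x2_b) * w_{cab}
#       = a1 + a2 + log sum_{a,b} exp(x1_a - a1) * exp(x2_b - a2) * w_{cab}
# where factoring out the maxima a1, a2 keeps the exponentials in a safe
# numeric range (the standard log-sum-exp trick).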
class Weightsum(ProductSumLayer):
    """ Weightsum layer """
    # Product over all variables and weighted sum over tracks and channels

    def __init__(self, tracks: int, variables: int, channels: int):
        # Initialize super (`variables` is unused here; the count is taken
        # from the input in forward)
        super().__init__((tracks, channels), (0, 1))

    def sample(self):
        prob = self.weights * torch.exp(self.x_sum[0] - torch.max(self.x_sum[0]))
        s = discrete_rand(prob)[0]
        return s[0], torch.full((self.variables,), s[1]).to(self.weights.device)

    def forward(self, x: torch.Tensor):
        # Product over variables
        x_sum = torch.sum(x, 2)
        # Store the inputs for use in the sampling routine
        if not self.training:
            self.x_sum = x_sum
            self.variables = x.shape[2]
        # Compute maximum
        a = torch.max(torch.max(x_sum, dim=1)[0], dim=1)[0]
        # Subtract maximum and compute exponential
        exa = torch.clamp(torch.exp(x_sum - a[:, None, None]), self.epsilon)
        # Compute the contraction
        y = a + torch.log(torch.einsum('ntc,tc->n', exa, self.weights))
        return y

class TrackSum(ProductSumLayer):
    """ TrackSum layer """
    # Weighted sum over tracks

    def __init__(self, tracks: int, channels: int):
        # Initialize super
        super().__init__((tracks, channels), (0,))

    def sample(self, track: int, channel_per_variable: torch.Tensor):
        prob = self.weights[:, None] * torch.exp(self.x[0] - torch.max(self.x[0], dim=0)[0])
        s = discrete_rand(prob)[0]
        return s[0], channel_per_variable

    def forward(self, x: torch.Tensor):
        # Module is only valid when the number of variables is 1
        assert x.shape[2] == 1
        # Store the inputs for use in the sampling routine
        if not self.training:
            self.x = x
        # Compute maximum
        a = torch.max(x, dim=1)[0]
        # Subtract maximum and compute exponential
        exa = torch.clamp(torch.exp(x - a[:, None]), self.epsilon)
        # Compute the contraction
        y = a + torch.log(torch.einsum('ntvc,tc->nvc', exa, self.weights))
        # Insert track dimension
        y = y[:, None]
        return y

class NormalLeaf(SuprLayer):
    """ NormalLeaf layer """

    def __init__(self, tracks: int, variables: int, channels: int):
        super().__init__()
        # Dimensions
        self.T, self.V, self.C = tracks, variables, channels
        # Parameters
        # self.mu = nn.Parameter(torch.randn(self.T, self.V, self.C))
        # self.mu = nn.Parameter(torch.linspace(0, 1, self.C)[None, None, :].repeat((self.T, self.V, 1)))
        self.mu = nn.Parameter(torch.rand(self.T, self.V, self.C))
        self.sig = nn.Parameter(torch.ones(self.T, self.V, self.C) * 0.5)
        # Which variables to marginalize
        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
        # Input
        self.register_buffer('x', torch.Tensor())
        # Output
        self.register_buffer('z', torch.Tensor())
        # EM accumulators
        self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
        self.register_buffer('z_x_acc', torch.zeros(self.T, self.V, self.C))
        self.register_buffer('z_x_sq_acc', torch.zeros(self.T, self.V, self.C))
    def em_batch(self):
        self.z_acc.data += torch.sum(self.z.grad, dim=0)
        self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)
        self.z_x_sq_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None] ** 2, dim=0)

    def em_update(self, learning_rate: float = 1.):
        # Mean
        sum_z = torch.clamp(self.z_acc, self.epsilon)
        self.mu.data *= 1. - learning_rate
        self.mu.data += learning_rate * self.z_x_acc / sum_z
        # Standard deviation
        self.sig.data *= 1. - learning_rate
        self.sig.data += learning_rate * torch.sqrt(torch.clamp(self.z_x_sq_acc / sum_z - self.mu ** 2, self.epsilon + 0.01))
        # Reset accumulators
        self.z_acc.zero_()
        self.z_x_acc.zero_()
        self.z_x_sq_acc.zero_()
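    # The update above can be read as the usual Gaussian M-step: z.grad holds
    # the posterior responsibilities, so mu becomes the responsibility-weighted
    # mean E[z*x] / E[z] and sig the corresponding standard deviation, clamped
    # away from zero for numerical stability.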
    def sample(self, track: int, channel_per_variable: torch.Tensor):
        variables_marginalize = torch.sum(self.marginalize).int()
        mu_marginalize = self.mu[track, self.marginalize, channel_per_variable[self.marginalize]]
        sig_marginalize = self.sig[track, self.marginalize, channel_per_variable[self.marginalize]]
        r = torch.empty_like(self.x[0])
        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * sig_marginalize
        r[~self.marginalize] = self.x[0][~self.marginalize]
        return r
    def forward(self, x: torch.Tensor):
        # Get shape
        batch_size = x.shape[0]
        # Store the data
        self.x = x
        # Compute the probability
        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
        # Get non-marginalized parameters and data
        mu_valid = self.mu[None, :, ~self.marginalize, :]
        sig_valid = self.sig[None, :, ~self.marginalize, :]
        x_valid = self.x[:, None, ~self.marginalize, None]
        # Evaluate log probability
        self.z.data[:, :, ~self.marginalize, :] = \
            torch.distributions.Normal(mu_valid, sig_valid).log_prob(x_valid).float()
        return self.z

class BernoulliLeaf(SuprLayer):
    """ BernoulliLeaf layer """

    def __init__(self, tracks: int, variables: int, channels: int):
        super().__init__()
        # Dimensions
        self.T, self.V, self.C = tracks, variables, channels
        # Parameters
        self.p = nn.Parameter(torch.rand(self.T, self.V, self.C))
        # Which variables to marginalize
        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
        # Input
        self.register_buffer('x', torch.Tensor())
        # Output
        self.register_buffer('z', torch.Tensor())
        # EM accumulators
        self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
        self.register_buffer('z_x_acc', torch.zeros(self.T, self.V, self.C))
    def em_batch(self):
        self.z_acc.data += torch.sum(self.z.grad, dim=0)
        self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)

    def em_update(self, learning_rate: float = 1.):
        # Probability
        sum_z = torch.clamp(self.z_acc, self.epsilon)
        self.p.data *= 1. - learning_rate
        self.p.data += learning_rate * self.z_x_acc / sum_z
        # Reset accumulators
        self.z_acc.zero_()
        self.z_x_acc.zero_()

    def sample(self, track: int, channel_per_variable: torch.Tensor):
        variables_marginalize = torch.sum(self.marginalize).int()
        p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize]]
        r = torch.empty_like(self.x[0])
        r[self.marginalize] = (torch.rand(variables_marginalize).to(self.x.device) < p_marginalize).float()
        r[~self.marginalize] = self.x[0][~self.marginalize]
        return r
    def forward(self, x: torch.Tensor):
        # Get shape
        batch_size = x.shape[0]
        # Store the data
        self.x = x
        # Compute the probability
        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
        # Get non-marginalized parameters and data
        p_valid = self.p[None, :, ~self.marginalize, :]
        x_valid = self.x[:, None, ~self.marginalize, None]
        # Evaluate log probability (take the log so the output is in the log
        # domain, consistent with NormalLeaf)
        self.z.data[:, :, ~self.marginalize, :] = \
            torch.log(torch.clamp(p_valid * (x_valid == 1) + (1 - p_valid) * (x_valid == 0), self.epsilon))
        return self.z

# The following utilities live in supr/utils.py (imported at the top of the
# layers module above).

# %% Imports
import math
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from PyQt5 import QtWidgets

# %%
def drawnow():
    plt.gcf().canvas.draw()
    plt.gcf().canvas.flush_events()

def arrange_figs(cols=1, min_rows=3, toolbar=False, x0=1400, y0=28, x1=1920, y1=1200):
    try:
        current_fig_num = plt.gcf().number
        extra = 37
        w = x1 - x0
        h = y1 - y0
        fignums = plt.get_fignums()
        n = len(fignums)
        rows = np.maximum(math.ceil(n / cols), min_rows)
        height = int(h / rows - extra)
        width = int(w / cols)
        for i, fn in enumerate(fignums):
            r = i % rows
            c = int(i / rows)
            plt.figure(fn)
            win = plt.get_current_fig_manager().window
            win.findChild(QtWidgets.QToolBar).setVisible(toolbar)
            win.setGeometry(x0 + width * c, y0 + int(h / rows * r) + extra, width, height)
        plt.figure(current_fig_num)
    except Exception:
        # Window management is only available with the Qt backend; fail silently otherwise
        pass

def unravel_indices(indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.LongTensor:
    r"""Converts flat indices into unraveled coordinates in a target shape.

    Args:
        indices: A tensor of (flat) indices, (*, N).
        shape: The targeted shape, (D,).

    Returns:
        The unraveled coordinates, (*, N, D).
    """
    coord = []
    for dim in reversed(shape):
        coord.append(indices % dim)
        indices = indices // dim
    coord = torch.stack(coord[::-1], dim=-1)
    return coord

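# Example (illustrative): flat index 5 in a 3 x 4 grid is row 1, column 1:
#   unravel_indices(torch.tensor([5]), (3, 4))  # -> tensor([[1, 1]])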
def discrete_rand(v: torch.Tensor, n: int = 1):
    idx = torch.sum(torch.rand(n)[:, None].to(v.device) > torch.cumsum(v.flatten(), 0)[None, :] / torch.sum(v), dim=1)
    return unravel_indices(idx, v.shape)

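# Example (illustrative): draw a multi-index from an unnormalized distribution;
# with the weights below, index 1 is returned with probability 0.9:
#   discrete_rand(torch.tensor([0.1, 0.9]))  # -> tensor([[1]]) most of the time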
def local_scramble_2d(dist: float, dim: tuple):
    grid = torch.meshgrid(*[torch.arange(d) for d in dim])
    n = [torch.argsort(m + torch.randn(dim) * dist, dim=i) for i, m in enumerate(grid)]
    idx = torch.reshape(torch.arange(torch.tensor(dim).prod()), dim)
    return idx[n[0], grid[1]][grid[0], n[1]].flatten()
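# Note: this returns a locally perturbed permutation of the flattened 2-d grid.
# Each coordinate is jittered by Gaussian noise with scale `dist` before
# sorting, so dist=0 yields the identity permutation and larger values mix
# indices over correspondingly larger neighborhoods.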