Commit da96baed authored by mnsc

added sequential layer

parent 8204da3f
import torch
import torch.nn as nn
from torch.nn.functional import pad
import math
@@ -8,10 +7,10 @@
from typing import List
# Data:
# N x V x C
# └───│───│─ N: Data points
#     └───│─ V: Variables
#         └─ C: Channels
# N x V
# └───│── N: Data points
#     └── V: Variables
#
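# Illustrative shapes only (the sizes below are assumed, not part of this commit):
#   x = torch.randn(100, 8, 3)  # N=100 data points, V=8 variables, C=3 channels
#   y = torch.randn(100, 8)     # N=100 data points, V=8 variables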
# Probability:
# N x T x V x C
# └───│───│───│─ N: Data points
@@ -31,13 +30,28 @@
class SuprLayer(nn.Module):
    def em_update(self, *args, **kwargs):
        pass
class SuprSequential(nn.Sequential):
    def __init__(self, *args):
        super().__init__(*args)

    def em_batch_update(self):
        # One EM step: accumulate batch statistics, then update the
        # parameters of every layer in the container
        with torch.no_grad():
            for module in self:
                module.em_batch()
                module.em_update()

    def sample(self):
        # Sample top-down: traverse the layers in reverse order, unpacking
        # each layer's sample into the call to the layer below it
        value = []
        for module in reversed(self):
            value = module.sample(*value)
        return value
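A minimal usage sketch of the new sequential container; the layer instances below are hypothetical stand-ins for concrete SuprLayer subclasses (which must implement em_batch, em_update and sample) and are not part of this diff:

# layer1/layer2 are assumed SuprLayer subclasses, not defined in this commit
model = SuprSequential(layer1, layer2)
log_p = model(x)         # forward pass, layers applied in order
model.em_batch_update()  # one EM step: em_batch() then em_update() per layer
x_new = model.sample()   # top-down sampling through the layers in reverse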
class Parallel(SuprLayer):
    def __init__(self, nets: List[SuprLayer]):
        super().__init__()
        # Wrap in ModuleList so the sub-layers' parameters are registered
        self.nets = nn.ModuleList(nets)

    def forward(self, x: List[torch.Tensor]):
        # Apply the i-th sub-layer to the i-th input tensor
        return [n(x) for n, x in zip(self.nets, x)]
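A sketch of the updated Parallel interface, which now takes a list of tensors; the two sub-layers and inputs here are assumptions:

# branch_a/branch_b are assumed SuprLayer instances
par = Parallel([branch_a, branch_b])
out_a, out_b = par([x_a, x_b])  # the i-th sub-layer receives the i-th tensor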
@@ -63,3 +63,8 @@
def local_scramble_2d(dist: float, dim: tuple):
    # 'grid' is defined in unchanged context lines elided by the diff
    n = [torch.argsort(m + torch.randn(dim) * dist, dim=i) for i, m in enumerate(grid)]
    idx = torch.reshape(torch.arange(torch.tensor(dim).prod()), dim)
    return idx[n[0], grid[1]][grid[0], n[1]].flatten()
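A hedged sketch of how the flattened permutation returned by local_scramble_2d might be applied; the image tensor and sizes are assumptions:

# Locally scramble the pixels of flattened 8x8 images with strength 2.0;
# img is an assumed tensor of shape (N, 64)
perm = local_scramble_2d(2.0, (8, 8))
img_scrambled = img[:, perm]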
def bsgen(v, v0):
    # Yield v, then keep halving it (rounding up) until it drops to v0 or below
    while v > v0:
        yield v
        v = math.ceil(v / 2)
\ No newline at end of file
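For reference, bsgen yields a halving schedule, e.g.:

# list(bsgen(100, 10)) == [100, 50, 25, 13]
# after 13, ceil(13 / 2) = 7 <= 10, so the generator stops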