diff --git a/supr/layers.py b/supr/layers.py index 1b34df5591518eb2cedd5b6f26b09a954c8c52b3..22c55ce590e16e6283e0c1461e9abc7a821974c8 100644 --- a/supr/layers.py +++ b/supr/layers.py @@ -1,5 +1,4 @@ import torch -import torch import torch.nn as nn from torch.nn.functional import pad import math @@ -8,10 +7,10 @@ from typing import List # Data: -# N x V x C -# └───│───│─ N: Data points -# └───│─ V: Variables -# └─ C: Channels +# N x V +# └───│── N: Data points +# └── V: Variables +# # Probability: # N x T x V x C # └───│───│───│─ N: Data points @@ -30,14 +29,29 @@ class SuprLayer(nn.Module): def em_update(self, *args, **kwargs): pass + +class SuprSequential(nn.Sequential): + def __init__(self, *args): + super().__init__(*args) + + def em_batch_update(self): + with torch.no_grad(): + for module in self: + module.em_batch() + module.em_update() + def sample(self): + value = [] + for module in reversed(self): + value = module.sample(*value) + return value class Parallel(SuprLayer): def __init__(self, nets: List[SuprLayer]): super().__init__() self.nets = nets - def forward(self, x: torch.Tensor): + def forward(self, x: List[torch.Tensor]): return [n(x) for n, x in zip(self.nets, x)] diff --git a/supr/utils.py b/supr/utils.py index 630b357e196514d116218e4c367c378016d260a7..790caf19f7be993df846506cc31cd9922af8a769 100644 --- a/supr/utils.py +++ b/supr/utils.py @@ -63,3 +63,8 @@ def local_scramble_2d(dist: float, dim: tuple): n = [torch.argsort(m + torch.randn(dim) * dist, dim=i) for i, m in enumerate(grid)] idx = torch.reshape(torch.arange(torch.tensor(dim).prod()), dim) return idx[n[0], grid[1]][grid[0], n[1]].flatten() + +def bsgen(v, v0): + while v > v0: + yield v + v = math.ceil(v / 2) \ No newline at end of file