mnsc / supr · Commits · ccce14dc

Commit ccce14dc — "initial commit"
Authored 3 years ago by mnsc
Parent: 0bd06f70
No related branches, tags, or merge requests found.

Showing 2 changed files with 417 additions and 0 deletions:
  supr/layers.py  +352 −0
  supr/utils.py   +65 −0
supr/layers.py · new file (mode 100644) · +352 −0
import torch
import torch.nn as nn
from torch.nn.functional import pad
import math
from supr.utils import discrete_rand, local_scramble_2d
from typing import List
# Data:
# N x V x C
# └───│───│─ N: Data points
#     └───│─ V: Variables
#         └─ C: Channels

# Probability:
# N x T x V x C
# └───│───│───│─ N: Data points
#     └───│───│─ T: Tracks
#         └───│─ V: Variables
#             └─ C: Channels
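
# Example (illustrative sizes): a batch of N=8 data points over V=16 variables with
# C=2 channels is a tensor of shape (8, 16, 2); with T=4 tracks, the corresponding
# probability tensor has shape (8, 4, 16, 2).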

class SuprLayer(nn.Module):
    epsilon = 1e-12

    def __init__(self):
        super().__init__()

    def em_batch(self):
        pass

    def em_update(self, *args, **kwargs):
        pass

class Parallel(SuprLayer):
    def __init__(self, nets: List[SuprLayer]):
        super().__init__()
        # Wrap in a ModuleList so the sub-layers' parameters are registered with the module
        self.nets = nn.ModuleList(nets)

    def forward(self, x: torch.Tensor):
        # Apply each sub-layer to its corresponding input
        return [net(xi) for net, xi in zip(self.nets, x)]

class ScrambleTracks(SuprLayer):
    """ Scrambles the variables in each track """

    def __init__(self, tracks: int, variables: int):
        super().__init__()
        # Permutation for each track
        perm = torch.stack([torch.randperm(variables) for _ in range(tracks)])
        self.register_buffer('perm', perm)

    def sample(self, track, channel_per_variable):
        return track, torch.scatter(channel_per_variable, 0, self.perm[track], channel_per_variable)

    def forward(self, x):
        return x[:, torch.arange(x.shape[1])[:, None], self.perm, :]

class ScrambleTracks2d(SuprLayer):
    """ Scrambles the variables in each track """

    def __init__(self, tracks: int, variables: int, distance: float, dims: tuple):
        super().__init__()
        # Permutation for each track
        perm = torch.stack([local_scramble_2d(distance, dims) for _ in range(tracks)])
        self.register_buffer('perm', perm)

    def sample(self, track, channel_per_variable):
        return track, torch.scatter(channel_per_variable, 0, self.perm[track], channel_per_variable)

    def forward(self, x):
        return x[:, torch.arange(x.shape[1])[:, None], self.perm, :]

class VariablesProduct(SuprLayer):
    """ Product over all variables """

    def __init__(self):
        super().__init__()

    def sample(self, track, channel_per_variable):
        return track, torch.full((self.variables,), channel_per_variable[0]).to(channel_per_variable.device)

    def forward(self, x):
        if not self.training:
            self.variables = x.shape[2]
        return torch.sum(x, dim=2, keepdim=True)

class ProductSumLayer(SuprLayer):
    """ Base class for product-sum layers """

    def __init__(self, weight_shape, normalize_dims):
        super().__init__()
        # Parameters
        self.weights = nn.Parameter(torch.rand(*weight_shape))
        self.weights.data /= torch.clamp(self.weights.sum(dim=normalize_dims, keepdim=True), self.epsilon)
        # Normalize dimensions
        self.normalize_dims = normalize_dims
        # EM accumulator
        self.register_buffer('weights_acc', torch.zeros(*weight_shape))

    def em_batch(self):
        self.weights_acc.data += self.weights * self.weights.grad

    def em_update(self, learning_rate: float = 1.):
        weights_grad = torch.clamp(self.weights_acc, self.epsilon)
        weights_grad /= torch.clamp(weights_grad.sum(dim=self.normalize_dims, keepdim=True), self.epsilon)
        if learning_rate < 1.:
            self.weights.data *= 1. - learning_rate
            self.weights.data += learning_rate * weights_grad
        else:
            self.weights.data = weights_grad
        # Reset accumulator
        self.weights_acc.zero_()
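
# Note: after backpropagating the total log-likelihood, weights.grad holds d(log L)/d(weights),
# so the quantity weights * weights.grad accumulated by em_batch plays the role of the expected
# counts in an EM-style update; em_update renormalizes them over normalize_dims and blends them
# into the weights at the given learning rate.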

class Einsum(ProductSumLayer):
    """ Einsum layer """

    def __init__(self, tracks: int, variables: int, channels: int, channels_out: int = None):
        # Dimensions
        variables_out = math.ceil(variables / 2)
        if channels_out is None:
            channels_out = channels
        # Initialize super
        super().__init__((tracks, variables_out, channels_out, channels, channels), (3, 4))
        # Padding
        self.x1_pad = torch.zeros(variables_out, dtype=torch.bool)
        self.x2_pad = torch.zeros(variables_out, dtype=torch.bool)
        # Zero-pad if necessary
        if variables % 2 == 1:
            # Pad on the right
            self.pad = True
            self.x2_padding = [0, 0, 0, 1]
            self.x2_pad[-1] = True
        else:
            self.pad = False

    # TODO: Implement choice of left, right, or both augmentation. Both returns 2 times the number of tracks
    def sample(self, track: int, channel_per_variable: torch.Tensor):
        r = []
        for v, c in enumerate(channel_per_variable):
            # Probability matrix
            px1 = torch.exp(self.x1[0, track, v, :][:, None])
            px2 = torch.exp(self.x2[0, track, v, :][None, :])
            prob = self.weights[track, v, c] * px1 * px2
            # Sample
            idx = discrete_rand(prob)[0]
            # Remove indices of padding
            idx_valid = idx[[not self.x1_pad[v], not self.x2_pad[v]]]
            # Store on list
            r.append(idx_valid)
        # Concatenate and return indices
        return track, torch.cat(r)

    def forward(self, x: torch.Tensor):
        # Split the input variables in two and apply padding if necessary
        x1 = x[:, :, 0::2, :]
        x2 = x[:, :, 1::2, :]
        if self.pad:
            x2 = pad(x2, self.x2_padding)
        # Store the inputs for use in sampling routine
        if not self.training:
            self.x1, self.x2 = x1, x2
        # Compute maximum
        a1, a2 = [torch.max(x, dim=3, keepdim=True)[0] for x in [x1, x2]]
        # Subtract maximum and compute exponential
        exa1, exa2 = [torch.clamp(torch.exp(x - a), self.epsilon) for x, a in [(x1, a1), (x2, a2)]]
        # Compute the contraction
        y = a1 + a2 + torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc', exa1, exa2, self.weights))
        return y
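
# Example (illustrative sizes): with tracks=4, variables=16, channels=2, the weight tensor
# has shape (4, 8, 2, 2, 2), and forward maps an input of shape (N, 4, 16, 2) to an output
# of shape (N, 4, 8, 2), pairing variables (0, 1), (2, 3), ... and summing over all channel
# combinations (a, b) within each pair.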

class Weightsum(ProductSumLayer):
    """ Weightsum layer """
    # Product over all variables and weighted sum over tracks and channels

    def __init__(self, tracks: int, variables: int, channels: int):
        # Initialize super
        super().__init__((tracks, channels), (0, 1))

    def sample(self):
        prob = self.weights * torch.exp(self.x_sum[0] - torch.max(self.x_sum[0]))
        s = discrete_rand(prob)[0]
        return s[0], torch.full((self.variables,), s[1]).to(self.weights.device)

    def forward(self, x: torch.Tensor):
        # Product over variables
        x_sum = torch.sum(x, 2)
        # Store the inputs for use in sampling routine
        if not self.training:
            self.x_sum = x_sum
            self.variables = x.shape[2]
        # Compute maximum
        a = torch.max(torch.max(x_sum, dim=1)[0], dim=1)[0]
        # Subtract maximum and compute exponential
        exa = torch.clamp(torch.exp(x_sum - a[:, None, None]), self.epsilon)
        # Compute the contraction
        y = a + torch.log(torch.einsum('ntc,tc->n', exa, self.weights))
        return y

class TrackSum(ProductSumLayer):
    """ TrackSum layer """
    # Weighted sum over tracks

    def __init__(self, tracks: int, channels: int):
        # Initialize super
        super().__init__((tracks, channels), (0,))

    def sample(self, track: int, channel_per_variable: torch.Tensor):
        prob = self.weights[:, None] * torch.exp(self.x[0] - torch.max(self.x[0], dim=0)[0])
        s = discrete_rand(prob)[0]
        return s[0], channel_per_variable

    def forward(self, x: torch.Tensor):
        # Module is only valid when number of variables is 1
        assert x.shape[2] == 1
        # Store the inputs for use in sampling routine
        if not self.training:
            self.x = x
        # Compute maximum
        a = torch.max(x, dim=1)[0]
        # Subtract maximum and compute exponential
        exa = torch.clamp(torch.exp(x - a[:, None]), self.epsilon)
        # Compute the contraction
        y = a + torch.log(torch.einsum('ntvc,tc->nvc', exa, self.weights))
        # Insert track dimension
        y = y[:, None]
        return y

class NormalLeaf(SuprLayer):
    """ NormalLeaf layer """

    def __init__(self, tracks: int, variables: int, channels: int):
        super().__init__()
        # Dimensions
        self.T, self.V, self.C = tracks, variables, channels
        # Parameters
        # self.mu = nn.Parameter(torch.randn(self.T, self.V, self.C))
        # self.mu = nn.Parameter(torch.linspace(0, 1, self.C)[None, None, :].repeat((self.T, self.V, 1)))
        self.mu = nn.Parameter(torch.rand(self.T, self.V, self.C))
        self.sig = nn.Parameter(torch.ones(self.T, self.V, self.C) * 0.5)
        # Which variables to marginalize
        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
        # Input
        self.register_buffer('x', torch.Tensor())
        # Output
        self.register_buffer('z', torch.Tensor())
        # EM accumulator
        self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
        self.register_buffer('z_x_acc', torch.zeros(self.T, self.V, self.C))
        self.register_buffer('z_x_sq_acc', torch.zeros(self.T, self.V, self.C))

    def em_batch(self):
        self.z_acc.data += torch.sum(self.z.grad, dim=0)
        self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)
        self.z_x_sq_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None] ** 2, dim=0)

    def em_update(self, learning_rate: float = 1.):
        # Mean
        sum_z = torch.clamp(self.z_acc, self.epsilon)
        self.mu.data *= 1. - learning_rate
        self.mu.data += learning_rate * self.z_x_acc / sum_z
        # Standard deviation
        self.sig.data *= 1 - learning_rate
        self.sig.data += learning_rate * torch.sqrt(torch.clamp(self.z_x_sq_acc / sum_z - self.mu ** 2, self.epsilon + 0.01))
        # Reset accumulators
        self.z_acc.zero_()
        self.z_x_acc.zero_()
        self.z_x_sq_acc.zero_()

    def sample(self, track: int, channel_per_variable: torch.Tensor):
        variables_marginalize = torch.sum(self.marginalize).int()
        mu_marginalize = self.mu[track, self.marginalize, channel_per_variable[self.marginalize]]
        sig_marginalize = self.sig[track, self.marginalize, channel_per_variable[self.marginalize]]
        r = torch.empty_like(self.x[0])
        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * sig_marginalize
        r[~self.marginalize] = self.x[0][~self.marginalize]
        return r

    def forward(self, x: torch.Tensor):
        # Get shape
        batch_size = x.shape[0]
        # Store the data
        self.x = x
        # Compute the probability
        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
        # Get non-marginalized parameters and data
        mu_valid = self.mu[None, :, ~self.marginalize, :]
        sig_valid = self.sig[None, :, ~self.marginalize, :]
        x_valid = self.x[:, None, ~self.marginalize, None]
        # Evaluate log probability
        self.z.data[:, :, ~self.marginalize, :] = \
            torch.distributions.Normal(mu_valid, sig_valid).log_prob(x_valid).float()
        return self.z
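
# Usage note: setting an entry of the 'marginalize' buffer to True treats that variable as
# unobserved; forward() then leaves its log-probability at zero, and sample() draws it from
# the corresponding Normal(mu, sig) instead of copying the stored input.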

class BernoulliLeaf(SuprLayer):
    """ BernoulliLeaf layer """

    def __init__(self, tracks: int, variables: int, channels: int):
        super().__init__()
        # Dimensions
        self.T, self.V, self.C = tracks, variables, channels
        # Parameters
        self.p = nn.Parameter(torch.rand(self.T, self.V, self.C))
        # Which variables to marginalize
        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
        # Input
        self.register_buffer('x', torch.Tensor())
        # Output
        self.register_buffer('z', torch.Tensor())
        # EM accumulator
        self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
        self.register_buffer('z_x_acc', torch.zeros(self.T, self.V, self.C))

    def em_batch(self):
        self.z_acc.data += torch.sum(self.z.grad, dim=0)
        self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)

    def em_update(self, learning_rate: float = 1.):
        # Probability
        sum_z = torch.clamp(self.z_acc, self.epsilon)
        self.p.data *= 1. - learning_rate
        self.p.data += learning_rate * self.z_x_acc / sum_z
        # Reset accumulators
        self.z_acc.zero_()
        self.z_x_acc.zero_()

    def sample(self, track: int, channel_per_variable: torch.Tensor):
        variables_marginalize = torch.sum(self.marginalize).int()
        p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize]]
        r = torch.empty_like(self.x[0])
        r[self.marginalize] = (torch.rand(variables_marginalize).to(self.x.device) < p_marginalize).float()
        r[~self.marginalize] = self.x[0][~self.marginalize]
        return r

    def forward(self, x: torch.Tensor):
        # Get shape
        batch_size = x.shape[0]
        # Store the data
        self.x = x
        # Compute the probability
        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
        # Get non-marginalized parameters and data
        p_valid = self.p[None, :, ~self.marginalize, :]
        x_valid = self.x[:, None, ~self.marginalize, None]
        # Evaluate probability (note: not log-probability, unlike NormalLeaf)
        self.z.data[:, :, ~self.marginalize, :] = \
            p_valid * (x_valid == 1) + (1 - p_valid) * (x_valid == 0)
        return self.z
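
For orientation, a minimal sketch of how the layers above could be composed: a leaf layer maps the data to per-track, per-variable, per-channel log-probabilities, Einsum layers repeatedly pair up variables, and a Weightsum root reduces the result to one log-likelihood per data point. The stack and all sizes below are illustrative assumptions, not part of this commit.

import torch
from supr.layers import NormalLeaf, Einsum, Weightsum

N, T, V, C = 8, 2, 4, 3                # batch size, tracks, variables, channels (assumed)

leaf = NormalLeaf(T, V, C)             # (N, V) data -> (N, T, V, C) log-probabilities
prod1 = Einsum(T, V, C)                # pairs variables: 4 -> 2
prod2 = Einsum(T, 2, C)                # pairs variables: 2 -> 1
root = Weightsum(T, 1, C)              # product over variables, weighted sum over tracks and channels

x = torch.randn(N, V)                  # one observation per variable
log_p = root(prod2(prod1(leaf(x))))    # shape (N,): log-likelihood of each data point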
supr/utils.py · new file (mode 100644) · +65 −0
# %% Imports
import matplotlib.pyplot as plt
from typing import Tuple
import torch
import numpy as np
import math
from PyQt5 import QtWidgets

# %%
def drawnow():
    # Redraw the current matplotlib figure and process pending GUI events
    plt.gcf().canvas.draw()
    plt.gcf().canvas.flush_events()

def arrange_figs(cols=1, min_rows=3, toolbar=False, x0=1400, y0=28, x1=1920, y1=1200):
    # Tile all open matplotlib figures in a grid inside the screen rectangle
    # (x0, y0)-(x1, y1); requires the Qt backend
    try:
        current_fig_num = plt.gcf().number
        extra = 37
        w = x1 - x0
        h = y1 - y0
        fignums = plt.get_fignums()
        n = len(fignums)
        rows = np.maximum(math.ceil(n / cols), min_rows)
        height = int(h / rows - extra)
        width = int(w / cols)
        for i, fn in enumerate(fignums):
            r = i % rows
            c = int(i / rows)
            plt.figure(fn)
            win = plt.get_current_fig_manager().window
            win.findChild(QtWidgets.QToolBar).setVisible(toolbar)
            win.setGeometry(x0 + width * c, y0 + int(h / rows * r) + extra, width, height)
        plt.figure(current_fig_num)
    except:
        pass

def unravel_indices(indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.LongTensor:
    r"""Converts flat indices into unraveled coordinates in a target shape.

    Args:
        indices: A tensor of (flat) indices, (*, N).
        shape: The targeted shape, (D,).

    Returns:
        The unraveled coordinates, (*, N, D).
    """
    coord = []
    for dim in reversed(shape):
        coord.append(indices % dim)
        indices = indices // dim
    coord = torch.stack(coord[::-1], dim=-1)
    return coord
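
# Example: unravel_indices(torch.tensor([5]), (3, 4)) returns tensor([[1, 1]]),
# i.e. flat index 5 corresponds to row 1, column 1 of a 3 x 4 array.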

def discrete_rand(v: torch.Tensor, n: int = 1):
    # Draw n random indices with probability proportional to the entries of v;
    # returns the unraveled indices with shape (n, v.ndim)
    idx = torch.sum(torch.rand(n)[:, None].to(v.device) > torch.cumsum(v.flatten(), 0)[None, :] / torch.sum(v), dim=1)
    return unravel_indices(idx, v.shape)

def local_scramble_2d(dist: float, dim: tuple):
    # Return a permutation of the flattened indices of a grid with shape `dim` in which
    # each element is only displaced locally; `dist` controls the typical displacement
    grid = torch.meshgrid(*[torch.arange(d) for d in dim])
    n = [torch.argsort(m + torch.randn(dim) * dist, dim=i) for i, m in enumerate(grid)]
    idx = torch.reshape(torch.arange(torch.tensor(dim).prod()), dim)
    return idx[n[0], grid[1]][grid[0], n[1]].flatten()