Commit 8b606b3e, authored 3 years ago by mnsc
pep8
parent 0c2aae1b
Changes: 1 changed file, supr/layers.py (+78 additions, -51 deletions)

@@ -81,7 +81,6 @@ class Sequential(nn.Sequential):
        return value


class Parallel(SuprLayer):
    def __init__(self, nets: List[SuprLayer]):
        super().__init__()

@@ -113,7 +112,8 @@ class ScrambleTracks2d(SuprLayer):
    def __init__(self, tracks: int, variables: int, distance: float, dims: tuple):
        super().__init__()
        # Permutation for each track
        perm = torch.stack([local_scramble_2d(distance, dims) for _ in range(tracks)])
        self.register_buffer('perm', perm)

    def sample(self, track, channel_per_variable):

@@ -146,7 +146,8 @@ class ProductSumLayer(SuprLayer):
        super().__init__()
        # Parameters
        self.weights = nn.Parameter(torch.rand(*weight_shape))
        self.weights.data /= torch.clamp(self.weights.sum(dim=normalize_dims, keepdim=True), self.epsilon)
        # Normalize dimensions
        self.normalize_dims = normalize_dims
        # EM accumulator

@@ -157,7 +158,8 @@ class ProductSumLayer(SuprLayer):
    def em_update(self, learning_rate: float = 1.):
        weights_grad = torch.clamp(self.weights_acc, self.epsilon)
        weights_grad /= torch.clamp(weights_grad.sum(dim=self.normalize_dims, keepdim=True), self.epsilon)
        if learning_rate < 1.:
            self.weights.data *= 1. - learning_rate
        self.weights.data += learning_rate * weights_grad
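
The em_update in this hunk renormalizes the accumulated expected counts over the layer's normalize_dims and then blends the result into the current weights; learning_rate = 1 corresponds to a full EM step. A minimal standalone sketch of that convex-combination update (names and shapes here are illustrative, not taken from the repository):

    import torch

    eps = 1e-20
    weights = torch.rand(3, 4)          # current weights, e.g. tracks x channels
    weights_acc = torch.rand(3, 4)      # expected counts accumulated over a batch
    normalize_dims = (1,)               # weights must sum to 1 over these dims
    learning_rate = 0.5

    weights_grad = torch.clamp(weights_acc, eps)
    weights_grad /= torch.clamp(weights_grad.sum(dim=normalize_dims, keepdim=True), eps)
    # Convex combination of the old weights and the new normalized estimate.
    weights = (1. - learning_rate) * weights + learning_rate * weights_grad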

@@ -188,7 +190,6 @@ class Einsum(ProductSumLayer):
            self.x2_pad[-1] = True
        else:
            self.pad = False

-    # TODO: Implement choice of left, right, or both augmentation. Both returns 2 times the number of tracks
    def sample(self, track: int, channel_per_variable: torch.Tensor):
        r = []

@@ -218,9 +219,12 @@ class Einsum(ProductSumLayer):
        # Compute maximum
        a1, a2 = [torch.max(x, dim=3, keepdim=True)[0] for x in [x1, x2]]
        # Subtract maximum and compute exponential
        exa1, exa2 = [torch.clamp(torch.exp(x - a), self.epsilon) for x, a in [(x1, a1), (x2, a2)]]
        # Compute the contraction
-        y = a1 + a2 + torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc', exa1, exa2, self.weights))
+        y = a1 + a2 + \
+            torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc', exa1, exa2, self.weights))
        return y
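
The contraction in this hunk is a log-einsum-exp: x1 and x2 hold log-values, the per-position maxima a1 and a2 are factored out so the exponentials stay in a numerically safe range, the weighted product-sum is computed in the linear domain with torch.einsum, and the result is mapped back to log space. A standalone sketch of the same pattern, with illustrative shapes and names only:

    import torch

    n, t, v, a, b, c = 2, 3, 4, 5, 5, 6          # batch, tracks, variables, channels
    x1 = torch.randn(n, t, v, a)                 # log-values from the first child
    x2 = torch.randn(n, t, v, b)                 # log-values from the second child
    weights = torch.rand(t, v, c, a, b)          # per-track/variable mixing weights
    eps = 1e-20

    # Factor out the maxima so exp() cannot overflow or vanish entirely.
    a1 = torch.max(x1, dim=3, keepdim=True)[0]
    a2 = torch.max(x2, dim=3, keepdim=True)[0]
    exa1 = torch.clamp(torch.exp(x1 - a1), eps)
    exa2 = torch.clamp(torch.exp(x2 - a2), eps)
    # Weighted product-sum in the linear domain, then back to the log domain.
    y = a1 + a2 + torch.log(torch.einsum('ntva,ntvb,tvcab->ntvc', exa1, exa2, weights))
    assert y.shape == (n, t, v, c)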

@@ -233,7 +237,8 @@ class Weightsum(ProductSumLayer):
        super().__init__((tracks, channels), (0, 1))

    def sample(self):
-        prob = self.weights * torch.exp(self.x_sum[0] - torch.max(self.x_sum[0]))
+        prob = self.weights * \
+            torch.exp(self.x_sum[0] - torch.max(self.x_sum[0]))
        s = discrete_rand(prob)[0]
        return s[0], torch.full((self.variables,), s[1]).to(self.weights.device)

@@ -262,7 +267,8 @@ class TrackSum(ProductSumLayer):
        super().__init__((tracks, channels), (0,))

    def sample(self, track: int, channel_per_variable: torch.Tensor):
-        prob = self.weights[:, None] * torch.exp(self.x[0] - torch.max(self.x[0], dim=0)[0])
+        prob = self.weights[:, None] * \
+            torch.exp(self.x[0] - torch.max(self.x[0], dim=0)[0])
        s = discrete_rand(prob)[0]
        return s[0], channel_per_variable
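
Both Weightsum.sample and TrackSum.sample draw an index with probability proportional to a weight times the exponentiated, max-shifted log-value; subtracting the maximum before exponentiating only rescales the unnormalized probabilities, so the draw is unchanged while underflow is avoided. A generic sketch of the same idea, using torch.multinomial in place of the project's discrete_rand helper (which is not shown in this diff):

    import torch

    weights = torch.rand(5)            # mixture weights over 5 children
    log_x = -100. + torch.randn(5)     # very negative log-likelihoods

    # exp(log_x) alone would underflow; shifting by the max keeps the ratios intact.
    prob = weights * torch.exp(log_x - torch.max(log_x))
    idx = torch.multinomial(prob, num_samples=1).item()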

@@ -282,15 +288,18 @@ class TrackSum(ProductSumLayer):
        y = y[:, None]
        return y


class SuprLeaf(SuprLayer):
    def __init__(self):
        super().__init__()


class NormalLeaf(SuprLeaf):
    """ NormalLeaf layer """

-    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1, mu0: torch.tensor = 0.,
-                 nu0: torch.tensor = 0., alpha0: torch.tensor = 0., beta0: torch.tensor = 0.):
+    def __init__(self, tracks: int, variables: int, channels: int, n: int = 1,
+                 mu0: torch.tensor = 0., nu0: torch.tensor = 0.,
+                 alpha0: torch.tensor = 0., beta0: torch.tensor = 0.):
        super().__init__()
        # Dimensions
        self.T, self.V, self.C = tracks, variables, channels

@@ -299,12 +308,11 @@ class NormalLeaf(SuprLeaf):
        # Prior
        self.mu0, self.nu0, self.alpha0, self.beta0 = mu0, nu0, alpha0, beta0
        # Parametes
        # self.mu = nn.Parameter(torch.randn(self.T, self.V, self.C))
        # self.mu = nn.Parameter(torch.linspace(0, 1, self.C)[None, None, :].repeat((self.T, self.V, 1)))
        self.mu = nn.Parameter(torch.rand(self.T, self.V, self.C))
        self.sig = nn.Parameter(torch.ones(self.T, self.V, self.C) * 0.5)
        # Which variables to marginalized
        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
        # Input
        self.register_buffer('x', torch.Tensor())
        # Output

@@ -315,23 +323,26 @@ class NormalLeaf(SuprLeaf):
        self.register_buffer('z_x_sq_acc', torch.zeros(self.T, self.V, self.C))

    def em_batch(self):
-        self.z_acc.data += torch.clamp(torch.sum(self.z.grad, dim=0), self.epsilon)
+        self.z_acc.data += torch.clamp(torch.sum(self.z.grad,
+                                                 dim=0), self.epsilon)
        self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)
        self.z_x_sq_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None] ** 2, dim=0)

    def em_update(self, learning_rate: float = 1.):
        # Sum of weights
        sum_z = torch.clamp(self.z_acc, self.epsilon)
        # Mean
        # mu_update = (self.nu0 * self.mu0 + self.n * (self.z_x_acc / sum_z)) / (self.nu0 + self.n)
        mu_update = (self.nu0 * self.mu0 + self.z_acc * (self.z_x_acc / sum_z)
                     ) / (self.nu0 + self.z_acc)
        self.mu.data *= 1. - learning_rate
        self.mu.data += learning_rate * mu_update
        # Standard deviation
        # sig_update = (self.n * (self.z_x_sq_acc / sum_z - self.mu ** 2) + 2 * self.beta0 + self.nu0 * (self.mu0 - self.mu) ** 2) / (self.n + 2 * self.alpha0 + 3)
        sig_update = (self.z_x_sq_acc - self.z_acc * self.mu ** 2 + 2 * self.beta0 +
                      self.nu0 * (self.mu0 - self.mu) ** 2
                      ) / (self.z_acc + 2 * self.alpha0 + 3)
        self.sig.data *= 1 - learning_rate
        self.sig.data += learning_rate * sig_update
        # Reset accumulators
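
Reading the accumulators as expected sufficient statistics (z_acc as the summed responsibilities N, z_x_acc as the responsibility-weighted sum of x, and z_x_sq_acc as the weighted sum of x squared), the mu and sig updates in this hunk have the form of MAP estimates under a Normal-Inverse-Gamma prior with hyperparameters mu0, nu0, alpha0, beta0. This is an editorial reading of the code, not something stated in the commit; in that notation the updates correspond to

    \mu \leftarrow \frac{\nu_0 \mu_0 + \sum_n r_n x_n}{\nu_0 + N}, \qquad
    \sigma^2 \leftarrow \frac{\sum_n r_n x_n^2 - N \mu^2 + 2 \beta_0 + \nu_0 (\mu_0 - \mu)^2}{N + 2 \alpha_0 + 3}

blended into the current parameters with the given learning_rate.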

@@ -341,8 +352,10 @@ class NormalLeaf(SuprLeaf):
    def sample(self, track: int, channel_per_variable: torch.Tensor):
        variables_marginalize = torch.sum(self.marginalize).int()
-        mu_marginalize = self.mu[track, self.marginalize, channel_per_variable[self.marginalize]]
+        mu_marginalize = self.mu[track, self.marginalize,
+                                 channel_per_variable[self.marginalize]]
        sig_marginalize = self.sig[track, self.marginalize, channel_per_variable[self.marginalize]]
        r = torch.empty_like(self.x[0])
        r[self.marginalize] = mu_marginalize + torch.randn(variables_marginalize).to(self.x.device) * torch.sqrt(
            torch.clamp(sig_marginalize, self.epsilon))

@@ -353,7 +366,8 @@ class NormalLeaf(SuprLeaf):
        return (torch.clamp(self.z.grad, self.epsilon) * self.mu).sum([1, 3])

    def var(self):
        return (torch.clamp(self.z.grad, self.epsilon) * (self.mu ** 2 + self.sig)).sum([1, 3]) - self.mean() ** 2

    def forward(self, x: torch.Tensor, marginalize=None):
        # Get shape

@@ -364,7 +378,8 @@ class NormalLeaf(SuprLeaf):
        # Store the data
        self.x = x
        # Compute the probability
        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
        # Get non-marginalized parameters and data
        mu_valid = self.mu[None, :, ~self.marginalize, :]
        sig_valid = self.sig[None, :, ~self.marginalize, :]

@@ -391,7 +406,8 @@ class BernoulliLeaf(SuprLeaf):
        # Parametes
        self.p = nn.Parameter(torch.rand(self.T, self.V, self.C))
        # Which variables to marginalized
        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
        # Input
        self.register_buffer('x', torch.Tensor())
        # Output

@@ -402,13 +418,13 @@ class BernoulliLeaf(SuprLeaf):
    def em_batch(self):
        self.z_acc.data += torch.sum(self.z.grad, dim=0)
        self.z_x_acc.data += torch.sum(self.z.grad * self.x[:, None, :, None], dim=0)

    def em_update(self, learning_rate: float = 1.):
        # Probability
        sum_z = torch.clamp(self.z_acc, self.epsilon)
        # p_update = (self.n * self.z_x_acc / sum_z + self.alpha0 - 1) / (self.n + self.alpha0 + self.beta0 - 2)
-        p_update = (self.z_x_acc + self.alpha0 - 1) / (self.z_acc + self.alpha0 + self.beta0 - 2)
+        p_update = (self.z_x_acc + self.alpha0 - 1) / \
+            (self.z_acc + self.alpha0 + self.beta0 - 2)
        self.p.data *= 1. - learning_rate
        self.p.data += learning_rate * p_update
        # Reset accumulators
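
The p_update in this hunk has the form of the mode of a Beta posterior: with N = z_acc (summed responsibilities), S = z_x_acc (responsibility-weighted count of ones) and a Beta(alpha0, beta0) prior, the update reads (again an editorial reading, not stated in the commit)

    p \leftarrow \frac{S + \alpha_0 - 1}{N + \alpha_0 + \beta_0 - 2}

and is blended into self.p with the given learning_rate.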

@@ -417,9 +433,11 @@ class BernoulliLeaf(SuprLeaf):
    def sample(self, track: int, channel_per_variable: torch.Tensor):
        variables_marginalize = torch.sum(self.marginalize).int()
        p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize]]
        r = torch.empty_like(self.x[0])
        r[self.marginalize] = (torch.rand(variables_marginalize).to(self.x.device) < p_marginalize).float()
        r[~self.marginalize] = self.x[0][~self.marginalize]
        return r

@@ -429,13 +447,15 @@ class BernoulliLeaf(SuprLeaf):
        # Store the data
        self.x = x
        # Compute the probability
        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
        # Get non-marginalized parameters and data
        p_valid = self.p[None, :, ~self.marginalize, :]
        x_valid = self.x[:, None, ~self.marginalize, None]
        # Evaluate log probability
        self.z.data[:, :, ~self.marginalize, :] = \
            torch.distributions.Bernoulli(probs=p_valid).log_prob(x_valid).float()
        return self.z
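
The forward pass in this hunk evaluates the Bernoulli log-probability only for non-marginalized variables and leaves the marginalized entries of z at zero, so those variables contribute probability one. A small standalone sketch of that masking pattern (dimensions and names here are illustrative):

    import torch

    batch, V, C = 2, 5, 3
    p = torch.rand(V, C)                          # probabilities per variable/channel
    x = torch.randint(0, 2, (batch, V)).float()   # binary observations
    marginalize = torch.tensor([False, True, False, False, True])

    z = torch.zeros(batch, V, C)
    # Log-probabilities for observed variables; marginalized rows stay at log(1) = 0.
    z[:, ~marginalize, :] = torch.distributions.Bernoulli(
        probs=p[~marginalize]).log_prob(x[:, ~marginalize, None])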

@@ -454,19 +474,22 @@ class CategoricalLeaf(SuprLeaf):
        # Parametes
        self.p = nn.Parameter(torch.rand(self.T, self.V, self.C, self.D))
        # Which variables to marginalized
        self.register_buffer('marginalize', torch.zeros(variables, dtype=torch.bool))
        # Input
        self.register_buffer('x', torch.Tensor())
        # Output
        self.register_buffer('z', torch.Tensor())
        # EM accumulator
        self.register_buffer('z_acc', torch.zeros(self.T, self.V, self.C))
        self.register_buffer('z_x_acc', torch.zeros(self.T, self.V, self.C, self.D))

    def em_batch(self):
        self.z_acc.data += torch.sum(self.z.grad, dim=0)
        x_onehot = torch.eye(self.D, dtype=bool)[self.x]
        self.z_x_acc.data += torch.sum(self.z.grad[:, :, :, :, None] * x_onehot[:, None, :, None, :], dim=0)

    def em_update(self, learning_rate: float = 1.):
        # Probability
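
The em_batch in this hunk builds one-hot encodings by indexing an identity matrix with the integer-valued data, then accumulates responsibility-weighted category counts. A short sketch of just the one-hot trick (shapes are illustrative):

    import torch

    D = 4                                      # number of categories
    x = torch.tensor([[0, 2, 3], [1, 1, 0]])   # integer data, shape (batch, variables)
    x_onehot = torch.eye(D, dtype=bool)[x]     # shape (batch, variables, D)
    # x_onehot[i, j, d] is True exactly when x[i, j] == d.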

@@ -482,9 +505,11 @@ class CategoricalLeaf(SuprLeaf):
        self.z_x_acc.zero_()

    def sample(self, track: int, channel_per_variable: torch.Tensor):
        p_marginalize = self.p[track, self.marginalize, channel_per_variable[self.marginalize], :]
        r = torch.empty_like(self.x[0])
        r_sample = torch.distributions.Categorical(probs=p_marginalize).sample()
        r[self.marginalize] = r_sample
        r[~self.marginalize] = self.x[0][~self.marginalize]
        return r

@@ -495,11 +520,13 @@ class CategoricalLeaf(SuprLeaf):
        # Store the data
        self.x = x
        # Compute the probability
        self.z = torch.zeros(batch_size, self.T, self.V, self.C, requires_grad=True, device=x.device)
        # Get non-marginalized parameters and data
        p_valid = self.p[None, :, ~self.marginalize, :, :]
        x_valid = self.x[:, None, ~self.marginalize, None]
        # Evaluate log probability
        self.z.data[:, :, ~self.marginalize, :] = \
            torch.distributions.Categorical(probs=p_valid).log_prob(x_valid).float()
        return self.z