Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
S
supr
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
mnsc
supr
Commits
0c2aae1b
Commit
0c2aae1b
authored
May 4, 2022
by
mnsc
Browse files
Options
Downloads
Patches
Plain Diff
added mean/var, updated regression, major other changes
parent
82a57cd2
No related branches found
No related tags found
No related merge requests found
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
demos/regression.py
+80
-21
80 additions, 21 deletions
demos/regression.py
supr/layers.py
+43
-6
43 additions, 6 deletions
supr/layers.py
with
123 additions
and
27 deletions
demos/regression.py
+
80
−
21
View file @
0c2aae1b
...
@@ -4,38 +4,51 @@ import torch
...
@@ -4,38 +4,51 @@ import torch
import
supr
import
supr
from
supr.utils
import
drawnow
from
supr.utils
import
drawnow
from
scipy.stats
import
norm
from
scipy.stats
import
norm
from
math
import
sqrt
import
numpy
as
np
# %% Dataset
# %% Dataset
N
=
2
00
N
=
1
00
x
=
torch
.
linspace
(
0
,
1
,
N
)
x
=
torch
.
linspace
(
0
,
1
,
N
)
y
=
1
-
2
*
x
+
(
torch
.
rand
(
N
)
>
0.5
)
*
(
x
>
0.5
)
+
torch
.
randn
(
N
)
*
0.1
y
=
1
-
2
*
x
+
(
torch
.
rand
(
N
)
>
0.5
)
*
(
x
>
0.5
)
+
torch
.
randn
(
N
)
*
0.1
x
[
x
>
0.5
]
+=
0.25
x
[
x
>
0.5
]
+=
0.25
x
[
x
<
0.5
]
-=
0.25
x
[
x
<
0.5
]
-=
0.25
x
[
0
]
=
-
1.
x
[
0
]
=
-
1.
y
[
0
]
=
0
y
[
0
]
=
-
0.5
X
=
torch
.
stack
((
x
,
y
),
dim
=
1
)
X
=
torch
.
stack
((
x
,
y
),
dim
=
1
)
# %% Grid to evaluate predictive distribution
# %% Grid to evaluate predictive distribution
x_grid
=
torch
.
linspace
(
-
2
,
2
,
200
)
x_res
,
y_res
=
400
,
500
y_grid
=
torch
.
linspace
(
-
2
,
2
,
200
)
x_min
,
x_max
=
-
2
,
2
X_grid
=
torch
.
stack
([
x
.
flatten
()
for
x
in
torch
.
meshgrid
(
x_grid
,
y_grid
,
indexing
=
'
ij
'
)],
dim
=
1
)
y_min
,
y_max
=
-
2
,
2
x_grid
=
torch
.
linspace
(
x_min
,
x_max
,
x_res
)
y_grid
=
torch
.
linspace
(
y_min
,
y_max
,
y_res
)
XY_grid
=
torch
.
stack
([
x
.
flatten
()
for
x
in
torch
.
meshgrid
(
x_grid
,
y_grid
,
indexing
=
'
ij
'
)],
dim
=
1
)
X_grid
=
torch
.
stack
([
x_grid
,
torch
.
zeros
(
x_res
)]).
T
# %% Sum-product network
# %% Sum-product network
# Parameters
tracks
=
1
tracks
=
1
variables
=
2
variables
=
2
channels
=
50
channels
=
50
# Priors for variance of x and y
# Priors for variance of x and y
alpha0
=
torch
.
tensor
([[[
1
],
[
1
]]])
alpha0
=
torch
.
tensor
([[[
1
],
[
1
]]])
beta0
=
torch
.
tensor
([[[.
0
5
],
[
0
.01
]]])
beta0
=
torch
.
tensor
([[[.
0
1
],
[.
01
]]])
# Construct SPN model
model
=
supr
.
Sequential
(
model
=
supr
.
Sequential
(
supr
.
NormalLeaf
(
tracks
,
variables
,
channels
,
n
=
N
,
mu0
=
0.
,
nu0
=
0
,
alpha0
=
alpha0
,
beta0
=
beta0
),
supr
.
NormalLeaf
(
tracks
,
variables
,
channels
,
n
=
N
,
mu0
=
0.
,
nu0
=
0
,
alpha0
=
alpha0
,
beta0
=
beta0
),
supr
.
Weightsum
(
tracks
,
variables
,
channels
)
supr
.
Weightsum
(
tracks
,
variables
,
channels
)
)
)
# Marginalization query
marginalize_y
=
torch
.
tensor
([
False
,
True
])
# %% Fit model and display results
# %% Fit model and display results
epochs
=
20
epochs
=
20
...
@@ -44,28 +57,74 @@ for epoch in range(epochs):
...
@@ -44,28 +57,74 @@ for epoch in range(epochs):
model
[
0
].
marginalize
=
torch
.
zeros
(
variables
,
dtype
=
torch
.
bool
)
model
[
0
].
marginalize
=
torch
.
zeros
(
variables
,
dtype
=
torch
.
bool
)
logp
=
model
(
X
).
sum
()
logp
=
model
(
X
).
sum
()
print
(
f
"
Log-posterior ∝
{
logp
:
.
2
f
}
"
)
print
(
f
"
Log-posterior ∝
{
logp
:
.
2
f
}
"
)
model
.
zero_grad
(
True
)
logp
.
backward
()
logp
.
backward
()
with
torch
.
no_grad
():
model
.
eval
()
# swap?
model
.
eval
()
model
.
em_batch_update
()
model
.
em_batch_update
()
model
.
zero_grad
(
True
)
p_xy
=
torch
.
exp
(
model
(
X_grid
).
reshape
(
len
(
x_grid
),
len
(
y_grid
)).
T
)
# Plot data and model
# -------------------------------------------------------------------------
# Evaluate joint distribution on grid
with
torch
.
no_grad
():
log_p_xy
=
model
(
XY_grid
)
p_xy
=
torch
.
exp
(
log_p_xy
).
reshape
(
x_res
,
y_res
)
model
[
0
].
marginalize
=
torch
.
tensor
([
False
,
True
])
# Evaluate marginal distribution on x-grid
p_x
=
torch
.
exp
(
model
(
X_grid
).
reshape
(
len
(
x_grid
),
len
(
y_grid
)).
T
)
log_p_x
=
model
(
X_grid
,
marginalize
=
marginalize_y
)
p_x
=
torch
.
exp
(
log_p_x
)
model
.
zero_grad
(
True
)
log_p_x
.
sum
().
backward
()
with
torch
.
no_grad
():
# Define prior conditional p(y|x)
Ndx
=
1
Ndx
=
1
p_prior
=
norm
(
0
,
0.5
).
pdf
(
y_grid
)[:,
None
]
sig_prior
=
1
p_y
=
norm
(
0
,
sqrt
(
sig_prior
)).
pdf
(
y_grid
)
# Compute normal approximation
m_pred
=
(
N
*
(
model
.
mean
())[:,
1
]
*
p_x
+
Ndx
*
0
)
/
(
N
*
p_x
+
Ndx
)
v_pred
=
(
N
*
p_x
*
(
model
.
var
()[:,
1
]
+
model
.
mean
()[:,
1
]
**
2
)
+
Ndx
*
sig_prior
)
/
(
N
*
p_x
+
Ndx
)
-
m_pred
**
2
std_pred
=
torch
.
sqrt
(
v_pred
)
# Compute predictive distribution
p_predictive
=
(
N
*
p_xy
+
Ndx
*
p_y
[
None
,
:])
/
(
N
*
p_x
[:,
None
]
+
Ndx
)
p_predictive
=
(
N
*
p_xy
+
Ndx
*
p_prior
)
/
(
N
*
p_x
+
Ndx
)
# Compute 95% highest-posterior region
hpr
=
torch
.
ones
((
x_res
,
y_res
),
dtype
=
torch
.
bool
)
for
k
in
range
(
x_res
):
p_sorted
=
-
np
.
sort
(
-
(
p_predictive
[
k
]
*
np
.
gradient
(
y_grid
)))
i
=
np
.
searchsorted
(
np
.
cumsum
(
p_sorted
),
0.95
)
idx
=
(
p_predictive
[
k
]
*
np
.
gradient
(
y_grid
))
<
p_sorted
[
i
]
hpr
[
k
,
idx
]
=
False
# Plot posterior
plt
.
figure
(
1
).
clf
()
plt
.
figure
(
1
).
clf
()
dx
=
(
x_grid
[
1
]
-
x_grid
[
0
])
/
2.
plt
.
title
(
'
Posterior distribution
'
)
dy
=
(
y_grid
[
1
]
-
y_grid
[
0
])
/
2.
dx
=
(
x_max
-
x_min
)
/
x_res
/
2
dy
=
(
y_max
-
y_min
)
/
y_res
/
2
extent
=
[
x_grid
[
0
]
-
dx
,
x_grid
[
-
1
]
+
dx
,
y_grid
[
0
]
-
dy
,
y_grid
[
-
1
]
+
dy
]
extent
=
[
x_grid
[
0
]
-
dx
,
x_grid
[
-
1
]
+
dx
,
y_grid
[
0
]
-
dy
,
y_grid
[
-
1
]
+
dy
]
plt
.
imshow
(
torch
.
log
(
p_predictive
),
extent
=
extent
,
aspect
=
'
auto
'
,
origin
=
'
lower
'
,
vmin
=-
4
,
vmax
=
1
)
plt
.
imshow
(
torch
.
log
(
p_predictive
).
T
,
extent
=
extent
,
plt
.
plot
(
x
,
y
,
'
.
'
,
color
=
'
tab:orange
'
,
alpha
=
.
5
,
markersize
=
4
,
markeredgewidth
=
0
)
aspect
=
'
auto
'
,
origin
=
'
lower
'
,
vmin
=-
4
,
vmax
=
1
,
cmap
=
'
Blues
'
)
plt
.
contour
(
hpr
.
T
,
levels
=
1
,
extent
=
extent
)
plt
.
plot
(
x
,
y
,
'
.
'
,
color
=
'
tab:orange
'
,
alpha
=
.
5
,
markersize
=
15
,
markeredgewidth
=
0
)
plt
.
axis
(
'
square
'
)
plt
.
xlim
([
x_min
,
x_max
])
plt
.
ylim
([
y_min
,
y_max
])
drawnow
()
# Plot normal approximation to posterior
plt
.
figure
(
2
).
clf
()
plt
.
title
(
'
Posterior Normal approximation
'
)
plt
.
plot
(
x
,
y
,
'
.
'
,
color
=
'
tab:orange
'
,
alpha
=
.
5
,
markersize
=
15
,
markeredgewidth
=
0
)
plt
.
plot
(
x_grid
,
m_pred
,
color
=
'
tab:orange
'
)
plt
.
fill_between
(
x_grid
,
m_pred
+
1.96
*
std_pred
,
m_pred
-
1.96
*
std_pred
,
color
=
'
tab:orange
'
,
alpha
=
0.1
)
plt
.
axis
(
'
square
'
)
plt
.
axis
(
'
square
'
)
plt
.
xlim
([
x_min
,
x_max
])
plt
.
ylim
([
y_min
,
y_max
])
drawnow
()
drawnow
()
This diff is collapsed.
Click to expand it.
supr/layers.py
+
43
−
6
View file @
0c2aae1b
...
@@ -50,12 +50,37 @@ class Sequential(nn.Sequential):
...
@@ -50,12 +50,37 @@ class Sequential(nn.Sequential):
module
.
em_batch
()
module
.
em_batch
()
module
.
em_update
()
module
.
em_update
()
def
em_batch
(
self
):
with
torch
.
no_grad
():
for
module
in
self
:
module
.
em_batch
()
def
em_update
(
self
):
with
torch
.
no_grad
():
for
module
in
self
:
module
.
em_update
()
def
sample
(
self
):
def
sample
(
self
):
value
=
[]
value
=
[]
for
module
in
reversed
(
self
):
for
module
in
reversed
(
self
):
value
=
module
.
sample
(
*
value
)
value
=
module
.
sample
(
*
value
)
return
value
return
value
def
mean
(
self
):
return
self
[
0
].
mean
()
def
var
(
self
):
return
self
[
0
].
var
()
def
forward
(
self
,
value
,
marginalize
=
None
):
for
module
in
self
:
if
isinstance
(
module
,
SuprLeaf
):
value
=
module
(
value
,
marginalize
=
marginalize
)
else
:
value
=
module
(
value
)
return
value
class
Parallel
(
SuprLayer
):
class
Parallel
(
SuprLayer
):
def
__init__
(
self
,
nets
:
List
[
SuprLayer
]):
def
__init__
(
self
,
nets
:
List
[
SuprLayer
]):
...
@@ -257,8 +282,11 @@ class TrackSum(ProductSumLayer):
...
@@ -257,8 +282,11 @@ class TrackSum(ProductSumLayer):
y
=
y
[:,
None
]
y
=
y
[:,
None
]
return
y
return
y
class
SuprLeaf
(
SuprLayer
):
def
__init__
(
self
):
super
().
__init__
()
class
NormalLeaf
(
SuprL
ayer
):
class
NormalLeaf
(
SuprL
eaf
):
"""
NormalLeaf layer
"""
"""
NormalLeaf layer
"""
def
__init__
(
self
,
tracks
:
int
,
variables
:
int
,
channels
:
int
,
n
:
int
=
1
,
mu0
:
torch
.
tensor
=
0.
,
def
__init__
(
self
,
tracks
:
int
,
variables
:
int
,
channels
:
int
,
n
:
int
=
1
,
mu0
:
torch
.
tensor
=
0.
,
...
@@ -321,9 +349,18 @@ class NormalLeaf(SuprLayer):
...
@@ -321,9 +349,18 @@ class NormalLeaf(SuprLayer):
r
[
~
self
.
marginalize
]
=
self
.
x
[
0
][
~
self
.
marginalize
]
r
[
~
self
.
marginalize
]
=
self
.
x
[
0
][
~
self
.
marginalize
]
return
r
return
r
def
forward
(
self
,
x
:
torch
.
Tensor
):
def
mean
(
self
):
return
(
torch
.
clamp
(
self
.
z
.
grad
,
self
.
epsilon
)
*
self
.
mu
).
sum
([
1
,
3
])
def
var
(
self
):
return
(
torch
.
clamp
(
self
.
z
.
grad
,
self
.
epsilon
)
*
(
self
.
mu
**
2
+
self
.
sig
)).
sum
([
1
,
3
])
-
self
.
mean
()
**
2
def
forward
(
self
,
x
:
torch
.
Tensor
,
marginalize
=
None
):
# Get shape
# Get shape
batch_size
=
x
.
shape
[
0
]
batch_size
=
x
.
shape
[
0
]
# Marginalize variables
if
marginalize
is
not
None
:
self
.
marginalize
=
marginalize
# Store the data
# Store the data
self
.
x
=
x
self
.
x
=
x
# Compute the probability
# Compute the probability
...
@@ -339,7 +376,7 @@ class NormalLeaf(SuprLayer):
...
@@ -339,7 +376,7 @@ class NormalLeaf(SuprLayer):
return
self
.
z
return
self
.
z
class
BernoulliLeaf
(
SuprL
ayer
):
class
BernoulliLeaf
(
SuprL
eaf
):
"""
BernoulliLeaf layer
"""
"""
BernoulliLeaf layer
"""
def
__init__
(
self
,
tracks
:
int
,
variables
:
int
,
channels
:
int
,
n
:
int
=
1
,
def
__init__
(
self
,
tracks
:
int
,
variables
:
int
,
channels
:
int
,
n
:
int
=
1
,
...
@@ -402,7 +439,7 @@ class BernoulliLeaf(SuprLayer):
...
@@ -402,7 +439,7 @@ class BernoulliLeaf(SuprLayer):
return
self
.
z
return
self
.
z
class
CategoricalLeaf
(
SuprL
ayer
):
class
CategoricalLeaf
(
SuprL
eaf
):
"""
CategoricalLeaf layer
"""
"""
CategoricalLeaf layer
"""
def
__init__
(
self
,
tracks
:
int
,
variables
:
int
,
channels
:
int
,
dimensions
:
int
,
n
:
int
=
1
,
def
__init__
(
self
,
tracks
:
int
,
variables
:
int
,
channels
:
int
,
dimensions
:
int
,
n
:
int
=
1
,
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment