Skip to content
Snippets Groups Projects
Commit 208b6a22 authored by Brad Nelson's avatar Brad Nelson
Browse files

check maxdim in cohom

parent a0ced1b1
No related branches found
No related tags found
No related merge requests found
from __future__ import print_function
from topologylayer.nn import LevelSetLayer2D
import matplotlib.pyplot as plt
import torch
import time
import numpy as np
def sum_finite(d):
diff = d[:,0] - d[:,1]
inds = diff < np.inf
return torch.sum(diff[inds])
# apparently there is some overhead the first time backward is called.
# we'll just get it over with now.
n = 16
y = torch.rand(n, n, dtype=torch.float).requires_grad_(True)
layer1 = LevelSetLayer2D((n, n), False)
dgm, issublevel = layer1(y)
p = sum_finite(dgm[0])
p.backward()
algs = ['hom', 'cohom']
tcs = {}
tfs = {}
tbs = {}
for alg in algs:
tcs[alg] = []
tfs[alg] = []
tbs[alg] = []
ns = [28, 64, 128]
for alg in algs:
for n in ns:
y = torch.rand(n, n, dtype=torch.float).requires_grad_(True)
t0 = time.time()
layer = LevelSetLayer2D((n, n), sublevel=False, alg=alg)
ta = time.time() - t0
tcs[alg].append(ta)
t0 = time.time()
dgm, issublevel = layer(y)
ta = time.time() - t0
tfs[alg].append(ta)
p = sum_finite(dgm[0])
t0 = time.time()
p.backward()
ta = time.time() - t0
tbs[alg].append(ta)
for alg in algs:
plt.loglog([n**2 for n in ns], tfs[alg], label=alg)
plt.legend()
plt.xlabel("n")
plt.ylabel("forward time")
plt.savefig("alg_time_forward_2d.png")
...@@ -19,8 +19,8 @@ class Cocycle{ ...@@ -19,8 +19,8 @@ class Cocycle{
Cocycle() : index(-1){} Cocycle() : index(-1){}
// initializations // initializations
Cocycle(int x) : index(x) , cochain(x) {} Cocycle(size_t x) : index(x) , cochain((int) x) {}
Cocycle(int x, std::vector<int> y) : index(x) , cochain(y) {} Cocycle(size_t x, std::vector<int> y) : index(x) , cochain(y) {}
// for debug purposes // for debug purposes
void insert(int x); void insert(int x);
......
...@@ -10,7 +10,8 @@ void reduction_step(SimplicialComplex &X,\ ...@@ -10,7 +10,8 @@ void reduction_step(SimplicialComplex &X,\
const size_t i,\ const size_t i,\
std::vector<Cocycle> &Z,\ std::vector<Cocycle> &Z,\
std::vector<torch::Tensor> &diagram,\ std::vector<torch::Tensor> &diagram,\
std::vector<int> &nbars) { std::vector<int> &nbars,
const size_t MAXDIM) {
// get cocycle // get cocycle
Cocycle c = X.bdr[i]; Cocycle c = X.bdr[i];
...@@ -20,13 +21,12 @@ void reduction_step(SimplicialComplex &X,\ ...@@ -20,13 +21,12 @@ void reduction_step(SimplicialComplex &X,\
auto pivot = Z.rbegin(); auto pivot = Z.rbegin();
for(auto x = Z.rbegin(); x != Z.rend(); ++x){ for(auto x = Z.rbegin(); x != Z.rend(); ++x){
// see if inner product is non-zero // see if inner product is non-zero
if(x->dot(c)){ if(x->dot(c) > 0){
if(flag==false){ if(flag==false){
// save as column that will be used for schur complement // save as column that will be used for schur complement
pivot = x; pivot = x;
flag=true; flag=true;
} } else {
else{
// schur complement // schur complement
x->add(*pivot); x->add(*pivot);
} }
...@@ -42,6 +42,14 @@ void reduction_step(SimplicialComplex &X,\ ...@@ -42,6 +42,14 @@ void reduction_step(SimplicialComplex &X,\
size_t dindx = c.index; size_t dindx = c.index;
// get birth dimension // get birth dimension
size_t hdim = X.dim(bindx); size_t hdim = X.dim(bindx);
//py::print("bindx: ", bindx, " dindx: ", dindx, " hdim: ", hdim);
// delete reduced column from active cocycles
// stupid translation from reverse to iterator
Z.erase(std::next(pivot).base());
// check if we want this bar
if (hdim > MAXDIM) { return; }
// get location in diagram // get location in diagram
size_t j = nbars[hdim]++; size_t j = nbars[hdim]++;
...@@ -52,10 +60,6 @@ void reduction_step(SimplicialComplex &X,\ ...@@ -52,10 +60,6 @@ void reduction_step(SimplicialComplex &X,\
// put birth/death indices of bar in X.backprop_lookup // put birth/death indices of bar in X.backprop_lookup
X.backprop_lookup[hdim][j] = {(int) bindx, (int) dindx}; X.backprop_lookup[hdim][j] = {(int) bindx, (int) dindx};
// delete reduced column from active cocycles
// stupid translation from reverse to iterator
Z.erase(std::next(pivot).base());
} else { } else {
// cocycle opened // cocycle opened
size_t bindx = c.index; size_t bindx = c.index;
...@@ -73,7 +77,7 @@ void reduction_step(SimplicialComplex &X,\ ...@@ -73,7 +77,7 @@ void reduction_step(SimplicialComplex &X,\
OUTPUTS: vector of tensors - t OUTPUTS: vector of tensors - t
t[k] is float32 tensor with barcode for dimension k t[k] is float32 tensor with barcode for dimension k
*/ */
std::vector<torch::Tensor> persistence_forward(SimplicialComplex &X, int MAXDIM) { std::vector<torch::Tensor> persistence_forward(SimplicialComplex &X, size_t MAXDIM) {
// produce sort permutation on X // produce sort permutation on X
X.sortedOrder(); X.sortedOrder();
...@@ -83,7 +87,7 @@ std::vector<torch::Tensor> persistence_forward(SimplicialComplex &X, int MAXDIM) ...@@ -83,7 +87,7 @@ std::vector<torch::Tensor> persistence_forward(SimplicialComplex &X, int MAXDIM)
// initialize reutrn diagram // initialize reutrn diagram
std::vector<torch::Tensor> diagram(MAXDIM+1); // return array std::vector<torch::Tensor> diagram(MAXDIM+1); // return array
for (int k = 0; k < MAXDIM+1; k++) { for (size_t k = 0; k < MAXDIM+1; k++) {
int Nk = X.numPairs(k); // number of bars in dimension k int Nk = X.numPairs(k); // number of bars in dimension k
// allocate return tensor // allocate return tensor
diagram[k] = torch::empty({Nk,2}, diagram[k] = torch::empty({Nk,2},
...@@ -94,13 +98,13 @@ std::vector<torch::Tensor> persistence_forward(SimplicialComplex &X, int MAXDIM) ...@@ -94,13 +98,13 @@ std::vector<torch::Tensor> persistence_forward(SimplicialComplex &X, int MAXDIM)
} }
// keep track of how many pairs we've put in diagram // keep track of how many pairs we've put in diagram
std::vector<int> nbars(MAXDIM+1); std::vector<int> nbars(MAXDIM+1);
for (int k = 0; k < MAXDIM+1; k++) { for (size_t k = 0; k < MAXDIM+1; k++) {
nbars[k] = 0; nbars[k] = 0;
} }
// go through reduction algorithm // go through reduction algorithm
for (size_t i : X.filtration_perm ) { for (size_t i : X.filtration_perm ) {
reduction_step(X, i, Z, diagram, nbars); reduction_step(X, i, Z, diagram, nbars, MAXDIM);
} }
// add infinite bars using removing columns in Z // add infinite bars using removing columns in Z
...@@ -112,7 +116,7 @@ std::vector<torch::Tensor> persistence_forward(SimplicialComplex &X, int MAXDIM) ...@@ -112,7 +116,7 @@ std::vector<torch::Tensor> persistence_forward(SimplicialComplex &X, int MAXDIM)
// get birth index // get birth index
size_t bindx = pivot->index; size_t bindx = pivot->index;
// get birth dimension // get birth dimension
int hdim = X.bdr[bindx].dim(); size_t hdim = X.bdr[bindx].dim();
if (hdim > MAXDIM) { continue; } if (hdim > MAXDIM) { continue; }
// get location in diagram // get location in diagram
...@@ -120,7 +124,7 @@ std::vector<torch::Tensor> persistence_forward(SimplicialComplex &X, int MAXDIM) ...@@ -120,7 +124,7 @@ std::vector<torch::Tensor> persistence_forward(SimplicialComplex &X, int MAXDIM)
// put births and deaths in diagram. // put births and deaths in diagram.
(diagram[hdim][j].data<float>())[0] = (float) X.full_function[bindx].first; (diagram[hdim][j].data<float>())[0] = (float) X.full_function[bindx].first;
(diagram[hdim][j].data<float>())[1] = std::numeric_limits<float>::infinity(); (diagram[hdim][j].data<float>())[1] = (float) std::numeric_limits<float>::infinity();
// put birth/death indices of bar in X.backprop_lookup // put birth/death indices of bar in X.backprop_lookup
X.backprop_lookup[hdim][j] = {(int) bindx, -1}; X.backprop_lookup[hdim][j] = {(int) bindx, -1};
......
...@@ -10,17 +10,9 @@ ...@@ -10,17 +10,9 @@
// typedef std::map<int,Interval> Barcode; // typedef std::map<int,Interval> Barcode;
// perform reduction step on active cocycles Z
// with cocycle x
void reuction_step(SimplicialComplex &X,\
const size_t i,\
std::vector<Cocycle> &Z,\
std::vector<torch::Tensor> &diagram,\
std::vector<int> &nbars);
// forward function for any filtration // forward function for any filtration
std::vector<torch::Tensor> persistence_forward( std::vector<torch::Tensor> persistence_forward(
SimplicialComplex &X, int MAXDIM); SimplicialComplex &X, size_t MAXDIM);
// backward function for lower-star // backward function for lower-star
torch::Tensor persistence_backward( torch::Tensor persistence_backward(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment