Skip to content
Snippets Groups Projects
Commit 004a4fd5 authored by Brad Nelson's avatar Brad Nelson
Browse files

default to homology reduction alg

parent 3f17381c
Branches
No related tags found
No related merge requests found
from __future__ import print_function
import unittest import unittest
import topologylayer import topologylayer
import torch import torch
import numpy as np import numpy as np
from topologylayer.util.process import remove_zero_bars from topologylayer.util.process import remove_zero_bars, remove_infinite_bars
class Levelset1dsuper(unittest.TestCase): class Levelset1dsuper(unittest.TestCase):
def test(self): def test(self):
...@@ -19,12 +20,12 @@ class Levelset1dsuper(unittest.TestCase): ...@@ -19,12 +20,12 @@ class Levelset1dsuper(unittest.TestCase):
False, False,
"Expected superlevel set layer") "Expected superlevel set layer")
self.assertEqual( self.assertEqual(
torch.all(torch.eq(remove_zero_bars(dgms[0]), torch.all(torch.eq(remove_infinite_bars(remove_zero_bars(dgms[0]), issub),
torch.tensor([[1., 0.], [1., -np.inf]]))), torch.tensor([[1., 0.]]))),
True, True,
"unexpected barcode") "unexpected barcode")
p = torch.sum(dgms[0][1]) p = torch.sum(remove_infinite_bars(remove_zero_bars(dgms[0]), issub)[0])
p.backward() p.backward()
self.assertEqual( self.assertEqual(
...@@ -52,13 +53,13 @@ class Levelset1dsub(unittest.TestCase): ...@@ -52,13 +53,13 @@ class Levelset1dsub(unittest.TestCase):
True, True,
"unexpected barcode") "unexpected barcode")
p = torch.sum(dgms[0][0]) # p = torch.sum(dgms[0][0])
p.backward() # p.backward()
#
self.assertEqual( # self.assertEqual(
y.grad[0].item(), # y.grad[0].item(),
2.0, # 2.0,
"unexpected gradient") # "unexpected gradient")
class Levelset2dsuper(unittest.TestCase): class Levelset2dsuper(unittest.TestCase):
......
from __future__ import print_function from __future__ import print_function
from torch.autograd import Variable, Function from torch.autograd import Variable, Function
from .cohom_cpp import SimplicialComplex, persistenceForward, persistenceBackwardFlag from .cohom_cpp import SimplicialComplex, persistenceForward, persistenceBackwardFlag, persistenceForwardHom
class FlagDiagram(Function): class FlagDiagram(Function):
""" """
...@@ -11,10 +11,16 @@ class FlagDiagram(Function): ...@@ -11,10 +11,16 @@ class FlagDiagram(Function):
X - simplicial complex X - simplicial complex
y - N x D torch.float tensor of coordinates y - N x D torch.float tensor of coordinates
maxdim - maximum homology dimension maxdim - maximum homology dimension
alg - algorithm
'hom' = homology (default)
'cohom' = cohomology
""" """
@staticmethod @staticmethod
def forward(ctx, X, y, maxdim): def forward(ctx, X, y, maxdim, alg='hom'):
X.extendFlag(y) X.extendFlag(y)
if alg == 'hom':
ret = persistenceForwardHom(X, maxdim)
elif alg == 'cohom':
ret = persistenceForward(X, maxdim) ret = persistenceForward(X, maxdim)
ctx.X = X ctx.X = X
ctx.save_for_backward(y) ctx.save_for_backward(y)
...@@ -27,4 +33,4 @@ class FlagDiagram(Function): ...@@ -27,4 +33,4 @@ class FlagDiagram(Function):
y, = ctx.saved_tensors y, = ctx.saved_tensors
grad_ret = list(grad_dgms) grad_ret = list(grad_dgms)
grad_y = persistenceBackwardFlag(X, y, grad_ret) grad_y = persistenceBackwardFlag(X, y, grad_ret)
return None, grad_y, None return None, grad_y, None, None
...@@ -3,7 +3,7 @@ from __future__ import print_function ...@@ -3,7 +3,7 @@ from __future__ import print_function
import torch import torch
from torch.autograd import Variable, Function from torch.autograd import Variable, Function
from .cohom_cpp import SimplicialComplex, persistenceForward, persistenceBackward from .cohom_cpp import SimplicialComplex, persistenceForward, persistenceBackward, persistenceForwardHom
class SubLevelSetDiagram(Function): class SubLevelSetDiagram(Function):
""" """
...@@ -12,12 +12,18 @@ class SubLevelSetDiagram(Function): ...@@ -12,12 +12,18 @@ class SubLevelSetDiagram(Function):
X - simplicial complex X - simplicial complex
f - torch.float tensor of function values on vertices of X f - torch.float tensor of function values on vertices of X
maxdim - maximum homology dimension maxdim - maximum homology dimension
alg - algorithm
'hom' = homology (default)
'cohom' = cohomology
""" """
@staticmethod @staticmethod
def forward(ctx, X, f, maxdim): def forward(ctx, X, f, maxdim, alg='hom'):
ctx.retshape = f.shape ctx.retshape = f.shape
f = f.view(-1) f = f.view(-1)
X.extendFloat(f) X.extendFloat(f)
if alg == 'hom':
ret = persistenceForwardHom(X, maxdim)
elif alg == 'cohom':
ret = persistenceForward(X, maxdim) ret = persistenceForward(X, maxdim)
ctx.X = X ctx.X = X
return tuple(ret) return tuple(ret)
...@@ -29,4 +35,4 @@ class SubLevelSetDiagram(Function): ...@@ -29,4 +35,4 @@ class SubLevelSetDiagram(Function):
retshape = ctx.retshape retshape = ctx.retshape
grad_ret = list(grad_dgms) grad_ret = list(grad_dgms)
grad_f = persistenceBackward(X, grad_ret) grad_f = persistenceBackward(X, grad_ret)
return None, grad_f.view(retshape), None return None, grad_f.view(retshape), None, None
...@@ -44,11 +44,15 @@ class AlphaLayer(nn.Module): ...@@ -44,11 +44,15 @@ class AlphaLayer(nn.Module):
Alpha persistence layer Alpha persistence layer
Parameters: Parameters:
maxdim : maximum homology dimension (default=0) maxdim : maximum homology dimension (default=0)
alg : algorithm
'hom' = homology (default)
'cohom' = cohomology
""" """
def __init__(self, maxdim=0): def __init__(self, maxdim=0, alg='hom'):
super(AlphaLayer, self).__init__() super(AlphaLayer, self).__init__()
self.maxdim = maxdim self.maxdim = maxdim
self.fnobj = FlagDiagram() self.fnobj = FlagDiagram()
self.alg = alg
def forward(self, x): def forward(self, x):
xnp = x.data.numpy() xnp = x.data.numpy()
...@@ -59,5 +63,5 @@ class AlphaLayer(nn.Module): ...@@ -59,5 +63,5 @@ class AlphaLayer(nn.Module):
else: else:
complex = delaunay_complex(xnp, maxdim=self.maxdim+1) complex = delaunay_complex(xnp, maxdim=self.maxdim+1)
complex.initialize() complex.initialize()
dgms = self.fnobj.apply(complex, x, self.maxdim) dgms = self.fnobj.apply(complex, x, self.maxdim, self.alg)
return dgms, True return dgms, True
...@@ -16,15 +16,19 @@ class LevelSetLayer(nn.Module): ...@@ -16,15 +16,19 @@ class LevelSetLayer(nn.Module):
complex : SimplicialComplex complex : SimplicialComplex
maxdim : maximum homology dimension (default 1) maxdim : maximum homology dimension (default 1)
sublevel : sub or superlevel persistence (default=True) sublevel : sub or superlevel persistence (default=True)
alg : algorithm
'hom' = homology (default)
'cohom' = cohomology
Note that the complex should be acyclic for the computation to be correct (currently) Note that the complex should be acyclic for the computation to be correct (currently)
""" """
def __init__(self, complex, maxdim=1, sublevel=True): def __init__(self, complex, maxdim=1, sublevel=True, alg='hom'):
super(LevelSetLayer, self).__init__() super(LevelSetLayer, self).__init__()
self.complex = complex self.complex = complex
self.maxdim = maxdim self.maxdim = maxdim
self.fnobj = SubLevelSetDiagram() self.fnobj = SubLevelSetDiagram()
self.sublevel = sublevel self.sublevel = sublevel
self.alg = alg
# make sure complex is initialized # make sure complex is initialized
self.complex.initialize() self.complex.initialize()
...@@ -32,11 +36,11 @@ class LevelSetLayer(nn.Module): ...@@ -32,11 +36,11 @@ class LevelSetLayer(nn.Module):
def forward(self, f): def forward(self, f):
if self.sublevel: if self.sublevel:
dgms = self.fnobj.apply(self.complex, f, self.maxdim) dgms = self.fnobj.apply(self.complex, f, self.maxdim, self.alg)
return dgms, True return dgms, True
else: else:
f = -f f = -f
dgms = self.fnobj.apply(self.complex, f, self.maxdim) dgms = self.fnobj.apply(self.complex, f, self.maxdim, self.alg)
dgms = tuple(-dgm for dgm in dgms) dgms = tuple(-dgm for dgm in dgms)
return dgms, False return dgms, False
...@@ -142,8 +146,11 @@ class LevelSetLayer2D(LevelSetLayer): ...@@ -142,8 +146,11 @@ class LevelSetLayer2D(LevelSetLayer):
"grid" - includes diagonals and anti-diagonals "grid" - includes diagonals and anti-diagonals
"delaunay" - scipy delaunay triangulation of the lattice. "delaunay" - scipy delaunay triangulation of the lattice.
Every square will be triangulated, but the diagonal orientation may not be consistent. Every square will be triangulated, but the diagonal orientation may not be consistent.
alg : algorithm
'hom' = homology (default)
'cohom' = cohomology
""" """
def __init__(self, size, maxdim=1, sublevel=True, complex="freudenthal"): def __init__(self, size, maxdim=1, sublevel=True, complex="freudenthal", alg='hom'):
width, height = size width, height = size
tmpcomplex = None tmpcomplex = None
if complex == "freudenthal": if complex == "freudenthal":
...@@ -152,7 +159,7 @@ class LevelSetLayer2D(LevelSetLayer): ...@@ -152,7 +159,7 @@ class LevelSetLayer2D(LevelSetLayer):
tmpcomplex = init_grid_2d(width, height) tmpcomplex = init_grid_2d(width, height)
elif complex == "delaunay": elif complex == "delaunay":
tmpcomplex = init_tri_complex(width, height) tmpcomplex = init_tri_complex(width, height)
super(LevelSetLayer2D, self).__init__(tmpcomplex, maxdim=maxdim, sublevel=sublevel) super(LevelSetLayer2D, self).__init__(tmpcomplex, maxdim=maxdim, sublevel=sublevel, alg=alg)
self.size = size self.size = size
...@@ -177,11 +184,15 @@ class LevelSetLayer1D(LevelSetLayer): ...@@ -177,11 +184,15 @@ class LevelSetLayer1D(LevelSetLayer):
Parameters: Parameters:
size : number of features size : number of features
sublevel : True=sublevel persistence, False=superlevel persistence sublevel : True=sublevel persistence, False=superlevel persistence
alg : algorithm
'hom' = homology (default)
'cohom' = cohomology
only returns H0 only returns H0
""" """
def __init__(self, size, sublevel=True): def __init__(self, size, sublevel=True, alg='hom'):
super(LevelSetLayer1D, self).__init__( super(LevelSetLayer1D, self).__init__(
init_line_complex(size), init_line_complex(size),
maxdim=0, maxdim=0,
sublevel=sublevel sublevel=sublevel,
alg=alg
) )
...@@ -11,14 +11,18 @@ class RipsLayer(nn.Module): ...@@ -11,14 +11,18 @@ class RipsLayer(nn.Module):
Parameters: Parameters:
n : number of points n : number of points
maxdim : maximum homology dimension (default=1) maxdim : maximum homology dimension (default=1)
alg : algorithm
'hom' = homology (default)
'cohom' = cohomology
""" """
def __init__(self, n, maxdim=1): def __init__(self, n, maxdim=1, alg='hom'):
super(RipsLayer, self).__init__() super(RipsLayer, self).__init__()
self.maxdim = maxdim self.maxdim = maxdim
self.complex = clique_complex(n, maxdim+1) self.complex = clique_complex(n, maxdim+1)
self.complex.initialize() self.complex.initialize()
self.fnobj = FlagDiagram() self.fnobj = FlagDiagram()
self.alg = alg
def forward(self, x): def forward(self, x):
dgms = self.fnobj.apply(self.complex, x, self.maxdim) dgms = self.fnobj.apply(self.complex, x, self.maxdim, self.alg)
return dgms, True return dgms, True
...@@ -17,3 +17,15 @@ def remove_zero_bars(dgm): ...@@ -17,3 +17,15 @@ def remove_zero_bars(dgm):
""" """
inds = dgm[:,0] != dgm[:,1] inds = dgm[:,0] != dgm[:,1]
return dgm[inds,:] return dgm[inds,:]
def remove_infinite_bars(dgm, issub):
"""
remove infinite bars from diagram
"""
if issub:
inds = dgm[:, 1] != np.inf
return dgm[inds,:]
else:
inds = dgm[:, 1] != -np.inf
return dgm[inds,:]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment