Skip to content
Snippets Groups Projects
Commit 00e8f833 authored by Brad Nelson's avatar Brad Nelson
Browse files

update unit tests for permuation invariance

parent 1d619848
No related branches found
No related tags found
No related merge requests found
......@@ -3,7 +3,7 @@ import unittest
import topologylayer
import torch
import numpy as np
from topologylayer.util.process import remove_zero_bars
from topologylayer.util.process import remove_zero_bars, remove_infinite_bars
class AlphaTest(unittest.TestCase):
def test(self):
......@@ -19,8 +19,8 @@ class AlphaTest(unittest.TestCase):
True,
"Expected sublevel set layer")
self.assertEqual(
torch.all(torch.eq(remove_zero_bars(dgms[0]),
torch.tensor([[0., 2.], [0., 2.], [0., 2.], [0., np.inf]]))),
torch.all(torch.eq(remove_infinite_bars(remove_zero_bars(dgms[0]), issub),
torch.tensor([[0., 2.], [0., 2.], [0., 2.]]))),
True,
"unexpected 0-dim barcode")
self.assertEqual(
......@@ -29,12 +29,12 @@ class AlphaTest(unittest.TestCase):
True,
"unexpected 1-dim barcode")
d0 = remove_zero_bars(dgms[0])
p = d0[0, 1] - d0[0, 0]
d0 = remove_infinite_bars(remove_zero_bars(dgms[0]), issub)
p = torch.sum(d0[:, 1] - d0[:, 0])
p.backward()
self.assertEqual(
torch.all(torch.eq(x.grad,
torch.tensor([[0,1],[0,-1],[0,0],[0,0]], dtype=torch.float))),
torch.tensor([[1,1],[1,-1],[-1,0],[-1,0]], dtype=torch.float))),
True,
"unexpected gradient")
......@@ -3,7 +3,7 @@ import unittest
import topologylayer
import torch
import numpy as np
from topologylayer.util.process import remove_zero_bars
from topologylayer.util.process import remove_zero_bars, remove_infinite_bars
class RipsTest(unittest.TestCase):
def test(self):
......@@ -19,8 +19,8 @@ class RipsTest(unittest.TestCase):
True,
"Expected sublevel set layer")
self.assertEqual(
torch.all(torch.eq(remove_zero_bars(dgms[0]),
torch.tensor([[0., 2.], [0., 2.], [0., 2.], [0., np.inf]]))),
torch.all(torch.eq(remove_infinite_bars(remove_zero_bars(dgms[0]), issub),
torch.tensor([[0., 2.], [0., 2.], [0., 2.]]))),
True,
"unexpected 0-dim barcode")
self.assertEqual(
......@@ -29,12 +29,12 @@ class RipsTest(unittest.TestCase):
True,
"unexpected 1-dim barcode")
d0 = remove_zero_bars(dgms[0])
p = d0[0, 1] - d0[0, 0]
d0 = remove_infinite_bars(remove_zero_bars(dgms[0]), issub)
p = torch.sum(d0[:, 1] - d0[:, 0])
p.backward()
self.assertEqual(
torch.all(torch.eq(x.grad,
torch.tensor([[0,1],[0,-1],[0,0],[0,0]], dtype=torch.float))),
torch.tensor([[1,1],[1,-1],[-1,0],[-1,0]], dtype=torch.float))),
True,
"unexpected gradient")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment