Unverified commit 1d626101, authored by Rickard Brüel Gabrielsson, committed by GitHub

Merge pull request #16 from bnels/gpufix

Allow tensors to come from GPU
parents 8a9bad5b 300d82eb
@@ -368,6 +368,10 @@ __Warning:__ 0-simplices (vertices) are assumed to start with `[0]` and end at `
 __Warning:__ the persistence computation currently assumes that the complex is acyclic at the end of the filtration in order to precompute the number of barcode pairs.

+# Note on GPU use
+
+You can use `topologylayer` with tensors that are on GPUs. However, because the persistent homology calculation takes place on CPU, there will be some overhead from memory movement.
+
 # (Deprecated) Dionysus Drivers
......
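To make the GPU note above concrete, here is a minimal usage sketch. It assumes the `RipsLayer` module from `topologylayer.nn` and its `(diagrams, issublevel)` return convention; the constructor arguments and the toy loss are illustrative and not part of this commit.

```python
import torch
from topologylayer.nn import RipsLayer  # assumed import path

# Build a point cloud on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.randn(100, 2, device=device, requires_grad=True)

# Assumed constructor: number of points and maximum homology dimension.
layer = RipsLayer(100, maxdim=1)

# The persistence computation itself runs on CPU; with this commit the
# diagrams come back on the same device as the input.
dgms, issublevel = layer(x)

# Toy loss: total length of the finite H0 bars.
births, deaths = dgms[0][:, 0], dgms[0][:, 1]
finite = torch.isfinite(deaths)
loss = (deaths[finite] - births[finite]).sum()
loss.backward()  # gradients land on x, back on the original device
```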
@@ -18,7 +18,10 @@ class FlagDiagram(Function):
     """
     @staticmethod
     def forward(ctx, X, y, maxdim, alg='hom'):
-        X.extendFlag(y)
+        device = y.device
+        ctx.device = device
+        ycpu = y.cpu()
+        X.extendFlag(ycpu)
         if alg == 'hom':
             ret = persistenceForwardHom(X, maxdim, 0)
         elif alg == 'hom2':
@@ -26,14 +29,16 @@
         elif alg == 'cohom':
             ret = persistenceForwardCohom(X, maxdim)
         ctx.X = X
-        ctx.save_for_backward(y)
+        ctx.save_for_backward(ycpu)
+        ret = [r.to(device) for r in ret]
         return tuple(ret)

     @staticmethod
     def backward(ctx, *grad_dgms):
         # print(grad_dgms)
         X = ctx.X
-        y, = ctx.saved_tensors
-        grad_ret = list(grad_dgms)
-        grad_y = persistenceBackwardFlag(X, y, grad_ret)
-        return None, grad_y, None, None
+        device = ctx.device
+        ycpu, = ctx.saved_tensors
+        grad_ret = [gd.cpu() for gd in grad_dgms]
+        grad_y = persistenceBackwardFlag(X, ycpu, grad_ret)
+        return None, grad_y.to(device), None, None
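The FlagDiagram change above follows a general pattern for wrapping a CPU-only kernel in a `torch.autograd.Function`: remember the input's device, run the kernel on CPU copies, and move both the forward outputs and the backward gradients back to that device. Below is a self-contained sketch of the pattern with a stand-in computation (doubling) in place of the persistence code; the class name is hypothetical.

```python
import torch
from torch.autograd import Function

class CpuOnlyDouble(Function):
    """Toy CPU-only op (y * 2) illustrating the device bookkeeping used above."""

    @staticmethod
    def forward(ctx, y):
        device = y.device            # remember where the input lives
        ctx.device = device
        ycpu = y.cpu()               # CPU copy for the CPU-only kernel
        ctx.save_for_backward(ycpu)
        out = ycpu * 2.0             # stand-in for the CPU-only computation
        return out.to(device)        # result goes back to the original device

    @staticmethod
    def backward(ctx, grad_out):
        device = ctx.device
        ycpu, = ctx.saved_tensors    # saved CPU copy (unused by this toy kernel)
        grad = grad_out.cpu() * 2.0  # backward pass also runs on CPU
        return grad.to(device)       # gradient returned on the input's device
```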
@@ -21,7 +21,9 @@ class SubLevelSetDiagram(Function):
     def forward(ctx, X, f, maxdim, alg='hom'):
         ctx.retshape = f.shape
         f = f.view(-1)
-        X.extendFloat(f)
+        device = f.device
+        ctx.device = device
+        X.extendFloat(f.cpu())
         if alg == 'hom':
             ret = persistenceForwardHom(X, maxdim, 0)
         elif alg == 'hom2':
@@ -29,13 +31,15 @@
         elif alg == 'cohom':
             ret = persistenceForwardCohom(X, maxdim)
         ctx.X = X
+        ret = [r.to(device) for r in ret]
         return tuple(ret)

     @staticmethod
     def backward(ctx, *grad_dgms):
         # print(grad_dgms)
         X = ctx.X
+        device = ctx.device
         retshape = ctx.retshape
-        grad_ret = list(grad_dgms)
+        grad_ret = [gd.cpu() for gd in grad_dgms]
         grad_f = persistenceBackward(X, grad_ret)
-        return None, grad_f.view(retshape), None, None
+        return None, grad_f.view(retshape).to(device), None, None
@@ -55,7 +55,7 @@ class AlphaLayer(nn.Module):
         self.alg = alg

     def forward(self, x):
-        xnp = x.data.numpy()
+        xnp = x.cpu().detach().numpy()
         complex = None
         if xnp.shape[1] == 1:
             xnp = xnp.flatten()
......
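The AlphaLayer fix above swaps `x.data.numpy()` for `x.cpu().detach().numpy()`, since `.numpy()` only works on CPU tensors without autograd history. A small illustration (not from this commit):

```python
import torch

x = torch.randn(10, 2, requires_grad=True)
if torch.cuda.is_available():
    x = x.to('cuda')               # exercise the GPU path when available

# Detach from the graph and move to CPU before converting to NumPy;
# this works whether x lives on the CPU or on a GPU.
xnp = x.cpu().detach().numpy()
print(xnp.shape)                   # (10, 2)
```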