Skip to content
Snippets Groups Projects
Commit 7860f7af authored by Brad Nelson's avatar Brad Nelson
Browse files

move tensors from gpu if appropriate

parent 8a9bad5b
Branches
No related tags found
No related merge requests found
......@@ -368,6 +368,10 @@ __Warning:__ 0-simplices (vertices) are assumed to start with `[0]` and end at `
__Warning:__ the persistence computation currently assumes that the complex is acyclic at the end of the filtration in order to precompute the number of barcode pairs.
# Note on GPU use
You can use `topologylayer` with tensors that are on GPUs. However, because the persistent homology calculation takes place on CPU, there will be some overhead from memory movement.
# (Deprecated) Dionysus Drivers
......
......@@ -18,7 +18,10 @@ class FlagDiagram(Function):
"""
@staticmethod
def forward(ctx, X, y, maxdim, alg='hom'):
X.extendFlag(y)
device = y.device
ctx.device = device
ycpu = y.cpu()
X.extendFlag(ycpu)
if alg == 'hom':
ret = persistenceForwardHom(X, maxdim, 0)
elif alg == 'hom2':
......@@ -26,14 +29,16 @@ class FlagDiagram(Function):
elif alg == 'cohom':
ret = persistenceForwardCohom(X, maxdim)
ctx.X = X
ctx.save_for_backward(y)
ctx.save_for_backward(ycpu)
ret = [r.to(device) for r in ret]
return tuple(ret)
@staticmethod
def backward(ctx, *grad_dgms):
# print(grad_dgms)
X = ctx.X
y, = ctx.saved_tensors
grad_ret = list(grad_dgms)
grad_y = persistenceBackwardFlag(X, y, grad_ret)
return None, grad_y, None, None
device = ctx.device
ycpu, = ctx.saved_tensors
grad_ret = [gd.cpu() for gd in grad_dgms]
grad_y = persistenceBackwardFlag(X, ycpu, grad_ret)
return None, grad_y.to(device), None, None
......@@ -21,7 +21,9 @@ class SubLevelSetDiagram(Function):
def forward(ctx, X, f, maxdim, alg='hom'):
ctx.retshape = f.shape
f = f.view(-1)
X.extendFloat(f)
device = f.device
ctx.device = device
X.extendFloat(f.cpu())
if alg == 'hom':
ret = persistenceForwardHom(X, maxdim, 0)
elif alg == 'hom2':
......@@ -29,13 +31,15 @@ class SubLevelSetDiagram(Function):
elif alg == 'cohom':
ret = persistenceForwardCohom(X, maxdim)
ctx.X = X
ret = [r.to(device) for r in ret]
return tuple(ret)
@staticmethod
def backward(ctx, *grad_dgms):
# print(grad_dgms)
X = ctx.X
device = ctx.device
retshape = ctx.retshape
grad_ret = list(grad_dgms)
grad_ret = [gd.cpu() for gd in grad_dgms]
grad_f = persistenceBackward(X, grad_ret)
return None, grad_f.view(retshape), None, None
return None, grad_f.view(retshape).to(device), None, None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment