Commit 7860f7af authored by Brad Nelson

move tensors from gpu if appropriate

parent 8a9bad5b
```diff
@@ -368,6 +368,10 @@ __Warning:__ 0-simplices (vertices) are assumed to start with `[0]` and end at `
 __Warning:__ the persistence computation currently assumes that the complex is acyclic at the end of the filtration in order to precompute the number of barcode pairs.
 
+# Note on GPU use
+
+You can use `topologylayer` with tensors that are on a GPU. However, because the persistent homology calculation takes place on the CPU, there will be some overhead from memory movement.
+
 # (Deprecated) Dionysus Drivers
```
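The new README note above describes the behavior this commit implements: GPU tensors are accepted, diagrams come back on the caller's device, and the CPU round trip is hidden from the user. A minimal usage sketch, assuming `LevelSetLayer2D` from `topologylayer.nn` with the `size`/`maxdim` arguments used elsewhere in the README (the exact signature may differ in your version):

```python
import torch
from topologylayer.nn import LevelSetLayer2D  # assumed import path

# Sketch: sublevel-set persistence of a random image living on the GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
layer = LevelSetLayer2D(size=(28, 28), maxdim=1)  # assumed constructor
img = torch.rand(28, 28, device=device, requires_grad=True)

dgms, issublevel = layer(img)  # homology runs on CPU; diagrams return on `device`
d0 = dgms[0]
finite = d0[torch.isfinite(d0).all(dim=1)]  # drop infinite (essential) bars
loss = (finite[:, 1] - finite[:, 0]).sum()  # total finite persistence in H0
loss.backward()                             # gradient comes back on the input device
assert img.grad.device == img.device
```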
```diff
@@ -18,7 +18,10 @@ class FlagDiagram(Function):
     """
     @staticmethod
     def forward(ctx, X, y, maxdim, alg='hom'):
-        X.extendFlag(y)
+        device = y.device
+        ctx.device = device
+        ycpu = y.cpu()
+        X.extendFlag(ycpu)
         if alg == 'hom':
             ret = persistenceForwardHom(X, maxdim, 0)
         elif alg == 'hom2':
@@ -26,14 +29,16 @@ class FlagDiagram(Function):
         elif alg == 'cohom':
             ret = persistenceForwardCohom(X, maxdim)
         ctx.X = X
-        ctx.save_for_backward(y)
+        ctx.save_for_backward(ycpu)
+        ret = [r.to(device) for r in ret]
         return tuple(ret)
 
     @staticmethod
     def backward(ctx, *grad_dgms):
         # print(grad_dgms)
         X = ctx.X
-        y, = ctx.saved_tensors
-        grad_ret = list(grad_dgms)
-        grad_y = persistenceBackwardFlag(X, y, grad_ret)
-        return None, grad_y, None, None
+        device = ctx.device
+        ycpu, = ctx.saved_tensors
+        grad_ret = [gd.cpu() for gd in grad_dgms]
+        grad_y = persistenceBackwardFlag(X, ycpu, grad_ret)
+        return None, grad_y.to(device), None, None
```
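Both `forward` and `backward` above follow the same generic recipe: record the input's device on `ctx`, hand CPU copies to the CPU-only kernels, and move results and gradients back before returning them to autograd. A self-contained sketch of that recipe, with a stand-in CPU-only op (a 1-D sort) in place of the persistence routines; `CpuRoundTrip` is illustrative and not part of `topologylayer`:

```python
import torch
from torch.autograd import Function

class CpuRoundTrip(Function):
    """Illustrative Function whose kernel only runs on CPU."""

    @staticmethod
    def forward(ctx, y):
        ctx.device = y.device          # remember where the caller's tensor lives
        ycpu = y.detach().cpu()        # CPU-only kernel gets a CPU copy
        out = torch.sort(ycpu).values  # stand-in for the CPU persistence call
        ctx.save_for_backward(ycpu)
        return out.to(ctx.device)      # hand the result back on the caller's device

    @staticmethod
    def backward(ctx, grad_out):
        ycpu, = ctx.saved_tensors
        grad_cpu = grad_out.cpu()      # incoming grads arrive on `device`
        perm = torch.argsort(ycpu)     # route grads back through the sort (1-D case)
        grad_y = torch.empty_like(grad_cpu)
        grad_y[perm] = grad_cpu
        return grad_y.to(ctx.device)   # gradient goes back to the original device
```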
```diff
@@ -21,7 +21,9 @@ class SubLevelSetDiagram(Function):
     def forward(ctx, X, f, maxdim, alg='hom'):
         ctx.retshape = f.shape
         f = f.view(-1)
-        X.extendFloat(f)
+        device = f.device
+        ctx.device = device
+        X.extendFloat(f.cpu())
         if alg == 'hom':
             ret = persistenceForwardHom(X, maxdim, 0)
         elif alg == 'hom2':
@@ -29,13 +31,15 @@ class SubLevelSetDiagram(Function):
         elif alg == 'cohom':
             ret = persistenceForwardCohom(X, maxdim)
         ctx.X = X
+        ret = [r.to(device) for r in ret]
         return tuple(ret)
 
     @staticmethod
     def backward(ctx, *grad_dgms):
         # print(grad_dgms)
         X = ctx.X
+        device = ctx.device
         retshape = ctx.retshape
-        grad_ret = list(grad_dgms)
+        grad_ret = [gd.cpu() for gd in grad_dgms]
         grad_f = persistenceBackward(X, grad_ret)
-        return None, grad_f.view(retshape), None, None
+        return None, grad_f.view(retshape).to(device), None, None
```
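A quick way to confirm the behavior these two changes establish: outputs and gradients stay on the caller's device even though the computation itself round-trips through the CPU. Using the illustrative `CpuRoundTrip` from the sketch above (assumes a CUDA-capable install; substitute `'cpu'` otherwise):

```python
y = torch.randn(8, device='cuda', requires_grad=True)
out = CpuRoundTrip.apply(y)        # forward round-trips through CPU
out.sum().backward()
assert out.device == y.device      # diagram-like output comes back on the GPU
assert y.grad.device == y.device   # so does the gradient
```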