diff --git a/README.md b/README.md index a107eff2a7b53851789f775baa271c7dc350ac59..ec770729b52be85c323785b1c76a961b95ae28b3 100644 --- a/README.md +++ b/README.md @@ -368,6 +368,10 @@ __Warning:__ 0-simplices (vertices) are assumed to start with `[0]` and end at ` __Warning:__ the persistence computation currently assumes that the complex is acyclic at the end of the filtration in order to precompute the number of barcode pairs. +# Note on GPU use + +You can use `topologylayer` with tensors that are on GPUs. However, because the persistent homology calculation takes place on CPU, there will be some overhead from memory movement. + # (Deprecated) Dionysus Drivers diff --git a/topologylayer/functional/flag.py b/topologylayer/functional/flag.py index 61a546605be54fd4144ae4f10f2f3b0067dd50c2..855bcd99eacf531e7cff74fa1a96a4b41a9d2b59 100644 --- a/topologylayer/functional/flag.py +++ b/topologylayer/functional/flag.py @@ -18,7 +18,10 @@ class FlagDiagram(Function): """ @staticmethod def forward(ctx, X, y, maxdim, alg='hom'): - X.extendFlag(y) + device = y.device + ctx.device = device + ycpu = y.cpu() + X.extendFlag(ycpu) if alg == 'hom': ret = persistenceForwardHom(X, maxdim, 0) elif alg == 'hom2': @@ -26,14 +29,16 @@ class FlagDiagram(Function): elif alg == 'cohom': ret = persistenceForwardCohom(X, maxdim) ctx.X = X - ctx.save_for_backward(y) + ctx.save_for_backward(ycpu) + ret = [r.to(device) for r in ret] return tuple(ret) @staticmethod def backward(ctx, *grad_dgms): # print(grad_dgms) X = ctx.X - y, = ctx.saved_tensors - grad_ret = list(grad_dgms) - grad_y = persistenceBackwardFlag(X, y, grad_ret) - return None, grad_y, None, None + device = ctx.device + ycpu, = ctx.saved_tensors + grad_ret = [gd.cpu() for gd in grad_dgms] + grad_y = persistenceBackwardFlag(X, ycpu, grad_ret) + return None, grad_y.to(device), None, None diff --git a/topologylayer/functional/sublevel.py b/topologylayer/functional/sublevel.py index a17d86e994da66c0fc849036758f1b34d5ba2b28..3d97e9a7bf626ffa7660202ea5a9bd78c995a415 100644 --- a/topologylayer/functional/sublevel.py +++ b/topologylayer/functional/sublevel.py @@ -21,7 +21,9 @@ class SubLevelSetDiagram(Function): def forward(ctx, X, f, maxdim, alg='hom'): ctx.retshape = f.shape f = f.view(-1) - X.extendFloat(f) + device = f.device + ctx.device = device + X.extendFloat(f.cpu()) if alg == 'hom': ret = persistenceForwardHom(X, maxdim, 0) elif alg == 'hom2': @@ -29,13 +31,15 @@ class SubLevelSetDiagram(Function): elif alg == 'cohom': ret = persistenceForwardCohom(X, maxdim) ctx.X = X + ret = [r.to(device) for r in ret] return tuple(ret) @staticmethod def backward(ctx, *grad_dgms): # print(grad_dgms) X = ctx.X + device = ctx.device retshape = ctx.retshape - grad_ret = list(grad_dgms) + grad_ret = [gd.cpu() for gd in grad_dgms] grad_f = persistenceBackward(X, grad_ret) - return None, grad_f.view(retshape), None, None + return None, grad_f.view(retshape).to(device), None, None