From 7860f7af4ac2796d744cc3a48f18995504e60a59 Mon Sep 17 00:00:00 2001 From: Brad Nelson <bjnelson@stanford.edu> Date: Thu, 25 Jul 2019 15:02:47 -0700 Subject: [PATCH] move tensors from gpu if appropriate --- README.md | 4 ++++ topologylayer/functional/flag.py | 17 +++++++++++------ topologylayer/functional/sublevel.py | 10 +++++++--- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index a107eff..ec77072 100644 --- a/README.md +++ b/README.md @@ -368,6 +368,10 @@ __Warning:__ 0-simplices (vertices) are assumed to start with `[0]` and end at ` __Warning:__ the persistence computation currently assumes that the complex is acyclic at the end of the filtration in order to precompute the number of barcode pairs. +# Note on GPU use + +You can use `topologylayer` with tensors that are on GPUs. However, because the persistent homology calculation takes place on CPU, there will be some overhead from memory movement. + # (Deprecated) Dionysus Drivers diff --git a/topologylayer/functional/flag.py b/topologylayer/functional/flag.py index 61a5466..855bcd9 100644 --- a/topologylayer/functional/flag.py +++ b/topologylayer/functional/flag.py @@ -18,7 +18,10 @@ class FlagDiagram(Function): """ @staticmethod def forward(ctx, X, y, maxdim, alg='hom'): - X.extendFlag(y) + device = y.device + ctx.device = device + ycpu = y.cpu() + X.extendFlag(ycpu) if alg == 'hom': ret = persistenceForwardHom(X, maxdim, 0) elif alg == 'hom2': @@ -26,14 +29,16 @@ class FlagDiagram(Function): elif alg == 'cohom': ret = persistenceForwardCohom(X, maxdim) ctx.X = X - ctx.save_for_backward(y) + ctx.save_for_backward(ycpu) + ret = [r.to(device) for r in ret] return tuple(ret) @staticmethod def backward(ctx, *grad_dgms): # print(grad_dgms) X = ctx.X - y, = ctx.saved_tensors - grad_ret = list(grad_dgms) - grad_y = persistenceBackwardFlag(X, y, grad_ret) - return None, grad_y, None, None + device = ctx.device + ycpu, = ctx.saved_tensors + grad_ret = [gd.cpu() for gd in grad_dgms] + grad_y = persistenceBackwardFlag(X, ycpu, grad_ret) + return None, grad_y.to(device), None, None diff --git a/topologylayer/functional/sublevel.py b/topologylayer/functional/sublevel.py index a17d86e..3d97e9a 100644 --- a/topologylayer/functional/sublevel.py +++ b/topologylayer/functional/sublevel.py @@ -21,7 +21,9 @@ class SubLevelSetDiagram(Function): def forward(ctx, X, f, maxdim, alg='hom'): ctx.retshape = f.shape f = f.view(-1) - X.extendFloat(f) + device = f.device + ctx.device = device + X.extendFloat(f.cpu()) if alg == 'hom': ret = persistenceForwardHom(X, maxdim, 0) elif alg == 'hom2': @@ -29,13 +31,15 @@ class SubLevelSetDiagram(Function): elif alg == 'cohom': ret = persistenceForwardCohom(X, maxdim) ctx.X = X + ret = [r.to(device) for r in ret] return tuple(ret) @staticmethod def backward(ctx, *grad_dgms): # print(grad_dgms) X = ctx.X + device = ctx.device retshape = ctx.retshape - grad_ret = list(grad_dgms) + grad_ret = [gd.cpu() for gd in grad_dgms] grad_f = persistenceBackward(X, grad_ret) - return None, grad_f.view(retshape), None, None + return None, grad_f.view(retshape).to(device), None, None -- GitLab