From 7860f7af4ac2796d744cc3a48f18995504e60a59 Mon Sep 17 00:00:00 2001
From: Brad Nelson <bjnelson@stanford.edu>
Date: Thu, 25 Jul 2019 15:02:47 -0700
Subject: [PATCH] move tensors from GPU if appropriate

FlagDiagram and SubLevelSetDiagram now move input tensors to the CPU
before the persistence computation and move the returned diagrams and
gradients back to the input tensor's device, so both layers can be used
with CUDA tensors.

---
 README.md                            | 20 ++++++++++++++++++++
 topologylayer/functional/flag.py     | 17 +++++++++++------
 topologylayer/functional/sublevel.py | 10 +++++++---
 3 files changed, 38 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index a107eff..ec77072 100644
--- a/README.md
+++ b/README.md
@@ -368,6 +368,26 @@ __Warning:__ 0-simplices (vertices) are assumed to start with `[0]` and end at `
 
 __Warning:__ the persistence computation currently assumes that the complex is acyclic at the end of the filtration in order to precompute the number of barcode pairs.
 
+# Note on GPU use
+
+You can use `topologylayer` with tensors that live on a GPU: inputs are moved to the CPU for the persistent homology computation, and the resulting diagrams and gradients are moved back to the input tensor's device. Because the computation itself runs on the CPU, expect some overhead from this memory movement.
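+
+For example (a minimal sketch, assuming a CUDA device is available, using the `LevelSetLayer2D` and `SumBarcodeLengths` layers from `topologylayer.nn`):
+
+```python
+import torch
+from topologylayer.nn import LevelSetLayer2D, SumBarcodeLengths
+
+# input image lives on the GPU and requires gradients
+x = torch.rand(28, 28, device='cuda', requires_grad=True)
+
+layer = LevelSetLayer2D(size=(28, 28))  # the diagram computation itself runs on the CPU
+loss_fn = SumBarcodeLengths(dim=0)
+
+dgminfo = layer(x)           # x is copied to the CPU internally; diagrams return on x's device
+loss_fn(dgminfo).backward()  # x.grad is populated on the GPU
+```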
+
 
 # (Deprecated) Dionysus Drivers
 
diff --git a/topologylayer/functional/flag.py b/topologylayer/functional/flag.py
index 61a5466..855bcd9 100644
--- a/topologylayer/functional/flag.py
+++ b/topologylayer/functional/flag.py
@@ -18,7 +18,10 @@ class FlagDiagram(Function):
     """
     @staticmethod
     def forward(ctx, X, y, maxdim, alg='hom'):
-        X.extendFlag(y)
+        device = y.device
+        ctx.device = device
+        ycpu = y.cpu()  # the persistence computation runs on the CPU
+        X.extendFlag(ycpu)
         if alg == 'hom':
             ret = persistenceForwardHom(X, maxdim, 0)
         elif alg == 'hom2':
@@ -26,14 +29,16 @@ class FlagDiagram(Function):
         elif alg == 'cohom':
             ret = persistenceForwardCohom(X, maxdim)
         ctx.X = X
-        ctx.save_for_backward(y)
+        ctx.save_for_backward(ycpu)
+        ret = [r.to(device) for r in ret]  # move diagrams back to the input device
         return tuple(ret)
 
     @staticmethod
     def backward(ctx, *grad_dgms):
         # print(grad_dgms)
         X = ctx.X
-        y, = ctx.saved_tensors
-        grad_ret = list(grad_dgms)
-        grad_y = persistenceBackwardFlag(X, y, grad_ret)
-        return None, grad_y, None, None
+        device = ctx.device
+        ycpu, = ctx.saved_tensors
+        grad_ret = [gd.cpu() for gd in grad_dgms]  # diagram gradients arrive on the input device
+        grad_y = persistenceBackwardFlag(X, ycpu, grad_ret)
+        return None, grad_y.to(device), None, None
diff --git a/topologylayer/functional/sublevel.py b/topologylayer/functional/sublevel.py
index a17d86e..3d97e9a 100644
--- a/topologylayer/functional/sublevel.py
+++ b/topologylayer/functional/sublevel.py
@@ -21,7 +21,9 @@ class SubLevelSetDiagram(Function):
     def forward(ctx, X, f, maxdim, alg='hom'):
         ctx.retshape = f.shape
         f = f.view(-1)
-        X.extendFloat(f)
+        device = f.device
+        ctx.device = device
+        X.extendFloat(f.cpu())  # the persistence computation runs on the CPU
         if alg == 'hom':
             ret = persistenceForwardHom(X, maxdim, 0)
         elif alg == 'hom2':
@@ -29,13 +31,15 @@ class SubLevelSetDiagram(Function):
         elif alg == 'cohom':
             ret = persistenceForwardCohom(X, maxdim)
         ctx.X = X
+        ret = [r.to(device) for r in ret]  # move diagrams back to the input device
         return tuple(ret)
 
     @staticmethod
     def backward(ctx, *grad_dgms):
         # print(grad_dgms)
         X = ctx.X
+        device = ctx.device
         retshape = ctx.retshape
-        grad_ret = list(grad_dgms)
+        grad_ret = [gd.cpu() for gd in grad_dgms]  # diagram gradients arrive on the input device
         grad_f = persistenceBackward(X, grad_ret)
-        return None, grad_f.view(retshape), None, None
+        return None, grad_f.view(retshape).to(device), None, None
-- 
GitLab