diff --git a/README.md b/README.md
index a107eff2a7b53851789f775baa271c7dc350ac59..ec770729b52be85c323785b1c76a961b95ae28b3 100644
--- a/README.md
+++ b/README.md
@@ -368,6 +368,24 @@ __Warning:__ 0-simplices (vertices) are assumed to start with `[0]` and end at `
 
 __Warning:__ the persistence computation currently assumes that the complex is acyclic at the end of the filtration in order to precompute the number of barcode pairs.
 
+# Note on GPU use
+
+You can use `topologylayer` with tensors that live on a GPU. However, the persistent homology computation itself runs on the CPU, so inputs are copied to host memory and the resulting diagrams (and their gradients) are copied back, which adds some data-transfer overhead.
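+
+For example (a minimal sketch: the `RipsLayer` call follows the examples above, and the loss is just an illustrative total-persistence penalty):
+
+```python
+import torch
+from topologylayer.nn import RipsLayer
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+x = torch.rand(100, 2, device=device, requires_grad=True)
+layer = RipsLayer(x.shape[0], maxdim=1)
+dgms, issublevel = layer(x)  # diagrams are returned on the same device as x
+loss = (dgms[1][:, 1] - dgms[1][:, 0]).sum()  # sum of H1 bar lengths
+loss.backward()  # x.grad lives on the same device as x
+```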
+
 
 # (Deprecated) Dionysus Drivers
 
diff --git a/topologylayer/functional/flag.py b/topologylayer/functional/flag.py
index 61a546605be54fd4144ae4f10f2f3b0067dd50c2..855bcd99eacf531e7cff74fa1a96a4b41a9d2b59 100644
--- a/topologylayer/functional/flag.py
+++ b/topologylayer/functional/flag.py
@@ -18,7 +18,12 @@ class FlagDiagram(Function):
     """
     @staticmethod
     def forward(ctx, X, y, maxdim, alg='hom'):
-        X.extendFlag(y)
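+        # the persistence computation runs on the CPU: remember the
+        # original device so diagrams and gradients can be moved back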
+        device = y.device
+        ctx.device = device
+        ycpu = y.cpu()
+        X.extendFlag(ycpu)
         if alg == 'hom':
             ret = persistenceForwardHom(X, maxdim, 0)
         elif alg == 'hom2':
@@ -26,14 +31,18 @@ class FlagDiagram(Function):
         elif alg == 'cohom':
             ret = persistenceForwardCohom(X, maxdim)
         ctx.X = X
-        ctx.save_for_backward(y)
+        ctx.save_for_backward(ycpu)
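+        # move the diagrams back to the device the input came from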
+        ret = [r.to(device) for r in ret]
         return tuple(ret)
 
     @staticmethod
     def backward(ctx, *grad_dgms):
         # print(grad_dgms)
         X = ctx.X
-        y, = ctx.saved_tensors
-        grad_ret = list(grad_dgms)
-        grad_y = persistenceBackwardFlag(X, y, grad_ret)
-        return None, grad_y, None, None
+        device = ctx.device
+        ycpu, = ctx.saved_tensors
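+        # incoming diagram gradients may live on a GPU; the backward runs on the CPU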
+        grad_ret = [gd.cpu() for gd in grad_dgms]
+        grad_y = persistenceBackwardFlag(X, ycpu, grad_ret)
+        return None, grad_y.to(device), None, None  # gradient only for y; X, maxdim, alg get None
diff --git a/topologylayer/functional/sublevel.py b/topologylayer/functional/sublevel.py
index a17d86e994da66c0fc849036758f1b34d5ba2b28..3d97e9a7bf626ffa7660202ea5a9bd78c995a415 100644
--- a/topologylayer/functional/sublevel.py
+++ b/topologylayer/functional/sublevel.py
@@ -21,7 +21,11 @@ class SubLevelSetDiagram(Function):
     def forward(ctx, X, f, maxdim, alg='hom'):
         ctx.retshape = f.shape
         f = f.view(-1)
-        X.extendFloat(f)
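+        # the persistence computation runs on the CPU: remember the
+        # original device so diagrams and gradients can be moved back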
+        device = f.device
+        ctx.device = device
+        X.extendFloat(f.cpu())
         if alg == 'hom':
             ret = persistenceForwardHom(X, maxdim, 0)
         elif alg == 'hom2':
@@ -29,13 +33,17 @@
         elif alg == 'cohom':
             ret = persistenceForwardCohom(X, maxdim)
         ctx.X = X
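+        # move the diagrams back to the device the input came from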
+        ret = [r.to(device) for r in ret]
         return tuple(ret)
 
     @staticmethod
     def backward(ctx, *grad_dgms):
         # print(grad_dgms)
         X = ctx.X
+        device = ctx.device
         retshape = ctx.retshape
-        grad_ret = list(grad_dgms)
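+        # incoming diagram gradients may live on a GPU; the backward runs on the CPU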
+        grad_ret = [gd.cpu() for gd in grad_dgms]
         grad_f = persistenceBackward(X, grad_ret)
-        return None, grad_f.view(retshape), None, None
+        return None, grad_f.view(retshape).to(device), None, None  # gradient only for f