diff --git a/src/train.py b/src/train.py index c27895408daa6dc1cd94639a1e64ce022b9bbfdd..84b776995edb1c20d6bbdb78abbd266d31fe5c85 100644 --- a/src/train.py +++ b/src/train.py @@ -142,6 +142,7 @@ def main(): random.seed(RNG_SEED) np.random.seed(RNG_SEED) torch.manual_seed(RNG_SEED) + torch.cuda.manual_seed(RNG_SEED) cudnn.deterministic = True cudnn.benchmark = False