Skip to content
Snippets Groups Projects
Commit abefb043 authored by OskarK's avatar OskarK
Browse files

Resolved problem where unit tests for inference/train_model didn't work on GPU

parent 6fde7ea1
No related branches found
No related tags found
1 merge request!23Implementation of Deep Learning unit tests, as well as paths to the 2d data for windows users in the UNet jupyter notebook.
This commit is part of merge request !23. Comments created here will be created in the context of that merge request.
...@@ -66,13 +66,15 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1 ...@@ -66,13 +66,15 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
inputs, targets = data inputs, targets = data
inputs = inputs.to(device) inputs = inputs.to(device)
if device == 'cuda': if torch.cuda.is_available():
targets = targets.to(device).type(torch.cuda.FloatTensor).unsqueeze(1) targets = targets.to(device).type(torch.cuda.FloatTensor).unsqueeze(1)
else: else:
targets = targets.to(device).type(torch.FloatTensor).unsqueeze(1) targets = targets.to(device).type(torch.FloatTensor).unsqueeze(1)
optimizer.zero_grad() optimizer.zero_grad()
outputs = model(inputs) outputs = model(inputs)
loss = criterion(outputs, targets) loss = criterion(outputs, targets)
# Backpropagation # Backpropagation
...@@ -99,7 +101,7 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1 ...@@ -99,7 +101,7 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
inputs, targets = data inputs, targets = data
inputs = inputs.to(device) inputs = inputs.to(device)
if device == 'cuda': if torch.cuda.is_available():
targets = targets.to(device).type(torch.cuda.FloatTensor).unsqueeze(1) targets = targets.to(device).type(torch.cuda.FloatTensor).unsqueeze(1)
else: else:
targets = targets.to(device).type(torch.FloatTensor).unsqueeze(1) targets = targets.to(device).type(torch.FloatTensor).unsqueeze(1)
...@@ -207,6 +209,7 @@ def inference(data,model): ...@@ -207,6 +209,7 @@ def inference(data,model):
else: else:
raise ValueError("Input image must be (C,H,W) format") raise ValueError("Input image must be (C,H,W) format")
model.to(device)
model.eval() model.eval()
# Make new list such that possible augmentations remain identical for all three rows # Make new list such that possible augmentations remain identical for all three rows
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment