Skip to content
Snippets Groups Projects
Commit 8844f5c7 authored by OskarK's avatar OskarK
Browse files

removing unecessary steps in training loop.

parent abefb043
No related branches found
No related tags found
1 merge request!23Implementation of Deep Learning unit tests, as well as paths to the 2d data for windows users in the UNet jupyter notebook.
This commit is part of merge request !23. Comments created here will be created in the context of that merge request.
......@@ -21,6 +21,7 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
eval_every (int, optional): frequency of model evaluation. Defaults to every epoch.
print_every (int, optional): frequency of log for model performance. Defaults to every 5 epochs.
Returns:
tuple:
train_loss (dict): dictionary with average losses and batch losses for training loop.
......@@ -65,12 +66,7 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
for data in train_loader:
inputs, targets = data
inputs = inputs.to(device)
if torch.cuda.is_available():
targets = targets.to(device).type(torch.cuda.FloatTensor).unsqueeze(1)
else:
targets = targets.to(device).type(torch.FloatTensor).unsqueeze(1)
targets = targets.to(device).unsqueeze(1)
optimizer.zero_grad()
outputs = model(inputs)
......@@ -100,11 +96,7 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
for data in val_loader:
inputs, targets = data
inputs = inputs.to(device)
if torch.cuda.is_available():
targets = targets.to(device).type(torch.cuda.FloatTensor).unsqueeze(1)
else:
targets = targets.to(device).type(torch.FloatTensor).unsqueeze(1)
targets = targets.to(device).unsqueeze(1)
with torch.no_grad():
outputs = model(inputs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment