Skip to content
Snippets Groups Projects
Commit 1307431c authored by Alessia Saccardo's avatar Alessia Saccardo
Browse files

fix unit tests on ml and models and reorganise test directories

parent fc415439
No related branches found
No related tags found
2 merge requests!145Refactor tests for processing and adapt it to new library structure, plus fix...,!140Fix unit tests
This commit is part of merge request !145. Comments created here will be created in the context of that merge request.
"""Provides a custom Dataset class for building a PyTorch dataset."""
from pathlib import Path
from PIL import Image
from qim3d.utils._logger import log
from qim3d.utils import log
import torch
import numpy as np
......
......@@ -4,8 +4,8 @@ import torch
import numpy as np
from torchinfo import summary
from qim3d.utils._logger import log
from qim3d.viz._metrics import plot_metrics
from qim3d.utils import log
from qim3d.viz import plot_metrics
from tqdm.auto import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
......
......@@ -2,7 +2,7 @@
import torch.nn as nn
from qim3d.utils._logger import log
from qim3d.utils import log
class UNet(nn.Module):
......
......@@ -3,13 +3,13 @@ import torch
# unit tests for UNet()
def test_starting_unet():
unet = qim3d.models.UNet()
unet = qim3d.ml.models.UNet()
assert unet.size == 'medium'
def test_forward_pass():
unet = qim3d.models.UNet()
unet = qim3d.ml.models.UNet()
# Size: B x C x H x W
x = torch.ones([1,1,256,256])
......@@ -19,14 +19,14 @@ def test_forward_pass():
# unit tests for Hyperparameters()
def test_hyper():
unet = qim3d.models.UNet()
hyperparams = qim3d.models.Hyperparameters(unet)
unet = qim3d.ml.models.UNet()
hyperparams = qim3d.ml.models.Hyperparameters(unet)
assert hyperparams.n_epochs == 10
def test_hyper_dict():
unet = qim3d.models.UNet()
hyperparams = qim3d.models.Hyperparameters(unet)
unet = qim3d.ml.models.UNet()
hyperparams = qim3d.ml.models.Hyperparameters(unet)
hyper_dict = hyperparams()
......
......@@ -12,16 +12,16 @@ def test_model_summary():
folder = "folder_data"
temp_data(folder, img_shape=img_shape, n=n)
unet = qim3d.models.UNet(size="small")
augment = qim3d.models.Augmentation(transform_train=None)
train_set, val_set, test_set = qim3d.models.prepare_datasets(
unet = qim3d.ml.models.UNet(size="small")
augment = qim3d.ml.Augmentation(transform_train=None)
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
folder, 1 / 3, unet, augment
)
_, val_loader, _ = qim3d.models.prepare_dataloaders(
_, val_loader, _ = qim3d.ml.prepare_dataloaders(
train_set, val_set, test_set, batch_size=1, num_workers=1, pin_memory=False
)
summary = qim3d.models.model_summary(val_loader, unet)
summary = qim3d.ml.model_summary(val_loader, unet)
assert summary.input_size[0] == (1, 1) + img_shape
......@@ -33,11 +33,11 @@ def test_inference():
folder = "folder_data"
temp_data(folder)
unet = qim3d.models.UNet(size="small")
augment = qim3d.models.Augmentation(transform_train=None)
train_set, _, _ = qim3d.models.prepare_datasets(folder, 1 / 3, unet, augment)
unet = qim3d.ml.models.UNet(size="small")
augment = qim3d.ml.Augmentation(transform_train=None)
train_set, _, _ = qim3d.ml.prepare_datasets(folder, 1 / 3, unet, augment)
_, targ, _ = qim3d.models.inference(train_set, unet)
_, targ, _ = qim3d.ml.inference(train_set, unet)
assert tuple(targ[0].unique()) == (0, 1)
......@@ -49,11 +49,11 @@ def test_inference_tuple():
folder = "folder_data"
temp_data(folder)
unet = qim3d.models.UNet(size="small")
unet = qim3d.ml.models.UNet(size="small")
data = [1, 2, 3]
with pytest.raises(ValueError, match="Data items must be tuples"):
qim3d.models.inference(data, unet)
qim3d.ml.inference(data, unet)
temp_data(folder, remove=True)
......@@ -63,11 +63,11 @@ def test_inference_tensor():
folder = "folder_data"
temp_data(folder)
unet = qim3d.models.UNet(size="small")
unet = qim3d.ml.models.UNet(size="small")
data = [(1, 2)]
with pytest.raises(ValueError, match="Data items must consist of tensors"):
qim3d.models.inference(data, unet)
qim3d.ml.inference(data, unet)
temp_data(folder, remove=True)
......@@ -77,12 +77,12 @@ def test_inference_dim():
folder = "folder_data"
temp_data(folder)
unet = qim3d.models.UNet(size="small")
unet = qim3d.ml.models.UNet(size="small")
data = [(ones(1), ones(1))]
# need the r"" for special characters
with pytest.raises(ValueError, match=r"Input image must be \(C,H,W\) format"):
qim3d.models.inference(data, unet)
qim3d.ml.inference(data, unet)
temp_data(folder, remove=True)
......@@ -94,17 +94,17 @@ def test_train_model():
n_epochs = 1
unet = qim3d.models.UNet(size="small")
augment = qim3d.models.Augmentation(transform_train=None)
hyperparams = qim3d.models.Hyperparameters(unet, n_epochs=n_epochs)
train_set, val_set, test_set = qim3d.models.prepare_datasets(
unet = qim3d.ml.models.UNet(size="small")
augment = qim3d.ml.Augmentation(transform_train=None)
hyperparams = qim3d.ml.Hyperparameters(unet, n_epochs=n_epochs)
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
folder, 1 / 3, unet, augment
)
train_loader, val_loader, _ = qim3d.models.prepare_dataloaders(
train_loader, val_loader, _ = qim3d.ml.prepare_dataloaders(
train_set, val_set, test_set, batch_size=1, num_workers=1, pin_memory=False
)
train_loss, _ = qim3d.models.train_model(
train_loss, _ = qim3d.ml.train_model(
unet, hyperparams, train_loader, val_loader, plot=False, return_loss=True
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment