Skip to content
Snippets Groups Projects
Commit d6c656f7 authored by Alessia Saccardo's avatar Alessia Saccardo
Browse files

fix unit-test segmentation and utils

parent 569b8d86
No related branches found
No related tags found
2 merge requests!145Refactor tests for processing and adapt it to new library structure, plus fix...,!140Fix unit tests
This commit is part of merge request !145. Comments created here will be created in the context of that merge request.
import numpy as np
import qim3d.filters as filters
from qim3d.utils._logger import log
from qim3d.utils import log
__all__ = ["remove_background", "fade_mask", "overlay_rgb_images"]
......
......@@ -5,13 +5,13 @@ import pytest
# unit tests for Augmentation()
def test_augmentation():
augment_class = qim3d.models.Augmentation()
augment_class = qim3d.ml.Augmentation()
assert augment_class.resize == "crop"
def test_augment():
augment_class = qim3d.models.Augmentation()
augment_class = qim3d.ml.Augmentation()
album_augment = augment_class.augment(256, 256)
......@@ -26,11 +26,11 @@ def test_resize():
ValueError,
match=f"Invalid resize type: {resize_str}. Use either 'crop', 'resize' or 'padding'.",
):
augment_class = qim3d.models.Augmentation(resize=resize_str)
augment_class = qim3d.ml.Augmentation(resize=resize_str)
def test_levels():
augment_class = qim3d.models.Augmentation()
augment_class = qim3d.ml.Augmentation()
level = "Not a valid level"
......
......@@ -10,7 +10,7 @@ def test_dataset():
folder = 'folder_data'
temp_data(folder, img_shape = img_shape)
images = qim3d.models.Dataset(folder)
images = qim3d.ml.Dataset(folder)
assert images[0][0].shape == img_shape
......@@ -19,19 +19,19 @@ def test_dataset():
# unit tests for check_resize()
def test_check_resize():
h_adjust,w_adjust = qim3d.models.data.check_resize(240,240,resize = 'crop',n_channels = 6)
h_adjust,w_adjust = qim3d.ml._data.check_resize(240,240,resize = 'crop',n_channels = 6)
assert (h_adjust,w_adjust) == (192,192)
def test_check_resize_pad():
h_adjust,w_adjust = qim3d.models.data.check_resize(16,16,resize = 'padding',n_channels = 6)
h_adjust,w_adjust = qim3d.ml._data.check_resize(16,16,resize = 'padding',n_channels = 6)
assert (h_adjust,w_adjust) == (64,64)
def test_check_resize_fail():
with pytest.raises(ValueError,match="The size of the image is too small compared to the depth of the UNet. Choose a different 'resize' and/or a smaller model."):
h_adjust,w_adjust = qim3d.models.data.check_resize(16,16,resize = 'crop',n_channels = 6)
h_adjust,w_adjust = qim3d.ml._data.check_resize(16,16,resize = 'crop',n_channels = 6)
# unit tests for prepare_datasets()
......@@ -42,9 +42,9 @@ def test_prepare_datasets():
folder = 'folder_data'
img = temp_data(folder,n = n)
my_model = qim3d.models.UNet()
my_augmentation = qim3d.models.Augmentation(transform_test='light')
train_set, val_set, test_set = qim3d.models.prepare_datasets(folder,validation,my_model,my_augmentation)
my_model = qim3d.ml.models.UNet()
my_augmentation = qim3d.ml.Augmentation(transform_test='light')
train_set, val_set, test_set = qim3d.ml.prepare_datasets(folder,validation,my_model,my_augmentation)
assert (len(train_set),len(val_set),len(test_set)) == (int((1-validation)*n), int(n*validation), n)
......@@ -56,7 +56,7 @@ def test_validation():
validation = 10
with pytest.raises(ValueError,match = "The validation fraction must be a float between 0 and 1."):
augment_class = qim3d.models.prepare_datasets('folder',validation,'my_model','my_augmentation')
augment_class = qim3d.ml.prepare_datasets('folder',validation,'my_model','my_augmentation')
# unit test for prepare_dataloaders()
......@@ -65,11 +65,11 @@ def test_prepare_dataloaders():
temp_data(folder)
batch_size = 1
my_model = qim3d.models.UNet()
my_augmentation = qim3d.models.Augmentation()
train_set, val_set, test_set = qim3d.models.prepare_datasets(folder,1/3,my_model,my_augmentation)
my_model = qim3d.ml.models.UNet()
my_augmentation = qim3d.ml.Augmentation()
train_set, val_set, test_set = qim3d.ml.prepare_datasets(folder,1/3,my_model,my_augmentation)
_,val_loader,_ = qim3d.models.prepare_dataloaders(train_set,val_set,test_set,
_,val_loader,_ = qim3d.ml.prepare_dataloaders(train_set,val_set,test_set,
batch_size,num_workers = 1,
pin_memory = False)
......
......@@ -4,12 +4,12 @@ doi = "https://doi.org/10.1007/s10851-021-01041-3"
def test_get_bibtex():
bibtext = qim3d.utils.doi.get_bibtex(doi)
bibtext = qim3d.utils._doi.get_bibtex(doi)
assert "Measuring Shape Relations Using r-Parallel Sets" in bibtext
def test_get_reference():
reference = qim3d.utils.doi.get_reference(doi)
reference = qim3d.utils._doi.get_reference(doi)
assert "Stephensen" in reference
......@@ -33,7 +33,7 @@ def test_get_local_ip():
else:
return False
local_ip = qim3d.utils.misc.get_local_ip()
local_ip = qim3d.utils._misc.get_local_ip()
assert validate_ip(local_ip) == True
......@@ -42,7 +42,7 @@ def test_stringify_path1():
"""Test that the function converts os.PathLike objects to strings"""
blobs_path = Path(qim3d.__file__).parents[0] / "img_examples" / "blobs_256x256.tif"
assert str(blobs_path) == qim3d.utils.misc.stringify_path(blobs_path)
assert str(blobs_path) == qim3d.utils._misc.stringify_path(blobs_path)
def test_stringify_path2():
......@@ -50,4 +50,4 @@ def test_stringify_path2():
# Create test_path
test_path = os.path.join("this", "path", "doesnt", "exist.tif")
assert test_path == qim3d.utils.misc.stringify_path(test_path)
assert test_path == qim3d.utils._misc.stringify_path(test_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment