Skip to content
Snippets Groups Projects
Commit 406a5c49 authored by Alessia Saccardo's avatar Alessia Saccardo
Browse files

fix unit tests test_img, need to fix on visualization

parent d6c656f7
No related branches found
No related tags found
2 merge requests!145Refactor tests for processing and adapt it to new library structure, plus fix...,!140Fix unit tests
This commit is part of merge request !140. Comments created here will be created in the context of that merge request.
......@@ -33,7 +33,7 @@ from skimage.transform import (
from qim3d.utils import log
from qim3d.utils import OmeZarrExportProgressBar
from qim3d.utils import get_n_chunks
from qim3d.utils._ome_zarr import get_n_chunks
ListOfArrayLike = Union[List[da.Array], List[np.ndarray]]
......
......@@ -36,11 +36,11 @@ def test_grid_pred():
n = 4
temp_data(folder, n=n)
model = qim3d.models.UNet()
augmentation = qim3d.models.Augmentation()
train_set, _, _ = qim3d.models.prepare_datasets(folder, 0.1, model, augmentation)
model = qim3d.ml.models.UNet()
augmentation = qim3d.ml.Augmentation()
train_set, _, _ = qim3d.ml.prepare_datasets(folder, 0.1, model, augmentation)
in_targ_pred = qim3d.models.inference(train_set, model)
in_targ_pred = qim3d.ml.inference(train_set, model)
fig = qim3d.viz.grid_pred(in_targ_pred)
......@@ -52,7 +52,7 @@ def test_grid_pred():
# unit tests for slices function
def test_slices_numpy_array_input():
example_volume = np.ones((10, 10, 10))
fig = qim3d.viz.slices_grid(example_volume, n_slices=1)
fig = qim3d.viz.slices_grid(example_volume, num_slices=1)
assert isinstance(fig, plt.Figure)
......@@ -77,7 +77,7 @@ def test_slices_wrong_position_format1():
ValueError,
match='Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".',
):
qim3d.viz.slices_grid(example_volume, position="invalid_slice")
qim3d.viz.slices_grid(example_volume, slice_positions="invalid_slice")
def test_slices_wrong_position_format2():
......@@ -86,7 +86,7 @@ def test_slices_wrong_position_format2():
ValueError,
match='Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".',
):
qim3d.viz.slices_grid(example_volume, position=1.5)
qim3d.viz.slices_grid(example_volume, slice_positions=1.5)
def test_slices_wrong_position_format3():
......@@ -95,16 +95,16 @@ def test_slices_wrong_position_format3():
ValueError,
match='Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".',
):
qim3d.viz.slices_grid(example_volume, position=[1, 2, 3.5])
qim3d.viz.slices_grid(example_volume, slice_positions=[1, 2, 3.5])
def test_slices_invalid_axis_value():
example_volume = np.ones((10, 10, 10))
with pytest.raises(
ValueError,
match="Invalid value for 'axis'. It should be an integer between 0 and 2",
match=f"Invalid value for 'slice_axis'. It should be an integer between 0 and {example_volume.ndim - 1}.",
):
qim3d.viz.slices_grid(example_volume, axis=3)
qim3d.viz.slices_grid(example_volume, slice_axis=3)
def test_slices_interpolation_option():
......@@ -113,8 +113,8 @@ def test_slices_interpolation_option():
interpolation_method = "bilinear"
fig = qim3d.viz.slices_grid(
example_volume,
n_slices=1,
img_width=img_width,
num_slices=1,
image_width=img_width,
interpolation=interpolation_method,
)
......@@ -128,27 +128,27 @@ def test_slices_interpolation_option():
def test_slices_multiple_slices():
example_volume = np.ones((10, 10, 10))
img_width = 3
n_slices = 3
fig = qim3d.viz.slices_grid(example_volume, n_slices=n_slices, img_width=img_width)
image_width = 3
num_slices = 3
fig = qim3d.viz.slices_grid(example_volume, num_slices=num_slices, image_width=image_width)
# Add assertions for the expected number of subplots in the figure
assert len(fig.get_axes()) == n_slices
assert len(fig.get_axes()) == num_slices
def test_slices_axis_argument():
# Non-symmetric input
example_volume = np.arange(1000).reshape((10, 10, 10))
img_width = 3
image_width = 3
# Call the function with different values of the axis
fig_axis_0 = qim3d.viz.slices_grid(
example_volume, n_slices=1, img_width=img_width, axis=0
example_volume, num_slices=1, image_width=image_width, slice_axis=0
)
fig_axis_1 = qim3d.viz.slices_grid(
example_volume, n_slices=1, img_width=img_width, axis=1
example_volume, num_slices=1, image_width=image_width, slice_axis=1
)
fig_axis_2 = qim3d.viz.slices_grid(
example_volume, n_slices=1, img_width=img_width, axis=2
example_volume, num_slices=1, image_width=image_width, slice_axis=2
)
# Ensure that different axes result in different plots
......@@ -216,7 +216,7 @@ def test_orthogonal_with_numpy_array():
# Create a sample NumPy array
vol = np.random.rand(10, 10, 10)
# Call the orthogonal function with the NumPy array
orthogonal_obj = qim3d.viz.orthogonal(vol)
orthogonal_obj = qim3d.viz.slicer_orthogonal(vol)
# Assert that the orthogonal object is created successfully
assert isinstance(orthogonal_obj, widgets.HBox)
......@@ -225,28 +225,28 @@ def test_orthogonal_with_torch_tensor():
# Create a sample PyTorch tensor
vol = torch.rand(10, 10, 10)
# Call the orthogonal function with the PyTorch tensor
orthogonal_obj = qim3d.viz.orthogonal(vol)
orthogonal_obj = qim3d.viz.slicer_orthogonal(vol)
# Assert that the orthogonal object is created successfully
assert isinstance(orthogonal_obj, widgets.HBox)
def test_orthogonal_with_different_parameters():
# Test with different colormaps
for cmap in ["viridis", "gray", "plasma"]:
orthogonal_obj = qim3d.viz.orthogonal(np.random.rand(10, 10, 10), cmap=cmap)
for color_map in ["viridis", "gray", "plasma"]:
orthogonal_obj = qim3d.viz.slicer_orthogonal(np.random.rand(10, 10, 10), color_map=color_map)
assert isinstance(orthogonal_obj, widgets.HBox)
# Test with different image sizes
for img_height, img_width in [(2, 2), (4, 4)]:
orthogonal_obj = qim3d.viz.orthogonal(
np.random.rand(10, 10, 10), img_height=img_height, img_width=img_width
for image_height, image_width in [(2, 2), (4, 4)]:
orthogonal_obj = qim3d.viz.slicer_orthogonal(
np.random.rand(10, 10, 10), image_height=image_height, image_width=image_width
)
assert isinstance(orthogonal_obj, widgets.HBox)
# Test with show_position set to True and False
for show_position in [True, False]:
orthogonal_obj = qim3d.viz.orthogonal(
np.random.rand(10, 10, 10), show_position=show_position
for display_positions in [True, False]:
orthogonal_obj = qim3d.viz.slicer_orthogonal(
np.random.rand(10, 10, 10), display_positions=display_positions
)
assert isinstance(orthogonal_obj, widgets.HBox)
......@@ -255,7 +255,7 @@ def test_orthogonal_initial_slider_value():
# Create a sample NumPy array
vol = np.random.rand(10, 7, 19)
# Call the orthogonal function with the NumPy array
orthogonal_obj = qim3d.viz.orthogonal(vol)
orthogonal_obj = qim3d.viz.slicer_orthogonal(vol)
for idx, slicer in enumerate(orthogonal_obj.children):
assert slicer.children[0].value == vol.shape[idx] // 2
......@@ -264,7 +264,7 @@ def test_orthogonal_slider_description():
# Create a sample NumPy array
vol = np.random.rand(10, 10, 10)
# Call the orthogonal function with the NumPy array
orthogonal_obj = qim3d.viz.orthogonal(vol)
orthogonal_obj = qim3d.viz.slicer_orthogonal(vol)
for idx, slicer in enumerate(orthogonal_obj.children):
assert slicer.children[0].description == ["Z", "Y", "X"][idx]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment