Skip to content
Snippets Groups Projects
Commit ae7f91a1 authored by fima's avatar fima :beers:
Browse files

Changes missing from last unit tests branch

parent 3403d3e5
No related branches found
No related tags found
1 merge request!145Refactor tests for processing and adapt it to new library structure, plus fix...
This commit is part of merge request !145. Comments created here will be created in the context of that merge request.
......@@ -27,10 +27,7 @@ __all__ = [
class FilterBase:
def __init__(self,
dask: bool = False,
chunks: str = "auto",
*args, **kwargs):
def __init__(self, *args, dask: bool = False, chunks: str = "auto", **kwargs):
"""
Base class for image filters.
......@@ -43,6 +40,7 @@ class FilterBase:
self.chunks = chunks
self.kwargs = kwargs
class Gaussian(FilterBase):
def __call__(self, input: np.ndarray) -> np.ndarray:
"""
......@@ -54,7 +52,9 @@ class Gaussian(FilterBase):
Returns:
The filtered image or volume.
"""
return gaussian(input, dask=self.dask, chunks=self.chunks, *self.args, **self.kwargs)
return gaussian(
input, dask=self.dask, chunks=self.chunks, *self.args, **self.kwargs
)
class Median(FilterBase):
......@@ -98,6 +98,7 @@ class Minimum(FilterBase):
"""
return minimum(input, dask=self.dask, chunks=self.chunks, **self.kwargs)
class Tophat(FilterBase):
def __call__(self, input: np.ndarray) -> np.ndarray:
"""
......@@ -144,6 +145,7 @@ class Pipeline:
![filtered volume](assets/screenshots/filter_processed.png)
"""
def __init__(self, *args: Type[FilterBase]):
"""
Represents a sequence of image filters.
......@@ -214,10 +216,9 @@ class Pipeline:
return input
def gaussian(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
*args, **kwargs) -> np.ndarray:
def gaussian(
vol: np.ndarray, dask: bool = False, chunks: str = "auto", *args, **kwargs
) -> np.ndarray:
"""
Applies a Gaussian filter to the input volume using scipy.ndimage.gaussian_filter or dask_image.ndfilters.gaussian_filter.
......@@ -243,10 +244,9 @@ def gaussian(vol: np.ndarray,
return res
def median(vol: np.ndarray,
dask: bool = False,
chunks: str ='auto',
**kwargs) -> np.ndarray:
def median(
vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs
) -> np.ndarray:
"""
Applies a median filter to the input volume using scipy.ndimage.median_filter or dask_image.ndfilters.median_filter.
......@@ -270,10 +270,9 @@ def median(vol: np.ndarray,
return res
def maximum(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
**kwargs) -> np.ndarray:
def maximum(
vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs
) -> np.ndarray:
"""
Applies a maximum filter to the input volume using scipy.ndimage.maximum_filter or dask_image.ndfilters.maximum_filter.
......@@ -297,10 +296,9 @@ def maximum(vol: np.ndarray,
return res
def minimum(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
**kwargs) -> np.ndarray:
def minimum(
vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs
) -> np.ndarray:
"""
Applies a minimum filter to the input volume using scipy.ndimage.minimum_filter or dask_image.ndfilters.minimum_filter.
......@@ -323,10 +321,10 @@ def minimum(vol: np.ndarray,
res = ndimage.minimum_filter(vol, **kwargs)
return res
def tophat(vol: np.ndarray,
dask: bool = False,
chunks: str = 'auto',
**kwargs) -> np.ndarray:
def tophat(
vol: np.ndarray, dask: bool = False, chunks: str = "auto", **kwargs
) -> np.ndarray:
"""
Remove background from the volume.
......@@ -349,7 +347,9 @@ def tophat(vol: np.ndarray,
log.info("Dask not supported for tophat filter, switching to scipy.")
if background == "bright":
log.info("Bright background selected, volume will be temporarily inverted when applying white_tophat")
log.info(
"Bright background selected, volume will be temporarily inverted when applying white_tophat"
)
vol = np.invert(vol)
selem = morphology.ball(radius)
......
......@@ -28,6 +28,7 @@ import gradio as gr
import numpy as np
from PIL import Image
import qim3d
from qim3d.gui.interface import BaseInterface
# TODO: img in launch should be self.img
......
......@@ -6,7 +6,7 @@ import torch
import numpy as np
from typing import Optional, Callable
import torch.nn as nn
from ._data import Augmentation
from ._augmentations import Augmentation
class Dataset(torch.utils.data.Dataset):
"""
......
......@@ -9,7 +9,7 @@ from qim3d.viz._metrics import plot_metrics
from tqdm.auto import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
from models._unet import Hyperparameters
from .models._unet import Hyperparameters
def train_model(
model: torch.nn.Module,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment