Skip to content
Snippets Groups Projects
Commit 5bd12b1e authored by Alessia Saccardo's avatar Alessia Saccardo
Browse files

threshold exploration refactoring, add vertical line in histogram function

parent 09ea2352
Branches
No related tags found
1 merge request!135Threshold exploration
This commit is part of merge request !135. Comments created here will be created in the context of that merge request.
import numpy as np import numpy as np
import qim3d.processing.filters as filters import qim3d.processing.filters as filters
from qim3d.utils.logger import log from qim3d.utils.logger import log
import skimage
import scipy
def remove_background( def remove_background(
...@@ -88,8 +90,6 @@ def watershed(bin_vol: np.ndarray, min_distance: int = 5) -> tuple[np.ndarray, i ...@@ -88,8 +90,6 @@ def watershed(bin_vol: np.ndarray, min_distance: int = 5) -> tuple[np.ndarray, i
![operations-watershed_after](assets/screenshots/operations-watershed_after.png) ![operations-watershed_after](assets/screenshots/operations-watershed_after.png)
""" """
import skimage
import scipy
if len(np.unique(bin_vol)) > 2: if len(np.unique(bin_vol)) > 2:
raise ValueError("bin_vol has to be binary volume - it must contain max 2 unique values.") raise ValueError("bin_vol has to be binary volume - it must contain max 2 unique values.")
......
...@@ -816,10 +816,10 @@ def chunks(zarr_path: str, **kwargs): ...@@ -816,10 +816,10 @@ def chunks(zarr_path: str, **kwargs):
display(final_layout) display(final_layout)
def histogram( def histogram(
vol: np.ndarray, volume: np.ndarray,
bins: Union[int, str] = "auto", bins: Union[int, str] = "auto",
slice_idx: Union[int, str, None] = None, slice_idx: Union[int, str, None] = None,
threshold: int = None, vertical_line: int = None,
axis: int = 0, axis: int = 0,
kde: bool = True, kde: bool = True,
log_scale: bool = False, log_scale: bool = False,
...@@ -840,11 +840,12 @@ def histogram( ...@@ -840,11 +840,12 @@ def histogram(
Utilizes [seaborn.histplot](https://seaborn.pydata.org/generated/seaborn.histplot.html) for visualization. Utilizes [seaborn.histplot](https://seaborn.pydata.org/generated/seaborn.histplot.html) for visualization.
Args: Args:
vol (np.ndarray): A 3D NumPy array representing the volume to be visualized. volume (np.ndarray): A 3D NumPy array representing the volume to be visualized.
bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto". bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto".
axis (int, optional): Axis along which to take a slice. Default is 0. axis (int, optional): Axis along which to take a slice. Default is 0.
slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis. slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis.
If "middle", the function uses the middle slice. If None, the entire volume is visualized. Default is None. If "middle", the function uses the middle slice. If None, the entire volume is visualized. Default is None.
vertical_line (int, optional): Intensity value for a vertical line to be drawn on the histogram. Default is None.
kde (bool, optional): Whether to overlay a kernel density estimate. Default is True. kde (bool, optional): Whether to overlay a kernel density estimate. Default is True.
log_scale (bool, optional): Whether to use a logarithmic scale on the y-axis. Default is False. log_scale (bool, optional): Whether to use a logarithmic scale on the y-axis. Default is False.
despine (bool, optional): If True, removes the top and right spines from the plot for cleaner appearance. Default is True. despine (bool, optional): If True, removes the top and right spines from the plot for cleaner appearance. Default is True.
...@@ -868,24 +869,24 @@ def histogram( ...@@ -868,24 +869,24 @@ def histogram(
ValueError: If `axis` is not a valid axis index (0, 1, or 2). ValueError: If `axis` is not a valid axis index (0, 1, or 2).
ValueError: If `slice_idx` is an integer and is out of range for the specified axis. ValueError: If `slice_idx` is an integer and is out of range for the specified axis.
""" """
if not (0 <= axis < vol.ndim): if not (0 <= axis < volume.ndim):
raise ValueError(f"Axis must be an integer between 0 and {vol.ndim - 1}.") raise ValueError(f"Axis must be an integer between 0 and {volume.ndim - 1}.")
if slice_idx == "middle": if slice_idx == "middle":
slice_idx = vol.shape[axis] // 2 slice_idx = volume.shape[axis] // 2
if slice_idx is not None: if slice_idx is not None:
if 0 <= slice_idx < vol.shape[axis]: if 0 <= slice_idx < volume.shape[axis]:
img_slice = np.take(vol, indices=slice_idx, axis=axis) img_slice = np.take(volume, indices=slice_idx, axis=axis)
data = img_slice.ravel() data = img_slice.ravel()
title = f"Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}" title = f"Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}"
else: else:
raise ValueError( raise ValueError(
f"Slice index out of range. Must be between 0 and {vol.shape[axis] - 1}." f"Slice index out of range. Must be between 0 and {volume.shape[axis] - 1}."
) )
else: else:
data = vol.ravel() data = volume.ravel()
title = f"Intensity histogram for whole volume {vol.shape}" title = f"Intensity histogram for whole volume {volume.shape}"
# Use provided Axes or create new figure # Use provided Axes or create new figure
if ax is None: if ax is None:
...@@ -910,15 +911,14 @@ def histogram( ...@@ -910,15 +911,14 @@ def histogram(
**sns_kwargs, **sns_kwargs,
) )
if threshold is not None: if vertical_line is not None:
ax.axvline( ax.axvline(
x=threshold, x=vertical_line,
color='red', color='red',
linestyle="--", linestyle="--",
linewidth=2, linewidth=2,
label=f"Threshold = {round(threshold)}"
) )
ax.legend()
if despine: if despine:
sns.despine( sns.despine(
...@@ -932,7 +932,6 @@ def histogram( ...@@ -932,7 +932,6 @@ def histogram(
trim=True, trim=True,
) )
ax.set_xlabel("Voxel Intensity") ax.set_xlabel("Voxel Intensity")
ax.set_ylabel("Frequency") ax.set_ylabel("Frequency")
...@@ -951,7 +950,6 @@ def histogram( ...@@ -951,7 +950,6 @@ def histogram(
def threshold( def threshold(
volume: np.ndarray, volume: np.ndarray,
cmap_image: str = 'viridis', cmap_image: str = 'viridis',
cmap_overlay: str = 'gray',
vmin: float = None, vmin: float = None,
vmax: float = None, vmax: float = None,
) -> widgets.VBox: ) -> widgets.VBox:
...@@ -1004,6 +1002,13 @@ def threshold( ...@@ -1004,6 +1002,13 @@ def threshold(
![interactive threshold](assets/screenshots/interactive_thresholding.gif) ![interactive threshold](assets/screenshots/interactive_thresholding.gif)
""" """
# Centralized state dictionary to track current parameters
state = {
'position': volume.shape[0] // 2,
'threshold': int((volume.min() + volume.max()) / 2),
'method': 'Manual',
}
threshold_methods = { threshold_methods = {
'Otsu': threshold_otsu, 'Otsu': threshold_otsu,
'Isodata': threshold_isodata, 'Isodata': threshold_isodata,
...@@ -1011,129 +1016,125 @@ def threshold( ...@@ -1011,129 +1016,125 @@ def threshold(
'Mean': threshold_mean, 'Mean': threshold_mean,
'Minimum': threshold_minimum, 'Minimum': threshold_minimum,
'Triangle': threshold_triangle, 'Triangle': threshold_triangle,
'Yen': threshold_yen 'Yen': threshold_yen,
} }
# Create the interactive widget # Create an output widget to display the plot
def _slicer(position, threshold, method): output = widgets.Output()
# Function to update the state and trigger visualization
def update_state(change):
# Update state based on widget values
state['position'] = position_slider.value
state['method'] = method_dropdown.value
if state['method'] == 'Manual':
state['threshold'] = threshold_slider.value
threshold_slider.disabled = False
else:
threshold_func = threshold_methods.get(state['method'])
if threshold_func:
slice_img = volume[state['position'], :, :]
computed_threshold = threshold_func(slice_img)
state['threshold'] = computed_threshold
# Programmatically update the slider without triggering callbacks
threshold_slider.unobserve_all()
threshold_slider.value = computed_threshold
threshold_slider.disabled = True
threshold_slider.observe(update_state, names='value')
else:
raise ValueError(f"Unsupported thresholding method: {state['method']}")
# Trigger visualization
update_visualization()
# Visualization function
def update_visualization():
slice_img = volume[state['position'], :, :]
with output:
output.clear_output(wait=True) # Clear previous plot
fig, axes = plt.subplots(1, 4, figsize=(25, 5)) fig, axes = plt.subplots(1, 4, figsize=(25, 5))
slice_img = volume[position, :, :] # Original image
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
new_vmin = ( new_vmin = (
None None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
) )
new_vmax = ( new_vmax = (
None None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
) )
# Add original image to the plot
axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax) axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax)
axes[0].set_title('Original') axes[0].set_title('Original')
axes[0].axis('off') axes[0].axis('off')
# Compute the threshold value # Histogram
if method == 'Manual':
threshold_slider.disabled = False
else:
# Apply the appropriate thresholding function
threshold_func = threshold_methods.get(method)
if threshold_func:
threshold = threshold_func(slice_img)
if threshold_slider.value != threshold:
threshold_slider.unobserve_all()
threshold_slider.value = threshold
threshold_slider.disabled = True
else:
raise ValueError(f"Unsupported thresholding method: {method}")
# Compute and add the histogram to the plot
histogram( histogram(
vol=volume, volume=volume,
bins=32, bins=32,
slice_idx=position, slice_idx=state['position'],
threshold=threshold, vertical_line=state['threshold'],
axis=1, axis=1,
kde=False, kde=False,
ax=axes[1], ax=axes[1],
show=False, show=False,
) )
axes[1].set_title(f"Histogram with Threshold = {int(state['threshold'])}")
axes[1].set_title(f'Histogram') # Binary mask
mask = slice_img > state['threshold']
# Compute and add the binary mask to the plot axes[2].imshow(mask, cmap='gray')
mask = slice_img > threshold
axes[2].imshow(mask, cmap='grey')
axes[2].set_title('Binary mask') axes[2].set_title('Binary mask')
axes[2].axis('off') axes[2].axis('off')
# both mask and img should be rgb # Overlay
# mask data in first channel and then black the other sure --> no cmap_overlay mask_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
# Compute and add the overlay to the plot mask_rgb[:, :, 0] = mask
masked_volume = qim3d.processing.operations.overlay_rgb_images( masked_volume = qim3d.processing.operations.overlay_rgb_images(
background=slice_img, background=slice_img,
foreground = mask, foreground=mask_rgb,
) )
axes[3].imshow(masked_volume, vmin=new_vmin, vmax=new_vmax)
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
new_vmin = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(masked_volume))
else vmin
)
new_vmax = (
None
if (isinstance(vmax, (float, int)) and vmax < np.min(masked_volume))
else vmax
)
axes[3].imshow(masked_volume, cmap=cmap_overlay, vmin=new_vmin, vmax=new_vmax)
axes[3].set_title('Overlay') axes[3].set_title('Overlay')
axes[3].axis('off') axes[3].axis('off')
return fig plt.show()
method_dropdown = widgets.Dropdown(
options=['Manual', 'Otsu', 'Isodata', 'Li', 'Mean', 'Minimum', 'Triangle', 'Yen'],
value='Manual', # default value
description='Method',
)
# Widgets
position_slider = widgets.IntSlider( position_slider = widgets.IntSlider(
value=volume.shape[0] // 2, value=state['position'],
min=0, min=0,
max=volume.shape[0] - 1, max=volume.shape[0] - 1,
description='Slice', description='Slice',
continuous_update=False,
) )
threshold_slider = widgets.IntSlider( threshold_slider = widgets.IntSlider(
value=int((volume.min() + volume.max()) / 2), value=state['threshold'],
min=volume.min(), min=volume.min(),
max=volume.max(), max=volume.max(),
description='Threshold', description='Threshold',
continuous_update=False,
) )
slicer_obj = widgets.interactive( method_dropdown = widgets.Dropdown(
_slicer, options=['Manual', 'Otsu', 'Isodata', 'Li', 'Mean', 'Minimum', 'Triangle', 'Yen'],
position=position_slider, value=state['method'],
threshold=threshold_slider, description='Method',
method = method_dropdown,
) )
# Attach the state update function to widgets
position_slider.observe(update_state, names='value')
threshold_slider.observe(update_state, names='value')
method_dropdown.observe(update_state, names='value')
# Layout
controls_left = widgets.VBox([position_slider, threshold_slider]) controls_left = widgets.VBox([position_slider, threshold_slider])
controls_right = widgets.VBox([method_dropdown]) controls_right = widgets.VBox([method_dropdown])
controls_layout = widgets.HBox([controls_left, controls_right], layout=widgets.Layout(justify_content='space-between')) controls_layout = widgets.HBox(
slicer_obj = widgets.VBox([controls_layout, slicer_obj.children[-1]]) [controls_left, controls_right],
slicer_obj.layout.align_items = "flex-start" layout=widgets.Layout(justify_content='flex-start'),
)
interactive_ui = widgets.VBox([controls_layout, output])
update_visualization()
return slicer_obj return interactive_ui
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment