From 88d24bb726b8117f32d33ad38f6f5a7fa2c64208 Mon Sep 17 00:00:00 2001 From: Alessia Saccardo <s212246@dtu.dk> Date: Tue, 10 Dec 2024 10:24:19 +0100 Subject: [PATCH] add histogram --- qim3d/viz/__init__.py | 1 + qim3d/viz/explore.py | 96 ++++++++++++++++++++++--------------------- 2 files changed, 50 insertions(+), 47 deletions(-) diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index 33d94416..dd46be61 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -8,6 +8,7 @@ from .explore import ( slices, chunks, histogram, + threshold, ) from .itk_vtk_viewer import itk_vtk, Installer, NotInstalledError from .k3d import vol, mesh diff --git a/qim3d/viz/explore.py b/qim3d/viz/explore.py index 49d11630..7e9b16f2 100644 --- a/qim3d/viz/explore.py +++ b/qim3d/viz/explore.py @@ -830,6 +830,7 @@ def histogram( element="step", return_fig=False, show=True, + ax=None, # New parameter for target axes **sns_kwargs, ): """ @@ -853,40 +854,26 @@ def histogram( element (str, optional): Type of histogram to draw ('bars', 'step', or 'poly'). Default is "step". return_fig (bool, optional): If True, returns the figure object instead of showing it directly. Default is False. show (bool, optional): If True, displays the plot. If False, suppresses display. Default is True. + ax (matplotlib.axes.Axes, optional): Axes object where the histogram will be plotted. Default is None. **sns_kwargs: Additional keyword arguments for `seaborn.histplot`. Returns: - Optional[matplotlib.figure.Figure]: If `return_fig` is True, returns the generated figure object. Otherwise, returns None. + Optional[matplotlib.figure.Figure or matplotlib.axes.Axes]: + If `return_fig` is True, returns the generated figure object. + If `return_fig` is False and `ax` is provided, returns the `Axes` object. + Otherwise, returns None. Raises: ValueError: If `axis` is not a valid axis index (0, 1, or 2). ValueError: If `slice_idx` is an integer and is out of range for the specified axis. - - Example: - ```python - import qim3d - - vol = qim3d.examples.bone_128x128x128 - qim3d.viz.histogram(vol) - ``` -  - - ```python - import qim3d - - vol = qim3d.examples.bone_128x128x128 - qim3d.viz.histogram(vol, bins=32, slice_idx="middle", axis=1, kde=False, log_scale=True) - ``` -  """ - if not (0 <= axis < vol.ndim): raise ValueError(f"Axis must be an integer between 0 and {vol.ndim - 1}.") if slice_idx == "middle": slice_idx = vol.shape[axis] // 2 - if slice_idx: + if slice_idx is not None: if 0 <= slice_idx < vol.shape[axis]: img_slice = np.take(vol, indices=slice_idx, axis=axis) data = img_slice.ravel() @@ -899,10 +886,14 @@ def histogram( data = vol.ravel() title = f"Intensity histogram for whole volume {vol.shape}" - fig, ax = plt.subplots(figsize=figsize) + # Use provided Axes or create new figure + if ax is None: + fig, ax = plt.subplots(figsize=figsize) + else: + fig = None if log_scale: - plt.yscale("log") + ax.set_yscale("log") if color == "qim3d": color = qim3d.viz.colormaps.qim(1.0) @@ -914,42 +905,32 @@ def histogram( color=color, element=element, edgecolor=edgecolor, + ax=ax, # Plot directly on the specified Axes **sns_kwargs, ) if despine: - sns.despine( - fig=None, - ax=None, - top=True, - right=True, - left=False, - bottom=False, - offset={"left": 0, "bottom": 18}, - trim=True, - ) + sns.despine(ax=ax, top=True, right=True) - plt.xlabel("Voxel Intensity") - plt.ylabel("Frequency") + ax.set_xlabel("Voxel Intensity") + ax.set_ylabel("Frequency") if show_title: - plt.title(title, fontsize=10) + ax.set_title(title, fontsize=10) # Handle show and return - if show: + if show and fig is not None: plt.show() - else: - plt.close(fig) if return_fig: return fig - - + elif ax is not None: + return ax def threshold( volume: np.ndarray, cmap_image: str = 'viridis', - cmap_threshold: str = 'gray', + cmap_overlay: str = 'gray', vmin: float = None, vmax: float = None, ): @@ -1014,7 +995,7 @@ def threshold( # Create the interactive widget def _slicer(position, threshold, method): - fig, axes = plt.subplots(1, 4, figsize=(9, 3)) + fig, axes = plt.subplots(1, 4, figsize=(25, 5)) slice_img = volume[position, :, :] # If vmin is higher than the highest value in the image ValueError is raised @@ -1030,12 +1011,12 @@ def threshold( else vmax ) + # Add original image to the plot axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax) axes[0].set_title('Original') - axes[0].axis('off') - - + axes[0].axis('off') + # Compute the threshold value if method == 'Manual': threshold_slider.disabled = False else: @@ -1050,16 +1031,37 @@ def threshold( else: raise ValueError(f"Unsupported thresholding method: {method}") + # Compute and add the histogram to the plot + histogram( + vol=volume, + bins=32, + slice_idx=position, + axis=1, + kde=False, + ax=axes[1], + show=False, + ) + axes[1].axvline( + x=threshold, + color="red", + linestyle="--", + linewidth=2, + label=f"Threshold = {threshold}", + ) + axes[1].set_title("Histogram") + # Compute and add the binary mask to the plot mask = slice_img > threshold - axes[2].imshow(mask, cmap=cmap_threshold) + axes[2].imshow(mask, cmap='grey') axes[2].set_title('Binary mask') axes[2].axis('off') + # Compute and add the overlay to the plot masked_volume = qim3d.processing.operations.overlay_rgb_images( background = slice_img, foreground = mask, ) + # If vmin is higher than the highest value in the image ValueError is raised # We don't want to override the values because next slices might be okay new_vmin = ( @@ -1072,7 +1074,7 @@ def threshold( if (isinstance(vmax, (float, int)) and vmax < np.min(masked_volume)) else vmax ) - axes[3].imshow(masked_volume, cmap=cmap_threshold, vmin=new_vmin, vmax=new_vmax) + axes[3].imshow(masked_volume, cmap=cmap_overlay, vmin=new_vmin, vmax=new_vmax) axes[3].set_title('Overlay') axes[3].axis('off') -- GitLab