From 88d24bb726b8117f32d33ad38f6f5a7fa2c64208 Mon Sep 17 00:00:00 2001
From: Alessia Saccardo <s212246@dtu.dk>
Date: Tue, 10 Dec 2024 10:24:19 +0100
Subject: [PATCH] add histogram

---
 qim3d/viz/__init__.py |  1 +
 qim3d/viz/explore.py  | 96 ++++++++++++++++++++++---------------------
 2 files changed, 50 insertions(+), 47 deletions(-)

diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py
index 33d94416..dd46be61 100644
--- a/qim3d/viz/__init__.py
+++ b/qim3d/viz/__init__.py
@@ -8,6 +8,7 @@ from .explore import (
     slices,
     chunks,
     histogram,
+    threshold,
 )
 from .itk_vtk_viewer import itk_vtk, Installer, NotInstalledError
 from .k3d import vol, mesh
diff --git a/qim3d/viz/explore.py b/qim3d/viz/explore.py
index 49d11630..7e9b16f2 100644
--- a/qim3d/viz/explore.py
+++ b/qim3d/viz/explore.py
@@ -830,6 +830,7 @@ def histogram(
     element="step",
     return_fig=False,
     show=True,
+    ax=None,  # New parameter for target axes
     **sns_kwargs,
 ):
     """
@@ -853,40 +854,26 @@ def histogram(
         element (str, optional): Type of histogram to draw ('bars', 'step', or 'poly'). Default is "step".
         return_fig (bool, optional): If True, returns the figure object instead of showing it directly. Default is False.
         show (bool, optional): If True, displays the plot. If False, suppresses display. Default is True.
+        ax (matplotlib.axes.Axes, optional): Axes object where the histogram will be plotted. Default is None.
         **sns_kwargs: Additional keyword arguments for `seaborn.histplot`.
 
     Returns:
-        Optional[matplotlib.figure.Figure]: If `return_fig` is True, returns the generated figure object. Otherwise, returns None.
+        Optional[matplotlib.figure.Figure or matplotlib.axes.Axes]: 
+            If `return_fig` is True, returns the generated figure object. 
+            If `return_fig` is False and `ax` is provided, returns the `Axes` object.
+            Otherwise, returns None.
 
     Raises:
         ValueError: If `axis` is not a valid axis index (0, 1, or 2).
         ValueError: If `slice_idx` is an integer and is out of range for the specified axis.
-
-    Example:
-        ```python
-        import qim3d
-
-        vol = qim3d.examples.bone_128x128x128
-        qim3d.viz.histogram(vol)
-        ```
-        ![viz histogram](assets/screenshots/viz-histogram-vol.png)
-
-        ```python
-        import qim3d
-
-        vol = qim3d.examples.bone_128x128x128
-        qim3d.viz.histogram(vol, bins=32, slice_idx="middle", axis=1, kde=False, log_scale=True)
-        ```
-        ![viz histogram](assets/screenshots/viz-histogram-slice.png)
     """
-
     if not (0 <= axis < vol.ndim):
         raise ValueError(f"Axis must be an integer between 0 and {vol.ndim - 1}.")
 
     if slice_idx == "middle":
         slice_idx = vol.shape[axis] // 2
 
-    if slice_idx:
+    if slice_idx is not None:
         if 0 <= slice_idx < vol.shape[axis]:
             img_slice = np.take(vol, indices=slice_idx, axis=axis)
             data = img_slice.ravel()
@@ -899,10 +886,14 @@ def histogram(
         data = vol.ravel()
         title = f"Intensity histogram for whole volume {vol.shape}"
 
-    fig, ax = plt.subplots(figsize=figsize)
+    # Use provided Axes or create new figure
+    if ax is None:
+        fig, ax = plt.subplots(figsize=figsize)
+    else:
+        fig = None
 
     if log_scale:
-        plt.yscale("log")
+        ax.set_yscale("log")
 
     if color == "qim3d":
         color = qim3d.viz.colormaps.qim(1.0)
@@ -914,42 +905,32 @@ def histogram(
         color=color,
         element=element,
         edgecolor=edgecolor,
+        ax=ax,  # Plot directly on the specified Axes
         **sns_kwargs,
     )
 
     if despine:
-        sns.despine(
-            fig=None,
-            ax=None,
-            top=True,
-            right=True,
-            left=False,
-            bottom=False,
-            offset={"left": 0, "bottom": 18},
-            trim=True,
-        )
+        sns.despine(ax=ax, top=True, right=True)
 
-    plt.xlabel("Voxel Intensity")
-    plt.ylabel("Frequency")
+    ax.set_xlabel("Voxel Intensity")
+    ax.set_ylabel("Frequency")
 
     if show_title:
-        plt.title(title, fontsize=10)
+        ax.set_title(title, fontsize=10)
 
     # Handle show and return
-    if show:
+    if show and fig is not None:
         plt.show()
-    else:
-        plt.close(fig)
 
     if return_fig:
         return fig
-
-
+    elif ax is not None:
+        return ax
 
 def threshold(
         volume: np.ndarray,
         cmap_image: str = 'viridis',
-        cmap_threshold: str = 'gray',
+        cmap_overlay: str = 'gray',
         vmin: float = None,
         vmax: float = None,
 ):
@@ -1014,7 +995,7 @@ def threshold(
     
     # Create the interactive widget
     def _slicer(position, threshold, method):
-        fig, axes = plt.subplots(1, 4, figsize=(9, 3))
+        fig, axes = plt.subplots(1, 4, figsize=(25, 5))
 
         slice_img = volume[position, :, :]
         # If vmin is higher than the highest value in the image ValueError is raised
@@ -1030,12 +1011,12 @@ def threshold(
             else vmax
         )
 
+        # Add original image to the plot
         axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax)
         axes[0].set_title('Original')
-        axes[0].axis('off')
-
-        
+        axes[0].axis('off')        
 
+        # Compute the threshold value
         if method == 'Manual':
             threshold_slider.disabled = False
         else:
@@ -1050,16 +1031,37 @@ def threshold(
             else:
                 raise ValueError(f"Unsupported thresholding method: {method}")
 
+        # Compute and add the histogram to the plot
+        histogram(
+            vol=volume,
+            bins=32,
+            slice_idx=position,
+            axis=1,
+            kde=False,
+            ax=axes[1],
+            show=False,
+        )
+        axes[1].axvline(
+            x=threshold, 
+            color="red", 
+            linestyle="--",  
+            linewidth=2, 
+            label=f"Threshold = {threshold}", 
+        )
+        axes[1].set_title("Histogram")
         
+        # Compute and add the binary mask to the plot
         mask = slice_img > threshold
-        axes[2].imshow(mask, cmap=cmap_threshold)
+        axes[2].imshow(mask, cmap='grey')
         axes[2].set_title('Binary mask')
         axes[2].axis('off')
 
+        # Compute and add the overlay to the plot
         masked_volume = qim3d.processing.operations.overlay_rgb_images(
             background = slice_img,
             foreground = mask,
         )
+
         # If vmin is higher than the highest value in the image ValueError is raised
         # We don't want to override the values because next slices might be okay
         new_vmin = (
@@ -1072,7 +1074,7 @@ def threshold(
             if (isinstance(vmax, (float, int)) and vmax < np.min(masked_volume))
             else vmax
         )
-        axes[3].imshow(masked_volume, cmap=cmap_threshold, vmin=new_vmin, vmax=new_vmax)
+        axes[3].imshow(masked_volume, cmap=cmap_overlay, vmin=new_vmin, vmax=new_vmax)
         axes[3].set_title('Overlay')
         axes[3].axis('off')
 
-- 
GitLab