From 5505b582c5669a7cc6a3c021ed8ee0d9e5799213 Mon Sep 17 00:00:00 2001
From: Felipe <fima@dtu.dk>
Date: Wed, 11 Dec 2024 13:23:38 +0100
Subject: [PATCH] data_exploration refactored

---
 qim3d/viz/__init__.py                         |  10 +-
 qim3d/viz/{explore.py => data_exploration.py} | 517 ++++++++++--------
 qim3d/viz/k3d.py                              |  36 +-
 3 files changed, 297 insertions(+), 266 deletions(-)
 rename qim3d/viz/{explore.py => data_exploration.py} (62%)

diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py
index 33d94416..e736db3e 100644
--- a/qim3d/viz/__init__.py
+++ b/qim3d/viz/__init__.py
@@ -1,16 +1,16 @@
 from . import colormaps
 from .cc import plot_cc
 from .detection import circles
-from .explore import (
-    interactive_fade_mask,
-    orthogonal,
+from .data_exploration import (
+    fade_mask,
     slicer,
-    slices,
+    slicer_orthogonal,
+    slices_grid,
     chunks,
     histogram,
 )
 from .itk_vtk_viewer import itk_vtk, Installer, NotInstalledError
-from .k3d import vol, mesh
+from .k3d import volumetric, mesh
 from .local_thickness_ import local_thickness
 from .structure_tensor import vectors
 from .metrics import plot_metrics, grid_overview, grid_pred, vol_masked
diff --git a/qim3d/viz/explore.py b/qim3d/viz/data_exploration.py
similarity index 62%
rename from qim3d/viz/explore.py
rename to qim3d/viz/data_exploration.py
index 8fe2a545..f2b958ac 100644
--- a/qim3d/viz/explore.py
+++ b/qim3d/viz/data_exploration.py
@@ -19,145 +19,157 @@ import seaborn as sns
 
 import qim3d
 
-def slices(
-    vol: np.ndarray,
-    axis: int = 0,
-    position: Optional[Union[str, int, List[int]]] = None,
-    n_slices: int = 5,
-    max_cols: int = 5,
-    cmap: str = "viridis",
-    vmin: float = None,
-    vmax: float = None,
-    img_height: int = 2,
-    img_width: int = 2,
-    show: bool = False,
-    show_position: bool = True,
+
+def slices_grid(
+    volume: np.ndarray,
+    slice_axis: int = 0,
+    slice_positions: Optional[Union[str, int, List[int]]] = None,
+    num_slices: int = 15,
+    max_columns: int = 5,
+    color_map: str = "magma",
+    value_min: float = None,
+    value_max: float = None,
+    image_size=None,
+    image_height: int = 2,
+    image_width: int = 2,
+    display_figure: bool = False,
+    display_positions: bool = True,
     interpolation: Optional[str] = None,
-    img_size=None,
-    cbar: bool = False,
-    cbar_style: str = "small",
-    **imshow_kwargs,
+    color_bar: bool = False,
+    color_bar_style: str = "small",
+    **matplotlib_imshow_kwargs,
 ) -> plt.Figure:
     """Displays one or several slices from a 3d volume.
 
-    By default if `position` is None, slices plots `n_slices` linearly spaced slices.
-    If `position` is given as a string or integer, slices will plot an overview with `n_slices` figures around that position.
-    If `position` is given as a list, `n_slices` will be ignored and the slices from `position` will be plotted.
+    By default if `slice_positions` is None, slices_grid plots `num_slices` linearly spaced slices.
+    If `slice_positions` is given as a string or integer, slices_grid will plot an overview with `num_slices` figures around that position.
+    If `slice_positions` is given as a list, `num_slices` will be ignored and the slices from `slice_positions` will be plotted.
 
     Args:
         vol np.ndarray: The 3D volume to be sliced.
-        axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
-        position (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None.
-        n_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 5.
-        max_cols (int, optional): The maximum number of columns to be plotted. Defaults to 5.
-        cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
-        vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
-        vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None
-        img_height (int, optional): Height of the figure.
-        img_width (int, optional): Width of the figure.
-        show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
-        show_position (bool, optional): If True, displays the position of the slices. Defaults to True.
+        slice_axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
+        slice_positions (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None.
+        num_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 15.
+        max_columns (int, optional): The maximum number of columns to be plotted. Defaults to 5.
+        color_map (str, optional): Specifies the color map for the image. Defaults to "viridis".
+        value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
+        value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None
+        image_height (int, optional): Height of the figure.
+        image_width (int, optional): Width of the figure.
+        display_figure (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
+        display_positions (bool, optional): If True, displays the position of the slices. Defaults to True.
         interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
-        cbar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False.
-        cbar_style (str, optional): Determines the style of the colorbar. Option 'small' is height of one image row. Option 'large' spans full height of image grid. Defaults to 'small'.
+        color_bar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False.
+        color_bar_style (str, optional): Determines the style of the colorbar. Option 'small' is height of one image row. Option 'large' spans full height of image grid. Defaults to 'small'.
 
     Returns:
         fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.
 
     Raises:
         ValueError: If the input is not a numpy.ndarray or da.core.Array.
-        ValueError: If the axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1.
+        ValueError: If the slice_axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1.
         ValueError: If the file or array is not a volume with at least 3 dimensions.
         ValueError: If the `position` keyword argument is not a integer, list of integers or one of the following strings: "start", "mid" or "end".
-        ValueError: If the cbar_style keyword argument is not one of the following strings: 'small' or 'large'.
+        ValueError: If the color_bar_style keyword argument is not one of the following strings: 'small' or 'large'.
 
     Example:
         ```python
         import qim3d
 
         vol = qim3d.examples.shell_225x128x128
-        qim3d.viz.slices(vol, n_slices=15)
+        qim3d.viz.slices_grid(vol, num_slices=15)
         ```
         ![Grid of slices](assets/screenshots/viz-slices.png)
     """
-    if img_size:
-        img_height = img_size
-        img_width = img_size
+    if image_size:
+        image_height = image_size
+        image_width = image_size
 
-    # If we pass python None to the imshow function, it will set to 
+    # If we pass python None to the imshow function, it will set to
     # default value 'antialiased'
     if interpolation is None:
-        interpolation = 'none'
+        interpolation = "none"
 
     # Numpy array or Torch tensor input
-    if not isinstance(vol, (np.ndarray, da.core.Array)):
+    if not isinstance(volume, (np.ndarray, da.core.Array)):
         raise ValueError("Data type not supported")
 
-    if vol.ndim < 3:
+    if volume.ndim < 3:
         raise ValueError(
             "The provided object is not a volume as it has less than 3 dimensions."
         )
 
-    cbar_style_options = ["small", "large"]
-    if cbar_style not in cbar_style_options:
-        raise ValueError(f"Value '{cbar_style}' is not valid for colorbar style. Please select from {cbar_style_options}.")
-    
-    if isinstance(vol, da.core.Array):
-        vol = vol.compute()
+    color_bar_style_options = ["small", "large"]
+    if color_bar_style not in color_bar_style_options:
+        raise ValueError(
+            f"Value '{color_bar_style}' is not valid for colorbar style. Please select from {color_bar_style_options}."
+        )
+
+    if isinstance(volume, da.core.Array):
+        volume = volume.compute()
 
     # Ensure axis is a valid choice
-    if not (0 <= axis < vol.ndim):
+    if not (0 <= slice_axis < volume.ndim):
         raise ValueError(
-            f"Invalid value for 'axis'. It should be an integer between 0 and {vol.ndim - 1}."
+            f"Invalid value for 'slice_axis'. It should be an integer between 0 and {volume.ndim - 1}."
         )
 
-    if type(cmap) == matplotlib.colors.LinearSegmentedColormap or cmap == 'objects':
-        num_labels = len(np.unique(vol))
+    # Here we deal with the case that the user wants to use the objects colormap directly
+    if (
+        type(color_map) == matplotlib.colors.LinearSegmentedColormap
+        or color_map == "objects"
+    ):
+        num_labels = len(np.unique(volume))
 
-        if cmap == 'objects':
-            cmap = qim3d.viz.colormaps.objects(num_labels)
-        # If vmin and vmax are not set like this, then in case the 
-        # number of objects changes on new slice, objects might change 
-        # colors. So when using a slider, the same object suddently 
+        if color_map == "objects":
+            color_map = qim3d.viz.colormaps.objects(num_labels)
+        # If value_min and value_max are not set like this, then in case the
+        # number of objects changes on new slice, objects might change
+        # colors. So when using a slider, the same object suddently
         # changes color (flickers), which is confusing and annoying.
-        vmin = 0
-        vmax = num_labels
-
+        value_min = 0
+        value_max = num_labels
 
     # Get total number of slices in the specified dimension
-    n_total = vol.shape[axis]
+    n_total = volume.shape[slice_axis]
 
     # Position is not provided - will use linearly spaced slices
-    if position is None:
-        slice_idxs = np.linspace(0, n_total - 1, n_slices, dtype=int)
+    if slice_positions is None:
+        slice_idxs = np.linspace(0, n_total - 1, num_slices, dtype=int)
     # Position is a string
-    elif isinstance(position, str) and position.lower() in ["start", "mid", "end"]:
-        if position.lower() == "start":
-            slice_idxs = _get_slice_range(0, n_slices, n_total)
-        elif position.lower() == "mid":
-            slice_idxs = _get_slice_range(n_total // 2, n_slices, n_total)
-        elif position.lower() == "end":
-            slice_idxs = _get_slice_range(n_total - 1, n_slices, n_total)
+    elif isinstance(slice_positions, str) and slice_positions.lower() in [
+        "start",
+        "mid",
+        "end",
+    ]:
+        if slice_positions.lower() == "start":
+            slice_idxs = _get_slice_range(0, num_slices, n_total)
+        elif slice_positions.lower() == "mid":
+            slice_idxs = _get_slice_range(n_total // 2, num_slices, n_total)
+        elif slice_positions.lower() == "end":
+            slice_idxs = _get_slice_range(n_total - 1, num_slices, n_total)
     #  Position is an integer
-    elif isinstance(position, int):
-        slice_idxs = _get_slice_range(position, n_slices, n_total)
+    elif isinstance(slice_positions, int):
+        slice_idxs = _get_slice_range(slice_positions, num_slices, n_total)
     # Position is a list of integers
-    elif isinstance(position, list) and all(isinstance(idx, int) for idx in position):
-        slice_idxs = position
+    elif isinstance(slice_positions, list) and all(
+        isinstance(idx, int) for idx in slice_positions
+    ):
+        slice_idxs = slice_positions
     else:
         raise ValueError(
             'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'
         )
 
     # Make grid
-    nrows = math.ceil(n_slices / max_cols)
-    ncols = min(n_slices, max_cols)
+    nrows = math.ceil(num_slices / max_columns)
+    ncols = min(num_slices, max_columns)
 
     # Generate figure
     fig, axs = plt.subplots(
         nrows=nrows,
         ncols=ncols,
-        figsize=(ncols * img_height, nrows * img_width),
+        figsize=(ncols * image_height, nrows * image_width),
         constrained_layout=True,
     )
 
@@ -165,47 +177,53 @@ def slices(
         axs = [axs]  # Convert to a list for uniformity
 
     # Convert to NumPy array in order to use the numpy.take method
-    if isinstance(vol, da.core.Array):
-        vol = vol.compute()
+    if isinstance(volume, da.core.Array):
+        volume = volume.compute()
 
-    if cbar:
-        # In this case, we want the vrange to be constant across the 
-        # slices, which makes them all comparable to a single cbar.
-        new_vmin = vmin if vmin is not None else np.min(vol)
-        new_vmax = vmax if vmax is not None else np.max(vol)
+    if color_bar:
+        # In this case, we want the vrange to be constant across the
+        # slices, which makes them all comparable to a single color_bar.
+        new_value_min = value_min if value_min is not None else np.min(volume)
+        new_value_max = value_max if value_max is not None else np.max(volume)
 
     # Run through each ax of the grid
     for i, ax_row in enumerate(axs):
         for j, ax in enumerate(np.atleast_1d(ax_row)):
-            slice_idx = i * max_cols + j
+            slice_idx = i * max_columns + j
             try:
-                slice_img = vol.take(slice_idxs[slice_idx], axis=axis)
+                slice_img = volume.take(slice_idxs[slice_idx], axis=slice_axis)
 
-                if not cbar:
-                    # If vmin is higher than the highest value in the 
-                    # image ValueError is raised. We don't want to 
+                if not color_bar:
+                    # If value_min is higher than the highest value in the
+                    # image ValueError is raised. We don't want to
                     # override the values because next slices might be okay
-                    new_vmin = (
+                    new_value_min = (
                         None
-                        if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
-                        else vmin
+                        if (
+                            isinstance(value_min, (float, int))
+                            and value_min > np.max(slice_img)
+                        )
+                        else value_min
                     )
-                    new_vmax = (
+                    new_value_max = (
                         None
-                        if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
-                        else vmax
+                        if (
+                            isinstance(value_max, (float, int))
+                            and value_max < np.min(slice_img)
+                        )
+                        else value_max
                     )
 
                 ax.imshow(
                     slice_img,
-                    cmap=cmap,
+                    cmap=color_map,
                     interpolation=interpolation,
-                    vmin=new_vmin,
-                    vmax=new_vmax,
-                    **imshow_kwargs,
+                    vmin=new_value_min,
+                    vmax=new_value_max,
+                    **matplotlib_imshow_kwargs,
                 )
 
-                if show_position:
+                if display_positions:
                     ax.text(
                         0.0,
                         1.0,
@@ -221,7 +239,7 @@ def slices(
                     ax.text(
                         1.0,
                         0.0,
-                        f"axis {axis} ",
+                        f"axis {slice_axis} ",
                         transform=ax.transAxes,
                         color="white",
                         fontsize=8,
@@ -237,33 +255,40 @@ def slices(
             # Hide the axis, so that we have a nice grid
             ax.axis("off")
 
-    if cbar:
+    if color_bar:
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", category=UserWarning)
             fig.tight_layout()
 
-        norm = matplotlib.colors.Normalize(vmin=new_vmin, vmax=new_vmax, clip=True)
-        mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
+        norm = matplotlib.colors.Normalize(
+            vmin=new_value_min, vmax=new_value_max, clip=True
+        )
+        mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=color_map)
 
-        if cbar_style =="small":
+        if color_bar_style == "small":
             # Figure coordinates of top-right axis
             tr_pos = np.atleast_1d(axs[0])[-1].get_position()
             # The width is divided by ncols to make it the same relative size to the images
-            cbar_ax = fig.add_axes(
+            color_bar_ax = fig.add_axes(
                 [tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height]
             )
-            fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical")
-        elif cbar_style == "large":
+            fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation="vertical")
+        elif color_bar_style == "large":
             # Figure coordinates of bottom- and top-right axis
             br_pos = np.atleast_1d(axs[-1])[-1].get_position()
             tr_pos = np.atleast_1d(axs[0])[-1].get_position()
             # The width is divided by ncols to make it the same relative size to the images
-            cbar_ax = fig.add_axes(
-                [br_pos.xmax + 0.05 / ncols, br_pos.y0+0.0015, 0.05 / ncols, (tr_pos.y1 - br_pos.y0)-0.0015]
+            color_bar_ax = fig.add_axes(
+                [
+                    br_pos.xmax + 0.05 / ncols,
+                    br_pos.y0 + 0.0015,
+                    0.05 / ncols,
+                    (tr_pos.y1 - br_pos.y0) - 0.0015,
+                ]
             )
-            fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical")
+            fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation="vertical")
 
-    if show:
+    if display_figure:
         plt.show()
 
     plt.close()
@@ -271,49 +296,51 @@ def slices(
     return fig
 
 
-def _get_slice_range(position: int, n_slices: int, n_total):
+def _get_slice_range(position: int, num_slices: int, n_total):
     """Helper function for `slices`. Returns the range of slices to be displayed around the given position."""
-    start_idx = position - n_slices // 2
+    start_idx = position - num_slices // 2
     end_idx = (
-        position + n_slices // 2 if n_slices % 2 == 0 else position + n_slices // 2 + 1
+        position + num_slices // 2
+        if num_slices % 2 == 0
+        else position + num_slices // 2 + 1
     )
     slice_idxs = np.arange(start_idx, end_idx)
 
     if slice_idxs[0] < 0:
-        slice_idxs = np.arange(0, n_slices)
+        slice_idxs = np.arange(0, num_slices)
     elif slice_idxs[-1] > n_total:
-        slice_idxs = np.arange(n_total - n_slices, n_total)
+        slice_idxs = np.arange(n_total - num_slices, n_total)
 
     return slice_idxs
 
 
 def slicer(
-    vol: np.ndarray,
-    axis: int = 0,
-    cmap: str = "viridis",
-    vmin: float = None,
-    vmax: float = None,
-    img_height: int = 3,
-    img_width: int = 3,
-    show_position: bool = False,
+    volume: np.ndarray,
+    slice_axis: int = 0,
+    color_map: str = "magma",
+    value_min: float = None,
+    value_max: float = None,
+    image_height: int = 3,
+    image_width: int = 3,
+    display_positions: bool = False,
     interpolation: Optional[str] = None,
-    img_size=None,
-    cbar: bool = False,
-    **imshow_kwargs,
+    image_size=None,
+    color_bar: bool = False,
+    **matplotlib_imshow_kwargs,
 ) -> widgets.interactive:
     """Interactive widget for visualizing slices of a 3D volume.
 
     Args:
-        vol (np.ndarray): The 3D volume to be sliced.
-        axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
-        cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
-        vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
-        vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None
-        img_height (int, optional): Height of the figure. Defaults to 3.
-        img_width (int, optional): Width of the figure. Defaults to 3.
-        show_position (bool, optional): If True, displays the position of the slices. Defaults to False.
+        volume (np.ndarray): The 3D volume to be sliced.
+        slice_axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
+        color_map (str, optional): Specifies the color map for the image. Defaults to "viridis".
+        value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
+        value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None
+        image_height (int, optional): Height of the figure. Defaults to 3.
+        image_width (int, optional): Width of the figure. Defaults to 3.
+        display_positions (bool, optional): If True, displays the position of the slices. Defaults to False.
         interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
-        cbar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False.
+        color_bar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False.
 
     Returns:
         slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume.
@@ -328,98 +355,98 @@ def slicer(
         ![viz slicer](assets/screenshots/viz-slicer.gif)
     """
 
-    if img_size:
-        img_height = img_size
-        img_width = img_size
+    if image_size:
+        image_height = image_size
+        image_width = image_size
 
     # Create the interactive widget
-    def _slicer(position):
-        fig = slices(
-            vol,
-            axis=axis,
-            cmap=cmap,
-            vmin=vmin,
-            vmax=vmax,
-            img_height=img_height,
-            img_width=img_width,
-            show_position=show_position,
+    def _slicer(slice_positions):
+        fig = slices_grid(
+            volume,
+            slice_axis=slice_axis,
+            color_map=color_map,
+            value_min=value_min,
+            value_max=value_max,
+            image_height=image_height,
+            image_width=image_width,
+            display_positions=display_positions,
             interpolation=interpolation,
-            position=position,
-            n_slices=1,
-            show=True,
-            cbar=cbar,
-            **imshow_kwargs,
+            slice_positions=slice_positions,
+            num_slices=1,
+            display_figure=True,
+            color_bar=color_bar,
+            **matplotlib_imshow_kwargs,
         )
         return fig
 
     position_slider = widgets.IntSlider(
-        value=vol.shape[axis] // 2,
+        value=volume.shape[slice_axis] // 2,
         min=0,
-        max=vol.shape[axis] - 1,
+        max=volume.shape[slice_axis] - 1,
         description="Slice",
         continuous_update=True,
     )
-    slicer_obj = widgets.interactive(_slicer, position=position_slider)
+    slicer_obj = widgets.interactive(_slicer, slice_positions=position_slider)
     slicer_obj.layout = widgets.Layout(align_items="flex-start")
 
     return slicer_obj
 
 
-def orthogonal(
-    vol: np.ndarray,
-    cmap: str = "viridis",
-    vmin: float = None,
-    vmax: float = None,
-    img_height: int = 3,
-    img_width: int = 3,
-    show_position: bool = False,
+def slicer_orthogonal(
+    volume: np.ndarray,
+    color_map: str = 'magma',
+    value_min: float = None,
+    value_max: float = None,
+    image_height: int = 3,
+    image_width: int = 3,
+    display_positions: bool = False,
     interpolation: Optional[str] = None,
-    img_size=None,
+    image_size=None,
 ):
     """Interactive widget for visualizing orthogonal slices of a 3D volume.
 
     Args:
-        vol (np.ndarray): The 3D volume to be sliced.
-        cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
-        vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
-        vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None
-        img_height (int, optional): Height of the figure.
-        img_width (int, optional): Width of the figure.
-        show_position (bool, optional): If True, displays the position of the slices. Defaults to False.
+        volume (np.ndarray): The 3D volume to be sliced.
+        color_map (str, optional): Specifies the color map for the image. Defaults to "viridis".
+        value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
+        value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None
+        image_height (int, optional): Height of the figure.
+        image_width (int, optional): Width of the figure.
+        display_positions (bool, optional): If True, displays the position of the slices. Defaults to False.
         interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
 
     Returns:
-        orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume.
+        slicer_orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume.
 
     Example:
         ```python
         import qim3d
 
         vol = qim3d.examples.fly_150x256x256
-        qim3d.viz.orthogonal(vol, cmap="magma")
+        qim3d.viz.slicer_orthogonal(vol, color_map="magma")
         ```
-        ![viz orthogonal](assets/screenshots/viz-orthogonal.gif)
+        ![viz slicer_orthogonal](assets/screenshots/viz-orthogonal.gif)
     """
 
-    if img_size:
-        img_height = img_size
-        img_width = img_size
-
-    get_slicer_for_axis = lambda axis: slicer(
-        vol,
-        axis=axis,
-        cmap=cmap,
-        vmin=vmin,
-        vmax=vmax,
-        img_height=img_height,
-        img_width=img_width,
-        show_position=show_position,
+    if image_size:
+        image_height = image_size
+        image_width = image_size
+
+    get_slicer_for_axis = lambda slice_axis: slicer(
+        volume,
+        slice_axis=slice_axis,
+        color_map=color_map,
+        value_min=value_min,
+        value_max=value_max,
+        image_height=image_height,
+        image_width=image_width,
+        display_positions=display_positions,
         interpolation=interpolation,
     )
 
-    z_slicer = get_slicer_for_axis(axis=0)
-    y_slicer = get_slicer_for_axis(axis=1)
-    x_slicer = get_slicer_for_axis(axis=2)
+    z_slicer = get_slicer_for_axis(slice_axis=0)
+    y_slicer = get_slicer_for_axis(slice_axis=1)
+    x_slicer = get_slicer_for_axis(slice_axis=2)
 
     z_slicer.children[0].description = "Z"
     y_slicer.children[0].description = "Y"
@@ -428,29 +455,29 @@ def orthogonal(
     return widgets.HBox([z_slicer, y_slicer, x_slicer])
 
 
-def interactive_fade_mask(
-    vol: np.ndarray,
+def fade_mask(
+    volume: np.ndarray,
     axis: int = 0,
-    cmap: str = "viridis",
-    vmin: float = None,
-    vmax: float = None,
+    color_map: str = 'magma',
+    value_min: float = None,
+    value_max: float = None,
 ):
     """Interactive widget for visualizing the effect of edge fading on a 3D volume.
 
     This can be used to select the best parameters before applying the mask.
 
     Args:
-        vol (np.ndarray): The volume to apply edge fading to.
+        volume (np.ndarray): The volume to apply edge fading to.
         axis (int, optional): The axis along which to apply the fading. Defaults to 0.
-        cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
-        vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
-        vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None
+        color_map (str, optional): Specifies the color map for the image. Defaults to "viridis".
+        value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
+        value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None
 
     Example:
         ```python
         import qim3d
         vol = qim3d.examples.cement_128x128x128
-        qim3d.viz.interactive_fade_mask(vol)
+        qim3d.viz.fade_mask(vol)
         ```
         ![operations-edge_fade_before](assets/screenshots/viz-fade_mask.gif)
 
@@ -460,58 +487,62 @@ def interactive_fade_mask(
     def _slicer(position, decay_rate, ratio, geometry, invert):
         fig, axes = plt.subplots(1, 3, figsize=(9, 3))
 
-        slice_img = vol[position, :, :]
-        # If vmin is higher than the highest value in the image ValueError is raised
+        slice_img = volume[position, :, :]
+        # If value_min is higher than the highest value in the image ValueError is raised
         # We don't want to override the values because next slices might be okay
-        new_vmin = (
+        new_value_min = (
             None
-            if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
-            else vmin
+            if (isinstance(value_min, (float, int)) and value_min > np.max(slice_img))
+            else value_min
         )
-        new_vmax = (
+        new_value_max = (
             None
-            if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
-            else vmax
+            if (isinstance(value_max, (float, int)) and value_max < np.min(slice_img))
+            else value_max
         )
 
-        axes[0].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax)
+        axes[0].imshow(
+            slice_img, cmap=color_map, value_min=new_value_min, value_max=new_value_max
+        )
         axes[0].set_title("Original")
         axes[0].axis("off")
 
         mask = qim3d.processing.operations.fade_mask(
-            np.ones_like(vol),
+            np.ones_like(volume),
             decay_rate=decay_rate,
             ratio=ratio,
             geometry=geometry,
             axis=axis,
             invert=invert,
         )
-        axes[1].imshow(mask[position, :, :], cmap=cmap)
+        axes[1].imshow(mask[position, :, :], cmap=color_map)
         axes[1].set_title("Mask")
         axes[1].axis("off")
 
-        masked_vol = qim3d.processing.operations.fade_mask(
-            vol,
+        masked_volume = qim3d.processing.operations.fade_mask(
+            volume,
             decay_rate=decay_rate,
             ratio=ratio,
             geometry=geometry,
             axis=axis,
             invert=invert,
         )
-        # If vmin is higher than the highest value in the image ValueError is raised
+        # If value_min is higher than the highest value in the image ValueError is raised
         # We don't want to override the values because next slices might be okay
-        slice_img = masked_vol[position, :, :]
-        new_vmin = (
+        slice_img = masked_volume[position, :, :]
+        new_value_min = (
             None
-            if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
-            else vmin
+            if (isinstance(value_min, (float, int)) and value_min > np.max(slice_img))
+            else value_min
         )
-        new_vmax = (
+        new_value_max = (
             None
-            if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
-            else vmax
+            if (isinstance(value_max, (float, int)) and value_max < np.min(slice_img))
+            else value_max
+        )
+        axes[2].imshow(
+            slice_img, cmap=color_map, vmin=new_value_min, vmax=new_value_max
         )
-        axes[2].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax)
         axes[2].set_title("Masked")
         axes[2].axis("off")
 
@@ -524,9 +555,9 @@ def interactive_fade_mask(
     )
 
     position_slider = widgets.IntSlider(
-        value=vol.shape[0] // 2,
+        value=volume.shape[0] // 2,
         min=0,
-        max=vol.shape[0] - 1,
+        max=volume.shape[0] - 1,
         description="Slice",
         continuous_update=False,
     )
@@ -659,13 +690,13 @@ def chunks(zarr_path: str, **kwargs):
             viz_widget = widgets.Output()
             with viz_widget:
                 viz_widget.clear_output(wait=True)
-                fig = qim3d.viz.slices(chunk, **kwargs)
+                fig = qim3d.viz.slices_grid(chunk, **kwargs)
                 display(fig)
-        elif visualization_method == "vol":
+        elif visualization_method == "volume":
             viz_widget = widgets.Output()
             with viz_widget:
                 viz_widget.clear_output(wait=True)
-                out = qim3d.viz.vol(chunk, show=False, **kwargs)
+                out = qim3d.viz.volumetric(chunk, show=False, **kwargs)
                 display(out)
         else:
             log.info(f"Invalid visualization method: {visualization_method}")
@@ -716,7 +747,7 @@ def chunks(zarr_path: str, **kwargs):
     )
 
     method_dropdown = widgets.Dropdown(
-        options=["slicer", "slices", "vol"],
+        options=["slicer", "slices", "volume"],
         value="slicer",
         description="Visualization",
         style={"description_width": description_width, "text_align": "left"},
@@ -815,7 +846,7 @@ def chunks(zarr_path: str, **kwargs):
 
 
 def histogram(
-    vol: np.ndarray,
+    volume: np.ndarray,
     bins: Union[int, str] = "auto",
     slice_idx: Union[int, str] = None,
     axis: int = 0,
@@ -833,11 +864,11 @@ def histogram(
 ):
     """
     Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume.
-    
+
     Utilizes [seaborn.histplot](https://seaborn.pydata.org/generated/seaborn.histplot.html) for visualization.
 
     Args:
-        vol (np.ndarray): A 3D NumPy array representing the volume to be visualized.
+        volume (np.ndarray): A 3D NumPy array representing the volume to be visualized.
         bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto".
         axis (int, optional): Axis along which to take a slice. Default is 0.
         slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis.
@@ -879,24 +910,24 @@ def histogram(
         ![viz histogram](assets/screenshots/viz-histogram-slice.png)
     """
 
-    if not (0 <= axis < vol.ndim):
-        raise ValueError(f"Axis must be an integer between 0 and {vol.ndim - 1}.")
+    if not (0 <= axis < volume.ndim):
+        raise ValueError(f"Axis must be an integer between 0 and {volume.ndim - 1}.")
 
     if slice_idx == "middle":
-        slice_idx = vol.shape[axis] // 2
+        slice_idx = volume.shape[axis] // 2
 
     if slice_idx:
-        if 0 <= slice_idx < vol.shape[axis]:
-            img_slice = np.take(vol, indices=slice_idx, axis=axis)
+        if 0 <= slice_idx < volume.shape[axis]:
+            img_slice = np.take(volume, indices=slice_idx, axis=axis)
             data = img_slice.ravel()
             title = f"Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}"
         else:
             raise ValueError(
-                f"Slice index out of range. Must be between 0 and {vol.shape[axis] - 1}."
+                f"Slice index out of range. Must be between 0 and {volume.shape[axis] - 1}."
             )
     else:
-        data = vol.ravel()
-        title = f"Intensity histogram for whole volume {vol.shape}"
+        data = volume.ravel()
+        title = f"Intensity histogram for whole volume {volume.shape}"
 
     fig, ax = plt.subplots(figsize=figsize)
 
diff --git a/qim3d/viz/k3d.py b/qim3d/viz/k3d.py
index c3018e08..853f8ae2 100644
--- a/qim3d/viz/k3d.py
+++ b/qim3d/viz/k3d.py
@@ -14,13 +14,13 @@ from qim3d.utils.logger import log
 from qim3d.utils.misc import downscale_img, scale_to_float16
 
 
-def vol(
+def volumetric(
     img,
     aspectmode="data",
     show=True,
     save=False,
     grid_visible=False,
-    cmap=None,
+    color_map='magma',
     constant_opacity=False,
     vmin=None,
     vmax=None,
@@ -43,8 +43,8 @@ def vol(
             If a string is provided, it's interpreted as the file path where the HTML
             file will be saved. Defaults to False.
         grid_visible (bool, optional): If True, the grid is visible in the plot. Defaults to False.
-        cmap (str or matplotlib.colors.Colormap or list, optional): The color map to be used for the volume rendering. If a string is passed, it should be a matplotlib colormap name. Defaults to None.
-        constant_opacity (bool, float): Set to True if doing an object label visualization with a corresponding cmap; otherwise, the plot may appear poorly. Defaults to False.
+        color_map (str or matplotlib.colors.Colormap or list, optional): The color map to be used for the volume rendering. If a string is passed, it should be a matplotlib colormap name. Defaults to None.
+        constant_opacity (bool, float): Set to True if doing an object label visualization with a corresponding color_map; otherwise, the plot may appear poorly. Defaults to False.
         vmin (float, optional): Together with vmax defines the data range the colormap covers. By default colormap covers the full range. Defaults to None.
         vmax (float, optional): Together with vmin defines the data range the colormap covers. By default colormap covers the full range. Defaults to None
         samples (int, optional): The number of samples to be used for the volume rendering in k3d. Defaults to 512.
@@ -60,7 +60,7 @@ def vol(
         ValueError: If `aspectmode` is not `'data'` or `'cube'`.
 
     Tip:
-        The function can be used for object label visualization using a `cmap` created with `qim3d.viz.colormaps.objects` along with setting `objects=True`. The latter ensures appropriate rendering.
+        The function can be used for object label visualization using a `color_map` created with `qim3d.viz.colormaps.objects` along with setting `objects=True`. The latter ensures appropriate rendering.
 
     Example:
         Display a volume inline:
@@ -69,7 +69,7 @@ def vol(
         import qim3d
 
         vol = qim3d.examples.bone_128x128x128
-        qim3d.viz.vol(vol)
+        qim3d.viz.volumetric(vol)
         ```
         <iframe src="https://platform.qim.dk/k3d/fima-bone_128x128x128-20240221113459.html" width="100%" height="500" frameborder="0"></iframe>
 
@@ -78,7 +78,7 @@ def vol(
         ```python
         import qim3d
         vol = qim3d.examples.bone_128x128x128
-        plot = qim3d.viz.vol(vol, show=False, save="plot.html")
+        plot = qim3d.viz.volumetric(vol, show=False, save="plot.html")
         ```
 
     """
@@ -129,21 +129,21 @@ def vol(
     if vmax:
         color_range[1] = vmax
 
-    # Handle the different formats that cmap can take
-    if cmap:
-        if isinstance(cmap, str):
-            cmap = plt.get_cmap(cmap)  # Convert to Colormap object
-        if isinstance(cmap, Colormap):
-            # Convert to the format of cmap required by k3d.volume
-            attr_vals = np.linspace(0.0, 1.0, num=cmap.N)
-            RGB_vals = cmap(np.arange(0, cmap.N))[:, :3]
-            cmap = np.column_stack((attr_vals, RGB_vals)).tolist()
+    # Handle the different formats that color_map can take
+    if color_map:
+        if isinstance(color_map, str):
+            color_map = plt.get_cmap(color_map)  # Convert to Colormap object
+        if isinstance(color_map, Colormap):
+            # Convert to the format of color_map required by k3d.volume
+            attr_vals = np.linspace(0.0, 1.0, num=color_map.N)
+            RGB_vals = color_map(np.arange(0, color_map.N))[:, :3]
+            color_map = np.column_stack((attr_vals, RGB_vals)).tolist()
 
     # Default k3d.volume settings
     opacity_function = []
     interpolation = True
     if constant_opacity:
-        # without these settings, the plot will look bad when cmap is created with qim3d.viz.colormaps.objects
+        # without these settings, the plot will look bad when color_map is created with qim3d.viz.colormaps.objects
         opacity_function = [0.0, float(constant_opacity), 1.0, float(constant_opacity)]
         interpolation = False
 
@@ -155,7 +155,7 @@ def vol(
             if aspectmode.lower() == "data"
             else None
         ),
-        color_map=cmap,
+        color_map=color_map,
         samples=samples,
         color_range=color_range,
         opacity_function=opacity_function,
-- 
GitLab