From 5505b582c5669a7cc6a3c021ed8ee0d9e5799213 Mon Sep 17 00:00:00 2001 From: Felipe <fima@dtu.dk> Date: Wed, 11 Dec 2024 13:23:38 +0100 Subject: [PATCH] data_exploration refactored --- qim3d/viz/__init__.py | 10 +- qim3d/viz/{explore.py => data_exploration.py} | 517 ++++++++++-------- qim3d/viz/k3d.py | 36 +- 3 files changed, 297 insertions(+), 266 deletions(-) rename qim3d/viz/{explore.py => data_exploration.py} (62%) diff --git a/qim3d/viz/__init__.py b/qim3d/viz/__init__.py index 33d94416..e736db3e 100644 --- a/qim3d/viz/__init__.py +++ b/qim3d/viz/__init__.py @@ -1,16 +1,16 @@ from . import colormaps from .cc import plot_cc from .detection import circles -from .explore import ( - interactive_fade_mask, - orthogonal, +from .data_exploration import ( + fade_mask, slicer, - slices, + slicer_orthogonal, + slices_grid, chunks, histogram, ) from .itk_vtk_viewer import itk_vtk, Installer, NotInstalledError -from .k3d import vol, mesh +from .k3d import volumetric, mesh from .local_thickness_ import local_thickness from .structure_tensor import vectors from .metrics import plot_metrics, grid_overview, grid_pred, vol_masked diff --git a/qim3d/viz/explore.py b/qim3d/viz/data_exploration.py similarity index 62% rename from qim3d/viz/explore.py rename to qim3d/viz/data_exploration.py index 8fe2a545..f2b958ac 100644 --- a/qim3d/viz/explore.py +++ b/qim3d/viz/data_exploration.py @@ -19,145 +19,157 @@ import seaborn as sns import qim3d -def slices( - vol: np.ndarray, - axis: int = 0, - position: Optional[Union[str, int, List[int]]] = None, - n_slices: int = 5, - max_cols: int = 5, - cmap: str = "viridis", - vmin: float = None, - vmax: float = None, - img_height: int = 2, - img_width: int = 2, - show: bool = False, - show_position: bool = True, + +def slices_grid( + volume: np.ndarray, + slice_axis: int = 0, + slice_positions: Optional[Union[str, int, List[int]]] = None, + num_slices: int = 15, + max_columns: int = 5, + color_map: str = "magma", + value_min: float = None, + value_max: float = None, + image_size=None, + image_height: int = 2, + image_width: int = 2, + display_figure: bool = False, + display_positions: bool = True, interpolation: Optional[str] = None, - img_size=None, - cbar: bool = False, - cbar_style: str = "small", - **imshow_kwargs, + color_bar: bool = False, + color_bar_style: str = "small", + **matplotlib_imshow_kwargs, ) -> plt.Figure: """Displays one or several slices from a 3d volume. - By default if `position` is None, slices plots `n_slices` linearly spaced slices. - If `position` is given as a string or integer, slices will plot an overview with `n_slices` figures around that position. - If `position` is given as a list, `n_slices` will be ignored and the slices from `position` will be plotted. + By default if `slice_positions` is None, slices_grid plots `num_slices` linearly spaced slices. + If `slice_positions` is given as a string or integer, slices_grid will plot an overview with `num_slices` figures around that position. + If `slice_positions` is given as a list, `num_slices` will be ignored and the slices from `slice_positions` will be plotted. Args: vol np.ndarray: The 3D volume to be sliced. - axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0. - position (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None. - n_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 5. - max_cols (int, optional): The maximum number of columns to be plotted. Defaults to 5. - cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". - vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None. - vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None - img_height (int, optional): Height of the figure. - img_width (int, optional): Width of the figure. - show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. - show_position (bool, optional): If True, displays the position of the slices. Defaults to True. + slice_axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0. + slice_positions (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None. + num_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 15. + max_columns (int, optional): The maximum number of columns to be plotted. Defaults to 5. + color_map (str, optional): Specifies the color map for the image. Defaults to "viridis". + value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None. + value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None + image_height (int, optional): Height of the figure. + image_width (int, optional): Width of the figure. + display_figure (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. + display_positions (bool, optional): If True, displays the position of the slices. Defaults to True. interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. - cbar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False. - cbar_style (str, optional): Determines the style of the colorbar. Option 'small' is height of one image row. Option 'large' spans full height of image grid. Defaults to 'small'. + color_bar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False. + color_bar_style (str, optional): Determines the style of the colorbar. Option 'small' is height of one image row. Option 'large' spans full height of image grid. Defaults to 'small'. Returns: fig (matplotlib.figure.Figure): The figure with the slices from the 3d array. Raises: ValueError: If the input is not a numpy.ndarray or da.core.Array. - ValueError: If the axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1. + ValueError: If the slice_axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1. ValueError: If the file or array is not a volume with at least 3 dimensions. ValueError: If the `position` keyword argument is not a integer, list of integers or one of the following strings: "start", "mid" or "end". - ValueError: If the cbar_style keyword argument is not one of the following strings: 'small' or 'large'. + ValueError: If the color_bar_style keyword argument is not one of the following strings: 'small' or 'large'. Example: ```python import qim3d vol = qim3d.examples.shell_225x128x128 - qim3d.viz.slices(vol, n_slices=15) + qim3d.viz.slices_grid(vol, num_slices=15) ```  """ - if img_size: - img_height = img_size - img_width = img_size + if image_size: + image_height = image_size + image_width = image_size - # If we pass python None to the imshow function, it will set to + # If we pass python None to the imshow function, it will set to # default value 'antialiased' if interpolation is None: - interpolation = 'none' + interpolation = "none" # Numpy array or Torch tensor input - if not isinstance(vol, (np.ndarray, da.core.Array)): + if not isinstance(volume, (np.ndarray, da.core.Array)): raise ValueError("Data type not supported") - if vol.ndim < 3: + if volume.ndim < 3: raise ValueError( "The provided object is not a volume as it has less than 3 dimensions." ) - cbar_style_options = ["small", "large"] - if cbar_style not in cbar_style_options: - raise ValueError(f"Value '{cbar_style}' is not valid for colorbar style. Please select from {cbar_style_options}.") - - if isinstance(vol, da.core.Array): - vol = vol.compute() + color_bar_style_options = ["small", "large"] + if color_bar_style not in color_bar_style_options: + raise ValueError( + f"Value '{color_bar_style}' is not valid for colorbar style. Please select from {color_bar_style_options}." + ) + + if isinstance(volume, da.core.Array): + volume = volume.compute() # Ensure axis is a valid choice - if not (0 <= axis < vol.ndim): + if not (0 <= slice_axis < volume.ndim): raise ValueError( - f"Invalid value for 'axis'. It should be an integer between 0 and {vol.ndim - 1}." + f"Invalid value for 'slice_axis'. It should be an integer between 0 and {volume.ndim - 1}." ) - if type(cmap) == matplotlib.colors.LinearSegmentedColormap or cmap == 'objects': - num_labels = len(np.unique(vol)) + # Here we deal with the case that the user wants to use the objects colormap directly + if ( + type(color_map) == matplotlib.colors.LinearSegmentedColormap + or color_map == "objects" + ): + num_labels = len(np.unique(volume)) - if cmap == 'objects': - cmap = qim3d.viz.colormaps.objects(num_labels) - # If vmin and vmax are not set like this, then in case the - # number of objects changes on new slice, objects might change - # colors. So when using a slider, the same object suddently + if color_map == "objects": + color_map = qim3d.viz.colormaps.objects(num_labels) + # If value_min and value_max are not set like this, then in case the + # number of objects changes on new slice, objects might change + # colors. So when using a slider, the same object suddently # changes color (flickers), which is confusing and annoying. - vmin = 0 - vmax = num_labels - + value_min = 0 + value_max = num_labels # Get total number of slices in the specified dimension - n_total = vol.shape[axis] + n_total = volume.shape[slice_axis] # Position is not provided - will use linearly spaced slices - if position is None: - slice_idxs = np.linspace(0, n_total - 1, n_slices, dtype=int) + if slice_positions is None: + slice_idxs = np.linspace(0, n_total - 1, num_slices, dtype=int) # Position is a string - elif isinstance(position, str) and position.lower() in ["start", "mid", "end"]: - if position.lower() == "start": - slice_idxs = _get_slice_range(0, n_slices, n_total) - elif position.lower() == "mid": - slice_idxs = _get_slice_range(n_total // 2, n_slices, n_total) - elif position.lower() == "end": - slice_idxs = _get_slice_range(n_total - 1, n_slices, n_total) + elif isinstance(slice_positions, str) and slice_positions.lower() in [ + "start", + "mid", + "end", + ]: + if slice_positions.lower() == "start": + slice_idxs = _get_slice_range(0, num_slices, n_total) + elif slice_positions.lower() == "mid": + slice_idxs = _get_slice_range(n_total // 2, num_slices, n_total) + elif slice_positions.lower() == "end": + slice_idxs = _get_slice_range(n_total - 1, num_slices, n_total) # Position is an integer - elif isinstance(position, int): - slice_idxs = _get_slice_range(position, n_slices, n_total) + elif isinstance(slice_positions, int): + slice_idxs = _get_slice_range(slice_positions, num_slices, n_total) # Position is a list of integers - elif isinstance(position, list) and all(isinstance(idx, int) for idx in position): - slice_idxs = position + elif isinstance(slice_positions, list) and all( + isinstance(idx, int) for idx in slice_positions + ): + slice_idxs = slice_positions else: raise ValueError( 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".' ) # Make grid - nrows = math.ceil(n_slices / max_cols) - ncols = min(n_slices, max_cols) + nrows = math.ceil(num_slices / max_columns) + ncols = min(num_slices, max_columns) # Generate figure fig, axs = plt.subplots( nrows=nrows, ncols=ncols, - figsize=(ncols * img_height, nrows * img_width), + figsize=(ncols * image_height, nrows * image_width), constrained_layout=True, ) @@ -165,47 +177,53 @@ def slices( axs = [axs] # Convert to a list for uniformity # Convert to NumPy array in order to use the numpy.take method - if isinstance(vol, da.core.Array): - vol = vol.compute() + if isinstance(volume, da.core.Array): + volume = volume.compute() - if cbar: - # In this case, we want the vrange to be constant across the - # slices, which makes them all comparable to a single cbar. - new_vmin = vmin if vmin is not None else np.min(vol) - new_vmax = vmax if vmax is not None else np.max(vol) + if color_bar: + # In this case, we want the vrange to be constant across the + # slices, which makes them all comparable to a single color_bar. + new_value_min = value_min if value_min is not None else np.min(volume) + new_value_max = value_max if value_max is not None else np.max(volume) # Run through each ax of the grid for i, ax_row in enumerate(axs): for j, ax in enumerate(np.atleast_1d(ax_row)): - slice_idx = i * max_cols + j + slice_idx = i * max_columns + j try: - slice_img = vol.take(slice_idxs[slice_idx], axis=axis) + slice_img = volume.take(slice_idxs[slice_idx], axis=slice_axis) - if not cbar: - # If vmin is higher than the highest value in the - # image ValueError is raised. We don't want to + if not color_bar: + # If value_min is higher than the highest value in the + # image ValueError is raised. We don't want to # override the values because next slices might be okay - new_vmin = ( + new_value_min = ( None - if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) - else vmin + if ( + isinstance(value_min, (float, int)) + and value_min > np.max(slice_img) + ) + else value_min ) - new_vmax = ( + new_value_max = ( None - if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) - else vmax + if ( + isinstance(value_max, (float, int)) + and value_max < np.min(slice_img) + ) + else value_max ) ax.imshow( slice_img, - cmap=cmap, + cmap=color_map, interpolation=interpolation, - vmin=new_vmin, - vmax=new_vmax, - **imshow_kwargs, + vmin=new_value_min, + vmax=new_value_max, + **matplotlib_imshow_kwargs, ) - if show_position: + if display_positions: ax.text( 0.0, 1.0, @@ -221,7 +239,7 @@ def slices( ax.text( 1.0, 0.0, - f"axis {axis} ", + f"axis {slice_axis} ", transform=ax.transAxes, color="white", fontsize=8, @@ -237,33 +255,40 @@ def slices( # Hide the axis, so that we have a nice grid ax.axis("off") - if cbar: + if color_bar: with warnings.catch_warnings(): warnings.simplefilter("ignore", category=UserWarning) fig.tight_layout() - norm = matplotlib.colors.Normalize(vmin=new_vmin, vmax=new_vmax, clip=True) - mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) + norm = matplotlib.colors.Normalize( + vmin=new_value_min, vmax=new_value_max, clip=True + ) + mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=color_map) - if cbar_style =="small": + if color_bar_style == "small": # Figure coordinates of top-right axis tr_pos = np.atleast_1d(axs[0])[-1].get_position() # The width is divided by ncols to make it the same relative size to the images - cbar_ax = fig.add_axes( + color_bar_ax = fig.add_axes( [tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height] ) - fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical") - elif cbar_style == "large": + fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation="vertical") + elif color_bar_style == "large": # Figure coordinates of bottom- and top-right axis br_pos = np.atleast_1d(axs[-1])[-1].get_position() tr_pos = np.atleast_1d(axs[0])[-1].get_position() # The width is divided by ncols to make it the same relative size to the images - cbar_ax = fig.add_axes( - [br_pos.xmax + 0.05 / ncols, br_pos.y0+0.0015, 0.05 / ncols, (tr_pos.y1 - br_pos.y0)-0.0015] + color_bar_ax = fig.add_axes( + [ + br_pos.xmax + 0.05 / ncols, + br_pos.y0 + 0.0015, + 0.05 / ncols, + (tr_pos.y1 - br_pos.y0) - 0.0015, + ] ) - fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical") + fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation="vertical") - if show: + if display_figure: plt.show() plt.close() @@ -271,49 +296,51 @@ def slices( return fig -def _get_slice_range(position: int, n_slices: int, n_total): +def _get_slice_range(position: int, num_slices: int, n_total): """Helper function for `slices`. Returns the range of slices to be displayed around the given position.""" - start_idx = position - n_slices // 2 + start_idx = position - num_slices // 2 end_idx = ( - position + n_slices // 2 if n_slices % 2 == 0 else position + n_slices // 2 + 1 + position + num_slices // 2 + if num_slices % 2 == 0 + else position + num_slices // 2 + 1 ) slice_idxs = np.arange(start_idx, end_idx) if slice_idxs[0] < 0: - slice_idxs = np.arange(0, n_slices) + slice_idxs = np.arange(0, num_slices) elif slice_idxs[-1] > n_total: - slice_idxs = np.arange(n_total - n_slices, n_total) + slice_idxs = np.arange(n_total - num_slices, n_total) return slice_idxs def slicer( - vol: np.ndarray, - axis: int = 0, - cmap: str = "viridis", - vmin: float = None, - vmax: float = None, - img_height: int = 3, - img_width: int = 3, - show_position: bool = False, + volume: np.ndarray, + slice_axis: int = 0, + color_map: str = "magma", + value_min: float = None, + value_max: float = None, + image_height: int = 3, + image_width: int = 3, + display_positions: bool = False, interpolation: Optional[str] = None, - img_size=None, - cbar: bool = False, - **imshow_kwargs, + image_size=None, + color_bar: bool = False, + **matplotlib_imshow_kwargs, ) -> widgets.interactive: """Interactive widget for visualizing slices of a 3D volume. Args: - vol (np.ndarray): The 3D volume to be sliced. - axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0. - cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". - vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None. - vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None - img_height (int, optional): Height of the figure. Defaults to 3. - img_width (int, optional): Width of the figure. Defaults to 3. - show_position (bool, optional): If True, displays the position of the slices. Defaults to False. + volume (np.ndarray): The 3D volume to be sliced. + slice_axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0. + color_map (str, optional): Specifies the color map for the image. Defaults to "viridis". + value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None. + value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None + image_height (int, optional): Height of the figure. Defaults to 3. + image_width (int, optional): Width of the figure. Defaults to 3. + display_positions (bool, optional): If True, displays the position of the slices. Defaults to False. interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. - cbar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False. + color_bar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False. Returns: slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume. @@ -328,98 +355,98 @@ def slicer(  """ - if img_size: - img_height = img_size - img_width = img_size + if image_size: + image_height = image_size + image_width = image_size # Create the interactive widget - def _slicer(position): - fig = slices( - vol, - axis=axis, - cmap=cmap, - vmin=vmin, - vmax=vmax, - img_height=img_height, - img_width=img_width, - show_position=show_position, + def _slicer(slice_positions): + fig = slices_grid( + volume, + slice_axis=slice_axis, + color_map=color_map, + value_min=value_min, + value_max=value_max, + image_height=image_height, + image_width=image_width, + display_positions=display_positions, interpolation=interpolation, - position=position, - n_slices=1, - show=True, - cbar=cbar, - **imshow_kwargs, + slice_positions=slice_positions, + num_slices=1, + display_figure=True, + color_bar=color_bar, + **matplotlib_imshow_kwargs, ) return fig position_slider = widgets.IntSlider( - value=vol.shape[axis] // 2, + value=volume.shape[slice_axis] // 2, min=0, - max=vol.shape[axis] - 1, + max=volume.shape[slice_axis] - 1, description="Slice", continuous_update=True, ) - slicer_obj = widgets.interactive(_slicer, position=position_slider) + slicer_obj = widgets.interactive(_slicer, slice_positions=position_slider) slicer_obj.layout = widgets.Layout(align_items="flex-start") return slicer_obj -def orthogonal( - vol: np.ndarray, - cmap: str = "viridis", - vmin: float = None, - vmax: float = None, - img_height: int = 3, - img_width: int = 3, - show_position: bool = False, +def slicer_orthogonal( + volume: np.ndarray, + color_map: str = 'magma', + value_min: float = None, + value_max: float = None, + image_height: int = 3, + image_width: int = 3, + display_positions: bool = False, interpolation: Optional[str] = None, - img_size=None, + image_size=None, ): """Interactive widget for visualizing orthogonal slices of a 3D volume. Args: - vol (np.ndarray): The 3D volume to be sliced. - cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". - vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None. - vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None - img_height (int, optional): Height of the figure. - img_width (int, optional): Width of the figure. - show_position (bool, optional): If True, displays the position of the slices. Defaults to False. + volume (np.ndarray): The 3D volume to be sliced. + color_map (str, optional): Specifies the color map for the image. Defaults to "viridis". + value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None. + value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None + image_height (int, optional): Height of the figure. + image_width (int, optional): Width of the figure. + display_positions (bool, optional): If True, displays the position of the slices. Defaults to False. interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. Returns: - orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume. + slicer_orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume. Example: ```python import qim3d vol = qim3d.examples.fly_150x256x256 - qim3d.viz.orthogonal(vol, cmap="magma") + qim3d.viz.slicer_orthogonal(vol, color_map="magma") ``` -  +  """ - if img_size: - img_height = img_size - img_width = img_size - - get_slicer_for_axis = lambda axis: slicer( - vol, - axis=axis, - cmap=cmap, - vmin=vmin, - vmax=vmax, - img_height=img_height, - img_width=img_width, - show_position=show_position, + if image_size: + image_height = image_size + image_width = image_size + + get_slicer_for_axis = lambda slice_axis: slicer( + volume, + slice_axis=slice_axis, + color_map=color_map, + value_min=value_min, + value_max=value_max, + image_height=image_height, + image_width=image_width, + display_positions=display_positions, interpolation=interpolation, ) - z_slicer = get_slicer_for_axis(axis=0) - y_slicer = get_slicer_for_axis(axis=1) - x_slicer = get_slicer_for_axis(axis=2) + z_slicer = get_slicer_for_axis(slice_axis=0) + y_slicer = get_slicer_for_axis(slice_axis=1) + x_slicer = get_slicer_for_axis(slice_axis=2) z_slicer.children[0].description = "Z" y_slicer.children[0].description = "Y" @@ -428,29 +455,29 @@ def orthogonal( return widgets.HBox([z_slicer, y_slicer, x_slicer]) -def interactive_fade_mask( - vol: np.ndarray, +def fade_mask( + volume: np.ndarray, axis: int = 0, - cmap: str = "viridis", - vmin: float = None, - vmax: float = None, + color_map: str = 'magma', + value_min: float = None, + value_max: float = None, ): """Interactive widget for visualizing the effect of edge fading on a 3D volume. This can be used to select the best parameters before applying the mask. Args: - vol (np.ndarray): The volume to apply edge fading to. + volume (np.ndarray): The volume to apply edge fading to. axis (int, optional): The axis along which to apply the fading. Defaults to 0. - cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". - vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None. - vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None + color_map (str, optional): Specifies the color map for the image. Defaults to "viridis". + value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None. + value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None Example: ```python import qim3d vol = qim3d.examples.cement_128x128x128 - qim3d.viz.interactive_fade_mask(vol) + qim3d.viz.fade_mask(vol) ```  @@ -460,58 +487,62 @@ def interactive_fade_mask( def _slicer(position, decay_rate, ratio, geometry, invert): fig, axes = plt.subplots(1, 3, figsize=(9, 3)) - slice_img = vol[position, :, :] - # If vmin is higher than the highest value in the image ValueError is raised + slice_img = volume[position, :, :] + # If value_min is higher than the highest value in the image ValueError is raised # We don't want to override the values because next slices might be okay - new_vmin = ( + new_value_min = ( None - if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) - else vmin + if (isinstance(value_min, (float, int)) and value_min > np.max(slice_img)) + else value_min ) - new_vmax = ( + new_value_max = ( None - if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) - else vmax + if (isinstance(value_max, (float, int)) and value_max < np.min(slice_img)) + else value_max ) - axes[0].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax) + axes[0].imshow( + slice_img, cmap=color_map, value_min=new_value_min, value_max=new_value_max + ) axes[0].set_title("Original") axes[0].axis("off") mask = qim3d.processing.operations.fade_mask( - np.ones_like(vol), + np.ones_like(volume), decay_rate=decay_rate, ratio=ratio, geometry=geometry, axis=axis, invert=invert, ) - axes[1].imshow(mask[position, :, :], cmap=cmap) + axes[1].imshow(mask[position, :, :], cmap=color_map) axes[1].set_title("Mask") axes[1].axis("off") - masked_vol = qim3d.processing.operations.fade_mask( - vol, + masked_volume = qim3d.processing.operations.fade_mask( + volume, decay_rate=decay_rate, ratio=ratio, geometry=geometry, axis=axis, invert=invert, ) - # If vmin is higher than the highest value in the image ValueError is raised + # If value_min is higher than the highest value in the image ValueError is raised # We don't want to override the values because next slices might be okay - slice_img = masked_vol[position, :, :] - new_vmin = ( + slice_img = masked_volume[position, :, :] + new_value_min = ( None - if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) - else vmin + if (isinstance(value_min, (float, int)) and value_min > np.max(slice_img)) + else value_min ) - new_vmax = ( + new_value_max = ( None - if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) - else vmax + if (isinstance(value_max, (float, int)) and value_max < np.min(slice_img)) + else value_max + ) + axes[2].imshow( + slice_img, cmap=color_map, vmin=new_value_min, vmax=new_value_max ) - axes[2].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax) axes[2].set_title("Masked") axes[2].axis("off") @@ -524,9 +555,9 @@ def interactive_fade_mask( ) position_slider = widgets.IntSlider( - value=vol.shape[0] // 2, + value=volume.shape[0] // 2, min=0, - max=vol.shape[0] - 1, + max=volume.shape[0] - 1, description="Slice", continuous_update=False, ) @@ -659,13 +690,13 @@ def chunks(zarr_path: str, **kwargs): viz_widget = widgets.Output() with viz_widget: viz_widget.clear_output(wait=True) - fig = qim3d.viz.slices(chunk, **kwargs) + fig = qim3d.viz.slices_grid(chunk, **kwargs) display(fig) - elif visualization_method == "vol": + elif visualization_method == "volume": viz_widget = widgets.Output() with viz_widget: viz_widget.clear_output(wait=True) - out = qim3d.viz.vol(chunk, show=False, **kwargs) + out = qim3d.viz.volumetric(chunk, show=False, **kwargs) display(out) else: log.info(f"Invalid visualization method: {visualization_method}") @@ -716,7 +747,7 @@ def chunks(zarr_path: str, **kwargs): ) method_dropdown = widgets.Dropdown( - options=["slicer", "slices", "vol"], + options=["slicer", "slices", "volume"], value="slicer", description="Visualization", style={"description_width": description_width, "text_align": "left"}, @@ -815,7 +846,7 @@ def chunks(zarr_path: str, **kwargs): def histogram( - vol: np.ndarray, + volume: np.ndarray, bins: Union[int, str] = "auto", slice_idx: Union[int, str] = None, axis: int = 0, @@ -833,11 +864,11 @@ def histogram( ): """ Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume. - + Utilizes [seaborn.histplot](https://seaborn.pydata.org/generated/seaborn.histplot.html) for visualization. Args: - vol (np.ndarray): A 3D NumPy array representing the volume to be visualized. + volume (np.ndarray): A 3D NumPy array representing the volume to be visualized. bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto". axis (int, optional): Axis along which to take a slice. Default is 0. slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis. @@ -879,24 +910,24 @@ def histogram(  """ - if not (0 <= axis < vol.ndim): - raise ValueError(f"Axis must be an integer between 0 and {vol.ndim - 1}.") + if not (0 <= axis < volume.ndim): + raise ValueError(f"Axis must be an integer between 0 and {volume.ndim - 1}.") if slice_idx == "middle": - slice_idx = vol.shape[axis] // 2 + slice_idx = volume.shape[axis] // 2 if slice_idx: - if 0 <= slice_idx < vol.shape[axis]: - img_slice = np.take(vol, indices=slice_idx, axis=axis) + if 0 <= slice_idx < volume.shape[axis]: + img_slice = np.take(volume, indices=slice_idx, axis=axis) data = img_slice.ravel() title = f"Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}" else: raise ValueError( - f"Slice index out of range. Must be between 0 and {vol.shape[axis] - 1}." + f"Slice index out of range. Must be between 0 and {volume.shape[axis] - 1}." ) else: - data = vol.ravel() - title = f"Intensity histogram for whole volume {vol.shape}" + data = volume.ravel() + title = f"Intensity histogram for whole volume {volume.shape}" fig, ax = plt.subplots(figsize=figsize) diff --git a/qim3d/viz/k3d.py b/qim3d/viz/k3d.py index c3018e08..853f8ae2 100644 --- a/qim3d/viz/k3d.py +++ b/qim3d/viz/k3d.py @@ -14,13 +14,13 @@ from qim3d.utils.logger import log from qim3d.utils.misc import downscale_img, scale_to_float16 -def vol( +def volumetric( img, aspectmode="data", show=True, save=False, grid_visible=False, - cmap=None, + color_map='magma', constant_opacity=False, vmin=None, vmax=None, @@ -43,8 +43,8 @@ def vol( If a string is provided, it's interpreted as the file path where the HTML file will be saved. Defaults to False. grid_visible (bool, optional): If True, the grid is visible in the plot. Defaults to False. - cmap (str or matplotlib.colors.Colormap or list, optional): The color map to be used for the volume rendering. If a string is passed, it should be a matplotlib colormap name. Defaults to None. - constant_opacity (bool, float): Set to True if doing an object label visualization with a corresponding cmap; otherwise, the plot may appear poorly. Defaults to False. + color_map (str or matplotlib.colors.Colormap or list, optional): The color map to be used for the volume rendering. If a string is passed, it should be a matplotlib colormap name. Defaults to None. + constant_opacity (bool, float): Set to True if doing an object label visualization with a corresponding color_map; otherwise, the plot may appear poorly. Defaults to False. vmin (float, optional): Together with vmax defines the data range the colormap covers. By default colormap covers the full range. Defaults to None. vmax (float, optional): Together with vmin defines the data range the colormap covers. By default colormap covers the full range. Defaults to None samples (int, optional): The number of samples to be used for the volume rendering in k3d. Defaults to 512. @@ -60,7 +60,7 @@ def vol( ValueError: If `aspectmode` is not `'data'` or `'cube'`. Tip: - The function can be used for object label visualization using a `cmap` created with `qim3d.viz.colormaps.objects` along with setting `objects=True`. The latter ensures appropriate rendering. + The function can be used for object label visualization using a `color_map` created with `qim3d.viz.colormaps.objects` along with setting `objects=True`. The latter ensures appropriate rendering. Example: Display a volume inline: @@ -69,7 +69,7 @@ def vol( import qim3d vol = qim3d.examples.bone_128x128x128 - qim3d.viz.vol(vol) + qim3d.viz.volumetric(vol) ``` <iframe src="https://platform.qim.dk/k3d/fima-bone_128x128x128-20240221113459.html" width="100%" height="500" frameborder="0"></iframe> @@ -78,7 +78,7 @@ def vol( ```python import qim3d vol = qim3d.examples.bone_128x128x128 - plot = qim3d.viz.vol(vol, show=False, save="plot.html") + plot = qim3d.viz.volumetric(vol, show=False, save="plot.html") ``` """ @@ -129,21 +129,21 @@ def vol( if vmax: color_range[1] = vmax - # Handle the different formats that cmap can take - if cmap: - if isinstance(cmap, str): - cmap = plt.get_cmap(cmap) # Convert to Colormap object - if isinstance(cmap, Colormap): - # Convert to the format of cmap required by k3d.volume - attr_vals = np.linspace(0.0, 1.0, num=cmap.N) - RGB_vals = cmap(np.arange(0, cmap.N))[:, :3] - cmap = np.column_stack((attr_vals, RGB_vals)).tolist() + # Handle the different formats that color_map can take + if color_map: + if isinstance(color_map, str): + color_map = plt.get_cmap(color_map) # Convert to Colormap object + if isinstance(color_map, Colormap): + # Convert to the format of color_map required by k3d.volume + attr_vals = np.linspace(0.0, 1.0, num=color_map.N) + RGB_vals = color_map(np.arange(0, color_map.N))[:, :3] + color_map = np.column_stack((attr_vals, RGB_vals)).tolist() # Default k3d.volume settings opacity_function = [] interpolation = True if constant_opacity: - # without these settings, the plot will look bad when cmap is created with qim3d.viz.colormaps.objects + # without these settings, the plot will look bad when color_map is created with qim3d.viz.colormaps.objects opacity_function = [0.0, float(constant_opacity), 1.0, float(constant_opacity)] interpolation = False @@ -155,7 +155,7 @@ def vol( if aspectmode.lower() == "data" else None ), - color_map=cmap, + color_map=color_map, samples=samples, color_range=color_range, opacity_function=opacity_function, -- GitLab