Skip to content
Snippets Groups Projects
Commit 5505b582 authored by fima's avatar fima :beers:
Browse files

data_exploration refactored

parent 8258ac8b
No related branches found
No related tags found
1 merge request!139Refactoring v1.0 prep
This commit is part of merge request !139. Comments created here will be created in the context of that merge request.
from . import colormaps from . import colormaps
from .cc import plot_cc from .cc import plot_cc
from .detection import circles from .detection import circles
from .explore import ( from .data_exploration import (
interactive_fade_mask, fade_mask,
orthogonal,
slicer, slicer,
slices, slicer_orthogonal,
slices_grid,
chunks, chunks,
histogram, histogram,
) )
from .itk_vtk_viewer import itk_vtk, Installer, NotInstalledError from .itk_vtk_viewer import itk_vtk, Installer, NotInstalledError
from .k3d import vol, mesh from .k3d import volumetric, mesh
from .local_thickness_ import local_thickness from .local_thickness_ import local_thickness
from .structure_tensor import vectors from .structure_tensor import vectors
from .metrics import plot_metrics, grid_overview, grid_pred, vol_masked from .metrics import plot_metrics, grid_overview, grid_pred, vol_masked
......
...@@ -19,145 +19,157 @@ import seaborn as sns ...@@ -19,145 +19,157 @@ import seaborn as sns
import qim3d import qim3d
def slices(
vol: np.ndarray, def slices_grid(
axis: int = 0, volume: np.ndarray,
position: Optional[Union[str, int, List[int]]] = None, slice_axis: int = 0,
n_slices: int = 5, slice_positions: Optional[Union[str, int, List[int]]] = None,
max_cols: int = 5, num_slices: int = 15,
cmap: str = "viridis", max_columns: int = 5,
vmin: float = None, color_map: str = "magma",
vmax: float = None, value_min: float = None,
img_height: int = 2, value_max: float = None,
img_width: int = 2, image_size=None,
show: bool = False, image_height: int = 2,
show_position: bool = True, image_width: int = 2,
display_figure: bool = False,
display_positions: bool = True,
interpolation: Optional[str] = None, interpolation: Optional[str] = None,
img_size=None, color_bar: bool = False,
cbar: bool = False, color_bar_style: str = "small",
cbar_style: str = "small", **matplotlib_imshow_kwargs,
**imshow_kwargs,
) -> plt.Figure: ) -> plt.Figure:
"""Displays one or several slices from a 3d volume. """Displays one or several slices from a 3d volume.
By default if `position` is None, slices plots `n_slices` linearly spaced slices. By default if `slice_positions` is None, slices_grid plots `num_slices` linearly spaced slices.
If `position` is given as a string or integer, slices will plot an overview with `n_slices` figures around that position. If `slice_positions` is given as a string or integer, slices_grid will plot an overview with `num_slices` figures around that position.
If `position` is given as a list, `n_slices` will be ignored and the slices from `position` will be plotted. If `slice_positions` is given as a list, `num_slices` will be ignored and the slices from `slice_positions` will be plotted.
Args: Args:
vol np.ndarray: The 3D volume to be sliced. vol np.ndarray: The 3D volume to be sliced.
axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0. slice_axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
position (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None. slice_positions (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None.
n_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 5. num_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 15.
max_cols (int, optional): The maximum number of columns to be plotted. Defaults to 5. max_columns (int, optional): The maximum number of columns to be plotted. Defaults to 5.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". color_map (str, optional): Specifies the color map for the image. Defaults to "viridis".
vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None. value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None
img_height (int, optional): Height of the figure. image_height (int, optional): Height of the figure.
img_width (int, optional): Width of the figure. image_width (int, optional): Width of the figure.
show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False. display_figure (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
show_position (bool, optional): If True, displays the position of the slices. Defaults to True. display_positions (bool, optional): If True, displays the position of the slices. Defaults to True.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
cbar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False. color_bar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False.
cbar_style (str, optional): Determines the style of the colorbar. Option 'small' is height of one image row. Option 'large' spans full height of image grid. Defaults to 'small'. color_bar_style (str, optional): Determines the style of the colorbar. Option 'small' is height of one image row. Option 'large' spans full height of image grid. Defaults to 'small'.
Returns: Returns:
fig (matplotlib.figure.Figure): The figure with the slices from the 3d array. fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.
Raises: Raises:
ValueError: If the input is not a numpy.ndarray or da.core.Array. ValueError: If the input is not a numpy.ndarray or da.core.Array.
ValueError: If the axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1. ValueError: If the slice_axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1.
ValueError: If the file or array is not a volume with at least 3 dimensions. ValueError: If the file or array is not a volume with at least 3 dimensions.
ValueError: If the `position` keyword argument is not a integer, list of integers or one of the following strings: "start", "mid" or "end". ValueError: If the `position` keyword argument is not a integer, list of integers or one of the following strings: "start", "mid" or "end".
ValueError: If the cbar_style keyword argument is not one of the following strings: 'small' or 'large'. ValueError: If the color_bar_style keyword argument is not one of the following strings: 'small' or 'large'.
Example: Example:
```python ```python
import qim3d import qim3d
vol = qim3d.examples.shell_225x128x128 vol = qim3d.examples.shell_225x128x128
qim3d.viz.slices(vol, n_slices=15) qim3d.viz.slices_grid(vol, num_slices=15)
``` ```
![Grid of slices](assets/screenshots/viz-slices.png) ![Grid of slices](assets/screenshots/viz-slices.png)
""" """
if img_size: if image_size:
img_height = img_size image_height = image_size
img_width = img_size image_width = image_size
# If we pass python None to the imshow function, it will set to # If we pass python None to the imshow function, it will set to
# default value 'antialiased' # default value 'antialiased'
if interpolation is None: if interpolation is None:
interpolation = 'none' interpolation = "none"
# Numpy array or Torch tensor input # Numpy array or Torch tensor input
if not isinstance(vol, (np.ndarray, da.core.Array)): if not isinstance(volume, (np.ndarray, da.core.Array)):
raise ValueError("Data type not supported") raise ValueError("Data type not supported")
if vol.ndim < 3: if volume.ndim < 3:
raise ValueError( raise ValueError(
"The provided object is not a volume as it has less than 3 dimensions." "The provided object is not a volume as it has less than 3 dimensions."
) )
cbar_style_options = ["small", "large"] color_bar_style_options = ["small", "large"]
if cbar_style not in cbar_style_options: if color_bar_style not in color_bar_style_options:
raise ValueError(f"Value '{cbar_style}' is not valid for colorbar style. Please select from {cbar_style_options}.") raise ValueError(
f"Value '{color_bar_style}' is not valid for colorbar style. Please select from {color_bar_style_options}."
)
if isinstance(vol, da.core.Array): if isinstance(volume, da.core.Array):
vol = vol.compute() volume = volume.compute()
# Ensure axis is a valid choice # Ensure axis is a valid choice
if not (0 <= axis < vol.ndim): if not (0 <= slice_axis < volume.ndim):
raise ValueError( raise ValueError(
f"Invalid value for 'axis'. It should be an integer between 0 and {vol.ndim - 1}." f"Invalid value for 'slice_axis'. It should be an integer between 0 and {volume.ndim - 1}."
) )
if type(cmap) == matplotlib.colors.LinearSegmentedColormap or cmap == 'objects': # Here we deal with the case that the user wants to use the objects colormap directly
num_labels = len(np.unique(vol)) if (
type(color_map) == matplotlib.colors.LinearSegmentedColormap
or color_map == "objects"
):
num_labels = len(np.unique(volume))
if cmap == 'objects': if color_map == "objects":
cmap = qim3d.viz.colormaps.objects(num_labels) color_map = qim3d.viz.colormaps.objects(num_labels)
# If vmin and vmax are not set like this, then in case the # If value_min and value_max are not set like this, then in case the
# number of objects changes on new slice, objects might change # number of objects changes on new slice, objects might change
# colors. So when using a slider, the same object suddently # colors. So when using a slider, the same object suddently
# changes color (flickers), which is confusing and annoying. # changes color (flickers), which is confusing and annoying.
vmin = 0 value_min = 0
vmax = num_labels value_max = num_labels
# Get total number of slices in the specified dimension # Get total number of slices in the specified dimension
n_total = vol.shape[axis] n_total = volume.shape[slice_axis]
# Position is not provided - will use linearly spaced slices # Position is not provided - will use linearly spaced slices
if position is None: if slice_positions is None:
slice_idxs = np.linspace(0, n_total - 1, n_slices, dtype=int) slice_idxs = np.linspace(0, n_total - 1, num_slices, dtype=int)
# Position is a string # Position is a string
elif isinstance(position, str) and position.lower() in ["start", "mid", "end"]: elif isinstance(slice_positions, str) and slice_positions.lower() in [
if position.lower() == "start": "start",
slice_idxs = _get_slice_range(0, n_slices, n_total) "mid",
elif position.lower() == "mid": "end",
slice_idxs = _get_slice_range(n_total // 2, n_slices, n_total) ]:
elif position.lower() == "end": if slice_positions.lower() == "start":
slice_idxs = _get_slice_range(n_total - 1, n_slices, n_total) slice_idxs = _get_slice_range(0, num_slices, n_total)
elif slice_positions.lower() == "mid":
slice_idxs = _get_slice_range(n_total // 2, num_slices, n_total)
elif slice_positions.lower() == "end":
slice_idxs = _get_slice_range(n_total - 1, num_slices, n_total)
# Position is an integer # Position is an integer
elif isinstance(position, int): elif isinstance(slice_positions, int):
slice_idxs = _get_slice_range(position, n_slices, n_total) slice_idxs = _get_slice_range(slice_positions, num_slices, n_total)
# Position is a list of integers # Position is a list of integers
elif isinstance(position, list) and all(isinstance(idx, int) for idx in position): elif isinstance(slice_positions, list) and all(
slice_idxs = position isinstance(idx, int) for idx in slice_positions
):
slice_idxs = slice_positions
else: else:
raise ValueError( raise ValueError(
'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".' 'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'
) )
# Make grid # Make grid
nrows = math.ceil(n_slices / max_cols) nrows = math.ceil(num_slices / max_columns)
ncols = min(n_slices, max_cols) ncols = min(num_slices, max_columns)
# Generate figure # Generate figure
fig, axs = plt.subplots( fig, axs = plt.subplots(
nrows=nrows, nrows=nrows,
ncols=ncols, ncols=ncols,
figsize=(ncols * img_height, nrows * img_width), figsize=(ncols * image_height, nrows * image_width),
constrained_layout=True, constrained_layout=True,
) )
...@@ -165,47 +177,53 @@ def slices( ...@@ -165,47 +177,53 @@ def slices(
axs = [axs] # Convert to a list for uniformity axs = [axs] # Convert to a list for uniformity
# Convert to NumPy array in order to use the numpy.take method # Convert to NumPy array in order to use the numpy.take method
if isinstance(vol, da.core.Array): if isinstance(volume, da.core.Array):
vol = vol.compute() volume = volume.compute()
if cbar: if color_bar:
# In this case, we want the vrange to be constant across the # In this case, we want the vrange to be constant across the
# slices, which makes them all comparable to a single cbar. # slices, which makes them all comparable to a single color_bar.
new_vmin = vmin if vmin is not None else np.min(vol) new_value_min = value_min if value_min is not None else np.min(volume)
new_vmax = vmax if vmax is not None else np.max(vol) new_value_max = value_max if value_max is not None else np.max(volume)
# Run through each ax of the grid # Run through each ax of the grid
for i, ax_row in enumerate(axs): for i, ax_row in enumerate(axs):
for j, ax in enumerate(np.atleast_1d(ax_row)): for j, ax in enumerate(np.atleast_1d(ax_row)):
slice_idx = i * max_cols + j slice_idx = i * max_columns + j
try: try:
slice_img = vol.take(slice_idxs[slice_idx], axis=axis) slice_img = volume.take(slice_idxs[slice_idx], axis=slice_axis)
if not cbar: if not color_bar:
# If vmin is higher than the highest value in the # If value_min is higher than the highest value in the
# image ValueError is raised. We don't want to # image ValueError is raised. We don't want to
# override the values because next slices might be okay # override the values because next slices might be okay
new_vmin = ( new_value_min = (
None None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) if (
else vmin isinstance(value_min, (float, int))
and value_min > np.max(slice_img)
) )
new_vmax = ( else value_min
)
new_value_max = (
None None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) if (
else vmax isinstance(value_max, (float, int))
and value_max < np.min(slice_img)
)
else value_max
) )
ax.imshow( ax.imshow(
slice_img, slice_img,
cmap=cmap, cmap=color_map,
interpolation=interpolation, interpolation=interpolation,
vmin=new_vmin, vmin=new_value_min,
vmax=new_vmax, vmax=new_value_max,
**imshow_kwargs, **matplotlib_imshow_kwargs,
) )
if show_position: if display_positions:
ax.text( ax.text(
0.0, 0.0,
1.0, 1.0,
...@@ -221,7 +239,7 @@ def slices( ...@@ -221,7 +239,7 @@ def slices(
ax.text( ax.text(
1.0, 1.0,
0.0, 0.0,
f"axis {axis} ", f"axis {slice_axis} ",
transform=ax.transAxes, transform=ax.transAxes,
color="white", color="white",
fontsize=8, fontsize=8,
...@@ -237,33 +255,40 @@ def slices( ...@@ -237,33 +255,40 @@ def slices(
# Hide the axis, so that we have a nice grid # Hide the axis, so that we have a nice grid
ax.axis("off") ax.axis("off")
if cbar: if color_bar:
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning) warnings.simplefilter("ignore", category=UserWarning)
fig.tight_layout() fig.tight_layout()
norm = matplotlib.colors.Normalize(vmin=new_vmin, vmax=new_vmax, clip=True) norm = matplotlib.colors.Normalize(
mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) vmin=new_value_min, vmax=new_value_max, clip=True
)
mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=color_map)
if cbar_style =="small": if color_bar_style == "small":
# Figure coordinates of top-right axis # Figure coordinates of top-right axis
tr_pos = np.atleast_1d(axs[0])[-1].get_position() tr_pos = np.atleast_1d(axs[0])[-1].get_position()
# The width is divided by ncols to make it the same relative size to the images # The width is divided by ncols to make it the same relative size to the images
cbar_ax = fig.add_axes( color_bar_ax = fig.add_axes(
[tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height] [tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height]
) )
fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical") fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation="vertical")
elif cbar_style == "large": elif color_bar_style == "large":
# Figure coordinates of bottom- and top-right axis # Figure coordinates of bottom- and top-right axis
br_pos = np.atleast_1d(axs[-1])[-1].get_position() br_pos = np.atleast_1d(axs[-1])[-1].get_position()
tr_pos = np.atleast_1d(axs[0])[-1].get_position() tr_pos = np.atleast_1d(axs[0])[-1].get_position()
# The width is divided by ncols to make it the same relative size to the images # The width is divided by ncols to make it the same relative size to the images
cbar_ax = fig.add_axes( color_bar_ax = fig.add_axes(
[br_pos.xmax + 0.05 / ncols, br_pos.y0+0.0015, 0.05 / ncols, (tr_pos.y1 - br_pos.y0)-0.0015] [
br_pos.xmax + 0.05 / ncols,
br_pos.y0 + 0.0015,
0.05 / ncols,
(tr_pos.y1 - br_pos.y0) - 0.0015,
]
) )
fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical") fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation="vertical")
if show: if display_figure:
plt.show() plt.show()
plt.close() plt.close()
...@@ -271,49 +296,51 @@ def slices( ...@@ -271,49 +296,51 @@ def slices(
return fig return fig
def _get_slice_range(position: int, n_slices: int, n_total): def _get_slice_range(position: int, num_slices: int, n_total):
"""Helper function for `slices`. Returns the range of slices to be displayed around the given position.""" """Helper function for `slices`. Returns the range of slices to be displayed around the given position."""
start_idx = position - n_slices // 2 start_idx = position - num_slices // 2
end_idx = ( end_idx = (
position + n_slices // 2 if n_slices % 2 == 0 else position + n_slices // 2 + 1 position + num_slices // 2
if num_slices % 2 == 0
else position + num_slices // 2 + 1
) )
slice_idxs = np.arange(start_idx, end_idx) slice_idxs = np.arange(start_idx, end_idx)
if slice_idxs[0] < 0: if slice_idxs[0] < 0:
slice_idxs = np.arange(0, n_slices) slice_idxs = np.arange(0, num_slices)
elif slice_idxs[-1] > n_total: elif slice_idxs[-1] > n_total:
slice_idxs = np.arange(n_total - n_slices, n_total) slice_idxs = np.arange(n_total - num_slices, n_total)
return slice_idxs return slice_idxs
def slicer( def slicer(
vol: np.ndarray, volume: np.ndarray,
axis: int = 0, slice_axis: int = 0,
cmap: str = "viridis", color_map: str = "magma",
vmin: float = None, value_min: float = None,
vmax: float = None, value_max: float = None,
img_height: int = 3, image_height: int = 3,
img_width: int = 3, image_width: int = 3,
show_position: bool = False, display_positions: bool = False,
interpolation: Optional[str] = None, interpolation: Optional[str] = None,
img_size=None, image_size=None,
cbar: bool = False, color_bar: bool = False,
**imshow_kwargs, **matplotlib_imshow_kwargs,
) -> widgets.interactive: ) -> widgets.interactive:
"""Interactive widget for visualizing slices of a 3D volume. """Interactive widget for visualizing slices of a 3D volume.
Args: Args:
vol (np.ndarray): The 3D volume to be sliced. volume (np.ndarray): The 3D volume to be sliced.
axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0. slice_axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". color_map (str, optional): Specifies the color map for the image. Defaults to "viridis".
vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None. value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None
img_height (int, optional): Height of the figure. Defaults to 3. image_height (int, optional): Height of the figure. Defaults to 3.
img_width (int, optional): Width of the figure. Defaults to 3. image_width (int, optional): Width of the figure. Defaults to 3.
show_position (bool, optional): If True, displays the position of the slices. Defaults to False. display_positions (bool, optional): If True, displays the position of the slices. Defaults to False.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
cbar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False. color_bar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False.
Returns: Returns:
slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume. slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume.
...@@ -328,98 +355,98 @@ def slicer( ...@@ -328,98 +355,98 @@ def slicer(
![viz slicer](assets/screenshots/viz-slicer.gif) ![viz slicer](assets/screenshots/viz-slicer.gif)
""" """
if img_size: if image_size:
img_height = img_size image_height = image_size
img_width = img_size image_width = image_size
# Create the interactive widget # Create the interactive widget
def _slicer(position): def _slicer(slice_positions):
fig = slices( fig = slices_grid(
vol, volume,
axis=axis, slice_axis=slice_axis,
cmap=cmap, color_map=color_map,
vmin=vmin, value_min=value_min,
vmax=vmax, value_max=value_max,
img_height=img_height, image_height=image_height,
img_width=img_width, image_width=image_width,
show_position=show_position, display_positions=display_positions,
interpolation=interpolation, interpolation=interpolation,
position=position, slice_positions=slice_positions,
n_slices=1, num_slices=1,
show=True, display_figure=True,
cbar=cbar, color_bar=color_bar,
**imshow_kwargs, **matplotlib_imshow_kwargs,
) )
return fig return fig
position_slider = widgets.IntSlider( position_slider = widgets.IntSlider(
value=vol.shape[axis] // 2, value=volume.shape[slice_axis] // 2,
min=0, min=0,
max=vol.shape[axis] - 1, max=volume.shape[slice_axis] - 1,
description="Slice", description="Slice",
continuous_update=True, continuous_update=True,
) )
slicer_obj = widgets.interactive(_slicer, position=position_slider) slicer_obj = widgets.interactive(_slicer, slice_positions=position_slider)
slicer_obj.layout = widgets.Layout(align_items="flex-start") slicer_obj.layout = widgets.Layout(align_items="flex-start")
return slicer_obj return slicer_obj
def orthogonal( def slicer_orthogonal(
vol: np.ndarray, volume: np.ndarray,
cmap: str = "viridis", color_map: str = 'magma',
vmin: float = None, value_min: float = None,
vmax: float = None, value_max: float = None,
img_height: int = 3, image_height: int = 3,
img_width: int = 3, image_width: int = 3,
show_position: bool = False, display_positions: bool = False,
interpolation: Optional[str] = None, interpolation: Optional[str] = None,
img_size=None, image_size=None,
): ):
"""Interactive widget for visualizing orthogonal slices of a 3D volume. """Interactive widget for visualizing orthogonal slices of a 3D volume.
Args: Args:
vol (np.ndarray): The 3D volume to be sliced. volume (np.ndarray): The 3D volume to be sliced.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". color_map (str, optional): Specifies the color map for the image. Defaults to "viridis".
vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None. value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None
img_height (int, optional): Height of the figure. image_height (int, optional): Height of the figure.
img_width (int, optional): Width of the figure. image_width (int, optional): Width of the figure.
show_position (bool, optional): If True, displays the position of the slices. Defaults to False. display_positions (bool, optional): If True, displays the position of the slices. Defaults to False.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None. interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
Returns: Returns:
orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume. slicer_orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume.
Example: Example:
```python ```python
import qim3d import qim3d
vol = qim3d.examples.fly_150x256x256 vol = qim3d.examples.fly_150x256x256
qim3d.viz.orthogonal(vol, cmap="magma") qim3d.viz.slicer_orthogonal(vol, color_map="magma")
``` ```
![viz orthogonal](assets/screenshots/viz-orthogonal.gif) ![viz slicer_orthogonal](assets/screenshots/viz-orthogonal.gif)
""" """
if img_size: if image_size:
img_height = img_size image_height = image_size
img_width = img_size image_width = image_size
get_slicer_for_axis = lambda axis: slicer( get_slicer_for_axis = lambda slice_axis: slicer(
vol, volume,
axis=axis, slice_axis=slice_axis,
cmap=cmap, color_map=color_map,
vmin=vmin, value_min=value_min,
vmax=vmax, value_max=value_max,
img_height=img_height, image_height=image_height,
img_width=img_width, image_width=image_width,
show_position=show_position, display_positions=display_positions,
interpolation=interpolation, interpolation=interpolation,
) )
z_slicer = get_slicer_for_axis(axis=0) z_slicer = get_slicer_for_axis(slice_axis=0)
y_slicer = get_slicer_for_axis(axis=1) y_slicer = get_slicer_for_axis(slice_axis=1)
x_slicer = get_slicer_for_axis(axis=2) x_slicer = get_slicer_for_axis(slice_axis=2)
z_slicer.children[0].description = "Z" z_slicer.children[0].description = "Z"
y_slicer.children[0].description = "Y" y_slicer.children[0].description = "Y"
...@@ -428,29 +455,29 @@ def orthogonal( ...@@ -428,29 +455,29 @@ def orthogonal(
return widgets.HBox([z_slicer, y_slicer, x_slicer]) return widgets.HBox([z_slicer, y_slicer, x_slicer])
def interactive_fade_mask( def fade_mask(
vol: np.ndarray, volume: np.ndarray,
axis: int = 0, axis: int = 0,
cmap: str = "viridis", color_map: str = 'magma',
vmin: float = None, value_min: float = None,
vmax: float = None, value_max: float = None,
): ):
"""Interactive widget for visualizing the effect of edge fading on a 3D volume. """Interactive widget for visualizing the effect of edge fading on a 3D volume.
This can be used to select the best parameters before applying the mask. This can be used to select the best parameters before applying the mask.
Args: Args:
vol (np.ndarray): The volume to apply edge fading to. volume (np.ndarray): The volume to apply edge fading to.
axis (int, optional): The axis along which to apply the fading. Defaults to 0. axis (int, optional): The axis along which to apply the fading. Defaults to 0.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis". color_map (str, optional): Specifies the color map for the image. Defaults to "viridis".
vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None. value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None
Example: Example:
```python ```python
import qim3d import qim3d
vol = qim3d.examples.cement_128x128x128 vol = qim3d.examples.cement_128x128x128
qim3d.viz.interactive_fade_mask(vol) qim3d.viz.fade_mask(vol)
``` ```
![operations-edge_fade_before](assets/screenshots/viz-fade_mask.gif) ![operations-edge_fade_before](assets/screenshots/viz-fade_mask.gif)
...@@ -460,58 +487,62 @@ def interactive_fade_mask( ...@@ -460,58 +487,62 @@ def interactive_fade_mask(
def _slicer(position, decay_rate, ratio, geometry, invert): def _slicer(position, decay_rate, ratio, geometry, invert):
fig, axes = plt.subplots(1, 3, figsize=(9, 3)) fig, axes = plt.subplots(1, 3, figsize=(9, 3))
slice_img = vol[position, :, :] slice_img = volume[position, :, :]
# If vmin is higher than the highest value in the image ValueError is raised # If value_min is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay # We don't want to override the values because next slices might be okay
new_vmin = ( new_value_min = (
None None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) if (isinstance(value_min, (float, int)) and value_min > np.max(slice_img))
else vmin else value_min
) )
new_vmax = ( new_value_max = (
None None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) if (isinstance(value_max, (float, int)) and value_max < np.min(slice_img))
else vmax else value_max
) )
axes[0].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax) axes[0].imshow(
slice_img, cmap=color_map, value_min=new_value_min, value_max=new_value_max
)
axes[0].set_title("Original") axes[0].set_title("Original")
axes[0].axis("off") axes[0].axis("off")
mask = qim3d.processing.operations.fade_mask( mask = qim3d.processing.operations.fade_mask(
np.ones_like(vol), np.ones_like(volume),
decay_rate=decay_rate, decay_rate=decay_rate,
ratio=ratio, ratio=ratio,
geometry=geometry, geometry=geometry,
axis=axis, axis=axis,
invert=invert, invert=invert,
) )
axes[1].imshow(mask[position, :, :], cmap=cmap) axes[1].imshow(mask[position, :, :], cmap=color_map)
axes[1].set_title("Mask") axes[1].set_title("Mask")
axes[1].axis("off") axes[1].axis("off")
masked_vol = qim3d.processing.operations.fade_mask( masked_volume = qim3d.processing.operations.fade_mask(
vol, volume,
decay_rate=decay_rate, decay_rate=decay_rate,
ratio=ratio, ratio=ratio,
geometry=geometry, geometry=geometry,
axis=axis, axis=axis,
invert=invert, invert=invert,
) )
# If vmin is higher than the highest value in the image ValueError is raised # If value_min is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay # We don't want to override the values because next slices might be okay
slice_img = masked_vol[position, :, :] slice_img = masked_volume[position, :, :]
new_vmin = ( new_value_min = (
None None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) if (isinstance(value_min, (float, int)) and value_min > np.max(slice_img))
else vmin else value_min
) )
new_vmax = ( new_value_max = (
None None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) if (isinstance(value_max, (float, int)) and value_max < np.min(slice_img))
else vmax else value_max
)
axes[2].imshow(
slice_img, cmap=color_map, vmin=new_value_min, vmax=new_value_max
) )
axes[2].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax)
axes[2].set_title("Masked") axes[2].set_title("Masked")
axes[2].axis("off") axes[2].axis("off")
...@@ -524,9 +555,9 @@ def interactive_fade_mask( ...@@ -524,9 +555,9 @@ def interactive_fade_mask(
) )
position_slider = widgets.IntSlider( position_slider = widgets.IntSlider(
value=vol.shape[0] // 2, value=volume.shape[0] // 2,
min=0, min=0,
max=vol.shape[0] - 1, max=volume.shape[0] - 1,
description="Slice", description="Slice",
continuous_update=False, continuous_update=False,
) )
...@@ -659,13 +690,13 @@ def chunks(zarr_path: str, **kwargs): ...@@ -659,13 +690,13 @@ def chunks(zarr_path: str, **kwargs):
viz_widget = widgets.Output() viz_widget = widgets.Output()
with viz_widget: with viz_widget:
viz_widget.clear_output(wait=True) viz_widget.clear_output(wait=True)
fig = qim3d.viz.slices(chunk, **kwargs) fig = qim3d.viz.slices_grid(chunk, **kwargs)
display(fig) display(fig)
elif visualization_method == "vol": elif visualization_method == "volume":
viz_widget = widgets.Output() viz_widget = widgets.Output()
with viz_widget: with viz_widget:
viz_widget.clear_output(wait=True) viz_widget.clear_output(wait=True)
out = qim3d.viz.vol(chunk, show=False, **kwargs) out = qim3d.viz.volumetric(chunk, show=False, **kwargs)
display(out) display(out)
else: else:
log.info(f"Invalid visualization method: {visualization_method}") log.info(f"Invalid visualization method: {visualization_method}")
...@@ -716,7 +747,7 @@ def chunks(zarr_path: str, **kwargs): ...@@ -716,7 +747,7 @@ def chunks(zarr_path: str, **kwargs):
) )
method_dropdown = widgets.Dropdown( method_dropdown = widgets.Dropdown(
options=["slicer", "slices", "vol"], options=["slicer", "slices", "volume"],
value="slicer", value="slicer",
description="Visualization", description="Visualization",
style={"description_width": description_width, "text_align": "left"}, style={"description_width": description_width, "text_align": "left"},
...@@ -815,7 +846,7 @@ def chunks(zarr_path: str, **kwargs): ...@@ -815,7 +846,7 @@ def chunks(zarr_path: str, **kwargs):
def histogram( def histogram(
vol: np.ndarray, volume: np.ndarray,
bins: Union[int, str] = "auto", bins: Union[int, str] = "auto",
slice_idx: Union[int, str] = None, slice_idx: Union[int, str] = None,
axis: int = 0, axis: int = 0,
...@@ -837,7 +868,7 @@ def histogram( ...@@ -837,7 +868,7 @@ def histogram(
Utilizes [seaborn.histplot](https://seaborn.pydata.org/generated/seaborn.histplot.html) for visualization. Utilizes [seaborn.histplot](https://seaborn.pydata.org/generated/seaborn.histplot.html) for visualization.
Args: Args:
vol (np.ndarray): A 3D NumPy array representing the volume to be visualized. volume (np.ndarray): A 3D NumPy array representing the volume to be visualized.
bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto". bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto".
axis (int, optional): Axis along which to take a slice. Default is 0. axis (int, optional): Axis along which to take a slice. Default is 0.
slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis. slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis.
...@@ -879,24 +910,24 @@ def histogram( ...@@ -879,24 +910,24 @@ def histogram(
![viz histogram](assets/screenshots/viz-histogram-slice.png) ![viz histogram](assets/screenshots/viz-histogram-slice.png)
""" """
if not (0 <= axis < vol.ndim): if not (0 <= axis < volume.ndim):
raise ValueError(f"Axis must be an integer between 0 and {vol.ndim - 1}.") raise ValueError(f"Axis must be an integer between 0 and {volume.ndim - 1}.")
if slice_idx == "middle": if slice_idx == "middle":
slice_idx = vol.shape[axis] // 2 slice_idx = volume.shape[axis] // 2
if slice_idx: if slice_idx:
if 0 <= slice_idx < vol.shape[axis]: if 0 <= slice_idx < volume.shape[axis]:
img_slice = np.take(vol, indices=slice_idx, axis=axis) img_slice = np.take(volume, indices=slice_idx, axis=axis)
data = img_slice.ravel() data = img_slice.ravel()
title = f"Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}" title = f"Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}"
else: else:
raise ValueError( raise ValueError(
f"Slice index out of range. Must be between 0 and {vol.shape[axis] - 1}." f"Slice index out of range. Must be between 0 and {volume.shape[axis] - 1}."
) )
else: else:
data = vol.ravel() data = volume.ravel()
title = f"Intensity histogram for whole volume {vol.shape}" title = f"Intensity histogram for whole volume {volume.shape}"
fig, ax = plt.subplots(figsize=figsize) fig, ax = plt.subplots(figsize=figsize)
......
...@@ -14,13 +14,13 @@ from qim3d.utils.logger import log ...@@ -14,13 +14,13 @@ from qim3d.utils.logger import log
from qim3d.utils.misc import downscale_img, scale_to_float16 from qim3d.utils.misc import downscale_img, scale_to_float16
def vol( def volumetric(
img, img,
aspectmode="data", aspectmode="data",
show=True, show=True,
save=False, save=False,
grid_visible=False, grid_visible=False,
cmap=None, color_map='magma',
constant_opacity=False, constant_opacity=False,
vmin=None, vmin=None,
vmax=None, vmax=None,
...@@ -43,8 +43,8 @@ def vol( ...@@ -43,8 +43,8 @@ def vol(
If a string is provided, it's interpreted as the file path where the HTML If a string is provided, it's interpreted as the file path where the HTML
file will be saved. Defaults to False. file will be saved. Defaults to False.
grid_visible (bool, optional): If True, the grid is visible in the plot. Defaults to False. grid_visible (bool, optional): If True, the grid is visible in the plot. Defaults to False.
cmap (str or matplotlib.colors.Colormap or list, optional): The color map to be used for the volume rendering. If a string is passed, it should be a matplotlib colormap name. Defaults to None. color_map (str or matplotlib.colors.Colormap or list, optional): The color map to be used for the volume rendering. If a string is passed, it should be a matplotlib colormap name. Defaults to None.
constant_opacity (bool, float): Set to True if doing an object label visualization with a corresponding cmap; otherwise, the plot may appear poorly. Defaults to False. constant_opacity (bool, float): Set to True if doing an object label visualization with a corresponding color_map; otherwise, the plot may appear poorly. Defaults to False.
vmin (float, optional): Together with vmax defines the data range the colormap covers. By default colormap covers the full range. Defaults to None. vmin (float, optional): Together with vmax defines the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin defines the data range the colormap covers. By default colormap covers the full range. Defaults to None vmax (float, optional): Together with vmin defines the data range the colormap covers. By default colormap covers the full range. Defaults to None
samples (int, optional): The number of samples to be used for the volume rendering in k3d. Defaults to 512. samples (int, optional): The number of samples to be used for the volume rendering in k3d. Defaults to 512.
...@@ -60,7 +60,7 @@ def vol( ...@@ -60,7 +60,7 @@ def vol(
ValueError: If `aspectmode` is not `'data'` or `'cube'`. ValueError: If `aspectmode` is not `'data'` or `'cube'`.
Tip: Tip:
The function can be used for object label visualization using a `cmap` created with `qim3d.viz.colormaps.objects` along with setting `objects=True`. The latter ensures appropriate rendering. The function can be used for object label visualization using a `color_map` created with `qim3d.viz.colormaps.objects` along with setting `objects=True`. The latter ensures appropriate rendering.
Example: Example:
Display a volume inline: Display a volume inline:
...@@ -69,7 +69,7 @@ def vol( ...@@ -69,7 +69,7 @@ def vol(
import qim3d import qim3d
vol = qim3d.examples.bone_128x128x128 vol = qim3d.examples.bone_128x128x128
qim3d.viz.vol(vol) qim3d.viz.volumetric(vol)
``` ```
<iframe src="https://platform.qim.dk/k3d/fima-bone_128x128x128-20240221113459.html" width="100%" height="500" frameborder="0"></iframe> <iframe src="https://platform.qim.dk/k3d/fima-bone_128x128x128-20240221113459.html" width="100%" height="500" frameborder="0"></iframe>
...@@ -78,7 +78,7 @@ def vol( ...@@ -78,7 +78,7 @@ def vol(
```python ```python
import qim3d import qim3d
vol = qim3d.examples.bone_128x128x128 vol = qim3d.examples.bone_128x128x128
plot = qim3d.viz.vol(vol, show=False, save="plot.html") plot = qim3d.viz.volumetric(vol, show=False, save="plot.html")
``` ```
""" """
...@@ -129,21 +129,21 @@ def vol( ...@@ -129,21 +129,21 @@ def vol(
if vmax: if vmax:
color_range[1] = vmax color_range[1] = vmax
# Handle the different formats that cmap can take # Handle the different formats that color_map can take
if cmap: if color_map:
if isinstance(cmap, str): if isinstance(color_map, str):
cmap = plt.get_cmap(cmap) # Convert to Colormap object color_map = plt.get_cmap(color_map) # Convert to Colormap object
if isinstance(cmap, Colormap): if isinstance(color_map, Colormap):
# Convert to the format of cmap required by k3d.volume # Convert to the format of color_map required by k3d.volume
attr_vals = np.linspace(0.0, 1.0, num=cmap.N) attr_vals = np.linspace(0.0, 1.0, num=color_map.N)
RGB_vals = cmap(np.arange(0, cmap.N))[:, :3] RGB_vals = color_map(np.arange(0, color_map.N))[:, :3]
cmap = np.column_stack((attr_vals, RGB_vals)).tolist() color_map = np.column_stack((attr_vals, RGB_vals)).tolist()
# Default k3d.volume settings # Default k3d.volume settings
opacity_function = [] opacity_function = []
interpolation = True interpolation = True
if constant_opacity: if constant_opacity:
# without these settings, the plot will look bad when cmap is created with qim3d.viz.colormaps.objects # without these settings, the plot will look bad when color_map is created with qim3d.viz.colormaps.objects
opacity_function = [0.0, float(constant_opacity), 1.0, float(constant_opacity)] opacity_function = [0.0, float(constant_opacity), 1.0, float(constant_opacity)]
interpolation = False interpolation = False
...@@ -155,7 +155,7 @@ def vol( ...@@ -155,7 +155,7 @@ def vol(
if aspectmode.lower() == "data" if aspectmode.lower() == "data"
else None else None
), ),
color_map=cmap, color_map=color_map,
samples=samples, samples=samples,
color_range=color_range, color_range=color_range,
opacity_function=opacity_function, opacity_function=opacity_function,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment