Skip to content
Snippets Groups Projects
Commit 5505b582 authored by fima's avatar fima :beers:
Browse files

data_exploration refactored

parent 8258ac8b
No related branches found
No related tags found
1 merge request!139Refactoring v1.0 prep
This commit is part of merge request !139. Comments created here will be created in the context of that merge request.
from . import colormaps
from .cc import plot_cc
from .detection import circles
from .explore import (
interactive_fade_mask,
orthogonal,
from .data_exploration import (
fade_mask,
slicer,
slices,
slicer_orthogonal,
slices_grid,
chunks,
histogram,
)
from .itk_vtk_viewer import itk_vtk, Installer, NotInstalledError
from .k3d import vol, mesh
from .k3d import volumetric, mesh
from .local_thickness_ import local_thickness
from .structure_tensor import vectors
from .metrics import plot_metrics, grid_overview, grid_pred, vol_masked
......
......@@ -19,145 +19,157 @@ import seaborn as sns
import qim3d
def slices(
vol: np.ndarray,
axis: int = 0,
position: Optional[Union[str, int, List[int]]] = None,
n_slices: int = 5,
max_cols: int = 5,
cmap: str = "viridis",
vmin: float = None,
vmax: float = None,
img_height: int = 2,
img_width: int = 2,
show: bool = False,
show_position: bool = True,
def slices_grid(
volume: np.ndarray,
slice_axis: int = 0,
slice_positions: Optional[Union[str, int, List[int]]] = None,
num_slices: int = 15,
max_columns: int = 5,
color_map: str = "magma",
value_min: float = None,
value_max: float = None,
image_size=None,
image_height: int = 2,
image_width: int = 2,
display_figure: bool = False,
display_positions: bool = True,
interpolation: Optional[str] = None,
img_size=None,
cbar: bool = False,
cbar_style: str = "small",
**imshow_kwargs,
color_bar: bool = False,
color_bar_style: str = "small",
**matplotlib_imshow_kwargs,
) -> plt.Figure:
"""Displays one or several slices from a 3d volume.
By default if `position` is None, slices plots `n_slices` linearly spaced slices.
If `position` is given as a string or integer, slices will plot an overview with `n_slices` figures around that position.
If `position` is given as a list, `n_slices` will be ignored and the slices from `position` will be plotted.
By default if `slice_positions` is None, slices_grid plots `num_slices` linearly spaced slices.
If `slice_positions` is given as a string or integer, slices_grid will plot an overview with `num_slices` figures around that position.
If `slice_positions` is given as a list, `num_slices` will be ignored and the slices from `slice_positions` will be plotted.
Args:
vol np.ndarray: The 3D volume to be sliced.
axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
position (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None.
n_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 5.
max_cols (int, optional): The maximum number of columns to be plotted. Defaults to 5.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None
img_height (int, optional): Height of the figure.
img_width (int, optional): Width of the figure.
show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
show_position (bool, optional): If True, displays the position of the slices. Defaults to True.
slice_axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
slice_positions (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None.
num_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 15.
max_columns (int, optional): The maximum number of columns to be plotted. Defaults to 5.
color_map (str, optional): Specifies the color map for the image. Defaults to "viridis".
value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None
image_height (int, optional): Height of the figure.
image_width (int, optional): Width of the figure.
display_figure (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
display_positions (bool, optional): If True, displays the position of the slices. Defaults to True.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
cbar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False.
cbar_style (str, optional): Determines the style of the colorbar. Option 'small' is height of one image row. Option 'large' spans full height of image grid. Defaults to 'small'.
color_bar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False.
color_bar_style (str, optional): Determines the style of the colorbar. Option 'small' is height of one image row. Option 'large' spans full height of image grid. Defaults to 'small'.
Returns:
fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.
Raises:
ValueError: If the input is not a numpy.ndarray or da.core.Array.
ValueError: If the axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1.
ValueError: If the slice_axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1.
ValueError: If the file or array is not a volume with at least 3 dimensions.
ValueError: If the `position` keyword argument is not a integer, list of integers or one of the following strings: "start", "mid" or "end".
ValueError: If the cbar_style keyword argument is not one of the following strings: 'small' or 'large'.
ValueError: If the color_bar_style keyword argument is not one of the following strings: 'small' or 'large'.
Example:
```python
import qim3d
vol = qim3d.examples.shell_225x128x128
qim3d.viz.slices(vol, n_slices=15)
qim3d.viz.slices_grid(vol, num_slices=15)
```
![Grid of slices](assets/screenshots/viz-slices.png)
"""
if img_size:
img_height = img_size
img_width = img_size
if image_size:
image_height = image_size
image_width = image_size
# If we pass python None to the imshow function, it will set to
# default value 'antialiased'
if interpolation is None:
interpolation = 'none'
interpolation = "none"
# Numpy array or Torch tensor input
if not isinstance(vol, (np.ndarray, da.core.Array)):
if not isinstance(volume, (np.ndarray, da.core.Array)):
raise ValueError("Data type not supported")
if vol.ndim < 3:
if volume.ndim < 3:
raise ValueError(
"The provided object is not a volume as it has less than 3 dimensions."
)
cbar_style_options = ["small", "large"]
if cbar_style not in cbar_style_options:
raise ValueError(f"Value '{cbar_style}' is not valid for colorbar style. Please select from {cbar_style_options}.")
color_bar_style_options = ["small", "large"]
if color_bar_style not in color_bar_style_options:
raise ValueError(
f"Value '{color_bar_style}' is not valid for colorbar style. Please select from {color_bar_style_options}."
)
if isinstance(vol, da.core.Array):
vol = vol.compute()
if isinstance(volume, da.core.Array):
volume = volume.compute()
# Ensure axis is a valid choice
if not (0 <= axis < vol.ndim):
if not (0 <= slice_axis < volume.ndim):
raise ValueError(
f"Invalid value for 'axis'. It should be an integer between 0 and {vol.ndim - 1}."
f"Invalid value for 'slice_axis'. It should be an integer between 0 and {volume.ndim - 1}."
)
if type(cmap) == matplotlib.colors.LinearSegmentedColormap or cmap == 'objects':
num_labels = len(np.unique(vol))
# Here we deal with the case that the user wants to use the objects colormap directly
if (
type(color_map) == matplotlib.colors.LinearSegmentedColormap
or color_map == "objects"
):
num_labels = len(np.unique(volume))
if cmap == 'objects':
cmap = qim3d.viz.colormaps.objects(num_labels)
# If vmin and vmax are not set like this, then in case the
if color_map == "objects":
color_map = qim3d.viz.colormaps.objects(num_labels)
# If value_min and value_max are not set like this, then in case the
# number of objects changes on new slice, objects might change
# colors. So when using a slider, the same object suddently
# changes color (flickers), which is confusing and annoying.
vmin = 0
vmax = num_labels
value_min = 0
value_max = num_labels
# Get total number of slices in the specified dimension
n_total = vol.shape[axis]
n_total = volume.shape[slice_axis]
# Position is not provided - will use linearly spaced slices
if position is None:
slice_idxs = np.linspace(0, n_total - 1, n_slices, dtype=int)
if slice_positions is None:
slice_idxs = np.linspace(0, n_total - 1, num_slices, dtype=int)
# Position is a string
elif isinstance(position, str) and position.lower() in ["start", "mid", "end"]:
if position.lower() == "start":
slice_idxs = _get_slice_range(0, n_slices, n_total)
elif position.lower() == "mid":
slice_idxs = _get_slice_range(n_total // 2, n_slices, n_total)
elif position.lower() == "end":
slice_idxs = _get_slice_range(n_total - 1, n_slices, n_total)
elif isinstance(slice_positions, str) and slice_positions.lower() in [
"start",
"mid",
"end",
]:
if slice_positions.lower() == "start":
slice_idxs = _get_slice_range(0, num_slices, n_total)
elif slice_positions.lower() == "mid":
slice_idxs = _get_slice_range(n_total // 2, num_slices, n_total)
elif slice_positions.lower() == "end":
slice_idxs = _get_slice_range(n_total - 1, num_slices, n_total)
# Position is an integer
elif isinstance(position, int):
slice_idxs = _get_slice_range(position, n_slices, n_total)
elif isinstance(slice_positions, int):
slice_idxs = _get_slice_range(slice_positions, num_slices, n_total)
# Position is a list of integers
elif isinstance(position, list) and all(isinstance(idx, int) for idx in position):
slice_idxs = position
elif isinstance(slice_positions, list) and all(
isinstance(idx, int) for idx in slice_positions
):
slice_idxs = slice_positions
else:
raise ValueError(
'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'
)
# Make grid
nrows = math.ceil(n_slices / max_cols)
ncols = min(n_slices, max_cols)
nrows = math.ceil(num_slices / max_columns)
ncols = min(num_slices, max_columns)
# Generate figure
fig, axs = plt.subplots(
nrows=nrows,
ncols=ncols,
figsize=(ncols * img_height, nrows * img_width),
figsize=(ncols * image_height, nrows * image_width),
constrained_layout=True,
)
......@@ -165,47 +177,53 @@ def slices(
axs = [axs] # Convert to a list for uniformity
# Convert to NumPy array in order to use the numpy.take method
if isinstance(vol, da.core.Array):
vol = vol.compute()
if isinstance(volume, da.core.Array):
volume = volume.compute()
if cbar:
if color_bar:
# In this case, we want the vrange to be constant across the
# slices, which makes them all comparable to a single cbar.
new_vmin = vmin if vmin is not None else np.min(vol)
new_vmax = vmax if vmax is not None else np.max(vol)
# slices, which makes them all comparable to a single color_bar.
new_value_min = value_min if value_min is not None else np.min(volume)
new_value_max = value_max if value_max is not None else np.max(volume)
# Run through each ax of the grid
for i, ax_row in enumerate(axs):
for j, ax in enumerate(np.atleast_1d(ax_row)):
slice_idx = i * max_cols + j
slice_idx = i * max_columns + j
try:
slice_img = vol.take(slice_idxs[slice_idx], axis=axis)
slice_img = volume.take(slice_idxs[slice_idx], axis=slice_axis)
if not cbar:
# If vmin is higher than the highest value in the
if not color_bar:
# If value_min is higher than the highest value in the
# image ValueError is raised. We don't want to
# override the values because next slices might be okay
new_vmin = (
new_value_min = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
if (
isinstance(value_min, (float, int))
and value_min > np.max(slice_img)
)
new_vmax = (
else value_min
)
new_value_max = (
None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
if (
isinstance(value_max, (float, int))
and value_max < np.min(slice_img)
)
else value_max
)
ax.imshow(
slice_img,
cmap=cmap,
cmap=color_map,
interpolation=interpolation,
vmin=new_vmin,
vmax=new_vmax,
**imshow_kwargs,
vmin=new_value_min,
vmax=new_value_max,
**matplotlib_imshow_kwargs,
)
if show_position:
if display_positions:
ax.text(
0.0,
1.0,
......@@ -221,7 +239,7 @@ def slices(
ax.text(
1.0,
0.0,
f"axis {axis} ",
f"axis {slice_axis} ",
transform=ax.transAxes,
color="white",
fontsize=8,
......@@ -237,33 +255,40 @@ def slices(
# Hide the axis, so that we have a nice grid
ax.axis("off")
if cbar:
if color_bar:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
fig.tight_layout()
norm = matplotlib.colors.Normalize(vmin=new_vmin, vmax=new_vmax, clip=True)
mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
norm = matplotlib.colors.Normalize(
vmin=new_value_min, vmax=new_value_max, clip=True
)
mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=color_map)
if cbar_style =="small":
if color_bar_style == "small":
# Figure coordinates of top-right axis
tr_pos = np.atleast_1d(axs[0])[-1].get_position()
# The width is divided by ncols to make it the same relative size to the images
cbar_ax = fig.add_axes(
color_bar_ax = fig.add_axes(
[tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height]
)
fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical")
elif cbar_style == "large":
fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation="vertical")
elif color_bar_style == "large":
# Figure coordinates of bottom- and top-right axis
br_pos = np.atleast_1d(axs[-1])[-1].get_position()
tr_pos = np.atleast_1d(axs[0])[-1].get_position()
# The width is divided by ncols to make it the same relative size to the images
cbar_ax = fig.add_axes(
[br_pos.xmax + 0.05 / ncols, br_pos.y0+0.0015, 0.05 / ncols, (tr_pos.y1 - br_pos.y0)-0.0015]
color_bar_ax = fig.add_axes(
[
br_pos.xmax + 0.05 / ncols,
br_pos.y0 + 0.0015,
0.05 / ncols,
(tr_pos.y1 - br_pos.y0) - 0.0015,
]
)
fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical")
fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation="vertical")
if show:
if display_figure:
plt.show()
plt.close()
......@@ -271,49 +296,51 @@ def slices(
return fig
def _get_slice_range(position: int, n_slices: int, n_total):
def _get_slice_range(position: int, num_slices: int, n_total):
"""Helper function for `slices`. Returns the range of slices to be displayed around the given position."""
start_idx = position - n_slices // 2
start_idx = position - num_slices // 2
end_idx = (
position + n_slices // 2 if n_slices % 2 == 0 else position + n_slices // 2 + 1
position + num_slices // 2
if num_slices % 2 == 0
else position + num_slices // 2 + 1
)
slice_idxs = np.arange(start_idx, end_idx)
if slice_idxs[0] < 0:
slice_idxs = np.arange(0, n_slices)
slice_idxs = np.arange(0, num_slices)
elif slice_idxs[-1] > n_total:
slice_idxs = np.arange(n_total - n_slices, n_total)
slice_idxs = np.arange(n_total - num_slices, n_total)
return slice_idxs
def slicer(
vol: np.ndarray,
axis: int = 0,
cmap: str = "viridis",
vmin: float = None,
vmax: float = None,
img_height: int = 3,
img_width: int = 3,
show_position: bool = False,
volume: np.ndarray,
slice_axis: int = 0,
color_map: str = "magma",
value_min: float = None,
value_max: float = None,
image_height: int = 3,
image_width: int = 3,
display_positions: bool = False,
interpolation: Optional[str] = None,
img_size=None,
cbar: bool = False,
**imshow_kwargs,
image_size=None,
color_bar: bool = False,
**matplotlib_imshow_kwargs,
) -> widgets.interactive:
"""Interactive widget for visualizing slices of a 3D volume.
Args:
vol (np.ndarray): The 3D volume to be sliced.
axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None
img_height (int, optional): Height of the figure. Defaults to 3.
img_width (int, optional): Width of the figure. Defaults to 3.
show_position (bool, optional): If True, displays the position of the slices. Defaults to False.
volume (np.ndarray): The 3D volume to be sliced.
slice_axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
color_map (str, optional): Specifies the color map for the image. Defaults to "viridis".
value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None
image_height (int, optional): Height of the figure. Defaults to 3.
image_width (int, optional): Width of the figure. Defaults to 3.
display_positions (bool, optional): If True, displays the position of the slices. Defaults to False.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
cbar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False.
color_bar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False.
Returns:
slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume.
......@@ -328,98 +355,98 @@ def slicer(
![viz slicer](assets/screenshots/viz-slicer.gif)
"""
if img_size:
img_height = img_size
img_width = img_size
if image_size:
image_height = image_size
image_width = image_size
# Create the interactive widget
def _slicer(position):
fig = slices(
vol,
axis=axis,
cmap=cmap,
vmin=vmin,
vmax=vmax,
img_height=img_height,
img_width=img_width,
show_position=show_position,
def _slicer(slice_positions):
fig = slices_grid(
volume,
slice_axis=slice_axis,
color_map=color_map,
value_min=value_min,
value_max=value_max,
image_height=image_height,
image_width=image_width,
display_positions=display_positions,
interpolation=interpolation,
position=position,
n_slices=1,
show=True,
cbar=cbar,
**imshow_kwargs,
slice_positions=slice_positions,
num_slices=1,
display_figure=True,
color_bar=color_bar,
**matplotlib_imshow_kwargs,
)
return fig
position_slider = widgets.IntSlider(
value=vol.shape[axis] // 2,
value=volume.shape[slice_axis] // 2,
min=0,
max=vol.shape[axis] - 1,
max=volume.shape[slice_axis] - 1,
description="Slice",
continuous_update=True,
)
slicer_obj = widgets.interactive(_slicer, position=position_slider)
slicer_obj = widgets.interactive(_slicer, slice_positions=position_slider)
slicer_obj.layout = widgets.Layout(align_items="flex-start")
return slicer_obj
def orthogonal(
vol: np.ndarray,
cmap: str = "viridis",
vmin: float = None,
vmax: float = None,
img_height: int = 3,
img_width: int = 3,
show_position: bool = False,
def slicer_orthogonal(
volume: np.ndarray,
color_map: str = 'magma',
value_min: float = None,
value_max: float = None,
image_height: int = 3,
image_width: int = 3,
display_positions: bool = False,
interpolation: Optional[str] = None,
img_size=None,
image_size=None,
):
"""Interactive widget for visualizing orthogonal slices of a 3D volume.
Args:
vol (np.ndarray): The 3D volume to be sliced.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None
img_height (int, optional): Height of the figure.
img_width (int, optional): Width of the figure.
show_position (bool, optional): If True, displays the position of the slices. Defaults to False.
volume (np.ndarray): The 3D volume to be sliced.
color_map (str, optional): Specifies the color map for the image. Defaults to "viridis".
value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None
image_height (int, optional): Height of the figure.
image_width (int, optional): Width of the figure.
display_positions (bool, optional): If True, displays the position of the slices. Defaults to False.
interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
Returns:
orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume.
slicer_orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume.
Example:
```python
import qim3d
vol = qim3d.examples.fly_150x256x256
qim3d.viz.orthogonal(vol, cmap="magma")
qim3d.viz.slicer_orthogonal(vol, color_map="magma")
```
![viz orthogonal](assets/screenshots/viz-orthogonal.gif)
![viz slicer_orthogonal](assets/screenshots/viz-orthogonal.gif)
"""
if img_size:
img_height = img_size
img_width = img_size
get_slicer_for_axis = lambda axis: slicer(
vol,
axis=axis,
cmap=cmap,
vmin=vmin,
vmax=vmax,
img_height=img_height,
img_width=img_width,
show_position=show_position,
if image_size:
image_height = image_size
image_width = image_size
get_slicer_for_axis = lambda slice_axis: slicer(
volume,
slice_axis=slice_axis,
color_map=color_map,
value_min=value_min,
value_max=value_max,
image_height=image_height,
image_width=image_width,
display_positions=display_positions,
interpolation=interpolation,
)
z_slicer = get_slicer_for_axis(axis=0)
y_slicer = get_slicer_for_axis(axis=1)
x_slicer = get_slicer_for_axis(axis=2)
z_slicer = get_slicer_for_axis(slice_axis=0)
y_slicer = get_slicer_for_axis(slice_axis=1)
x_slicer = get_slicer_for_axis(slice_axis=2)
z_slicer.children[0].description = "Z"
y_slicer.children[0].description = "Y"
......@@ -428,29 +455,29 @@ def orthogonal(
return widgets.HBox([z_slicer, y_slicer, x_slicer])
def interactive_fade_mask(
vol: np.ndarray,
def fade_mask(
volume: np.ndarray,
axis: int = 0,
cmap: str = "viridis",
vmin: float = None,
vmax: float = None,
color_map: str = 'magma',
value_min: float = None,
value_max: float = None,
):
"""Interactive widget for visualizing the effect of edge fading on a 3D volume.
This can be used to select the best parameters before applying the mask.
Args:
vol (np.ndarray): The volume to apply edge fading to.
volume (np.ndarray): The volume to apply edge fading to.
axis (int, optional): The axis along which to apply the fading. Defaults to 0.
cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
vmin (float, optional): Together with vmax define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin define the data range the colormap covers. By default colormap covers the full range. Defaults to None
color_map (str, optional): Specifies the color map for the image. Defaults to "viridis".
value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None
Example:
```python
import qim3d
vol = qim3d.examples.cement_128x128x128
qim3d.viz.interactive_fade_mask(vol)
qim3d.viz.fade_mask(vol)
```
![operations-edge_fade_before](assets/screenshots/viz-fade_mask.gif)
......@@ -460,58 +487,62 @@ def interactive_fade_mask(
def _slicer(position, decay_rate, ratio, geometry, invert):
fig, axes = plt.subplots(1, 3, figsize=(9, 3))
slice_img = vol[position, :, :]
# If vmin is higher than the highest value in the image ValueError is raised
slice_img = volume[position, :, :]
# If value_min is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
new_vmin = (
new_value_min = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
if (isinstance(value_min, (float, int)) and value_min > np.max(slice_img))
else value_min
)
new_vmax = (
new_value_max = (
None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
if (isinstance(value_max, (float, int)) and value_max < np.min(slice_img))
else value_max
)
axes[0].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax)
axes[0].imshow(
slice_img, cmap=color_map, value_min=new_value_min, value_max=new_value_max
)
axes[0].set_title("Original")
axes[0].axis("off")
mask = qim3d.processing.operations.fade_mask(
np.ones_like(vol),
np.ones_like(volume),
decay_rate=decay_rate,
ratio=ratio,
geometry=geometry,
axis=axis,
invert=invert,
)
axes[1].imshow(mask[position, :, :], cmap=cmap)
axes[1].imshow(mask[position, :, :], cmap=color_map)
axes[1].set_title("Mask")
axes[1].axis("off")
masked_vol = qim3d.processing.operations.fade_mask(
vol,
masked_volume = qim3d.processing.operations.fade_mask(
volume,
decay_rate=decay_rate,
ratio=ratio,
geometry=geometry,
axis=axis,
invert=invert,
)
# If vmin is higher than the highest value in the image ValueError is raised
# If value_min is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
slice_img = masked_vol[position, :, :]
new_vmin = (
slice_img = masked_volume[position, :, :]
new_value_min = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
if (isinstance(value_min, (float, int)) and value_min > np.max(slice_img))
else value_min
)
new_vmax = (
new_value_max = (
None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
if (isinstance(value_max, (float, int)) and value_max < np.min(slice_img))
else value_max
)
axes[2].imshow(
slice_img, cmap=color_map, vmin=new_value_min, vmax=new_value_max
)
axes[2].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax)
axes[2].set_title("Masked")
axes[2].axis("off")
......@@ -524,9 +555,9 @@ def interactive_fade_mask(
)
position_slider = widgets.IntSlider(
value=vol.shape[0] // 2,
value=volume.shape[0] // 2,
min=0,
max=vol.shape[0] - 1,
max=volume.shape[0] - 1,
description="Slice",
continuous_update=False,
)
......@@ -659,13 +690,13 @@ def chunks(zarr_path: str, **kwargs):
viz_widget = widgets.Output()
with viz_widget:
viz_widget.clear_output(wait=True)
fig = qim3d.viz.slices(chunk, **kwargs)
fig = qim3d.viz.slices_grid(chunk, **kwargs)
display(fig)
elif visualization_method == "vol":
elif visualization_method == "volume":
viz_widget = widgets.Output()
with viz_widget:
viz_widget.clear_output(wait=True)
out = qim3d.viz.vol(chunk, show=False, **kwargs)
out = qim3d.viz.volumetric(chunk, show=False, **kwargs)
display(out)
else:
log.info(f"Invalid visualization method: {visualization_method}")
......@@ -716,7 +747,7 @@ def chunks(zarr_path: str, **kwargs):
)
method_dropdown = widgets.Dropdown(
options=["slicer", "slices", "vol"],
options=["slicer", "slices", "volume"],
value="slicer",
description="Visualization",
style={"description_width": description_width, "text_align": "left"},
......@@ -815,7 +846,7 @@ def chunks(zarr_path: str, **kwargs):
def histogram(
vol: np.ndarray,
volume: np.ndarray,
bins: Union[int, str] = "auto",
slice_idx: Union[int, str] = None,
axis: int = 0,
......@@ -837,7 +868,7 @@ def histogram(
Utilizes [seaborn.histplot](https://seaborn.pydata.org/generated/seaborn.histplot.html) for visualization.
Args:
vol (np.ndarray): A 3D NumPy array representing the volume to be visualized.
volume (np.ndarray): A 3D NumPy array representing the volume to be visualized.
bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto".
axis (int, optional): Axis along which to take a slice. Default is 0.
slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis.
......@@ -879,24 +910,24 @@ def histogram(
![viz histogram](assets/screenshots/viz-histogram-slice.png)
"""
if not (0 <= axis < vol.ndim):
raise ValueError(f"Axis must be an integer between 0 and {vol.ndim - 1}.")
if not (0 <= axis < volume.ndim):
raise ValueError(f"Axis must be an integer between 0 and {volume.ndim - 1}.")
if slice_idx == "middle":
slice_idx = vol.shape[axis] // 2
slice_idx = volume.shape[axis] // 2
if slice_idx:
if 0 <= slice_idx < vol.shape[axis]:
img_slice = np.take(vol, indices=slice_idx, axis=axis)
if 0 <= slice_idx < volume.shape[axis]:
img_slice = np.take(volume, indices=slice_idx, axis=axis)
data = img_slice.ravel()
title = f"Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}"
else:
raise ValueError(
f"Slice index out of range. Must be between 0 and {vol.shape[axis] - 1}."
f"Slice index out of range. Must be between 0 and {volume.shape[axis] - 1}."
)
else:
data = vol.ravel()
title = f"Intensity histogram for whole volume {vol.shape}"
data = volume.ravel()
title = f"Intensity histogram for whole volume {volume.shape}"
fig, ax = plt.subplots(figsize=figsize)
......
......@@ -14,13 +14,13 @@ from qim3d.utils.logger import log
from qim3d.utils.misc import downscale_img, scale_to_float16
def vol(
def volumetric(
img,
aspectmode="data",
show=True,
save=False,
grid_visible=False,
cmap=None,
color_map='magma',
constant_opacity=False,
vmin=None,
vmax=None,
......@@ -43,8 +43,8 @@ def vol(
If a string is provided, it's interpreted as the file path where the HTML
file will be saved. Defaults to False.
grid_visible (bool, optional): If True, the grid is visible in the plot. Defaults to False.
cmap (str or matplotlib.colors.Colormap or list, optional): The color map to be used for the volume rendering. If a string is passed, it should be a matplotlib colormap name. Defaults to None.
constant_opacity (bool, float): Set to True if doing an object label visualization with a corresponding cmap; otherwise, the plot may appear poorly. Defaults to False.
color_map (str or matplotlib.colors.Colormap or list, optional): The color map to be used for the volume rendering. If a string is passed, it should be a matplotlib colormap name. Defaults to None.
constant_opacity (bool, float): Set to True if doing an object label visualization with a corresponding color_map; otherwise, the plot may appear poorly. Defaults to False.
vmin (float, optional): Together with vmax defines the data range the colormap covers. By default colormap covers the full range. Defaults to None.
vmax (float, optional): Together with vmin defines the data range the colormap covers. By default colormap covers the full range. Defaults to None
samples (int, optional): The number of samples to be used for the volume rendering in k3d. Defaults to 512.
......@@ -60,7 +60,7 @@ def vol(
ValueError: If `aspectmode` is not `'data'` or `'cube'`.
Tip:
The function can be used for object label visualization using a `cmap` created with `qim3d.viz.colormaps.objects` along with setting `objects=True`. The latter ensures appropriate rendering.
The function can be used for object label visualization using a `color_map` created with `qim3d.viz.colormaps.objects` along with setting `objects=True`. The latter ensures appropriate rendering.
Example:
Display a volume inline:
......@@ -69,7 +69,7 @@ def vol(
import qim3d
vol = qim3d.examples.bone_128x128x128
qim3d.viz.vol(vol)
qim3d.viz.volumetric(vol)
```
<iframe src="https://platform.qim.dk/k3d/fima-bone_128x128x128-20240221113459.html" width="100%" height="500" frameborder="0"></iframe>
......@@ -78,7 +78,7 @@ def vol(
```python
import qim3d
vol = qim3d.examples.bone_128x128x128
plot = qim3d.viz.vol(vol, show=False, save="plot.html")
plot = qim3d.viz.volumetric(vol, show=False, save="plot.html")
```
"""
......@@ -129,21 +129,21 @@ def vol(
if vmax:
color_range[1] = vmax
# Handle the different formats that cmap can take
if cmap:
if isinstance(cmap, str):
cmap = plt.get_cmap(cmap) # Convert to Colormap object
if isinstance(cmap, Colormap):
# Convert to the format of cmap required by k3d.volume
attr_vals = np.linspace(0.0, 1.0, num=cmap.N)
RGB_vals = cmap(np.arange(0, cmap.N))[:, :3]
cmap = np.column_stack((attr_vals, RGB_vals)).tolist()
# Handle the different formats that color_map can take
if color_map:
if isinstance(color_map, str):
color_map = plt.get_cmap(color_map) # Convert to Colormap object
if isinstance(color_map, Colormap):
# Convert to the format of color_map required by k3d.volume
attr_vals = np.linspace(0.0, 1.0, num=color_map.N)
RGB_vals = color_map(np.arange(0, color_map.N))[:, :3]
color_map = np.column_stack((attr_vals, RGB_vals)).tolist()
# Default k3d.volume settings
opacity_function = []
interpolation = True
if constant_opacity:
# without these settings, the plot will look bad when cmap is created with qim3d.viz.colormaps.objects
# without these settings, the plot will look bad when color_map is created with qim3d.viz.colormaps.objects
opacity_function = [0.0, float(constant_opacity), 1.0, float(constant_opacity)]
interpolation = False
......@@ -155,7 +155,7 @@ def vol(
if aspectmode.lower() == "data"
else None
),
color_map=cmap,
color_map=color_map,
samples=samples,
color_range=color_range,
opacity_function=opacity_function,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment