Skip to content
Snippets Groups Projects
Commit d4c1e128 authored by fima's avatar fima :beers:
Browse files

Layout update

parent 3706b4c2
Branches
No related tags found
1 merge request!128Chunk visualization
This commit is part of merge request !128. Comments created here will be created in the context of that merge request.
......@@ -10,7 +10,7 @@ from typing import List, Optional, Union
import dask.array as da
import ipywidgets as widgets
import matplotlib.pyplot as plt
from IPython.display import display
from IPython.display import SVG, display
import matplotlib
import numpy as np
import zarr
......@@ -159,11 +159,24 @@ def slices(
if not cbar:
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
new_vmin = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
)
new_vmax = (
None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
)
ax.imshow(
slice_img, cmap=cmap, interpolation=interpolation,vmin = new_vmin, vmax = new_vmax, **imshow_kwargs
slice_img,
cmap=cmap,
interpolation=interpolation,
vmin=new_vmin,
vmax=new_vmax,
**imshow_kwargs,
)
if show_position:
......@@ -208,8 +221,10 @@ def slices(
# Figure coordinates of top-right axis
tr_pos = np.atleast_1d(axs[0])[-1].get_position()
# The width is divided by ncols to make it the same relative size to the images
cbar_ax = fig.add_axes([tr_pos.x1 + 0.05/ncols, tr_pos.y0, 0.05/ncols, tr_pos.height])
fig.colorbar(mappable=mappable, cax=cbar_ax, orientation='vertical')
cbar_ax = fig.add_axes(
[tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height]
)
fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical")
if show:
plt.show()
......@@ -376,7 +391,13 @@ def orthogonal(
return widgets.HBox([z_slicer, y_slicer, x_slicer])
def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', vmin:float = None, vmax:float = None):
def interactive_fade_mask(
vol: np.ndarray,
axis: int = 0,
cmap: str = "viridis",
vmin: float = None,
vmax: float = None,
):
"""Interactive widget for visualizing the effect of edge fading on a 3D volume.
This can be used to select the best parameters before applying the mask.
......@@ -405,8 +426,16 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', v
slice_img = vol[position, :, :]
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
new_vmin = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
)
new_vmax = (
None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
)
axes[0].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax)
axes[0].set_title("Original")
......@@ -435,8 +464,16 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', v
# If vmin is higher than the highest value in the image ValueError is raised
# We don't want to override the values because next slices might be okay
slice_img = masked_vol[position, :, :]
new_vmin = None if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img)) else vmin
new_vmax = None if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img)) else vmax
new_vmin = (
None
if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
else vmin
)
new_vmax = (
None
if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
else vmax
)
axes[2].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax)
axes[2].set_title("Masked")
axes[2].axis("off")
......@@ -491,13 +528,12 @@ def interactive_fade_mask(vol: np.ndarray, axis: int = 0,cmap:str = 'viridis', v
return slicer_obj
def chunks(zarr_path: str, visualization_method = 'slicer', **kwargs):
def chunks(zarr_path: str, **kwargs):
"""
Function to visualize chunks of a Zarr dataset using the specified visualization method.
Args:
zarr_path (str): Path to the Zarr dataset.
visualization_method (str, optional): The visualization method to use ('slicer', 'slices', or 'vol'). Each method leverages the corresponding qim3d visualization function. Defaults to 'slicer'.
**kwargs: Additional keyword arguments to pass to the visualization method.
Example:
......@@ -505,31 +541,42 @@ def chunks(zarr_path: str, visualization_method = 'slicer', **kwargs):
import qim3d
zarr_path = "path/to/zarr/dataset.zarr"
qim3d.viz.chunks(zarr_path, visualization_method='vol')
qim3d.viz.chunks(zarr_path)
```
![chunks-visualization](assets/screenshots/chunks_visualization.gif)
"""
# Load the Zarr dataset
zarr_data = zarr.open(zarr_path, mode='r')
zarr_data = zarr.open(zarr_path, mode="r")
# Save arguments for later use
visualization_method = visualization_method
preserved_kwargs = kwargs
# visualization_method = visualization_method
# preserved_kwargs = kwargs
# Create label to display the chunk coordinates
chunk_info_label = widgets.HTML(value="Chunk coordinates and size will appear here")
widget_title = widgets.HTML("<h2>Chunk Explorer</h2>")
chunk_info_label = widgets.HTML(value="Chunk info will be displayed here")
def load_and_visualize(scale, z_coord, y_coord, x_coord, visualization_method, **kwargs):
def load_and_visualize(
scale, z_coord, y_coord, x_coord, visualization_method, **kwargs
):
# Get chunk shape for the selected scale
chunk_shape = zarr_data[scale].chunks
# Calculate slice indices for the selected chunk
slices = (
slice(z_coord * chunk_shape[0], min((z_coord + 1) * chunk_shape[0], zarr_data[scale].shape[0])),
slice(y_coord * chunk_shape[1], min((y_coord + 1) * chunk_shape[1], zarr_data[scale].shape[1])),
slice(x_coord * chunk_shape[2], min((x_coord + 1) * chunk_shape[2], zarr_data[scale].shape[2]))
slice(
z_coord * chunk_shape[0],
min((z_coord + 1) * chunk_shape[0], zarr_data[scale].shape[0]),
),
slice(
y_coord * chunk_shape[1],
min((y_coord + 1) * chunk_shape[1], zarr_data[scale].shape[1]),
),
slice(
x_coord * chunk_shape[2],
min((x_coord + 1) * chunk_shape[2], zarr_data[scale].shape[2]),
),
)
# Extract start and stop values from each slice object
......@@ -537,53 +584,40 @@ def chunks(zarr_path: str, visualization_method = 'slicer', **kwargs):
y_start, y_stop = slices[1].start, slices[1].stop
x_start, x_stop = slices[2].start, slices[2].stop
# Extract the chunk
chunk = zarr_data[scale][slices]
# Update the chunk info label with the chunk coordinates
info_string = (
f"<b>shape:</b> {chunk_shape}\n"
+ f"<b>coordinates:</b> ({z_coord}, {y_coord}, {x_coord})\n"
+ f"<b>ranges: </b>Z({z_start}-{z_stop}) Y({y_start}-{y_stop}) X({x_start}-{x_stop})\n"
+ f"<b>dtype:</b> {chunk.dtype}\n"
+ f"<b>min value:</b> {np.min(chunk)}\n"
+ f"<b>max value:</b> {np.max(chunk)}\n"
+ f"<b>mean value:</b> {np.mean(chunk)}\n"
)
chunk_info_label.value = f"""
<div style="font-size: 14px; text-align: center;">
<b>Chunk Info</b>
</div>
<div style="font-size: 14px; display: flex; justify-content: space-between;">
<div style="flex: 1; text-align: left;">
<table style="font-size: 13px; border-collapse: collapse; width: 100%;">
<tr style="background-color: #f5f5f5;">
<td colspan="2" style="text-align: center;"><b>Range:</b></td>
</tr>
<tr>
<td style="text-align: right; padding-right: 10px;">Z:</td>
<td>{z_start}-{z_stop}</td>
</tr>
<tr style="background-color: #f5f5f5;">
<td style="text-align: right; padding-right: 10px;">Y:</td>
<td>{y_start}-{y_stop}</td>
</tr>
<tr>
<td style="text-align: right; padding-right: 10px;">X:</td>
<td>{x_start}-{x_stop}</td>
</tr>
</table>
</div>
<div style="flex: 1; text-align: right; padding-left: 20px; padding-top: 3px; font-size: 13px; white-space: nowrap;">
<b>Chunk size:</b> ({z_stop - z_start}, {y_stop - y_start}, {x_stop - x_start})
<div style="font-size: 14px; text-align: left; margin-left:32px">
<h3 style="margin: 0px">Chunk Info</h3>
<div style="font-size: 14px; text-align: left;">
<pre>{info_string}</pre>
</div>
</div>
"""
# Extract the chunk
chunk = zarr_data[scale][slices]
"""
# Prepare chunk visualization based on the selected method
if visualization_method == 'slicer': # return a widget
if visualization_method == "slicer": # return a widget
viz_widget = qim3d.viz.slicer(chunk, **kwargs)
elif visualization_method == 'slices': # return a plt.Figure
elif visualization_method == "slices": # return a plt.Figure
viz_widget = widgets.Output()
with viz_widget:
viz_widget.clear_output(wait=True)
fig = qim3d.viz.slices(chunk, **kwargs)
display(fig)
elif visualization_method == 'vol':
elif visualization_method == "vol":
viz_widget = widgets.Output()
with viz_widget:
viz_widget.clear_output(wait=True)
......@@ -598,14 +632,17 @@ def chunks(zarr_path: str, visualization_method = 'slicer', **kwargs):
def get_num_chunks(shape, chunk_size):
return [(s + chunk_size[i] - 1) // chunk_size[i] for i, s in enumerate(shape)]
scale_options = {f"{i} {zarr_data[i].shape}": i for i in range(len(zarr_data))} # len(zarr_data) gives number of scales
scale_options = {
f"{i} {zarr_data[i].shape}": i for i in range(len(zarr_data))
} # len(zarr_data) gives number of scales
description_width = "128px"
# Create dropdown for scale
scale_dropdown = widgets.Dropdown(
options=scale_options,
value=0, # Default to first scale
description='Scale:',
description="OME-Zarr scale",
style={"description_width": description_width, "text_align": "left"},
)
# Initialize the options for x, y, and z based on the first scale by default
......@@ -616,34 +653,44 @@ def chunks(zarr_path: str, visualization_method = 'slicer', **kwargs):
z_dropdown = widgets.Dropdown(
options=list(range(num_chunks[0])),
value=0,
description='Z:',
description="First dimension (Z)",
style={"description_width": description_width, "text_align": "left"},
)
y_dropdown = widgets.Dropdown(
options=list(range(num_chunks[1])),
value=0,
description='Y:',
description="Second dimension (Y)",
style={"description_width": description_width, "text_align": "left"},
)
x_dropdown = widgets.Dropdown(
options=list(range(num_chunks[2])),
value=0,
description='X:',
description="Third dimension (X)",
style={"description_width": description_width, "text_align": "left"},
)
method_dropdown = widgets.Dropdown(
options=["slicer", "slices", "vol"],
value="slicer",
description="Visualization",
style={"description_width": description_width, "text_align": "left"},
)
# Funtion to temporarily disable observers
def disable_observers():
x_dropdown.unobserve(update_visualization, names='value')
y_dropdown.unobserve(update_visualization, names='value')
z_dropdown.unobserve(update_visualization, names='value')
x_dropdown.unobserve(update_visualization, names="value")
y_dropdown.unobserve(update_visualization, names="value")
z_dropdown.unobserve(update_visualization, names="value")
method_dropdown.unobserve(update_visualization, names="value")
# Funtion to enable observers
def enable_observers():
x_dropdown.observe(update_visualization, names='value')
y_dropdown.observe(update_visualization, names='value')
z_dropdown.observe(update_visualization, names='value')
x_dropdown.observe(update_visualization, names="value")
y_dropdown.observe(update_visualization, names="value")
z_dropdown.observe(update_visualization, names="value")
method_dropdown.observe(update_visualization, names="value")
# Function to update the x, y, z dropdowns when the scale changes and reset the coordinates to 0
def update_coordinate_dropdowns(scale):
......@@ -652,20 +699,28 @@ def chunks(zarr_path: str, visualization_method = 'slicer', **kwargs):
multiscale_shape = zarr_data[scale].shape
chunk_shape = zarr_data[scale].chunks
num_chunks = get_num_chunks(multiscale_shape, chunk_shape) # Calculate new chunk options
num_chunks = get_num_chunks(
multiscale_shape, chunk_shape
) # Calculate new chunk options
# Reset X, Y, Z dropdowns to 0
z_dropdown.options = list(range(num_chunks[0]))
z_dropdown.value = 0 # Reset to 0
z_dropdown.disabled = len(z_dropdown.options) == 1 # Disable if only one option (0) is available
z_dropdown.disabled = (
len(z_dropdown.options) == 1
) # Disable if only one option (0) is available
y_dropdown.options = list(range(num_chunks[1]))
y_dropdown.value = 0 # Reset to 0
y_dropdown.disabled = len(y_dropdown.options) == 1 # Disable if only one option (0) is available
y_dropdown.disabled = (
len(y_dropdown.options) == 1
) # Disable if only one option (0) is available
x_dropdown.options = list(range(num_chunks[2]))
x_dropdown.value = 0 # Reset to 0
x_dropdown.disabled = len(x_dropdown.options) == 1 # Disable if only one option (0) is available
x_dropdown.disabled = (
len(x_dropdown.options) == 1
) # Disable if only one option (0) is available
enable_observers()
......@@ -677,25 +732,39 @@ def chunks(zarr_path: str, visualization_method = 'slicer', **kwargs):
x_coord = x_dropdown.value
y_coord = y_dropdown.value
z_coord = z_dropdown.value
visualization_method = method_dropdown.value
# Clear and update the chunk visualization
slicer_widget = load_and_visualize(scale, z_coord, y_coord, x_coord, visualization_method, **preserved_kwargs)
slicer_widget = load_and_visualize(
scale, z_coord, y_coord, x_coord, visualization_method, **kwargs
)
# Recreate the layout and display the new visualization
vbox_layout.children = [hbox_layout, slicer_widget]
final_layout.children = [widget_title, hbox_layout, slicer_widget]
# Attach an observer to scale dropdown to update x, y, z dropdowns when the scale changes
scale_dropdown.observe(lambda change: update_coordinate_dropdowns(scale_dropdown.value), names='value')
scale_dropdown.observe(
lambda change: update_coordinate_dropdowns(scale_dropdown.value), names="value"
)
enable_observers()
# Create first visualization
slicer_widget = load_and_visualize(scale_dropdown.value, z_dropdown.value, y_dropdown.value, x_dropdown.value, visualization_method, **preserved_kwargs)
slicer_widget = load_and_visualize(
scale_dropdown.value,
z_dropdown.value,
y_dropdown.value,
x_dropdown.value,
method_dropdown.value,
**kwargs,
)
# Create the layout
vbox_dropbox = widgets.VBox([scale_dropdown, z_dropdown, y_dropdown, x_dropdown], layout=widgets.Layout(margin='35px 80px 0 0'))
vbox_dropbox = widgets.VBox(
[scale_dropdown, z_dropdown, y_dropdown, x_dropdown, method_dropdown]
)
hbox_layout = widgets.HBox([vbox_dropbox, chunk_info_label])
vbox_layout = widgets.VBox([hbox_layout, slicer_widget])
final_layout = widgets.VBox([widget_title, hbox_layout, slicer_widget])
# Display the VBox
display(vbox_layout)
\ No newline at end of file
display(final_layout)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment