Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
qim3d
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
QIM
Tools
qim3d
Commits
baf99ab2
Commit
baf99ab2
authored
5 months ago
by
fima
Browse files
Options
Downloads
Patches
Plain Diff
working tool
parent
0402a140
No related branches found
No related tags found
1 merge request
!157
Threshold for v1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
qim3d/viz/_data_exploration.py
+247
-234
247 additions, 234 deletions
qim3d/viz/_data_exploration.py
with
247 additions
and
234 deletions
qim3d/viz/_data_exploration.py
+
247
−
234
View file @
baf99ab2
...
...
@@ -14,6 +14,16 @@ import matplotlib.pyplot as plt
import
numpy
as
np
import
seaborn
as
sns
import
skimage.measure
from
skimage.filters
import
(
threshold_otsu
,
threshold_isodata
,
threshold_li
,
threshold_mean
,
threshold_minimum
,
threshold_triangle
,
threshold_yen
,
)
from
IPython.display
import
clear_output
,
display
import
qim3d
...
...
@@ -26,7 +36,7 @@ def slices_grid(
slice_positions
:
Optional
[
Union
[
str
,
int
,
List
[
int
]]]
=
None
,
num_slices
:
int
=
15
,
max_columns
:
int
=
5
,
color_map
:
str
=
'
magma
'
,
color_map
:
str
=
"
magma
"
,
value_min
:
float
=
None
,
value_max
:
float
=
None
,
image_size
:
int
=
None
,
...
...
@@ -36,7 +46,7 @@ def slices_grid(
display_positions
:
bool
=
True
,
interpolation
:
Optional
[
str
]
=
None
,
color_bar
:
bool
=
False
,
color_bar_style
:
str
=
'
small
'
,
color_bar_style
:
str
=
"
small
"
,
**
matplotlib_imshow_kwargs
,
)
->
matplotlib
.
figure
.
Figure
:
"""
...
...
@@ -90,18 +100,18 @@ def slices_grid(
# If we pass python None to the imshow function, it will set to
# default value 'antialiased'
if
interpolation
is
None
:
interpolation
=
'
none
'
interpolation
=
"
none
"
# Numpy array or Torch tensor input
if
not
isinstance
(
volume
,
(
np
.
ndarray
,
da
.
core
.
Array
)):
raise
ValueError
(
'
Data type not supported
'
)
raise
ValueError
(
"
Data type not supported
"
)
if
volume
.
ndim
<
3
:
raise
ValueError
(
'
The provided object is not a volume as it has less than 3 dimensions.
'
"
The provided object is not a volume as it has less than 3 dimensions.
"
)
color_bar_style_options
=
[
'
small
'
,
'
large
'
]
color_bar_style_options
=
[
"
small
"
,
"
large
"
]
if
color_bar_style
not
in
color_bar_style_options
:
raise
ValueError
(
f
"
Value
'
{
color_bar_style
}
'
is not valid for colorbar style. Please select from
{
color_bar_style_options
}
.
"
...
...
@@ -119,11 +129,11 @@ def slices_grid(
# Here we deal with the case that the user wants to use the objects colormap directly
if
(
type
(
color_map
)
==
matplotlib
.
colors
.
LinearSegmentedColormap
or
color_map
==
'
segmentation
'
or
color_map
==
"
segmentation
"
):
num_labels
=
len
(
np
.
unique
(
volume
))
if
color_map
==
'
segmentation
'
:
if
color_map
==
"
segmentation
"
:
color_map
=
qim3d
.
viz
.
colormaps
.
segmentation
(
num_labels
)
# If value_min and value_max are not set like this, then in case the
# number of objects changes on new slice, objects might change
...
...
@@ -140,15 +150,15 @@ def slices_grid(
slice_idxs
=
np
.
linspace
(
0
,
n_total
-
1
,
num_slices
,
dtype
=
int
)
# Position is a string
elif
isinstance
(
slice_positions
,
str
)
and
slice_positions
.
lower
()
in
[
'
start
'
,
'
mid
'
,
'
end
'
,
"
start
"
,
"
mid
"
,
"
end
"
,
]:
if
slice_positions
.
lower
()
==
'
start
'
:
if
slice_positions
.
lower
()
==
"
start
"
:
slice_idxs
=
_get_slice_range
(
0
,
num_slices
,
n_total
)
elif
slice_positions
.
lower
()
==
'
mid
'
:
elif
slice_positions
.
lower
()
==
"
mid
"
:
slice_idxs
=
_get_slice_range
(
n_total
//
2
,
num_slices
,
n_total
)
elif
slice_positions
.
lower
()
==
'
end
'
:
elif
slice_positions
.
lower
()
==
"
end
"
:
slice_idxs
=
_get_slice_range
(
n_total
-
1
,
num_slices
,
n_total
)
# Position is an integer
elif
isinstance
(
slice_positions
,
int
):
...
...
@@ -229,25 +239,25 @@ def slices_grid(
ax
.
text
(
0.0
,
1.0
,
f
'
slice
{
slice_idxs
[
slice_idx
]
}
'
,
f
"
slice
{
slice_idxs
[
slice_idx
]
}
"
,
transform
=
ax
.
transAxes
,
color
=
'
white
'
,
color
=
"
white
"
,
fontsize
=
8
,
va
=
'
top
'
,
ha
=
'
left
'
,
bbox
=
dict
(
facecolor
=
'
#303030
'
,
linewidth
=
0
,
pad
=
0
),
va
=
"
top
"
,
ha
=
"
left
"
,
bbox
=
dict
(
facecolor
=
"
#303030
"
,
linewidth
=
0
,
pad
=
0
),
)
ax
.
text
(
1.0
,
0.0
,
f
'
axis
{
slice_axis
}
'
,
f
"
axis
{
slice_axis
}
"
,
transform
=
ax
.
transAxes
,
color
=
'
white
'
,
color
=
"
white
"
,
fontsize
=
8
,
va
=
'
bottom
'
,
ha
=
'
right
'
,
bbox
=
dict
(
facecolor
=
'
#303030
'
,
linewidth
=
0
,
pad
=
0
),
va
=
"
bottom
"
,
ha
=
"
right
"
,
bbox
=
dict
(
facecolor
=
"
#303030
"
,
linewidth
=
0
,
pad
=
0
),
)
except
IndexError
:
...
...
@@ -255,11 +265,11 @@ def slices_grid(
pass
# Hide the axis, so that we have a nice grid
ax
.
axis
(
'
off
'
)
ax
.
axis
(
"
off
"
)
if
color_bar
:
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
'
ignore
'
,
category
=
UserWarning
)
warnings
.
simplefilter
(
"
ignore
"
,
category
=
UserWarning
)
fig
.
tight_layout
()
norm
=
matplotlib
.
colors
.
Normalize
(
...
...
@@ -267,15 +277,15 @@ def slices_grid(
)
mappable
=
matplotlib
.
cm
.
ScalarMappable
(
norm
=
norm
,
cmap
=
color_map
)
if
color_bar_style
==
'
small
'
:
if
color_bar_style
==
"
small
"
:
# Figure coordinates of top-right axis
tr_pos
=
np
.
atleast_1d
(
axs
[
0
])[
-
1
].
get_position
()
# The width is divided by ncols to make it the same relative size to the images
color_bar_ax
=
fig
.
add_axes
(
[
tr_pos
.
x1
+
0.05
/
ncols
,
tr_pos
.
y0
,
0.05
/
ncols
,
tr_pos
.
height
]
)
fig
.
colorbar
(
mappable
=
mappable
,
cax
=
color_bar_ax
,
orientation
=
'
vertical
'
)
elif
color_bar_style
==
'
large
'
:
fig
.
colorbar
(
mappable
=
mappable
,
cax
=
color_bar_ax
,
orientation
=
"
vertical
"
)
elif
color_bar_style
==
"
large
"
:
# Figure coordinates of bottom- and top-right axis
br_pos
=
np
.
atleast_1d
(
axs
[
-
1
])[
-
1
].
get_position
()
tr_pos
=
np
.
atleast_1d
(
axs
[
0
])[
-
1
].
get_position
()
...
...
@@ -288,7 +298,7 @@ def slices_grid(
(
tr_pos
.
y1
-
br_pos
.
y0
)
-
0.0015
,
]
)
fig
.
colorbar
(
mappable
=
mappable
,
cax
=
color_bar_ax
,
orientation
=
'
vertical
'
)
fig
.
colorbar
(
mappable
=
mappable
,
cax
=
color_bar_ax
,
orientation
=
"
vertical
"
)
if
display_figure
:
plt
.
show
()
...
...
@@ -319,7 +329,7 @@ def _get_slice_range(position: int, num_slices: int, n_total: int) -> np.ndarray
def
slicer
(
volume
:
np
.
ndarray
,
slice_axis
:
int
=
0
,
color_map
:
str
=
'
magma
'
,
color_map
:
str
=
"
magma
"
,
value_min
:
float
=
None
,
value_max
:
float
=
None
,
image_height
:
int
=
3
,
...
...
@@ -363,14 +373,14 @@ def slicer(
image_height
=
image_size
image_width
=
image_size
color_bar_options
=
[
None
,
'
slices
'
,
'
volume
'
]
color_bar_options
=
[
None
,
"
slices
"
,
"
volume
"
]
if
color_bar
not
in
color_bar_options
:
raise
ValueError
(
f
"
Unrecognized value
'
{
color_bar
}
'
for parameter color_bar.
"
f
'
Expected one of
{
color_bar_options
}
.
'
f
"
Expected one of
{
color_bar_options
}
.
"
)
show_color_bar
=
color_bar
is
not
None
if
color_bar
==
'
slices
'
:
if
color_bar
==
"
slices
"
:
# Precompute the minimum and maximum along each slice for faster widget sliding.
non_slice_axes
=
tuple
(
i
for
i
in
range
(
volume
.
ndim
)
if
i
!=
slice_axis
)
slice_mins
=
np
.
min
(
volume
,
axis
=
non_slice_axes
)
...
...
@@ -378,7 +388,7 @@ def slicer(
# Create the interactive widget
def
_slicer
(
slice_positions
):
if
color_bar
==
'
slices
'
:
if
color_bar
==
"
slices
"
:
dynamic_min
=
slice_mins
[
slice_positions
]
dynamic_max
=
slice_maxs
[
slice_positions
]
else
:
...
...
@@ -407,18 +417,18 @@ def slicer(
value
=
volume
.
shape
[
slice_axis
]
//
2
,
min
=
0
,
max
=
volume
.
shape
[
slice_axis
]
-
1
,
description
=
'
Slice
'
,
description
=
"
Slice
"
,
continuous_update
=
True
,
)
slicer_obj
=
widgets
.
interactive
(
_slicer
,
slice_positions
=
position_slider
)
slicer_obj
.
layout
=
widgets
.
Layout
(
align_items
=
'
flex-start
'
)
slicer_obj
.
layout
=
widgets
.
Layout
(
align_items
=
"
flex-start
"
)
return
slicer_obj
def
slicer_orthogonal
(
volume
:
np
.
ndarray
,
color_map
:
str
=
'
magma
'
,
color_map
:
str
=
"
magma
"
,
value_min
:
float
=
None
,
value_max
:
float
=
None
,
image_height
:
int
=
3
,
...
...
@@ -474,9 +484,9 @@ def slicer_orthogonal(
y_slicer
=
get_slicer_for_axis
(
slice_axis
=
1
)
x_slicer
=
get_slicer_for_axis
(
slice_axis
=
2
)
z_slicer
.
children
[
0
].
description
=
'
Z
'
y_slicer
.
children
[
0
].
description
=
'
Y
'
x_slicer
.
children
[
0
].
description
=
'
X
'
z_slicer
.
children
[
0
].
description
=
"
Z
"
y_slicer
.
children
[
0
].
description
=
"
Y
"
x_slicer
.
children
[
0
].
description
=
"
X
"
return
widgets
.
HBox
([
z_slicer
,
y_slicer
,
x_slicer
])
...
...
@@ -484,7 +494,7 @@ def slicer_orthogonal(
def
fade_mask
(
volume
:
np
.
ndarray
,
axis
:
int
=
0
,
color_map
:
str
=
'
magma
'
,
color_map
:
str
=
"
magma
"
,
value_min
:
float
=
None
,
value_max
:
float
=
None
,
)
->
widgets
.
interactive
:
...
...
@@ -534,8 +544,8 @@ def fade_mask(
axes
[
0
].
imshow
(
slice_img
,
cmap
=
color_map
,
vmin
=
new_value_min
,
vmax
=
new_value_max
)
axes
[
0
].
set_title
(
'
Original
'
)
axes
[
0
].
axis
(
'
off
'
)
axes
[
0
].
set_title
(
"
Original
"
)
axes
[
0
].
axis
(
"
off
"
)
mask
=
qim3d
.
operations
.
fade_mask
(
np
.
ones_like
(
volume
),
...
...
@@ -546,8 +556,8 @@ def fade_mask(
invert
=
invert
,
)
axes
[
1
].
imshow
(
mask
[
position
,
:,
:],
cmap
=
color_map
)
axes
[
1
].
set_title
(
'
Mask
'
)
axes
[
1
].
axis
(
'
off
'
)
axes
[
1
].
set_title
(
"
Mask
"
)
axes
[
1
].
axis
(
"
off
"
)
masked_volume
=
qim3d
.
operations
.
fade_mask
(
volume
,
...
...
@@ -573,22 +583,22 @@ def fade_mask(
axes
[
2
].
imshow
(
slice_img
,
cmap
=
color_map
,
vmin
=
new_value_min
,
vmax
=
new_value_max
)
axes
[
2
].
set_title
(
'
Masked
'
)
axes
[
2
].
axis
(
'
off
'
)
axes
[
2
].
set_title
(
"
Masked
"
)
axes
[
2
].
axis
(
"
off
"
)
return
fig
shape_dropdown
=
widgets
.
Dropdown
(
options
=
[
'
spherical
'
,
'
cylindrical
'
],
value
=
'
spherical
'
,
# default value
description
=
'
Geometry
'
,
options
=
[
"
spherical
"
,
"
cylindrical
"
],
value
=
"
spherical
"
,
# default value
description
=
"
Geometry
"
,
)
position_slider
=
widgets
.
IntSlider
(
value
=
volume
.
shape
[
0
]
//
2
,
min
=
0
,
max
=
volume
.
shape
[
0
]
-
1
,
description
=
'
Slice
'
,
description
=
"
Slice
"
,
continuous_update
=
False
,
)
decay_rate_slider
=
widgets
.
FloatSlider
(
...
...
@@ -596,7 +606,7 @@ def fade_mask(
min
=
1
,
max
=
50
,
step
=
1.0
,
description
=
'
Decay Rate
'
,
description
=
"
Decay Rate
"
,
continuous_update
=
False
,
)
ratio_slider
=
widgets
.
FloatSlider
(
...
...
@@ -604,14 +614,14 @@ def fade_mask(
min
=
0.1
,
max
=
1
,
step
=
0.01
,
description
=
'
Ratio
'
,
description
=
"
Ratio
"
,
continuous_update
=
False
,
)
# Create the Checkbox widget
invert_checkbox
=
widgets
.
Checkbox
(
value
=
False
,
description
=
'
Invert
'
,
# default value
description
=
"
Invert
"
,
# default value
)
slicer_obj
=
widgets
.
interactive
(
...
...
@@ -622,7 +632,7 @@ def fade_mask(
geometry
=
shape_dropdown
,
invert
=
invert_checkbox
,
)
slicer_obj
.
layout
=
widgets
.
Layout
(
align_items
=
'
flex-start
'
)
slicer_obj
.
layout
=
widgets
.
Layout
(
align_items
=
"
flex-start
"
)
return
slicer_obj
...
...
@@ -654,15 +664,15 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive:
"""
# Load the Zarr dataset
zarr_data
=
zarr
.
open
(
zarr_path
,
mode
=
'
r
'
)
zarr_data
=
zarr
.
open
(
zarr_path
,
mode
=
"
r
"
)
# Save arguments for later use
# visualization_method = visualization_method
# preserved_kwargs = kwargs
# Create label to display the chunk coordinates
widget_title
=
widgets
.
HTML
(
'
<h2>Chunk Explorer</h2>
'
)
chunk_info_label
=
widgets
.
HTML
(
value
=
'
Chunk info will be displayed here
'
)
widget_title
=
widgets
.
HTML
(
"
<h2>Chunk Explorer</h2>
"
)
chunk_info_label
=
widgets
.
HTML
(
value
=
"
Chunk info will be displayed here
"
)
def
load_and_visualize
(
scale
,
z_coord
,
y_coord
,
x_coord
,
visualization_method
,
**
kwargs
...
...
@@ -696,13 +706,13 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive:
# Update the chunk info label with the chunk coordinates
info_string
=
(
f
'
<b>shape:</b>
{
chunk_shape
}
\n
'
+
f
'
<b>coordinates:</b> (
{
z_coord
}
,
{
y_coord
}
,
{
x_coord
}
)
\n
'
+
f
'
<b>ranges: </b>Z(
{
z_start
}
-
{
z_stop
}
) Y(
{
y_start
}
-
{
y_stop
}
) X(
{
x_start
}
-
{
x_stop
}
)
\n
'
+
f
'
<b>dtype:</b>
{
chunk
.
dtype
}
\n
'
+
f
'
<b>min value:</b>
{
np
.
min
(
chunk
)
}
\n
'
+
f
'
<b>max value:</b>
{
np
.
max
(
chunk
)
}
\n
'
+
f
'
<b>mean value:</b>
{
np
.
mean
(
chunk
)
}
\n
'
f
"
<b>shape:</b>
{
chunk_shape
}
\n
"
+
f
"
<b>coordinates:</b> (
{
z_coord
}
,
{
y_coord
}
,
{
x_coord
}
)
\n
"
+
f
"
<b>ranges: </b>Z(
{
z_start
}
-
{
z_stop
}
) Y(
{
y_start
}
-
{
y_stop
}
) X(
{
x_start
}
-
{
x_stop
}
)
\n
"
+
f
"
<b>dtype:</b>
{
chunk
.
dtype
}
\n
"
+
f
"
<b>min value:</b>
{
np
.
min
(
chunk
)
}
\n
"
+
f
"
<b>max value:</b>
{
np
.
max
(
chunk
)
}
\n
"
+
f
"
<b>mean value:</b>
{
np
.
mean
(
chunk
)
}
\n
"
)
chunk_info_label
.
value
=
f
"""
...
...
@@ -716,22 +726,22 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive:
"""
# Prepare chunk visualization based on the selected method
if
visualization_method
==
'
slicer
'
:
# return a widget
if
visualization_method
==
"
slicer
"
:
# return a widget
viz_widget
=
qim3d
.
viz
.
slicer
(
chunk
,
**
kwargs
)
elif
visualization_method
==
'
slices
'
:
# return a plt.Figure
elif
visualization_method
==
"
slices
"
:
# return a plt.Figure
viz_widget
=
widgets
.
Output
()
with
viz_widget
:
viz_widget
.
clear_output
(
wait
=
True
)
fig
=
qim3d
.
viz
.
slices_grid
(
chunk
,
**
kwargs
)
display
(
fig
)
elif
visualization_method
==
'
volume
'
:
elif
visualization_method
==
"
volume
"
:
viz_widget
=
widgets
.
Output
()
with
viz_widget
:
viz_widget
.
clear_output
(
wait
=
True
)
out
=
qim3d
.
viz
.
volumetric
(
chunk
,
show
=
False
,
**
kwargs
)
display
(
out
)
else
:
log
.
info
(
f
'
Invalid visualization method:
{
visualization_method
}
'
)
log
.
info
(
f
"
Invalid visualization method:
{
visualization_method
}
"
)
return
viz_widget
...
...
@@ -740,16 +750,16 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive:
return
[(
s
+
chunk_size
[
i
]
-
1
)
//
chunk_size
[
i
]
for
i
,
s
in
enumerate
(
shape
)]
scale_options
=
{
f
'
{
i
}
{
zarr_data
[
i
].
shape
}
'
:
i
for
i
in
range
(
len
(
zarr_data
))
f
"
{
i
}
{
zarr_data
[
i
].
shape
}
"
:
i
for
i
in
range
(
len
(
zarr_data
))
}
# len(zarr_data) gives number of scales
description_width
=
'
128px
'
description_width
=
"
128px
"
# Create dropdown for scale
scale_dropdown
=
widgets
.
Dropdown
(
options
=
scale_options
,
value
=
0
,
# Default to first scale
description
=
'
OME-Zarr scale
'
,
style
=
{
'
description_width
'
:
description_width
,
'
text_align
'
:
'
left
'
},
description
=
"
OME-Zarr scale
"
,
style
=
{
"
description_width
"
:
description_width
,
"
text_align
"
:
"
left
"
},
)
# Initialize the options for x, y, and z based on the first scale by default
...
...
@@ -760,44 +770,44 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive:
z_dropdown
=
widgets
.
Dropdown
(
options
=
list
(
range
(
num_chunks
[
0
])),
value
=
0
,
description
=
'
First dimension (Z)
'
,
style
=
{
'
description_width
'
:
description_width
,
'
text_align
'
:
'
left
'
},
description
=
"
First dimension (Z)
"
,
style
=
{
"
description_width
"
:
description_width
,
"
text_align
"
:
"
left
"
},
)
y_dropdown
=
widgets
.
Dropdown
(
options
=
list
(
range
(
num_chunks
[
1
])),
value
=
0
,
description
=
'
Second dimension (Y)
'
,
style
=
{
'
description_width
'
:
description_width
,
'
text_align
'
:
'
left
'
},
description
=
"
Second dimension (Y)
"
,
style
=
{
"
description_width
"
:
description_width
,
"
text_align
"
:
"
left
"
},
)
x_dropdown
=
widgets
.
Dropdown
(
options
=
list
(
range
(
num_chunks
[
2
])),
value
=
0
,
description
=
'
Third dimension (X)
'
,
style
=
{
'
description_width
'
:
description_width
,
'
text_align
'
:
'
left
'
},
description
=
"
Third dimension (X)
"
,
style
=
{
"
description_width
"
:
description_width
,
"
text_align
"
:
"
left
"
},
)
method_dropdown
=
widgets
.
Dropdown
(
options
=
[
'
slicer
'
,
'
slices
'
,
'
volume
'
],
value
=
'
slicer
'
,
description
=
'
Visualization
'
,
style
=
{
'
description_width
'
:
description_width
,
'
text_align
'
:
'
left
'
},
options
=
[
"
slicer
"
,
"
slices
"
,
"
volume
"
],
value
=
"
slicer
"
,
description
=
"
Visualization
"
,
style
=
{
"
description_width
"
:
description_width
,
"
text_align
"
:
"
left
"
},
)
# Funtion to temporarily disable observers
def
disable_observers
():
x_dropdown
.
unobserve
(
update_visualization
,
names
=
'
value
'
)
y_dropdown
.
unobserve
(
update_visualization
,
names
=
'
value
'
)
z_dropdown
.
unobserve
(
update_visualization
,
names
=
'
value
'
)
method_dropdown
.
unobserve
(
update_visualization
,
names
=
'
value
'
)
x_dropdown
.
unobserve
(
update_visualization
,
names
=
"
value
"
)
y_dropdown
.
unobserve
(
update_visualization
,
names
=
"
value
"
)
z_dropdown
.
unobserve
(
update_visualization
,
names
=
"
value
"
)
method_dropdown
.
unobserve
(
update_visualization
,
names
=
"
value
"
)
# Funtion to enable observers
def
enable_observers
():
x_dropdown
.
observe
(
update_visualization
,
names
=
'
value
'
)
y_dropdown
.
observe
(
update_visualization
,
names
=
'
value
'
)
z_dropdown
.
observe
(
update_visualization
,
names
=
'
value
'
)
method_dropdown
.
observe
(
update_visualization
,
names
=
'
value
'
)
x_dropdown
.
observe
(
update_visualization
,
names
=
"
value
"
)
y_dropdown
.
observe
(
update_visualization
,
names
=
"
value
"
)
z_dropdown
.
observe
(
update_visualization
,
names
=
"
value
"
)
method_dropdown
.
observe
(
update_visualization
,
names
=
"
value
"
)
# Function to update the x, y, z dropdowns when the scale changes and reset the coordinates to 0
def
update_coordinate_dropdowns
(
scale
):
...
...
@@ -850,7 +860,7 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive:
# Attach an observer to scale dropdown to update x, y, z dropdowns when the scale changes
scale_dropdown
.
observe
(
lambda
change
:
update_coordinate_dropdowns
(
scale_dropdown
.
value
),
names
=
'
value
'
lambda
change
:
update_coordinate_dropdowns
(
scale_dropdown
.
value
),
names
=
"
value
"
)
enable_observers
()
...
...
@@ -878,21 +888,23 @@ def chunks(zarr_path: str, **kwargs) -> widgets.interactive:
def
histogram
(
volume
:
np
.
ndarray
,
bins
:
Union
[
int
,
str
]
=
'
auto
'
,
slice_idx
:
Union
[
int
,
str
]
=
None
,
bins
:
Union
[
int
,
str
]
=
"
auto
"
,
slice_idx
:
Union
[
int
,
str
,
None
]
=
None
,
vertical_line
:
int
=
None
,
axis
:
int
=
0
,
kde
:
bool
=
True
,
log_scale
:
bool
=
False
,
despine
:
bool
=
True
,
show_title
:
bool
=
True
,
color
:
str
=
'
qim3d
'
,
edgecolor
:
str
|
None
=
None
,
figsize
:
t
uple
[
float
,
float
]
=
(
8
,
4.5
),
element
:
str
=
'
step
'
,
color
:
str
=
"
qim3d
"
,
edgecolor
:
Optional
[
str
]
=
None
,
figsize
:
T
uple
[
float
,
float
]
=
(
8
,
4.5
),
element
:
str
=
"
step
"
,
return_fig
:
bool
=
False
,
show
:
bool
=
True
,
**
sns_kwargs
,
)
->
None
|
matplotlib
.
figure
.
Figure
:
ax
:
Optional
[
plt
.
Axes
]
=
None
,
**
sns_kwargs
:
Union
[
str
,
float
,
int
,
bool
]
)
->
Optional
[
Union
[
plt
.
Figure
,
plt
.
Axes
]]:
"""
Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume.
...
...
@@ -900,73 +912,63 @@ def histogram(
Args:
volume (np.ndarray): A 3D NumPy array representing the volume to be visualized.
bins (
int or
str, optional): Number of histogram bins or a binning strategy (e.g.,
"
auto
"
). Default is
"
auto
"
.
bins (
Union[int,
str
]
, optional): Number of histogram bins or a binning strategy (e.g.,
"
auto
"
). Default is
"
auto
"
.
axis (int, optional): Axis along which to take a slice. Default is 0.
slice_idx (
int or str or None
, optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis.
slice_idx (
Union[int, str]
, optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis.
If
"
middle
"
, the function uses the middle slice. If None, the entire volume is visualized. Default is None.
vertical_line (int, optional): Intensity value for a vertical line to be drawn on the histogram. Default is None.
kde (bool, optional): Whether to overlay a kernel density estimate. Default is True.
log_scale (bool, optional): Whether to use a logarithmic scale on the y-axis. Default is False.
despine (bool, optional): If True, removes the top and right spines from the plot for cleaner appearance. Default is True.
show_title (bool, optional): If True, displays a title with slice information. Default is True.
color (str, optional): Color for the histogram bars. If
"
qim3d
"
, defaults to the qim3d color. Default is
"
qim3d
"
.
edgecolor (str, optional): Color for the edges of the histogram bars. Default is None.
figsize (tuple
of floats
, optional): Size of the figure (width, height). Default is (8, 4.5).
figsize (tuple, optional): Size of the figure (width, height). Default is (8, 4.5).
element (str, optional): Type of histogram to draw (
'
bars
'
,
'
step
'
, or
'
poly
'
). Default is
"
step
"
.
return_fig (bool, optional): If True, returns the figure object instead of showing it directly. Default is False.
show (bool, optional): If True, displays the plot. If False, suppresses display. Default is True.
**sns_kwargs (Any): Additional keyword arguments for `seaborn.histplot`.
ax (matplotlib.axes.Axes, optional): Axes object where the histogram will be plotted. Default is None.
**sns_kwargs: Additional keyword arguments for `seaborn.histplot`.
Returns:
fig (Optional[matplotlib.figure.Figure]): If `return_fig` is True, returns the generated figure object. Otherwise, returns None.
Optional[matplotlib.figure.Figure or matplotlib.axes.Axes]:
If `return_fig` is True, returns the generated figure object.
If `return_fig` is False and `ax` is provided, returns the `Axes` object.
Otherwise, returns None.
Raises:
ValueError: If `axis` is not a valid axis index (0, 1, or 2).
ValueError: If `slice_idx` is an integer and is out of range for the specified axis.
Example:
```python
import qim3d
vol = qim3d.examples.bone_128x128x128
qim3d.viz.histogram(vol)
```

```python
import qim3d
vol = qim3d.examples.bone_128x128x128
qim3d.viz.histogram(vol, bins=32, slice_idx=
"
middle
"
, axis=1, kde=False, log_scale=True)
```

"""
if
not
(
0
<=
axis
<
volume
.
ndim
):
raise
ValueError
(
f
'
Axis must be an integer between 0 and
{
volume
.
ndim
-
1
}
.
'
)
raise
ValueError
(
f
"
Axis must be an integer between 0 and
{
volume
.
ndim
-
1
}
.
"
)
if
slice_idx
==
'
middle
'
:
if
slice_idx
==
"
middle
"
:
slice_idx
=
volume
.
shape
[
axis
]
//
2
if
slice_idx
:
if
slice_idx
is
not
None
:
if
0
<=
slice_idx
<
volume
.
shape
[
axis
]:
img_slice
=
np
.
take
(
volume
,
indices
=
slice_idx
,
axis
=
axis
)
data
=
img_slice
.
ravel
()
title
=
f
'
Intensity histogram of slice #
{
slice_idx
}
{
img_slice
.
shape
}
along axis
{
axis
}
'
title
=
f
"
Intensity histogram of slice #
{
slice_idx
}
{
img_slice
.
shape
}
along axis
{
axis
}
"
else
:
raise
ValueError
(
f
'
Slice index out of range. Must be between 0 and
{
volume
.
shape
[
axis
]
-
1
}
.
'
f
"
Slice index out of range. Must be between 0 and
{
volume
.
shape
[
axis
]
-
1
}
.
"
)
else
:
data
=
volume
.
ravel
()
title
=
f
'
Intensity histogram for whole volume
{
volume
.
shape
}
'
title
=
f
"
Intensity histogram for whole volume
{
volume
.
shape
}
"
# Use provided Axes or create new figure
if
ax
is
None
:
fig
,
ax
=
plt
.
subplots
(
figsize
=
figsize
)
else
:
fig
=
None
if
log_scale
:
plt
.
yscale
(
'
log
'
)
ax
.
set_
yscale
(
"
log
"
)
if
color
==
'
qim3d
'
:
if
color
==
"
qim3d
"
:
color
=
qim3d
.
viz
.
colormaps
.
qim
(
1.0
)
sns
.
histplot
(
...
...
@@ -976,35 +978,46 @@ def histogram(
color
=
color
,
element
=
element
,
edgecolor
=
edgecolor
,
ax
=
ax
,
# Plot directly on the specified Axes
**
sns_kwargs
,
)
if
vertical_line
is
not
None
:
ax
.
axvline
(
x
=
vertical_line
,
color
=
'
red
'
,
linestyle
=
"
--
"
,
linewidth
=
2
,
)
if
despine
:
sns
.
despine
(
fig
=
None
,
ax
=
None
,
ax
=
ax
,
top
=
True
,
right
=
True
,
left
=
False
,
bottom
=
False
,
offset
=
{
'
left
'
:
0
,
'
bottom
'
:
18
},
offset
=
{
"
left
"
:
0
,
"
bottom
"
:
18
},
trim
=
True
,
)
plt
.
xlabel
(
'
Voxel Intensity
'
)
plt
.
ylabel
(
'
Frequency
'
)
ax
.
set_
xlabel
(
"
Voxel Intensity
"
)
ax
.
set_
ylabel
(
"
Frequency
"
)
if
show_title
:
plt
.
title
(
title
,
fontsize
=
10
)
ax
.
set_
title
(
title
,
fontsize
=
10
)
# Handle show and return
if
show
:
if
show
and
fig
is
not
None
:
plt
.
show
()
else
:
plt
.
close
(
fig
)
if
return_fig
:
return
fig
elif
ax
is
not
None
:
return
ax
class
_LineProfile
:
...
...
@@ -1045,29 +1058,29 @@ class _LineProfile:
self
.
y_widget
.
value
=
self
.
y_max
//
2
def
initialize_widgets
(
self
):
layout
=
widgets
.
Layout
(
width
=
'
300px
'
,
height
=
'
auto
'
)
layout
=
widgets
.
Layout
(
width
=
"
300px
"
,
height
=
"
auto
"
)
self
.
x_widget
=
widgets
.
IntSlider
(
min
=
self
.
pad
,
step
=
1
,
description
=
''
,
layout
=
layout
min
=
self
.
pad
,
step
=
1
,
description
=
""
,
layout
=
layout
)
self
.
y_widget
=
widgets
.
IntSlider
(
min
=
self
.
pad
,
step
=
1
,
description
=
''
,
layout
=
layout
min
=
self
.
pad
,
step
=
1
,
description
=
""
,
layout
=
layout
)
self
.
angle_widget
=
widgets
.
IntSlider
(
min
=
0
,
max
=
360
,
step
=
1
,
value
=
0
,
description
=
''
,
layout
=
layout
min
=
0
,
max
=
360
,
step
=
1
,
value
=
0
,
description
=
""
,
layout
=
layout
)
self
.
line_fraction_widget
=
widgets
.
FloatRangeSlider
(
min
=
0
,
max
=
1
,
step
=
0.01
,
value
=
[
0
,
1
],
description
=
''
,
layout
=
layout
min
=
0
,
max
=
1
,
step
=
0.01
,
value
=
[
0
,
1
],
description
=
""
,
layout
=
layout
)
self
.
slice_axis_widget
=
widgets
.
Dropdown
(
options
=
[
0
,
1
,
2
],
value
=
self
.
slice_axis
,
description
=
'
Slice axis
'
options
=
[
0
,
1
,
2
],
value
=
self
.
slice_axis
,
description
=
"
Slice axis
"
)
self
.
slice_axis_widget
.
layout
.
width
=
'
250px
'
self
.
slice_axis_widget
.
layout
.
width
=
"
250px
"
self
.
slice_index_widget
=
widgets
.
IntSlider
(
min
=
0
,
step
=
1
,
description
=
'
Slice index
'
,
layout
=
layout
min
=
0
,
step
=
1
,
description
=
"
Slice index
"
,
layout
=
layout
)
self
.
slice_index_widget
.
layout
.
width
=
'
400px
'
self
.
slice_index_widget
.
layout
.
width
=
"
400px
"
def
calculate_line_endpoints
(
self
,
x
,
y
,
angle
):
"""
...
...
@@ -1108,7 +1121,7 @@ class _LineProfile:
image
=
np
.
take
(
self
.
volume
,
slice_index
,
slice_axis
)
angle
=
np
.
radians
(
angle_deg
)
src
,
dst
=
(
np
.
array
(
point
,
dtype
=
'
float32
'
)
np
.
array
(
point
,
dtype
=
"
float32
"
)
for
point
in
self
.
calculate_line_endpoints
(
x
,
y
,
angle
)
)
...
...
@@ -1136,12 +1149,12 @@ class _LineProfile:
colors
=
self
.
cmap
(
norm
(
np
.
arange
(
num_segments
-
1
)))
lc
=
matplotlib
.
collections
.
LineCollection
(
segments
,
colors
=
colors
,
linewidth
=
2
)
ax
[
0
].
imshow
(
image
,
cmap
=
'
gray
'
)
ax
[
0
].
imshow
(
image
,
cmap
=
"
gray
"
)
ax
[
0
].
add_collection
(
lc
)
# pivot point
ax
[
0
].
plot
(
y
,
x
,
marker
=
'
s
'
,
linestyle
=
''
,
color
=
'
cyan
'
,
markersize
=
4
)
ax
[
0
].
set_xlabel
(
f
'
axis
{
np
.
delete
(
np
.
arange
(
3
),
self
.
slice_axis
)[
1
]
}
'
)
ax
[
0
].
set_ylabel
(
f
'
axis
{
np
.
delete
(
np
.
arange
(
3
),
self
.
slice_axis
)[
0
]
}
'
)
ax
[
0
].
plot
(
y
,
x
,
marker
=
"
s
"
,
linestyle
=
""
,
color
=
"
cyan
"
,
markersize
=
4
)
ax
[
0
].
set_xlabel
(
f
"
axis
{
np
.
delete
(
np
.
arange
(
3
),
self
.
slice_axis
)[
1
]
}
"
)
ax
[
0
].
set_ylabel
(
f
"
axis
{
np
.
delete
(
np
.
arange
(
3
),
self
.
slice_axis
)[
0
]
}
"
)
# Profile intensity plot
norm
=
plt
.
Normalize
(
0
,
vmax
=
len
(
y_pline
)
-
1
)
...
...
@@ -1154,7 +1167,7 @@ class _LineProfile:
ax
[
1
].
add_collection
(
lc
)
ax
[
1
].
autoscale
()
ax
[
1
].
set_xlabel
(
'
Distance along line
'
)
ax
[
1
].
set_xlabel
(
"
Distance along line
"
)
ax
[
1
].
grid
(
True
)
plt
.
tight_layout
()
plt
.
show
()
...
...
@@ -1162,7 +1175,7 @@ class _LineProfile:
def
build_interactive
(
self
):
# Group widgets into two columns
title_style
=
(
'
text-align:center; font-size:16px; font-weight:bold; margin-bottom:5px;
'
"
text-align:center; font-size:16px; font-weight:bold; margin-bottom:5px;
"
)
title_column1
=
widgets
.
HTML
(
f
"
<div style=
'
{
title_style
}
'
>Line parameterization</div>
"
...
...
@@ -1172,11 +1185,11 @@ class _LineProfile:
)
# Make label widgets instead of descriptions which have different lengths.
label_layout
=
widgets
.
Layout
(
width
=
'
120px
'
)
label_x
=
widgets
.
Label
(
'
Vertical position
'
,
layout
=
label_layout
)
label_y
=
widgets
.
Label
(
'
Horizontal position
'
,
layout
=
label_layout
)
label_angle
=
widgets
.
Label
(
'
Angle (°)
'
,
layout
=
label_layout
)
label_fraction
=
widgets
.
Label
(
'
Fraction range
'
,
layout
=
label_layout
)
label_layout
=
widgets
.
Layout
(
width
=
"
120px
"
)
label_x
=
widgets
.
Label
(
"
Vertical position
"
,
layout
=
label_layout
)
label_y
=
widgets
.
Label
(
"
Horizontal position
"
,
layout
=
label_layout
)
label_angle
=
widgets
.
Label
(
"
Angle (°)
"
,
layout
=
label_layout
)
label_fraction
=
widgets
.
Label
(
"
Fraction range
"
,
layout
=
label_layout
)
row_x
=
widgets
.
HBox
([
label_x
,
self
.
x_widget
])
row_y
=
widgets
.
HBox
([
label_y
,
self
.
y_widget
])
...
...
@@ -1194,12 +1207,12 @@ class _LineProfile:
interactive_plot
=
widgets
.
interactive_output
(
self
.
update
,
{
'
slice_axis
'
:
self
.
slice_axis_widget
,
'
slice_index
'
:
self
.
slice_index_widget
,
'
x
'
:
self
.
x_widget
,
'
y
'
:
self
.
y_widget
,
'
angle_deg
'
:
self
.
angle_widget
,
'
fraction_range
'
:
self
.
line_fraction_widget
,
"
slice_axis
"
:
self
.
slice_axis_widget
,
"
slice_index
"
:
self
.
slice_index_widget
,
"
x
"
:
self
.
x_widget
,
"
y
"
:
self
.
y_widget
,
"
angle_deg
"
:
self
.
angle_widget
,
"
fraction_range
"
:
self
.
line_fraction_widget
,
},
)
...
...
@@ -1209,9 +1222,9 @@ class _LineProfile:
def
line_profile
(
volume
:
np
.
ndarray
,
slice_axis
:
int
=
0
,
slice_index
:
int
|
str
=
'
middle
'
,
vertical_position
:
int
|
str
=
'
middle
'
,
horizontal_position
:
int
|
str
=
'
middle
'
,
slice_index
:
int
|
str
=
"
middle
"
,
vertical_position
:
int
|
str
=
"
middle
"
,
horizontal_position
:
int
|
str
=
"
middle
"
,
angle
:
int
=
0
,
fraction_range
:
Tuple
[
float
,
float
]
=
(
0.00
,
1.00
),
)
->
widgets
.
interactive
:
...
...
@@ -1246,16 +1259,16 @@ def line_profile(
if
isinstance
(
pos
,
int
):
if
not
pos_range
[
0
]
<=
pos
<
pos_range
[
1
]:
raise
ValueError
(
f
'
Value for
{
name
}
must be inside [
{
pos_range
[
0
]
}
,
{
pos_range
[
1
]
}
]
'
f
"
Value for
{
name
}
must be inside [
{
pos_range
[
0
]
}
,
{
pos_range
[
1
]
}
]
"
)
return
pos
elif
isinstance
(
pos
,
str
):
pos
=
pos
.
lower
()
if
pos
==
'
start
'
:
if
pos
==
"
start
"
:
return
pos_range
[
0
]
elif
pos
==
'
middle
'
:
elif
pos
==
"
middle
"
:
return
pos_range
[
0
]
+
(
pos_range
[
1
]
-
pos_range
[
0
])
//
2
elif
pos
==
'
end
'
:
elif
pos
==
"
end
"
:
return
pos_range
[
1
]
else
:
raise
ValueError
(
...
...
@@ -1263,27 +1276,27 @@ def line_profile(
"
Must be
'
start
'
,
'
middle
'
, or
'
end
'
.
"
)
else
:
raise
TypeError
(
'
Axis position must be of type int or str.
'
)
raise
TypeError
(
"
Axis position must be of type int or str.
"
)
if
not
isinstance
(
volume
,
(
np
.
ndarray
,
da
.
core
.
Array
)):
raise
ValueError
(
'
Data type for volume not supported.
'
)
raise
ValueError
(
"
Data type for volume not supported.
"
)
if
volume
.
ndim
!=
3
:
raise
ValueError
(
'
Volume must be 3D.
'
)
raise
ValueError
(
"
Volume must be 3D.
"
)
dims
=
volume
.
shape
slice_index
=
parse_position
(
slice_index
,
(
0
,
dims
[
slice_axis
]
-
1
),
'
slice_index
'
)
slice_index
=
parse_position
(
slice_index
,
(
0
,
dims
[
slice_axis
]
-
1
),
"
slice_index
"
)
# the omission of the ends for the pivot point is due to border issues.
vertical_position
=
parse_position
(
vertical_position
,
(
1
,
np
.
delete
(
dims
,
slice_axis
)[
0
]
-
2
),
'
vertical_position
'
vertical_position
,
(
1
,
np
.
delete
(
dims
,
slice_axis
)[
0
]
-
2
),
"
vertical_position
"
)
horizontal_position
=
parse_position
(
horizontal_position
,
(
1
,
np
.
delete
(
dims
,
slice_axis
)[
1
]
-
2
),
'
horizontal_position
'
,
"
horizontal_position
"
,
)
if
not
isinstance
(
angle
,
int
|
float
):
raise
ValueError
(
'
Invalid type for angle.
'
)
raise
ValueError
(
"
Invalid type for angle.
"
)
angle
=
round
(
angle
)
%
360
if
not
(
...
...
@@ -1291,7 +1304,7 @@ def line_profile(
and
0.0
<=
fraction_range
[
1
]
<=
1.0
and
fraction_range
[
0
]
<=
fraction_range
[
1
]
):
raise
ValueError
(
'
Invalid values for fraction_range.
'
)
raise
ValueError
(
"
Invalid values for fraction_range.
"
)
lp
=
_LineProfile
(
volume
,
...
...
@@ -1307,7 +1320,7 @@ def line_profile(
def
threshold
(
volume
:
np
.
ndarray
,
cmap_image
:
str
=
'
viridis
'
,
cmap_image
:
str
=
'
magma
'
,
vmin
:
float
=
None
,
vmax
:
float
=
None
,
)
->
widgets
.
VBox
:
...
...
@@ -1363,19 +1376,19 @@ def threshold(
# Centralized state dictionary to track current parameters
state
=
{
'
position
'
:
volume
.
shape
[
0
]
//
2
,
'
threshold
'
:
int
((
volume
.
min
()
+
volume
.
max
())
/
2
),
'
method
'
:
'
Manual
'
,
"
position
"
:
volume
.
shape
[
0
]
//
2
,
"
threshold
"
:
int
((
volume
.
min
()
+
volume
.
max
())
/
2
),
"
method
"
:
"
Manual
"
,
}
threshold_methods
=
{
'
Otsu
'
:
threshold_otsu
,
'
Isodata
'
:
threshold_isodata
,
'
Li
'
:
threshold_li
,
'
Mean
'
:
threshold_mean
,
'
Minimum
'
:
threshold_minimum
,
'
Triangle
'
:
threshold_triangle
,
'
Yen
'
:
threshold_yen
,
"
Otsu
"
:
threshold_otsu
,
"
Isodata
"
:
threshold_isodata
,
"
Li
"
:
threshold_li
,
"
Mean
"
:
threshold_mean
,
"
Minimum
"
:
threshold_minimum
,
"
Triangle
"
:
threshold_triangle
,
"
Yen
"
:
threshold_yen
,
}
# Create an output widget to display the plot
...
...
@@ -1384,24 +1397,24 @@ def threshold(
# Function to update the state and trigger visualization
def
update_state
(
change
):
# Update state based on widget values
state
[
'
position
'
]
=
position_slider
.
value
state
[
'
method
'
]
=
method_dropdown
.
value
state
[
"
position
"
]
=
position_slider
.
value
state
[
"
method
"
]
=
method_dropdown
.
value
if
state
[
'
method
'
]
==
'
Manual
'
:
state
[
'
threshold
'
]
=
threshold_slider
.
value
if
state
[
"
method
"
]
==
"
Manual
"
:
state
[
"
threshold
"
]
=
threshold_slider
.
value
threshold_slider
.
disabled
=
False
else
:
threshold_func
=
threshold_methods
.
get
(
state
[
'
method
'
])
threshold_func
=
threshold_methods
.
get
(
state
[
"
method
"
])
if
threshold_func
:
slice_img
=
volume
[
state
[
'
position
'
],
:,
:]
slice_img
=
volume
[
state
[
"
position
"
],
:,
:]
computed_threshold
=
threshold_func
(
slice_img
)
state
[
'
threshold
'
]
=
computed_threshold
state
[
"
threshold
"
]
=
computed_threshold
# Programmatically update the slider without triggering callbacks
threshold_slider
.
unobserve_all
()
threshold_slider
.
value
=
computed_threshold
threshold_slider
.
disabled
=
True
threshold_slider
.
observe
(
update_state
,
names
=
'
value
'
)
threshold_slider
.
observe
(
update_state
,
names
=
"
value
"
)
else
:
raise
ValueError
(
f
"
Unsupported thresholding method:
{
state
[
'
method
'
]
}
"
)
...
...
@@ -1410,7 +1423,7 @@ def threshold(
# Visualization function
def
update_visualization
():
slice_img
=
volume
[
state
[
'
position
'
],
:,
:]
slice_img
=
volume
[
state
[
"
position
"
],
:,
:]
with
output
:
output
.
clear_output
(
wait
=
True
)
# Clear previous plot
fig
,
axes
=
plt
.
subplots
(
1
,
4
,
figsize
=
(
25
,
5
))
...
...
@@ -1427,15 +1440,15 @@ def threshold(
else
vmax
)
axes
[
0
].
imshow
(
slice_img
,
cmap
=
cmap_image
,
vmin
=
new_vmin
,
vmax
=
new_vmax
)
axes
[
0
].
set_title
(
'
Original
'
)
axes
[
0
].
axis
(
'
off
'
)
axes
[
0
].
set_title
(
"
Original
"
)
axes
[
0
].
axis
(
"
off
"
)
# Histogram
histogram
(
volume
=
volume
,
bins
=
32
,
slice_idx
=
state
[
'
position
'
],
vertical_line
=
state
[
'
threshold
'
],
slice_idx
=
state
[
"
position
"
],
vertical_line
=
state
[
"
threshold
"
],
axis
=
1
,
kde
=
False
,
ax
=
axes
[
1
],
...
...
@@ -1444,65 +1457,65 @@ def threshold(
axes
[
1
].
set_title
(
f
"
Histogram with Threshold =
{
int
(
state
[
'
threshold
'
])
}
"
)
# Binary mask
mask
=
slice_img
>
state
[
'
threshold
'
]
axes
[
2
].
imshow
(
mask
,
cmap
=
'
gray
'
)
axes
[
2
].
set_title
(
'
Binary mask
'
)
axes
[
2
].
axis
(
'
off
'
)
mask
=
slice_img
>
state
[
"
threshold
"
]
axes
[
2
].
imshow
(
mask
,
cmap
=
"
gray
"
)
axes
[
2
].
set_title
(
"
Binary mask
"
)
axes
[
2
].
axis
(
"
off
"
)
# Overlay
mask_rgb
=
np
.
zeros
((
mask
.
shape
[
0
],
mask
.
shape
[
1
],
3
),
dtype
=
np
.
uint8
)
mask_rgb
[:,
:,
0
]
=
mask
masked_volume
=
qim3d
.
processing
.
operations
.
overlay_rgb_images
(
masked_volume
=
qim3d
.
operations
.
overlay_rgb_images
(
background
=
slice_img
,
foreground
=
mask_rgb
,
)
axes
[
3
].
imshow
(
masked_volume
,
vmin
=
new_vmin
,
vmax
=
new_vmax
)
axes
[
3
].
set_title
(
'
Overlay
'
)
axes
[
3
].
axis
(
'
off
'
)
axes
[
3
].
set_title
(
"
Overlay
"
)
axes
[
3
].
axis
(
"
off
"
)
plt
.
show
()
# Widgets
position_slider
=
widgets
.
IntSlider
(
value
=
state
[
'
position
'
],
value
=
state
[
"
position
"
],
min
=
0
,
max
=
volume
.
shape
[
0
]
-
1
,
description
=
'
Slice
'
,
description
=
"
Slice
"
,
)
threshold_slider
=
widgets
.
IntSlider
(
value
=
state
[
'
threshold
'
],
value
=
state
[
"
threshold
"
],
min
=
volume
.
min
(),
max
=
volume
.
max
(),
description
=
'
Threshold
'
,
description
=
"
Threshold
"
,
)
method_dropdown
=
widgets
.
Dropdown
(
options
=
[
'
Manual
'
,
'
Otsu
'
,
'
Isodata
'
,
'
Li
'
,
'
Mean
'
,
'
Minimum
'
,
'
Triangle
'
,
'
Yen
'
,
"
Manual
"
,
"
Otsu
"
,
"
Isodata
"
,
"
Li
"
,
"
Mean
"
,
"
Minimum
"
,
"
Triangle
"
,
"
Yen
"
,
],
value
=
state
[
'
method
'
],
description
=
'
Method
'
,
value
=
state
[
"
method
"
],
description
=
"
Method
"
,
)
# Attach the state update function to widgets
position_slider
.
observe
(
update_state
,
names
=
'
value
'
)
threshold_slider
.
observe
(
update_state
,
names
=
'
value
'
)
method_dropdown
.
observe
(
update_state
,
names
=
'
value
'
)
position_slider
.
observe
(
update_state
,
names
=
"
value
"
)
threshold_slider
.
observe
(
update_state
,
names
=
"
value
"
)
method_dropdown
.
observe
(
update_state
,
names
=
"
value
"
)
# Layout
controls_left
=
widgets
.
VBox
([
position_slider
,
threshold_slider
])
controls_right
=
widgets
.
VBox
([
method_dropdown
])
controls_layout
=
widgets
.
HBox
(
[
controls_left
,
controls_right
],
layout
=
widgets
.
Layout
(
justify_content
=
'
flex-start
'
),
layout
=
widgets
.
Layout
(
justify_content
=
"
flex-start
"
),
)
interactive_ui
=
widgets
.
VBox
([
controls_layout
,
output
])
update_visualization
()
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment