Skip to content
Snippets Groups Projects

Implemented 3D connected components as wrapper class for scipy.ndimage.label

1 file
+ 19
11
Compare changes
  • Side-by-side
  • Inline
+ 19
11
@@ -2,6 +2,8 @@ import numpy as np
@@ -2,6 +2,8 @@ import numpy as np
import torch
import torch
from scipy.ndimage import find_objects, label
from scipy.ndimage import find_objects, label
 
from qim3d.io.logger import log
 
class CC:
class CC:
def __init__(self, connected_components, num_connected_components):
def __init__(self, connected_components, num_connected_components):
@@ -12,27 +14,31 @@ class CC:
@@ -12,27 +14,31 @@ class CC:
connected_components (np.ndarray): The connected components.
connected_components (np.ndarray): The connected components.
num_connected_components (int): The number of connected components.
num_connected_components (int): The number of connected components.
"""
"""
self.connected_components = connected_components
self._connected_components = connected_components
self.num_connected_components = num_connected_components
self.cc_count = num_connected_components
 
 
def __len__(self):
 
"""
 
Returns the number of connected components in the object.
 
"""
 
return self.cc_count
def get_cc(self, index=None, crop=False):
def get_cc(self, index=None, crop=False):
"""
"""
Get the connected component with the given index, if index is None selects a random component.
Get the connected component with the given index, if index is None selects a random component.
Args:
Args:
index (int): The index of the connected component. If none selects a random component.
index (int): The index of the connected component. If none returns all components.
crop (bool): If True, the volume is cropped to the bounding box of the connected component.
crop (bool): If True, the volume is cropped to the bounding box of the connected component.
Returns:
Returns:
np.ndarray: The connected component as a binary mask.
np.ndarray: The connected component as a binary mask.
"""
"""
if index is None:
if index is None:
volume = self.connected_components == np.random.randint(
volume = self._connected_components
1, self.num_connected_components + 1
)
else:
else:
assert 1 <= index <= self.num_connected_components, "Index out of range."
assert 1 <= index <= self.cc_count, "Index out of range."
volume = self.connected_components == index
volume = self._connected_components == index
if crop:
if crop:
# As we index get_bounding_box element 0 will be the bounding box for the connected component at index
# As we index get_bounding_box element 0 will be the bounding box for the connected component at index
@@ -52,10 +58,10 @@ class CC:
@@ -52,10 +58,10 @@ class CC:
"""
"""
if index:
if index:
assert 1 <= index <= self.num_connected_components, "Index out of range."
assert 1 <= index <= self.cc_count, "Index out of range."
return find_objects(self.connected_components == index)
return find_objects(self._connected_components == index)
else:
else:
return find_objects(self.connected_components)
return find_objects(self._connected_components)
def get_3d_cc(image: np.ndarray | torch.Tensor):
def get_3d_cc(image: np.ndarray | torch.Tensor):
@@ -69,4 +75,6 @@ def get_3d_cc(image: np.ndarray | torch.Tensor):
@@ -69,4 +75,6 @@ def get_3d_cc(image: np.ndarray | torch.Tensor):
class: Returns class object of the connected components.
class: Returns class object of the connected components.
"""
"""
connected_components, num_connected_components = label(image)
connected_components, num_connected_components = label(image)
 
log.info(f"Total number of connected components found: {num_connected_components}")
 
return CC(connected_components, num_connected_components)
return CC(connected_components, num_connected_components)
Loading