Skip to content
Snippets Groups Projects
Commit f8facfae authored by David Johansen's avatar David Johansen
Browse files

Refactored chunking into a class. Added volume chunking functionality.

parent a58f87e4
No related branches found
No related tags found
No related merge requests found
import os
from pathlib import Path
import argparse
import re
from multiprocessing import Pool
import zarr import zarr
import numpy as np import numpy as np
import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import skimage as ski
from cil.io import ZEISSDataReader, NikonDataReader from cil.io import ZEISSDataReader, NikonDataReader
from cil.framework import AcquisitionData from cil.framework import AcquisitionData, ImageGeometry
from cil.processors import TransmissionAbsorptionConverter from cil.processors import TransmissionAbsorptionConverter
from cil.recon import FDK from cil.recon import FDK
from cil.utilities.display import show2D, show_geometry
def create_reader(file_name, roi=None): def create_reader(file_name, roi=None):
if file_name.endswith('txrm'): if file_name.endswith('txrm'):
...@@ -21,96 +28,180 @@ def create_reader(file_name, roi=None): ...@@ -21,96 +28,180 @@ def create_reader(file_name, roi=None):
else: else:
return DataReader(file_name=file_name, roi=roi) return DataReader(file_name=file_name, roi=roi)
def float_to_uint16(arr):
    """Min-max normalize a float array and quantize it to the full uint16 range.

    arr: array-like of real values.
    Returns a np.uint16 array of the same shape, with arr.min() -> 0 and
    arr.max() -> 65535.

    A constant input (max == min) maps to all zeros; the original 0/0 here
    produced NaNs whose uint16 cast is undefined.
    """
    min_val, max_val = np.min(arr), np.max(arr)
    if max_val == min_val:
        # degenerate input: nothing to normalize, avoid division by zero
        return np.zeros(np.shape(arr), dtype=np.uint16)
    norm_arr = (arr - min_val) / (max_val - min_val)
    return (norm_arr * (2**16 - 1)).astype(np.uint16)
def save_center_slice(arr, path):
    """Save the middle slice of a 3D stack (along axis 0) as a grayscale PNG.

    arr: 3D numpy array (stack, row, col).
    path: Output image file path.
    """
    mid = arr.shape[0] // 2
    plt.figure()
    plt.imshow(arr[mid, :, :], cmap='gray')
    plt.colorbar()
    plt.tight_layout()
    plt.savefig(path)
    # close the current figure so repeated calls don't leak figure objects
    plt.close()
class ChunkedCT:
    """Chunked FDK reconstruction of a cone-beam CT scan.

    Projections are read from tiff files in chunks along the angle axis and
    staged into a uint16 zarr array (parallelized over processes); the volume
    is then reconstructed slab-by-slab along the vertical axis and written to
    a float32 zarr array. Partial FDK reconstructions over projection subsets
    are summed, relying on FDK being linear in the projection data.
    """

    def __init__(self, config_path, tiff_dir, proj_chunk_size=100, vol_chunks=1, processes=None, verbose=1):
        """
        config_path: Scanner metadata file (.txrm/.xtekct) used to build the
            acquisition geometry via create_reader.
        tiff_dir: Directory containing the .tif projection images.
        proj_chunk_size: Number of projections per chunk along the angle axis.
        vol_chunks: Number of chunks along the volume height. Must be a
            divisor of the default image-geometry height (now enforced).
        processes: Worker processes for projection staging (None -> os default).
        verbose: Non-zero saves diagnostic center-slice images under out/.
        """
        self.tiff_dir = tiff_dir
        self.vol_chunks = vol_chunks
        self.proj_chunk_size = proj_chunk_size
        self.processes = processes
        self.recon_zarr_path = 'out/recon.zarr'
        self.proj_zarr_path = 'out/projections.zarr'
        self.verbose = verbose
        self.tiff_paths = self.get_tiff_paths()

        self.ag = create_reader(config_path).get_geometry()
        self.ig = self.ag.get_ImageGeometry()
        self.proj_shape = self.ag.shape
        # Enforce the documented precondition instead of silently truncating
        # the volume height with floor division.
        if self.ig.voxel_num_z % self.vol_chunks != 0:
            raise ValueError(
                f'vol_chunks={self.vol_chunks} must divide the volume height '
                f'{self.ig.voxel_num_z}'
            )
        self.vol_chunk_height = self.ig.voxel_num_z // self.vol_chunks
        # All outputs are hard-coded under out/; create it up front so the
        # first zarr/png write does not fail on a missing directory.
        os.makedirs(os.path.dirname(self.recon_zarr_path), exist_ok=True)

    def _num_proj_chunks(self):
        """Number of projection chunks: ceil(num_projections / chunk size)."""
        return (self.proj_shape[0] - 1) // self.proj_chunk_size + 1

    def get_tiff_paths(self):
        """Return full paths of all .tif files in self.tiff_dir, naturally sorted.

        Natural sort compares digit runs numerically, so proj_2 < proj_10.
        """
        def atoi(text):
            return int(text) if text.isdigit() else text

        def natural_keys(text):
            # https://stackoverflow.com/questions/5967500/how-to-correctly-sort-a-string-with-a-number-inside
            return [atoi(c) for c in re.split(r'(\d+)', text)]

        tiff_paths = [f for f in os.listdir(self.tiff_dir) if f.endswith(".tif")]
        tiff_paths.sort(key=natural_keys)
        return [os.path.join(self.tiff_dir, f) for f in tiff_paths]

    @staticmethod
    def read_tiffs(tiff_paths):
        """Read tiff files into one uint16 stack of shape (len(tiff_paths), H, W)."""
        first = ski.io.imread(tiff_paths[0])
        im_stack = np.zeros((len(tiff_paths), *first.shape), dtype=np.uint16)
        im_stack[0] = first
        for i in range(1, len(tiff_paths)):
            im_stack[i] = ski.io.imread(tiff_paths[i])
        return im_stack

    @staticmethod
    def process_chunk(tiff_paths, proj_zarr_path, proj_chunk_size, proj_shape, chunk_index):
        """Read one angle-chunk of projections and write it into the projection zarr.

        Static (and taking all state as arguments) so it can be dispatched
        through multiprocessing.Pool without pickling the instance.
        """
        start = chunk_index * proj_chunk_size
        end = min((chunk_index + 1) * proj_chunk_size, proj_shape[0])
        chunk = ChunkedCT.read_tiffs(tiff_paths[start:end])
        expected_shape = (end - start, *proj_shape[1:])
        if chunk.shape != expected_shape:
            raise ValueError(f'Chunk shape mismatch: expected {expected_shape}, got {chunk.shape}')
        zarr_array = zarr.open_array(store=proj_zarr_path, mode='a')
        zarr_array[start:end, :, :] = chunk
        print(f'Finished writing projection chunk {chunk_index}.')

    def write_projection_chunks(self):
        """Stage all projections into the uint16 projection zarr, in parallel."""
        nchunks = self._num_proj_chunks()
        compressor = zarr.codecs.Blosc(cname='zstd')

        # Create/overwrite the destination array; workers then open it in
        # append mode and each writes a disjoint chunk-aligned slab.
        zarr.open_array(
            store=self.proj_zarr_path,
            mode='w',
            dtype=np.uint16,
            shape=self.proj_shape,
            chunks=(self.proj_chunk_size, *self.proj_shape[1:]),
            compressor=compressor
        )

        with Pool(self.processes) as pool:
            # Report via public API rather than poking Pool._processes.
            print(f'processes={self.processes if self.processes is not None else os.cpu_count()}')
            pool.starmap(
                ChunkedCT.process_chunk,
                [(self.tiff_paths, self.proj_zarr_path, self.proj_chunk_size, self.proj_shape, i)
                 for i in range(nchunks)]
            )

    def reconstruct_projection_chunk(self, ig, proj_chunk_index):
        """FDK-reconstruct one projection chunk onto the image geometry ig."""
        start = proj_chunk_index * self.proj_chunk_size
        end = min((proj_chunk_index + 1) * self.proj_chunk_size, self.ag.shape[0])
        ag_chunk = self.ag.copy().set_angles(self.ag.angles[start:end])
        zarr_array = zarr.open_array(store=self.proj_zarr_path, mode='r')
        # Undo the uint16 quantization used for on-disk staging.
        chunk = zarr_array[start:end, :, :].astype(np.float32) / (2**16 - 1)
        data_chunk = AcquisitionData(array=chunk, deep_copy=False, geometry=ag_chunk)
        data_chunk = TransmissionAbsorptionConverter()(data_chunk)
        data_chunk.reorder(order='tigre')
        return FDK(data_chunk, image_geometry=ig).run(verbose=0)

    def reconstruct_volume_chunk(self, vol_chunk_index):
        """Reconstruct one horizontal slab of the volume and write it to the recon zarr.

        The slab is accumulated over all projection chunks; summing partial
        FDK reconstructions is valid because FDK is linear in the projections.
        """
        ig_chunk = self.ig.copy()
        ig_chunk.voxel_num_z = self.vol_chunk_height
        # Offset of this slab's center from the full-volume center, in voxel
        # units, scaled by voxel size (crucial, as voxel size is usually != 1).
        ig_chunk.center_z = (
            -(1 - 1/self.vol_chunks)/2 * self.ig.voxel_num_z +
            vol_chunk_index*self.vol_chunk_height
        ) * ig_chunk.voxel_size_z

        vol_accumulated = np.zeros(ig_chunk.shape, dtype=np.float32)
        for proj_chunk_index in range(self._num_proj_chunks()):
            recon_chunk = self.reconstruct_projection_chunk(ig_chunk, proj_chunk_index)
            vol_accumulated += recon_chunk.as_array()
            if self.verbose:
                save_center_slice(
                    recon_chunk.as_array(),
                    f'out/center_recon_vol_{vol_chunk_index}_proj_{proj_chunk_index}.png'
                )
        if self.verbose:
            save_center_slice(vol_accumulated, f'out/center_recon_{vol_chunk_index}_accumulated.png')

        zarr_array = zarr.open_array(store=self.recon_zarr_path, mode='a')
        start = self.vol_chunk_height * vol_chunk_index
        end = self.vol_chunk_height * (vol_chunk_index + 1)
        zarr_array[start:end, :, :] = vol_accumulated
        print(f'Finished writing volume chunk {vol_chunk_index}.')

    def reconstruct(self):
        """Run the full pipeline: stage projections, then reconstruct all volume chunks."""
        self.write_projection_chunks()

        # NOTE(review): the recon array is created with shape ig.shape[::-1]
        # while reconstruct_volume_chunk writes slabs of shape ig_chunk.shape
        # along axis 0 — confirm the CIL ImageGeometry axis ordering makes
        # these consistent (they only trivially agree for cubic volumes).
        zarr.open_array(
            store=self.recon_zarr_path,
            mode='w',
            dtype=np.float32,
            shape=self.ig.shape[::-1],
            chunks=(self.vol_chunk_height, *self.ig.shape[1:][::-1]),
        )
        for vol_chunk_index in range(self.vol_chunks):
            self.reconstruct_volume_chunk(vol_chunk_index)
if __name__ == '__main__':
    # CLI tunes parallelism and chunking; the scan location is hard-coded below.
    arg_parser = argparse.ArgumentParser()
    # arg_parser.add_argument('-i', '--tiff_dir', type=str, required=True, help='Path to CT scan metadata file')
    arg_parser.add_argument('-p', '--processes', type=int, default=None, help='Number of processes')
    arg_parser.add_argument('-c', '--chunk-size', type=int, default=100, help='Projection chunk size')
    arg_parser.add_argument('-v', '--vol-chunks', type=int, default=1, help='Volume chunks')
    cli_args = arg_parser.parse_args()

    config_path = '/dtu/3d-imaging-center/projects/2021_DANFIX_Casper/raw_data_3DIM/Casper_top_3_2 [2021-03-17 16.54.39]/Casper_top_3_2_recon.xtekct'
    # The projection tiffs live next to the metadata file.
    tiff_dir = Path(config_path).parent

    ChunkedCT(
        config_path=config_path,
        tiff_dir=tiff_dir,
        proj_chunk_size=cli_args.chunk_size,
        vol_chunks=cli_args.vol_chunks,
        processes=cli_args.processes,
    ).reconstruct()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment