Skip to content
Snippets Groups Projects
Commit 15c1c609 authored by David Johansen's avatar David Johansen
Browse files

Extended volume chunking to all dimensions from only the vertical dimension.

parent 6528eea7
No related branches found
No related tags found
No related merge requests found
out/
bjobs/out/
bjobs/logs/
out/*
bjobs/out/*
bjobs/logs/*
__pycache__/
\ No newline at end of file
......@@ -3,6 +3,7 @@ from pathlib import Path
import argparse
import re
from multiprocessing import Pool
import itertools
import zarr
import numpy as np
......@@ -42,15 +43,17 @@ def save_center_slice(arr, path):
plt.close()
class ChunkedCT:
def __init__(self, config_path, tiff_dir, proj_chunk_size=100, vol_chunks=1, batch_size=20, processes=None, verbose=1):
def __init__(self, config_path, tiff_dir, proj_chunk_size, vol_chunks, batch_size=20, processes=None, verbose=1):
"""
Perform CT reconstruction by chunking the projections and reconstruction volume.
Currently does not compress as chunk size limit for compression seems to be ~2 GB for the compressors I tried.
Parameters:
config_path: Path to CT config file.
tiff_dir: Directory where all tiffs are assumed to be projections.
proj_chunk_size: Number of projections per projection chunk.
vol_chunks: Number of chunks along the height (only chunked along this axis). Must be a divisor of the default ig height.
vol_chunks: Tuple of number of chunks along each direction in the order of Z, Y, X. Must be divisors of the corresponding default ig dimension length.
batch_size: Number of TIFFs per worker at a time during parallelized reading.
processes: Number of workers during parallelized reading. Default None chooses this to all available CPUs.
"""
......@@ -63,6 +66,7 @@ class ChunkedCT:
self.processes = processes
self.batch_size = batch_size
os.makedirs('out', exist_ok=True)
self.recon_zarr_path = 'out/recon.zarr'
self.proj_zarr_path = 'out/projections.zarr'
self.verbose = verbose
......@@ -73,7 +77,10 @@ class ChunkedCT:
self.ig = self.ag.get_ImageGeometry()
self.proj_shape = self.ag.shape
self.vol_chunk_height = self.ig.voxel_num_z // self.vol_chunks
if not all(self.ig.shape[i] % self.vol_chunks[i] == 0 for i in range(3)):
raise ValueError(f"ig shape not divisible by vol_chunks in all dimensions")
self.vol_chunk_shape = tuple(self.ig.shape[i] // self.vol_chunks[i] for i in range(3))
def set_tiff_paths(self):
"""
......@@ -117,7 +124,7 @@ class ChunkedCT:
def write_projection_chunks(self):
nchunks = (self.proj_shape[0] - 1) // self.proj_chunk_size + 1
compressor = zarr.codecs.Blosc(cname='zstd')
# compressor = zarr.codecs.Blosc(cname='zstd')
zarr.open_array(
store=self.proj_zarr_path,
......@@ -125,7 +132,7 @@ class ChunkedCT:
dtype=np.uint16,
shape=self.proj_shape,
chunks=(self.proj_chunk_size, *self.proj_shape[1:]),
compressor=compressor
compressor=None
)
with Pool(self.processes) as pool:
......@@ -139,7 +146,7 @@ class ChunkedCT:
ag_chunk = self.ag.copy().set_angles(self.ag.angles[start:end])
zarr_array = zarr.open_array(store=self.proj_zarr_path, mode='r')
chunk = zarr_array[start:end, :, :].astype(np.float32) / (2**16 - 1)
chunk = zarr_array[start:end, :, :].astype(np.float32) / (2**16 - 1) # here assuming TIFFs are uint16
data_chunk = AcquisitionData(array=chunk, deep_copy=False, geometry=ag_chunk)
data_chunk = TransmissionAbsorptionConverter()(data_chunk)
......@@ -147,32 +154,43 @@ class ChunkedCT:
recon_chunk = FDK(data_chunk, image_geometry=ig).run(verbose=0)
return recon_chunk
def reconstruct_volume_chunk(self, vol_chunk_index):
def make_ig_chunk(self, chunk_coords):
ig_chunk = self.ig.copy()
ig_chunk.voxel_num_z = self.vol_chunk_height
ig_chunk.center_z = (
-(1 - 1/self.vol_chunks)/2 * self.ig.voxel_num_z +
vol_chunk_index*self.vol_chunk_height
) * ig_chunk.voxel_size_z # scaling crucial as it is usually != 1
for axis, dim_label in enumerate(['z', 'y', 'x']):
setattr(ig_chunk, f'voxel_num_{dim_label}', self.vol_chunk_shape[axis])
chunk = chunk_coords[axis]
num_chunks = self.vol_chunks[axis]
ig_voxel_num = getattr(self.ig, f'voxel_num_{dim_label}')
ig_chunk_voxel_num = getattr(ig_chunk, f'voxel_num_{dim_label}')
ig_chunk_voxel_size = getattr(ig_chunk, f'voxel_size_{dim_label}')
nchunks = (self.proj_shape[0] - 1) // self.proj_chunk_size + 1
# center is in 'world' coordinates, hence the scaling with voxel_size
center = (-(1 - 1/num_chunks)/2 * ig_voxel_num + chunk * ig_chunk_voxel_num) * ig_chunk_voxel_size
setattr(ig_chunk, f'center_{dim_label}', center)
return ig_chunk
def reconstruct_volume_chunk(self, chunk_coords):
ig_chunk = self.make_ig_chunk(chunk_coords)
num_proj_chunks = (self.proj_shape[0] - 1) // self.proj_chunk_size + 1
vol_accumulated = np.zeros(ig_chunk.shape, dtype=np.float32)
for proj_chunk_index in range(nchunks):
for proj_chunk_index in range(num_proj_chunks):
recon_chunk = self.reconstruct_projection_chunk(ig_chunk, proj_chunk_index)
vol_accumulated += recon_chunk.as_array()
if self.verbose:
save_center_slice(recon_chunk.as_array(), f'out/center/recon_vol_{vol_chunk_index}_proj_{proj_chunk_index}.png')
str_chunk = '_'.join(str(i) for i in chunk_coords)
save_center_slice(recon_chunk.as_array(), f'out/center/recon_vol_{str_chunk}_proj_{proj_chunk_index}.png')
if self.verbose:
save_center_slice(vol_accumulated, f'out/center/recon_{vol_chunk_index}_accumulated.png')
save_center_slice(vol_accumulated, f'out/center/recon_{str_chunk}_accumulated.png')
zarr_array = zarr.open_array(store=self.recon_zarr_path, mode='a')
start = self.vol_chunk_height * vol_chunk_index
end = self.vol_chunk_height * (vol_chunk_index+1)
zarr_array[start:end, :, :] = vol_accumulated
print(f'Finished writing volume chunk {vol_chunk_index}.')
start = [s*c for s, c in zip(self.vol_chunk_shape, chunk_coords)]
end = [s*(c+1) for s, c in zip(self.vol_chunk_shape, chunk_coords)]
zarr_array[start[0]:end[0], start[1]:end[1], start[2]:end[2]] = vol_accumulated
print(f'Finished writing volume chunk {chunk_coords}.')
def reconstruct(self):
self.write_projection_chunks()
......@@ -183,26 +201,31 @@ class ChunkedCT:
mode='w',
dtype=np.float32,
shape=self.ig.shape[::-1],
chunks=(self.vol_chunk_height, *self.ig.shape[1:][::-1]),
chunks=self.vol_chunk_shape,
compressor=None
)
for vol_chunk_index in range(self.vol_chunks):
self.reconstruct_volume_chunk(vol_chunk_index)
for chunk_coords in itertools.product(*map(range, self.vol_chunks)):
self.reconstruct_volume_chunk(chunk_coords)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# parser.add_argument('-i', '--tiff_dir', type=str, required=True, help='Path to CT scan metadata file')
parser.add_argument('-p', '--processes', type=int, default=None, help='Number of processes')
parser.add_argument('-c', '--chunk-size', type=int, default=100, help='Projection chunk size')
parser.add_argument('-v', '--vol-chunks', type=int, default=1, help='Volume chunks')
# parser.add_argument('-m', '--meta_path', type=str, required=True, help='Path to CT scan metadata file.')
# parser.add_argument('-t', '--tiff_dir', type=str, default=True, help='Path to directory of TIFF projections.')
parser.add_argument('-p', '--processes', type=int, default=None, help="Number of processes for multiprocessing. None uses all available.")
parser.add_argument('-c', '--chunk-size', type=int, default=100, help="Projection chunk size.")
parser.add_argument('-v', '--vol-chunks', type=str, default='1,1,1', help="Volume chunks specified as number of chunks in each dimension e.g. 2,2,2")
args = parser.parse_args()
vol_chunks = tuple(int(d) for d in args.vol_chunks.split(','))
if len(vol_chunks) != 3:
raise ValueError
config_path = '/dtu/3d-imaging-center/projects/2021_DANFIX_Casper/raw_data_3DIM/Casper_top_3_2 [2021-03-17 16.54.39]/Casper_top_3_2_recon.xtekct'
tiff_dir = Path(config_path).parent
processor = ChunkedCT(
config_path=config_path, tiff_dir=tiff_dir,
proj_chunk_size=args.chunk_size, vol_chunks=args.vol_chunks, processes=args.processes
proj_chunk_size=args.chunk_size, vol_chunks=vol_chunks, processes=args.processes
)
processor.reconstruct()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment