nabu 2024.1.10__py3-none-any.whl → 2024.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nabu/__init__.py +1 -1
- nabu/app/bootstrap.py +2 -3
- nabu/app/cast_volume.py +4 -2
- nabu/app/cli_configs.py +5 -0
- nabu/app/composite_cor.py +1 -1
- nabu/app/create_distortion_map_from_poly.py +5 -6
- nabu/app/diag_to_pix.py +7 -19
- nabu/app/diag_to_rot.py +14 -29
- nabu/app/double_flatfield.py +32 -44
- nabu/app/parse_reconstruction_log.py +3 -0
- nabu/app/reconstruct.py +53 -15
- nabu/app/reconstruct_helical.py +2 -2
- nabu/app/stitching.py +27 -13
- nabu/app/tests/__init__.py +0 -0
- nabu/app/tests/test_reduce_dark_flat.py +4 -1
- nabu/cuda/kernel.py +11 -2
- nabu/cuda/processing.py +2 -2
- nabu/cuda/src/cone.cu +77 -0
- nabu/cuda/src/hierarchical_backproj.cu +271 -0
- nabu/cuda/utils.py +0 -6
- nabu/estimation/alignment.py +5 -19
- nabu/estimation/cor.py +173 -599
- nabu/estimation/cor_sino.py +356 -26
- nabu/estimation/focus.py +63 -11
- nabu/estimation/tests/test_cor.py +124 -58
- nabu/estimation/tests/test_focus.py +6 -6
- nabu/estimation/tilt.py +2 -1
- nabu/estimation/utils.py +5 -33
- nabu/io/__init__.py +1 -1
- nabu/io/cast_volume.py +1 -1
- nabu/io/reader.py +416 -21
- nabu/io/tests/test_readers.py +422 -0
- nabu/io/tests/test_writers.py +1 -102
- nabu/io/writer.py +4 -433
- nabu/opencl/kernel.py +14 -3
- nabu/opencl/processing.py +8 -0
- nabu/pipeline/config_validators.py +5 -2
- nabu/pipeline/datadump.py +12 -5
- nabu/pipeline/estimators.py +162 -188
- nabu/pipeline/fullfield/chunked.py +168 -92
- nabu/pipeline/fullfield/chunked_cuda.py +7 -3
- nabu/pipeline/fullfield/computations.py +2 -7
- nabu/pipeline/fullfield/dataset_validator.py +0 -4
- nabu/pipeline/fullfield/nabu_config.py +37 -13
- nabu/pipeline/fullfield/processconfig.py +22 -13
- nabu/pipeline/fullfield/reconstruction.py +13 -9
- nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
- nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
- nabu/pipeline/helical/helical_reconstruction.py +1 -1
- nabu/pipeline/params.py +21 -1
- nabu/pipeline/processconfig.py +1 -12
- nabu/pipeline/reader.py +146 -0
- nabu/pipeline/tests/test_estimators.py +44 -72
- nabu/pipeline/utils.py +4 -2
- nabu/pipeline/writer.py +10 -2
- nabu/preproc/ccd_cuda.py +1 -1
- nabu/preproc/ctf.py +14 -7
- nabu/preproc/ctf_cuda.py +2 -3
- nabu/preproc/double_flatfield.py +5 -12
- nabu/preproc/double_flatfield_cuda.py +2 -2
- nabu/preproc/flatfield.py +5 -1
- nabu/preproc/flatfield_cuda.py +5 -1
- nabu/preproc/phase.py +24 -73
- nabu/preproc/phase_cuda.py +5 -8
- nabu/preproc/tests/test_ctf.py +11 -7
- nabu/preproc/tests/test_flatfield.py +67 -122
- nabu/preproc/tests/test_paganin.py +54 -30
- nabu/processing/azim.py +206 -0
- nabu/processing/convolution_cuda.py +1 -1
- nabu/processing/fft_cuda.py +15 -17
- nabu/processing/histogram.py +2 -0
- nabu/processing/histogram_cuda.py +2 -1
- nabu/processing/kernel_base.py +3 -0
- nabu/processing/muladd_cuda.py +1 -0
- nabu/processing/padding_opencl.py +1 -1
- nabu/processing/roll_opencl.py +1 -0
- nabu/processing/rotation_cuda.py +2 -2
- nabu/processing/tests/test_fft.py +17 -10
- nabu/processing/unsharp_cuda.py +1 -1
- nabu/reconstruction/cone.py +104 -40
- nabu/reconstruction/fbp.py +3 -0
- nabu/reconstruction/fbp_base.py +7 -2
- nabu/reconstruction/filtering.py +20 -7
- nabu/reconstruction/filtering_cuda.py +7 -1
- nabu/reconstruction/hbp.py +424 -0
- nabu/reconstruction/mlem.py +99 -0
- nabu/reconstruction/reconstructor.py +2 -0
- nabu/reconstruction/rings_cuda.py +19 -19
- nabu/reconstruction/sinogram_cuda.py +1 -0
- nabu/reconstruction/sinogram_opencl.py +3 -1
- nabu/reconstruction/tests/test_cone.py +10 -5
- nabu/reconstruction/tests/test_deringer.py +7 -6
- nabu/reconstruction/tests/test_fbp.py +124 -10
- nabu/reconstruction/tests/test_filtering.py +13 -11
- nabu/reconstruction/tests/test_halftomo.py +30 -4
- nabu/reconstruction/tests/test_mlem.py +91 -0
- nabu/reconstruction/tests/test_reconstructor.py +8 -3
- nabu/resources/dataset_analyzer.py +142 -92
- nabu/resources/gpu.py +1 -0
- nabu/resources/nxflatfield.py +134 -125
- nabu/resources/templates/id16a_fluo.conf +42 -0
- nabu/resources/tests/test_extract.py +10 -0
- nabu/resources/tests/test_nxflatfield.py +2 -2
- nabu/stitching/alignment.py +80 -24
- nabu/stitching/config.py +105 -68
- nabu/stitching/definitions.py +1 -0
- nabu/stitching/frame_composition.py +68 -60
- nabu/stitching/overlap.py +91 -51
- nabu/stitching/single_axis_stitching.py +32 -0
- nabu/stitching/slurm_utils.py +6 -6
- nabu/stitching/stitcher/__init__.py +0 -0
- nabu/stitching/stitcher/base.py +124 -0
- nabu/stitching/stitcher/dumper/__init__.py +3 -0
- nabu/stitching/stitcher/dumper/base.py +94 -0
- nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
- nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
- nabu/stitching/stitcher/post_processing.py +555 -0
- nabu/stitching/stitcher/pre_processing.py +1068 -0
- nabu/stitching/stitcher/single_axis.py +484 -0
- nabu/stitching/stitcher/stitcher.py +0 -0
- nabu/stitching/stitcher/y_stitcher.py +13 -0
- nabu/stitching/stitcher/z_stitcher.py +45 -0
- nabu/stitching/stitcher_2D.py +278 -0
- nabu/stitching/tests/test_config.py +12 -37
- nabu/stitching/tests/test_frame_composition.py +33 -59
- nabu/stitching/tests/test_overlap.py +149 -7
- nabu/stitching/tests/test_utils.py +1 -1
- nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
- nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
- nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
- nabu/stitching/utils/__init__.py +1 -0
- nabu/stitching/utils/post_processing.py +281 -0
- nabu/stitching/utils/tests/test_post-processing.py +21 -0
- nabu/stitching/{utils.py → utils/utils.py} +79 -52
- nabu/stitching/y_stitching.py +27 -0
- nabu/stitching/z_stitching.py +32 -2281
- nabu/testutils.py +1 -152
- nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
- nabu/utils.py +158 -61
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/METADATA +24 -17
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/RECORD +145 -121
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/WHEEL +1 -1
- nabu/io/tiffwriter_zmm.py +0 -99
- nabu/pipeline/fallback_utils.py +0 -149
- nabu/pipeline/helical/tests/test_accumulator.py +0 -158
- nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
- nabu/pipeline/helical/tests/test_strategy.py +0 -61
- nabu/pipeline/helical/utils.py +0 -51
- nabu/pipeline/tests/test_chunk_reader.py +0 -74
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,555 @@
|
|
1
|
+
import logging
|
2
|
+
import numpy
|
3
|
+
import os
|
4
|
+
import h5py
|
5
|
+
from typing import Union
|
6
|
+
from nabu.stitching.config import PostProcessedSingleAxisStitchingConfiguration
|
7
|
+
from nabu.stitching.alignment import AlignmentAxis1
|
8
|
+
from nabu.stitching.alignment import PaddedRawData
|
9
|
+
from math import ceil
|
10
|
+
from tomoscan.io import HDF5File
|
11
|
+
from tomoscan.esrf.scan.utils import cwd_context
|
12
|
+
from tomoscan.esrf import NXtomoScan
|
13
|
+
from tomoscan.series import Series
|
14
|
+
from tomoscan.volumebase import VolumeBase
|
15
|
+
from tomoscan.esrf.volume import HDF5Volume
|
16
|
+
from typing import Iterable
|
17
|
+
from contextlib import AbstractContextManager
|
18
|
+
from pyunitsystem.metricsystem import MetricSystem
|
19
|
+
from nabu.stitching.config import (
|
20
|
+
PostProcessedSingleAxisStitchingConfiguration,
|
21
|
+
KEY_IMG_REG_METHOD,
|
22
|
+
)
|
23
|
+
from nabu.stitching.utils.utils import find_volumes_relative_shifts
|
24
|
+
from nabu.io.utils import DatasetReader
|
25
|
+
from .single_axis import SingleAxisStitcher
|
26
|
+
|
27
|
+
# module-level logger, named after this module
_logger = logging.getLogger(name=__name__)
|
28
|
+
|
29
|
+
|
30
|
+
class FlippingValueError(ValueError):
    """Raised when the requested volume flips cannot be handled by the selected stitching mode."""
|
32
|
+
|
33
|
+
|
34
|
+
class PostProcessingStitching(SingleAxisStitcher):
    """
    Stitcher used during post-processing stitching (on reconstructed volumes).
    Inputs are the volumes listed in the configuration; output is expected to be a volume
    (a VolumeBase instance, see `check_inputs`).
    """

    def __init__(self, configuration, progress=None) -> None:
        """
        :param configuration: stitching configuration. Must be a PostProcessedSingleAxisStitchingConfiguration.
        :param progress: optional progress object. When provided, `_create_stitched_volume` sets its `total`
            and calls its `update()` once per stitched slice.
        :raises TypeError: if `configuration` is not a PostProcessedSingleAxisStitchingConfiguration.
        """
        if not isinstance(configuration, PostProcessedSingleAxisStitchingConfiguration):
            raise TypeError(
                f"configuration is expected to be an instance of {PostProcessedSingleAxisStitchingConfiguration}. Get {type(configuration)} instead"
            )
        self._input_volumes = configuration.input_volumes
        # lazily computed by get_output_data_type
        self.__output_data_type = None

        self._series = Series("series", iterable=self._input_volumes, use_identifiers=False)

        super().__init__(configuration, progress=progress)

    @property
    def stitching_axis_in_frame_space(self):
        """
        Stitching axis expressed in the frame space. Only axis 0 is handled for post-processing stitching.

        :raises NotImplementedError: for any other axis.
        """
        if self.axis == 0:
            return 0
        elif self.axis in (1, 2):
            raise NotImplementedError(f"post-processing stitching along axis {self.axis} is not handled.")
        else:
            raise NotImplementedError(f"stitching axis must be in (0, 1, 2). Get {self.axis}")

    def settle_flips(self):
        """
        Settle flips (see parent class) then make sure they are compatible with the dumping mode.

        :raises FlippingValueError: when data is not duplicated and volumes have different left/right flips,
            or when any up/down flip is requested.
        """
        super().settle_flips()
        if not self.configuration.duplicate_data:
            if len(numpy.unique(self.configuration.flip_lr)) > 1:
                raise FlippingValueError(
                    "Stitching without data duplication cannot handle volume with different flip. Please run the stitching with data duplication"
                )
            if True in self.configuration.flip_ud:
                raise FlippingValueError(
                    "Stitching without data duplication cannot handle with up / down flips. Please run the stitching with data duplication"
                )

    def order_input_tomo_objects(self):
        """
        Make sure input volumes are ordered by decreasing position along axis 0.

        Positions come from `configuration.axis_0_pos_px` when provided. Otherwise they come from the
        volume bounding boxes, with a fallback to the scan referenced in the volume metadata.
        If the volumes are given in exactly the reverse order, the series is reordered. Every per-volume
        configuration list (positions in mm / px and flips) is reversed accordingly.
        If positions cannot be found, a warning is logged and the current ordering is kept.

        :raises ValueError: if the ordering is neither decreasing nor exactly increasing.
        """

        def get_min_bound(volume):
            try:
                bb = volume.get_bounding_box(axis=self.axis)
            except ValueError:  # if missing information
                bb = None
            if bb is not None:
                return bb.min

            # can't find the bounding box (missing volume metadata): try to get it from the scan
            metadata = volume.metadata or volume.load_metadata()
            dataset_config = metadata.get("nabu_config", {}).get("dataset", {})
            scan_location = dataset_config.get("location", None)
            scan_entry = dataset_config.get("hdf5_entry", None)
            if scan_location is not None:
                # this workaround (until most volumes have position metadata) works only for HDF5Volume
                with cwd_context(os.path.dirname(volume.file_path)):
                    o_scan = NXtomoScan(scan_location, scan_entry)
                bb_acqui = o_scan.get_bounding_box(axis=None)
                # check before use: the bounding box may be unavailable
                if bb_acqui is not None:
                    # volume position will be required by the next steps: set it directly
                    volume.position = (numpy.array(bb_acqui.max) - numpy.array(bb_acqui.min)) / 2.0 + numpy.array(
                        bb_acqui.min
                    )
                    # for now translations are stored in pixel size ref instead of real_pixel_size
                    volume.pixel_size = o_scan.x_real_pixel_size
                    return bb_acqui.min[0]
            raise ValueError("Unable to find volume position. Unable to deduce z position")

        try:
            # order volumes from higher z to lower z
            # if axis 0 position is provided then use it directly
            if self.configuration.axis_0_pos_px is not None and len(self.configuration.axis_0_pos_px) > 0:
                order = numpy.argsort(self.configuration.axis_0_pos_px)
                sorted_series = Series(
                    self.series.name,
                    numpy.take_along_axis(numpy.array(self.series[:]), order, axis=0)[::-1],
                    use_identifiers=False,
                )
            else:
                # else use the bounding box
                sorted_series = Series(
                    self.series.name,
                    sorted(self.series[:], key=get_min_bound, reverse=True),
                    use_identifiers=False,
                )
        except ValueError:
            _logger.warning(
                "Unable to find volume positions in metadata. Expect the volume to be ordered already (decreasing along axis 0.)"
            )
        else:
            if sorted_series != self.series:
                # only a full reversal of the input order is accepted
                if sorted_series[:] != self.series[::-1]:
                    raise ValueError(
                        "Unable to get comprehensive input. ordering along axis 0 is not respected (decreasing)."
                    )
                _logger.warning(
                    f"decreasing order haven't been respected. Need to reorder {self.serie_label} ({[str(scan) for scan in sorted_series[:]]}). Will also reorder positions"
                )
                # per-volume positions must follow the new volume order
                for position_attr in (
                    "axis_0_pos_mm",
                    "axis_0_pos_px",
                    "axis_1_pos_mm",
                    "axis_1_pos_px",
                    "axis_2_pos_mm",
                    "axis_2_pos_px",
                ):
                    positions = getattr(self.configuration, position_attr)
                    if positions is not None:
                        setattr(self.configuration, position_attr, positions[::-1])
                # per-volume flips must follow the new volume order too (each flip list reversed into itself)
                if not numpy.isscalar(self._configuration.flip_ud):
                    self._configuration.flip_ud = self._configuration.flip_ud[::-1]
                if not numpy.isscalar(self._configuration.flip_lr):
                    self._configuration.flip_lr = self._configuration.flip_lr[::-1]

            self._series = sorted_series

    def check_inputs(self):
        """
        Ensure input data is coherent.

        :raises ValueError: if no output volume is provided, if there is no input volume, or if a
            per-volume position list (axis 0, 1 or 2; px or mm) does not have one entry per input volume.
        :raises TypeError: if the output volume is not a VolumeBase instance.
        """
        # check output volume
        if self.configuration.output_volume is None:
            raise ValueError("output volume should be provided")

        n_volumes = len(self.series)
        if n_volumes == 0:
            raise ValueError("no scan to stich together")

        if not isinstance(self.configuration.output_volume, VolumeBase):
            raise TypeError(f"make sure we return a volume identifier not {(type(self.configuration.output_volume))}")

        # check positions along axis 0, 1 and 2 (one value per volume when provided)
        for positions in (
            self.configuration.axis_0_pos_px,
            self.configuration.axis_0_pos_mm,
            self.configuration.axis_1_pos_px,
            self.configuration.axis_1_pos_mm,
            self.configuration.axis_2_pos_px,
            self.configuration.axis_2_pos_mm,
        ):
            if isinstance(positions, Iterable) and len(positions) != n_volumes:
                raise ValueError(f"expect {n_volumes} overlap defined. Get {len(positions)}")

        # the first scan will define the expected reading order, and expected flip.
        # if all scans are flipped then we will keep it this way
        self._reading_orders = [1]

    @staticmethod
    def _get_bunch_of_data(
        bunch_start: int,
        bunch_end: int,
        step: int,
        volumes: tuple,
        flip_lr_arr: Iterable[bool],
        flip_ud_arr: Iterable[bool],
    ):
        """
        Goal is to load contiguous frames as much as possible...

        For every slice index along axis 1 in `bunch_start:bunch_end:step`, yield the list of frames
        (`sub_volume[:, i, :]`), one per volume.
        Warning: frames of different volumes can have different shapes.

        :param flip_lr_arr: one left/right flip flag per volume (iterated in lockstep with `volumes`)
        :param flip_ud_arr: one up/down flip flag per volume (iterated in lockstep with `volumes`)
        """

        def get_sub_volume(volume, flip_lr, flip_ud):
            sub_volume = volume[:, bunch_start:bunch_end:step, :]
            # NOTE(review): on this 3D sub-volume numpy.fliplr flips axis 1, i.e. it reverses the order of the
            # slices inside the current bunch rather than mirroring each yielded frame. Confirm this is intended.
            if flip_lr:
                sub_volume = numpy.fliplr(sub_volume)
            if flip_ud:
                sub_volume = numpy.flipud(sub_volume)
            return sub_volume

        sub_volumes = [
            get_sub_volume(volume, flip_lr, flip_ud)
            for volume, flip_lr, flip_ud in zip(volumes, flip_lr_arr, flip_ud_arr)
        ]
        # generator on itself: we want to iterate over the y axis
        n_slices_in_bunch = ceil((bunch_end - bunch_start) / step)
        assert isinstance(n_slices_in_bunch, int)
        for i in range(n_slices_in_bunch):
            yield [sub_volume[:, i, :] for sub_volume in sub_volumes]

    def compute_estimated_shifts(self):
        """
        Compute the initial relative shifts between consecutive volumes.

        - axis 0: overlap (as int) between the lower bound of the upper volume and the upper bound of the
          lower volume, computed from `configuration.axis_0_pos_px` (volume centers) and volume shapes.
        - axis 1: relative shifts deduced from `configuration.axis_1_pos_px`.
        - axis 2: 0.
        """
        axis_0_pos_px = self.configuration.axis_0_pos_px
        self._axis_0_rel_ini_shifts = []
        # compute overlap along axis 0
        for upper_volume, lower_volume, upper_volume_axis_0_pos, lower_volume_axis_0_pos in zip(
            self.series[:-1], self.series[1:], axis_0_pos_px[:-1], axis_0_pos_px[1:]
        ):
            upper_volume_low_pos = upper_volume_axis_0_pos - upper_volume.get_volume_shape()[0] / 2
            lower_volume_high_pos = lower_volume_axis_0_pos + lower_volume.get_volume_shape()[0] / 2
            self._axis_0_rel_ini_shifts.append(
                int(lower_volume_high_pos - upper_volume_low_pos)  # overlaps are expected to be int for now
            )
        self._axis_1_rel_ini_shifts = self.from_abs_pos_to_rel_pos(self.configuration.axis_1_pos_px)
        self._axis_2_rel_ini_shifts = [0.0] * (len(self.series) - 1)

    def _compute_positions_as_px(self):
        """
        Make sure the configuration provides positions as pixels along axes 0, 1 and 2.

        For each axis, in this priority: positions already given in px are kept; otherwise positions given
        in mm are converted using the voxel size; otherwise positions are deduced from the volume bounding
        boxes, relative to the first volume. Once converted, the mm positions are reset to None.

        :raises ValueError: if positions of an axis are provided both as px and as mm.
        """

        def get_position_as_px_on_axis(axis, pos_as_px, pos_as_mm):
            if pos_as_px is not None:
                if pos_as_mm is not None:
                    raise ValueError(
                        f"position of axis {axis} is provided twice: as mm and as px. Please provide one only ({pos_as_mm} vs {pos_as_px})"
                    )
                else:
                    return pos_as_px

            elif pos_as_mm is not None:
                # deduce from position given in configuration and pixel size
                # NOTE(review): MetricSystem.MILLIMETER.value is 1e-3 (in meters); dividing a mm value by it does
                # not yield meters. Confirm the unit actually stored in `axis_N_pos_mm`.
                axis_N_pos_px = []
                for volume, pos_in_mm in zip(self.series, pos_as_mm):
                    voxel_size_m = self.configuration.voxel_size or volume.voxel_size
                    axis_N_pos_px.append((pos_in_mm / MetricSystem.MILLIMETER.value) / voxel_size_m[0])
                return axis_N_pos_px
            else:
                # deduce from motor position and pixel size
                axis_N_pos_px = []
                base_position_m = self.series[0].get_bounding_box(axis=axis).min
                for volume in self.series:
                    voxel_size_m = self.configuration.voxel_size or volume.voxel_size
                    volume_axis_bb = volume.get_bounding_box(axis=axis)
                    axis_N_mean_pos_m = (volume_axis_bb.max - volume_axis_bb.min) / 2 + volume_axis_bb.min
                    axis_N_mean_rel_pos_m = axis_N_mean_pos_m - base_position_m
                    axis_N_pos_px.append(int(axis_N_mean_rel_pos_m / voxel_size_m[0]))
                return axis_N_pos_px

        self.configuration.axis_0_pos_px = get_position_as_px_on_axis(
            axis=0,
            pos_as_px=self.configuration.axis_0_pos_px,
            pos_as_mm=self.configuration.axis_0_pos_mm,
        )
        self.configuration.axis_0_pos_mm = None

        self.configuration.axis_1_pos_px = get_position_as_px_on_axis(
            axis=1,
            pos_as_px=self.configuration.axis_1_pos_px,
            pos_as_mm=self.configuration.axis_1_pos_mm,
        )
        # reset like axis 0 and 2, otherwise a new call would fail with "provided twice"
        self.configuration.axis_1_pos_mm = None

        self.configuration.axis_2_pos_px = get_position_as_px_on_axis(
            axis=2,
            pos_as_px=self.configuration.axis_2_pos_px,
            pos_as_mm=self.configuration.axis_2_pos_mm,
        )
        self.configuration.axis_2_pos_mm = None

    def _compute_shifts(self):
        """
        Refine the initial relative shifts between consecutive volumes with `find_volumes_relative_shifts`.
        Store the results as `_axis_0_rel_final_shifts` / `_axis_1_rel_final_shifts`;
        `_axis_2_rel_final_shifts` is set to 0.

        :raises ValueError: if no input volume is provided.
        """
        n_volumes = len(self.configuration.input_volumes)
        if n_volumes == 0:
            raise ValueError("no scan to stich provided")

        slice_for_shift = self.configuration.slice_for_cross_correlation or "middle"
        y_rel_shifts = self._axis_0_rel_ini_shifts
        x_rel_shifts = self._axis_1_rel_ini_shifts
        dim_axis_1 = max([volume.get_volume_shape()[1] for volume in self.series])
        # image registration methods do not depend on the volume pair: look them up once
        x_cross_algo = self.configuration.axis_2_params.get(KEY_IMG_REG_METHOD, None)
        y_cross_algo = self.configuration.axis_0_params.get(KEY_IMG_REG_METHOD, None)

        final_rel_shifts = []
        for (
            upper_volume,
            lower_volume,
            x_rel_shift,
            y_rel_shift,
            flip_ud_upper,
            flip_ud_lower,
        ) in zip(
            self.series[:-1],
            self.series[1:],
            x_rel_shifts,
            y_rel_shifts,
            self.configuration.flip_ud[:-1],
            self.configuration.flip_ud[1:],
        ):
            # compute relative shift
            found_shift_y, found_shift_x = find_volumes_relative_shifts(
                upper_volume=upper_volume,
                lower_volume=lower_volume,
                dtype=self.get_output_data_type(),
                dim_axis_1=dim_axis_1,
                slice_for_shift=slice_for_shift,
                x_cross_correlation_function=x_cross_algo,
                y_cross_correlation_function=y_cross_algo,
                x_shifts_params=self.configuration.axis_2_params,
                y_shifts_params=self.configuration.axis_0_params,
                estimated_shifts=(y_rel_shift, x_rel_shift),
                flip_ud_lower_frame=flip_ud_lower,
                flip_ud_upper_frame=flip_ud_upper,
                alignment_axis_1=self.configuration.alignment_axis_1,
                alignment_axis_2=self.configuration.alignment_axis_2,
                overlap_axis=self.axis,
            )
            final_rel_shifts.append(
                (found_shift_y, found_shift_x),
            )

        # set back values. Now position should start at 0
        self._axis_0_rel_final_shifts = [final_shift[0] for final_shift in final_rel_shifts]
        self._axis_1_rel_final_shifts = [final_shift[1] for final_shift in final_rel_shifts]
        self._axis_2_rel_final_shifts = [0.0] * len(final_rel_shifts)
        # report through the logger only (no stray stdout prints)
        _logger.info(f"axis 2 relative shifts (x in radio ref) to be used will be {self._axis_1_rel_final_shifts}")
        _logger.info(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_0_rel_final_shifts}")

    def get_output_data_type(self):
        """Return the dtype of the first input volume. The value is computed once then cached."""
        if self.__output_data_type is None:

            def find_output_data_type():
                first_vol = self._input_volumes[0]
                if first_vol.data is not None:
                    return first_vol.data.dtype
                elif isinstance(first_vol, HDF5Volume):
                    # avoid loading the data: read the dtype from the dataset
                    with DatasetReader(first_vol.data_url) as vol_dataset:
                        return vol_dataset.dtype
                else:
                    return first_vol.load_data(store=False).dtype

            self.__output_data_type = find_output_data_type()
        return self.__output_data_type

    def _create_stitched_volume(self, store_composition: bool):
        """
        Stitch the input volumes slice by slice (iterating along axis 1, by bunches) and dump the result
        through `self.dumper` into `configuration.output_volume`.

        :param store_composition: if True, keep the frame composition of the first stitched slice in
            `self._frame_composition`.
        """
        overlap_kernels = self._overlap_kernels
        self._slices_to_stitch, n_slices = self.configuration.settle_slices()

        # sync overwrite_results with volume overwrite parameter
        self.configuration.output_volume.overwrite = self.configuration.overwrite_results

        # init final volume
        final_volume = self.configuration.output_volume
        final_volume_shape = (
            int(
                numpy.asarray([volume.get_volume_shape()[0] for volume in self._input_volumes]).sum()
                - numpy.asarray([abs(overlap) for overlap in self._axis_0_rel_final_shifts]).sum(),
            ),
            n_slices,
            self._stitching_constant_length,
        )

        data_type = self.get_output_data_type()

        if self.progress:
            self.progress.total = final_volume_shape[1]

        y_index = 0
        if isinstance(self._slices_to_stitch, slice):
            step = self._slices_to_stitch.step or 1
        else:
            step = 1

        output_dataset_args = {
            "volume": final_volume,
            "volume_shape": final_volume_shape,
            "dtype": data_type,
            "dumper": self.dumper,
        }
        from .dumper.postprocessing import PostProcessingStitchingDumperNoDD

        # TODO: FIXME: for now not very elegant but in the case of avoiding data duplication
        # we need to provide the information about the stitched part shape.
        # this should be moved to the dumper in the future
        if isinstance(self.dumper, PostProcessingStitchingDumperNoDD):
            output_dataset_args["stitching_sources_arr_shapes"] = tuple(
                [(abs(overlap), n_slices, self._stitching_constant_length) for overlap in self._axis_0_rel_final_shifts]
            )

        with self.dumper.OutputDatasetContext(**output_dataset_args):
            # note: output_dataset is a HDF5 dataset if final volume is an HDF5 volume else is a numpy array
            with _RawDatasetsContext(
                self._input_volumes,
                alignment_axis_1=self.configuration.alignment_axis_1,
            ) as raw_datasets:
                # note: raw_datasets can be numpy arrays or HDF5 datasets (in the case of HDF5Volume)
                # to speed up we read by bunch of dataset. For numpy arrays this doesn't change anything
                # but for HDF5 datasets this can speed up a lot the processing (depending on HDF5 dataset chunks)
                # note: we read through axis 1
                if isinstance(self.dumper, PostProcessingStitchingDumperNoDD):
                    self.dumper.raw_regions_hdf5_dataset = raw_datasets
                for bunch_start, bunch_end in PostProcessingStitching._data_bunch_iterator(
                    slices=self._slices_to_stitch, bunch_size=50
                ):
                    for data_frames in PostProcessingStitching._get_bunch_of_data(
                        bunch_start,
                        bunch_end,
                        step=step,
                        volumes=raw_datasets,
                        flip_lr_arr=self.configuration.flip_lr,
                        flip_ud_arr=self.configuration.flip_ud,
                    ):
                        if self.configuration.rescale_frames:
                            data_frames = self.rescale_frames(data_frames)
                        if self.configuration.normalization_by_sample.is_active():
                            data_frames = self.normalize_frame_by_sample(data_frames)

                        sf = PostProcessingStitching.stitch_frames(
                            frames=data_frames,
                            axis=self.axis,
                            output_dtype=data_type,
                            x_relative_shifts=self._axis_1_rel_final_shifts,
                            y_relative_shifts=self._axis_0_rel_final_shifts,
                            overlap_kernels=overlap_kernels,
                            dumper=self.dumper,
                            i_frame=y_index,
                            return_composition_cls=store_composition if y_index == 0 else False,
                            stitching_axis=self.axis,
                            check_inputs=y_index == 0,  # only check inputs on the first iteration
                        )
                        if y_index == 0 and store_composition:
                            _, self._frame_composition = sf

                        if self.progress is not None:
                            self.progress.update()
                        y_index += 1

    # alias to general API
    def _create_stitching(self, store_composition):
        self._create_stitched_volume(store_composition=store_composition)
|
476
|
+
|
477
|
+
|
478
|
+
class _RawDatasetsContext(AbstractContextManager):
    """
    Return volume data for all input volumes (target: used for volume stitching).
    If the volume is an HDF5Volume then the HDF5 dataset will be used (on disk).
    If the volume is of another type then it will be loaded in memory then used (more memory consuming).
    If volumes have different sizes along axis 1, each data is wrapped by `add_padding` so that all of them
    share the largest axis 1 size.
    HDF5 files opened on enter are all closed on exit.
    """

    def __init__(self, volumes: tuple, alignment_axis_1) -> None:
        """
        :param volumes: input volumes (VolumeBase instances)
        :param alignment_axis_1: alignment used when padding along axis 1 (see `add_padding`)
        :raises TypeError: if any volume is not a VolumeBase instance.
        """
        super().__init__()
        for volume in volumes:
            if not isinstance(volume, VolumeBase):
                raise TypeError(
                    f"Volumes are expected to be an instance of {VolumeBase}. {type(volume)} provided instead"
                )

        self._volumes = volumes
        self.__file_handlers = []
        self._alignment_axis_1 = alignment_axis_1

    @property
    def alignment_axis_1(self):
        return self._alignment_axis_1

    def _close_file_handlers(self):
        """Close every HDF5 file opened by `__enter__` and forget them."""
        file_handlers, self.__file_handlers = self.__file_handlers, []
        for file_handler in file_handlers:
            file_handler.close()

    def __enter__(self):
        """
        :return: list with one dataset (numpy array, HDF5 dataset or padded wrapper) per input volume.
        :raises ValueError: if no data can be found for a volume.
        """
        # handle the specific case of HDF5. Goal: avoid getting the full stitched volume in memory
        datasets = []
        shapes = {volume.get_volume_shape()[1] for volume in self._volumes}
        axis_1_dim = max(shapes, default=0)
        axis_1_need_padding = len(shapes) > 1

        try:
            for volume in self._volumes:
                if volume.data is not None:
                    data = volume.data
                elif isinstance(volume, HDF5Volume):
                    file_handler = HDF5File(volume.data_url.file_path(), mode="r")
                    # register the handler first so it is closed even if accessing the dataset fails
                    self.__file_handlers.append(file_handler)
                    data = file_handler[volume.data_url.data_path()]
                else:
                    # for other file formats: load the full dataset in memory
                    data = volume.load_data(store=False)
                if data is None:
                    raise ValueError(f"No data found for volume {volume.get_identifier()}")
                if axis_1_need_padding:
                    data = self.add_padding(data=data, axis_1_dim=axis_1_dim, alignment=self.alignment_axis_1)
                datasets.append(data)
        except Exception:
            # if some error happens during loading: release every HDF5 file opened so far
            self._close_file_handlers()
            raise

        return datasets

    def __exit__(self, exc_type, exc_value, traceback):
        # close ALL handlers: `close()` returns None, so the former `success and close()` short-circuited
        # after the first file and leaked the others.
        self._close_file_handlers()
        # falsy return value: exceptions raised inside the `with` block are propagated
        return None

    def add_padding(self, data: Union[h5py.Dataset, numpy.ndarray], axis_1_dim, alignment: AlignmentAxis1):
        """
        Wrap `data` in a PaddedRawData so that its axis 1 size becomes `axis_1_dim`.

        :param alignment: where the original data sits along axis 1 (BACK: padding before, FRONT: padding
            after, CENTER: padding split on both sides).
        :raises ValueError: if the alignment is not handled.
        """
        alignment = AlignmentAxis1.from_value(alignment)
        if alignment is AlignmentAxis1.BACK:
            axis_1_pad_width = (axis_1_dim - data.shape[1], 0)
        elif alignment is AlignmentAxis1.CENTER:
            half_width = int((axis_1_dim - data.shape[1]) / 2)
            axis_1_pad_width = (half_width, axis_1_dim - data.shape[1] - half_width)
        elif alignment is AlignmentAxis1.FRONT:
            axis_1_pad_width = (0, axis_1_dim - data.shape[1])
        else:
            raise ValueError(f"alignment {alignment} is not handled")

        return PaddedRawData(
            data=data,
            axis_1_pad_width=axis_1_pad_width,
        )
|