nabu 2024.1.9__py3-none-any.whl → 2024.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151)
  1. nabu/__init__.py +1 -1
  2. nabu/app/bootstrap.py +2 -3
  3. nabu/app/cast_volume.py +4 -2
  4. nabu/app/cli_configs.py +5 -0
  5. nabu/app/composite_cor.py +1 -1
  6. nabu/app/create_distortion_map_from_poly.py +5 -6
  7. nabu/app/diag_to_pix.py +7 -19
  8. nabu/app/diag_to_rot.py +14 -29
  9. nabu/app/double_flatfield.py +32 -44
  10. nabu/app/parse_reconstruction_log.py +3 -0
  11. nabu/app/reconstruct.py +53 -15
  12. nabu/app/reconstruct_helical.py +2 -2
  13. nabu/app/stitching.py +27 -13
  14. nabu/app/tests/test_reduce_dark_flat.py +4 -1
  15. nabu/cuda/kernel.py +11 -2
  16. nabu/cuda/processing.py +2 -2
  17. nabu/cuda/src/cone.cu +77 -0
  18. nabu/cuda/src/hierarchical_backproj.cu +271 -0
  19. nabu/cuda/utils.py +0 -6
  20. nabu/estimation/alignment.py +5 -19
  21. nabu/estimation/cor.py +173 -599
  22. nabu/estimation/cor_sino.py +356 -26
  23. nabu/estimation/focus.py +63 -11
  24. nabu/estimation/tests/test_cor.py +124 -58
  25. nabu/estimation/tests/test_focus.py +6 -6
  26. nabu/estimation/tilt.py +2 -1
  27. nabu/estimation/utils.py +5 -33
  28. nabu/io/__init__.py +1 -1
  29. nabu/io/cast_volume.py +1 -1
  30. nabu/io/reader.py +416 -21
  31. nabu/io/tests/test_readers.py +422 -0
  32. nabu/io/tests/test_writers.py +1 -102
  33. nabu/io/writer.py +4 -433
  34. nabu/opencl/kernel.py +14 -3
  35. nabu/opencl/processing.py +8 -0
  36. nabu/pipeline/config_validators.py +5 -2
  37. nabu/pipeline/datadump.py +12 -5
  38. nabu/pipeline/estimators.py +162 -188
  39. nabu/pipeline/fullfield/chunked.py +168 -92
  40. nabu/pipeline/fullfield/chunked_cuda.py +7 -3
  41. nabu/pipeline/fullfield/computations.py +2 -7
  42. nabu/pipeline/fullfield/dataset_validator.py +0 -4
  43. nabu/pipeline/fullfield/nabu_config.py +37 -13
  44. nabu/pipeline/fullfield/processconfig.py +22 -13
  45. nabu/pipeline/fullfield/reconstruction.py +13 -9
  46. nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
  47. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
  48. nabu/pipeline/helical/helical_reconstruction.py +1 -1
  49. nabu/pipeline/params.py +21 -1
  50. nabu/pipeline/processconfig.py +1 -12
  51. nabu/pipeline/reader.py +146 -0
  52. nabu/pipeline/tests/test_estimators.py +44 -72
  53. nabu/pipeline/utils.py +4 -2
  54. nabu/pipeline/writer.py +10 -2
  55. nabu/preproc/ccd_cuda.py +1 -1
  56. nabu/preproc/ctf.py +14 -7
  57. nabu/preproc/ctf_cuda.py +2 -3
  58. nabu/preproc/double_flatfield.py +5 -12
  59. nabu/preproc/double_flatfield_cuda.py +2 -2
  60. nabu/preproc/flatfield.py +5 -1
  61. nabu/preproc/flatfield_cuda.py +5 -1
  62. nabu/preproc/phase.py +24 -73
  63. nabu/preproc/phase_cuda.py +5 -8
  64. nabu/preproc/tests/test_ctf.py +11 -7
  65. nabu/preproc/tests/test_flatfield.py +67 -122
  66. nabu/preproc/tests/test_paganin.py +54 -30
  67. nabu/processing/azim.py +206 -0
  68. nabu/processing/convolution_cuda.py +1 -1
  69. nabu/processing/fft_cuda.py +15 -17
  70. nabu/processing/histogram.py +2 -0
  71. nabu/processing/histogram_cuda.py +2 -1
  72. nabu/processing/kernel_base.py +3 -0
  73. nabu/processing/muladd_cuda.py +1 -0
  74. nabu/processing/padding_opencl.py +1 -1
  75. nabu/processing/roll_opencl.py +1 -0
  76. nabu/processing/rotation_cuda.py +2 -2
  77. nabu/processing/tests/test_fft.py +17 -10
  78. nabu/processing/unsharp_cuda.py +1 -1
  79. nabu/reconstruction/cone.py +104 -40
  80. nabu/reconstruction/fbp.py +3 -0
  81. nabu/reconstruction/fbp_base.py +7 -2
  82. nabu/reconstruction/filtering.py +20 -7
  83. nabu/reconstruction/filtering_cuda.py +7 -1
  84. nabu/reconstruction/hbp.py +424 -0
  85. nabu/reconstruction/mlem.py +99 -0
  86. nabu/reconstruction/reconstructor.py +2 -0
  87. nabu/reconstruction/rings_cuda.py +19 -19
  88. nabu/reconstruction/sinogram_cuda.py +1 -0
  89. nabu/reconstruction/sinogram_opencl.py +3 -1
  90. nabu/reconstruction/tests/test_cone.py +10 -5
  91. nabu/reconstruction/tests/test_deringer.py +7 -6
  92. nabu/reconstruction/tests/test_fbp.py +124 -10
  93. nabu/reconstruction/tests/test_filtering.py +13 -11
  94. nabu/reconstruction/tests/test_halftomo.py +30 -4
  95. nabu/reconstruction/tests/test_mlem.py +91 -0
  96. nabu/reconstruction/tests/test_reconstructor.py +8 -3
  97. nabu/resources/dataset_analyzer.py +142 -92
  98. nabu/resources/gpu.py +1 -0
  99. nabu/resources/nxflatfield.py +134 -125
  100. nabu/resources/templates/id16a_fluo.conf +42 -0
  101. nabu/resources/tests/test_extract.py +10 -0
  102. nabu/resources/tests/test_nxflatfield.py +2 -2
  103. nabu/stitching/alignment.py +80 -24
  104. nabu/stitching/config.py +105 -68
  105. nabu/stitching/definitions.py +1 -0
  106. nabu/stitching/frame_composition.py +68 -60
  107. nabu/stitching/overlap.py +91 -51
  108. nabu/stitching/single_axis_stitching.py +32 -0
  109. nabu/stitching/slurm_utils.py +6 -6
  110. nabu/stitching/stitcher/__init__.py +0 -0
  111. nabu/stitching/stitcher/base.py +124 -0
  112. nabu/stitching/stitcher/dumper/__init__.py +3 -0
  113. nabu/stitching/stitcher/dumper/base.py +94 -0
  114. nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
  115. nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
  116. nabu/stitching/stitcher/post_processing.py +555 -0
  117. nabu/stitching/stitcher/pre_processing.py +1068 -0
  118. nabu/stitching/stitcher/single_axis.py +484 -0
  119. nabu/stitching/stitcher/stitcher.py +0 -0
  120. nabu/stitching/stitcher/y_stitcher.py +13 -0
  121. nabu/stitching/stitcher/z_stitcher.py +45 -0
  122. nabu/stitching/stitcher_2D.py +278 -0
  123. nabu/stitching/tests/test_config.py +12 -37
  124. nabu/stitching/tests/test_frame_composition.py +33 -59
  125. nabu/stitching/tests/test_overlap.py +149 -7
  126. nabu/stitching/tests/test_utils.py +1 -1
  127. nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
  128. nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
  129. nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
  130. nabu/stitching/utils/__init__.py +1 -0
  131. nabu/stitching/utils/post_processing.py +281 -0
  132. nabu/stitching/utils/tests/test_post-processing.py +21 -0
  133. nabu/stitching/{utils.py → utils/utils.py} +79 -52
  134. nabu/stitching/y_stitching.py +27 -0
  135. nabu/stitching/z_stitching.py +32 -2263
  136. nabu/testutils.py +1 -152
  137. nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
  138. nabu/utils.py +158 -61
  139. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/METADATA +10 -3
  140. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/RECORD +144 -121
  141. nabu/io/tiffwriter_zmm.py +0 -99
  142. nabu/pipeline/fallback_utils.py +0 -149
  143. nabu/pipeline/helical/tests/test_accumulator.py +0 -158
  144. nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
  145. nabu/pipeline/helical/tests/test_strategy.py +0 -61
  146. nabu/pipeline/helical/utils.py +0 -51
  147. nabu/pipeline/tests/test_chunk_reader.py +0 -74
  148. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
  149. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/WHEEL +0 -0
  150. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
  151. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,555 @@
1
+ import logging
2
+ import numpy
3
+ import os
4
+ import h5py
5
+ from typing import Union
6
+ from nabu.stitching.config import PostProcessedSingleAxisStitchingConfiguration
7
+ from nabu.stitching.alignment import AlignmentAxis1
8
+ from nabu.stitching.alignment import PaddedRawData
9
+ from math import ceil
10
+ from tomoscan.io import HDF5File
11
+ from tomoscan.esrf.scan.utils import cwd_context
12
+ from tomoscan.esrf import NXtomoScan
13
+ from tomoscan.series import Series
14
+ from tomoscan.volumebase import VolumeBase
15
+ from tomoscan.esrf.volume import HDF5Volume
16
+ from typing import Iterable
17
+ from contextlib import AbstractContextManager
18
+ from pyunitsystem.metricsystem import MetricSystem
19
+ from nabu.stitching.config import (
20
+ PostProcessedSingleAxisStitchingConfiguration,
21
+ KEY_IMG_REG_METHOD,
22
+ )
23
+ from nabu.stitching.utils.utils import find_volumes_relative_shifts
24
+ from nabu.io.utils import DatasetReader
25
+ from .single_axis import SingleAxisStitcher
26
+
27
# module-level logger, named after this module
_logger = logging.getLogger(name=__name__)
28
+
29
+
30
class FlippingValueError(ValueError):
    """Raised when the requested frame flips cannot be handled by the stitching (e.g. stitching without data duplication)."""
32
+
33
+
34
class PostProcessingStitching(SingleAxisStitcher):
    """
    Stitcher used for post-processing stitching, i.e. stitching of already reconstructed volumes.

    The stitched result is written to ``configuration.output_volume``, which must be a volume
    (checked in ``check_inputs``). Only stitching along axis 0 is handled for now.
    """

    def __init__(self, configuration, progress=None) -> None:
        """
        :param configuration: stitching configuration; must be a PostProcessedSingleAxisStitchingConfiguration.
        :param progress: optional progress-bar-like object (its ``total`` is set and ``update()`` is called
            once per stitched slice), or None.
        :raises TypeError: if ``configuration`` is not a PostProcessedSingleAxisStitchingConfiguration.
        """
        if not isinstance(configuration, PostProcessedSingleAxisStitchingConfiguration):
            raise TypeError(
                f"configuration is expected to be an instance of {PostProcessedSingleAxisStitchingConfiguration}. Get {type(configuration)} instead"
            )
        self._input_volumes = configuration.input_volumes
        # lazily resolved by get_output_data_type()
        self.__output_data_type = None

        self._series = Series("series", iterable=self._input_volumes, use_identifiers=False)

        super().__init__(configuration, progress=progress)

    @property
    def stitching_axis_in_frame_space(self):
        """Stitching axis expressed in the frame space. Only axis 0 is handled for post-processing stitching."""
        if self.axis == 0:
            return 0
        elif self.axis in (1, 2):
            raise NotImplementedError(f"post-processing stitching along axis {self.axis} is not handled.")
        else:
            raise NotImplementedError(f"stitching axis must be in (0, 1, 2). Get {self.axis}")

    def settle_flips(self):
        """
        Settle flips (see parent class), then make sure they are compatible with the data duplication mode.

        :raises FlippingValueError: without data duplication, if volumes have different left / right flips
            or if any volume is flipped up / down.
        """
        super().settle_flips()
        if not self.configuration.duplicate_data:
            if len(numpy.unique(self.configuration.flip_lr)) > 1:
                raise FlippingValueError(
                    "Stitching without data duplication cannot handle volume with different flip. Please run the stitching with data duplication"
                )
            if True in self.configuration.flip_ud:
                raise FlippingValueError(
                    "Stitching without data duplication cannot handle with up / down flips. Please run the stitching with data duplication"
                )

    def order_input_tomo_objects(self):
        """
        Order the input volumes by decreasing position along axis 0.

        Positions come from ``configuration.axis_0_pos_px`` when provided, otherwise from each volume
        bounding box (falling back on the scan referenced in the volume metadata).
        If the series is in the reverse order, it is reordered together with every per-volume position
        and flip. If positions cannot be found, a warning is logged and the current order is kept.

        :raises ValueError: if the series is neither in decreasing nor in increasing order.
        """

        def get_min_bound(volume):
            try:
                bb = volume.get_bounding_box(axis=self.axis)
            except ValueError:  # if missing information
                bb = None
            if bb is not None:
                return bb.min

            # bounding box cannot be found (missing volume metadata): try to get it from the scan
            metadata = volume.metadata or volume.load_metadata()
            dataset_config = metadata.get("nabu_config", {}).get("dataset", {})
            scan_location = dataset_config.get("location", None)
            scan_entry = dataset_config.get("hdf5_entry", None)
            if scan_location is not None:
                # this work around (until most volumes have position metadata) works only for HDF5Volume
                with cwd_context(os.path.dirname(volume.file_path)):
                    o_scan = NXtomoScan(scan_location, scan_entry)
                    bb_acqui = o_scan.get_bounding_box(axis=None)
                    # check before dereferencing: a missing bounding box must end up in the ValueError below
                    if bb_acqui is not None:
                        # volume position will be required for the next steps: set it directly
                        volume.position = (numpy.array(bb_acqui.max) - numpy.array(bb_acqui.min)) / 2.0 + numpy.array(
                            bb_acqui.min
                        )
                        # for now translations are stored in pixel size ref instead of real_pixel_size
                        volume.pixel_size = o_scan.x_real_pixel_size
                        return bb_acqui.min[0]
            raise ValueError("Unable to find volume position. Unable to deduce z position")

        try:
            # order volumes from higher z to lower z
            axis_0_pos_px = self.configuration.axis_0_pos_px
            if axis_0_pos_px is not None and len(axis_0_pos_px) > 0:
                # if axis 0 position is provided then use it directly
                order = numpy.argsort(axis_0_pos_px)
                sorted_series = Series(
                    self.series.name,
                    numpy.take_along_axis(numpy.array(self.series[:]), order, axis=0)[::-1],
                    use_identifiers=False,
                )
            else:
                # else use bounding box
                sorted_series = Series(
                    self.series.name,
                    sorted(self.series[:], key=get_min_bound, reverse=True),
                    use_identifiers=False,
                )
        except ValueError:
            _logger.warning(
                "Unable to find volume positions in metadata. Expect the volume to be ordered already (decreasing along axis 0.)"
            )
            return

        if sorted_series != self.series:
            if sorted_series[:] != self.series[::-1]:
                raise ValueError(
                    "Unable to get comprehensive input. ordering along axis 0 is not respected (decreasing)."
                )
            _logger.warning(
                f"decreasing order haven't been respected. Need to reorder {self.serie_label} ({[str(scan) for scan in sorted_series[:]]}). Will also reorder positions"
            )
            # every per-volume parameter must follow the new (reversed) order
            for position_attr in (
                "axis_0_pos_mm",
                "axis_0_pos_px",
                "axis_1_pos_mm",
                "axis_1_pos_px",
                "axis_2_pos_mm",
                "axis_2_pos_px",
            ):
                positions = getattr(self.configuration, position_attr)
                if positions is not None:
                    setattr(self.configuration, position_attr, positions[::-1])
            if not numpy.isscalar(self._configuration.flip_ud):
                self._configuration.flip_ud = self._configuration.flip_ud[::-1]
            if not numpy.isscalar(self._configuration.flip_lr):
                self._configuration.flip_lr = self._configuration.flip_lr[::-1]

        self._series = sorted_series

    def check_inputs(self):
        """
        Ensure input data is coherent.

        :raises ValueError: if no output volume is provided, if there is no volume to stitch, or if
            positions provided along any axis do not have one value per volume.
        :raises TypeError: if the output volume is not a VolumeBase instance.
        """
        if self.configuration.output_volume is None:
            raise ValueError("output volume should be provided")

        n_volumes = len(self.series)
        if n_volumes == 0:
            raise ValueError("no scan to stich together")

        if not isinstance(self.configuration.output_volume, VolumeBase):
            raise TypeError(f"make sure we return a volume identifier not {(type(self.configuration.output_volume))}")

        # positions along axes 0, 1 and 2 (in px and / or mm), when provided, need one value per volume
        for position_attr in (
            "axis_0_pos_px",
            "axis_0_pos_mm",
            "axis_1_pos_px",
            "axis_1_pos_mm",
            "axis_2_pos_px",
            "axis_2_pos_mm",
        ):
            positions = getattr(self.configuration, position_attr)
            if isinstance(positions, Iterable) and len(positions) != n_volumes:
                raise ValueError(f"expect {n_volumes} overlap defined. Get {len(positions)}")

        # the first scan will define the expected reading order, and expected flip.
        # if all scans are flipped then we will keep it this way
        self._reading_orders = [1]

    @staticmethod
    def _get_bunch_of_data(
        bunch_start: int,
        bunch_end: int,
        step: int,
        volumes: tuple,
        flip_lr_arr: bool,
        flip_ud_arr: bool,
    ):
        """
        Read slices ``bunch_start:bunch_end:step`` along axis 1 of every volume at once, then yield them one by one.

        Goal is to load contiguous frames as much as possible (one read per volume and per bunch).

        :param flip_lr_arr: one left / right flip flag per volume.
        :param flip_ud_arr: one up / down flip flag per volume.
        :return: generator yielding, for each slice of the bunch, the list of frames ``sub_volume[:, i, :]``
            (one per volume). Warning: frames of different volumes can have different shapes.
        """

        def get_sub_volume(volume, flip_lr, flip_ud):
            sub_volume = volume[:, bunch_start:bunch_end:step, :]
            if flip_lr:
                # NOTE(review): on a 3D array numpy.fliplr reverses axis 1 (the slice axis iterated below),
                # not the horizontal axis of the yielded frames. Confirm this is the intended flip.
                sub_volume = numpy.fliplr(sub_volume)
            if flip_ud:
                sub_volume = numpy.flipud(sub_volume)
            return sub_volume

        sub_volumes = [
            get_sub_volume(volume, flip_lr, flip_ud)
            for volume, flip_lr, flip_ud in zip(volumes, flip_lr_arr, flip_ud_arr)
        ]
        # generator on itself: we want to iterate over axis 1
        n_slices_in_bunch = ceil((bunch_end - bunch_start) / step)
        assert isinstance(n_slices_in_bunch, int)
        for i in range(n_slices_in_bunch):
            yield [sub_volume[:, i, :] for sub_volume in sub_volumes]

    def compute_estimated_shifts(self):
        """
        Compute initial (estimated) relative shifts between consecutive volumes of the series.

        - axis 0: overlap deduced from the axis 0 positions (px) and the volume heights (truncated to int).
        - axis 1: relative positions deduced from ``configuration.axis_1_pos_px``.
        - axis 2: set to 0.
        """
        axis_0_pos_px = self.configuration.axis_0_pos_px
        self._axis_0_rel_ini_shifts = []
        # compute overlap along axis 0
        for upper_volume, lower_volume, upper_volume_axis_0_pos, lower_volume_axis_0_pos in zip(
            self.series[:-1], self.series[1:], axis_0_pos_px[:-1], axis_0_pos_px[1:]
        ):
            upper_volume_low_pos = upper_volume_axis_0_pos - upper_volume.get_volume_shape()[0] / 2
            lower_volume_high_pos = lower_volume_axis_0_pos + lower_volume.get_volume_shape()[0] / 2
            self._axis_0_rel_ini_shifts.append(
                int(lower_volume_high_pos - upper_volume_low_pos)  # overlap are expected to be int for now
            )
        self._axis_1_rel_ini_shifts = self.from_abs_pos_to_rel_pos(self.configuration.axis_1_pos_px)
        self._axis_2_rel_ini_shifts = [0.0] * (len(self.series) - 1)

    def _compute_positions_as_px(self):
        """
        Compute, for each axis, the volume positions in pixels and store them in ``configuration.axis_N_pos_px``.

        For each axis the position is taken from (by priority): the provided px positions, the provided mm
        positions converted with the voxel size, or the volume bounding boxes (relative to the first volume).
        Once converted, the mm positions are reset to None so positions are only defined once.

        :raises ValueError: if positions of an axis are provided both in px and in mm.
        """

        def get_position_as_px_on_axis(axis, pos_as_px, pos_as_mm):
            if pos_as_px is not None:
                if pos_as_mm is not None:
                    raise ValueError(
                        f"position of axis {axis} is provided twice: as mm and as px. Please provide one only ({pos_as_mm} vs {pos_as_px})"
                    )
                return pos_as_px

            elif pos_as_mm is not None:
                # deduce from position given in configuration and pixel size
                # NOTE(review): dividing by MetricSystem.MILLIMETER.value looks like a m -> mm conversion
                # rather than mm -> m; confirm the unit convention of axis_N_pos_mm.
                axis_N_pos_px = []
                for volume, pos_in_mm in zip(self.series, pos_as_mm):
                    voxel_size_m = self.configuration.voxel_size or volume.voxel_size
                    axis_N_pos_px.append((pos_in_mm / MetricSystem.MILLIMETER.value) / voxel_size_m[0])
                return axis_N_pos_px
            else:
                # deduce from motor position and pixel size
                axis_N_pos_px = []
                base_position_m = self.series[0].get_bounding_box(axis=axis).min
                for volume in self.series:
                    voxel_size_m = self.configuration.voxel_size or volume.voxel_size
                    volume_axis_bb = volume.get_bounding_box(axis=axis)
                    axis_N_mean_pos_m = (volume_axis_bb.max - volume_axis_bb.min) / 2 + volume_axis_bb.min
                    axis_N_mean_rel_pos_m = axis_N_mean_pos_m - base_position_m
                    axis_N_pos_px.append(int(axis_N_mean_rel_pos_m / voxel_size_m[0]))
                return axis_N_pos_px

        self.configuration.axis_0_pos_px = get_position_as_px_on_axis(
            axis=0,
            pos_as_px=self.configuration.axis_0_pos_px,
            pos_as_mm=self.configuration.axis_0_pos_mm,
        )
        self.configuration.axis_0_pos_mm = None

        self.configuration.axis_1_pos_px = get_position_as_px_on_axis(
            axis=1,
            pos_as_px=self.configuration.axis_1_pos_px,
            pos_as_mm=self.configuration.axis_1_pos_mm,
        )
        # like axes 0 and 2: once converted, positions must not be defined twice
        self.configuration.axis_1_pos_mm = None

        self.configuration.axis_2_pos_px = get_position_as_px_on_axis(
            axis=2,
            pos_as_px=self.configuration.axis_2_pos_px,
            pos_as_mm=self.configuration.axis_2_pos_mm,
        )
        self.configuration.axis_2_pos_mm = None

    def _compute_shifts(self):
        """
        Refine the estimated shifts between consecutive volumes (``find_volumes_relative_shifts``) and store
        the final relative shifts in ``self._axis_{0,1,2}_rel_final_shifts``.

        :raises ValueError: if no input volume is provided.
        """
        n_volumes = len(self.configuration.input_volumes)
        if n_volumes == 0:
            raise ValueError("no scan to stich provided")

        slice_for_shift = self.configuration.slice_for_cross_correlation or "middle"
        y_rel_shifts = self._axis_0_rel_ini_shifts
        x_rel_shifts = self._axis_1_rel_ini_shifts
        dim_axis_1 = max(volume.get_volume_shape()[1] for volume in self.series)

        final_rel_shifts = []
        for upper_volume, lower_volume, x_rel_shift, y_rel_shift, flip_ud_upper, flip_ud_lower in zip(
            self.series[:-1],
            self.series[1:],
            x_rel_shifts,
            y_rel_shifts,
            self.configuration.flip_ud[:-1],
            self.configuration.flip_ud[1:],
        ):
            x_cross_algo = self.configuration.axis_2_params.get(KEY_IMG_REG_METHOD, None)
            y_cross_algo = self.configuration.axis_0_params.get(KEY_IMG_REG_METHOD, None)

            # compute relative shift
            found_shift_y, found_shift_x = find_volumes_relative_shifts(
                upper_volume=upper_volume,
                lower_volume=lower_volume,
                dtype=self.get_output_data_type(),
                dim_axis_1=dim_axis_1,
                slice_for_shift=slice_for_shift,
                x_cross_correlation_function=x_cross_algo,
                y_cross_correlation_function=y_cross_algo,
                x_shifts_params=self.configuration.axis_2_params,
                y_shifts_params=self.configuration.axis_0_params,
                estimated_shifts=(y_rel_shift, x_rel_shift),
                flip_ud_lower_frame=flip_ud_lower,
                flip_ud_upper_frame=flip_ud_upper,
                alignment_axis_1=self.configuration.alignment_axis_1,
                alignment_axis_2=self.configuration.alignment_axis_2,
                overlap_axis=self.axis,
            )
            final_rel_shifts.append((found_shift_y, found_shift_x))

        # set back values. Now position should start at 0
        self._axis_0_rel_final_shifts = [final_shift[0] for final_shift in final_rel_shifts]
        self._axis_1_rel_final_shifts = [final_shift[1] for final_shift in final_rel_shifts]
        self._axis_2_rel_final_shifts = [0.0] * len(final_rel_shifts)
        # report through the logger only (no stray prints from library code)
        _logger.info(f"axis 2 relative shifts (x in radio ref) to be used will be {self._axis_1_rel_final_shifts}")
        _logger.info(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_0_rel_final_shifts}")

    def get_output_data_type(self):
        """
        Return the dtype of the stitched volume: the dtype of the first input volume (computed once, then cached).
        """
        if self.__output_data_type is None:

            def find_output_data_type():
                first_vol = self._input_volumes[0]
                if first_vol.data is not None:
                    return first_vol.data.dtype
                elif isinstance(first_vol, HDF5Volume):
                    # read dtype from the dataset without loading the data
                    with DatasetReader(first_vol.data_url) as vol_dataset:
                        return vol_dataset.dtype
                else:
                    return first_vol.load_data(store=False).dtype

            self.__output_data_type = find_output_data_type()
        return self.__output_data_type

    def _create_stitched_volume(self, store_composition: bool):
        """
        Stitch the input volumes slice by slice (iterating along axis 1) and write the result through ``self.dumper``.

        :param store_composition: if True, the frame composition of the first stitched slice is stored
            in ``self._frame_composition``.
        """
        overlap_kernels = self._overlap_kernels
        self._slices_to_stitch, n_slices = self.configuration.settle_slices()

        # sync overwrite_results with volume overwrite parameter
        self.configuration.output_volume.overwrite = self.configuration.overwrite_results

        # init final volume
        final_volume = self.configuration.output_volume
        final_volume_shape = (
            int(
                numpy.asarray([volume.get_volume_shape()[0] for volume in self._input_volumes]).sum()
                - numpy.asarray([abs(overlap) for overlap in self._axis_0_rel_final_shifts]).sum(),
            ),
            n_slices,
            self._stitching_constant_length,
        )

        data_type = self.get_output_data_type()

        # explicit None check: a progress bar object can be falsy (e.g. total of 0 / None)
        if self.progress is not None:
            self.progress.total = final_volume_shape[1]

        y_index = 0
        if isinstance(self._slices_to_stitch, slice):
            step = self._slices_to_stitch.step or 1
        else:
            step = 1

        output_dataset_args = {
            "volume": final_volume,
            "volume_shape": final_volume_shape,
            "dtype": data_type,
            "dumper": self.dumper,
        }
        from .dumper.postprocessing import PostProcessingStitchingDumperNoDD

        # TODO: FIXME: for now not very elegant but in the case of avoiding data duplication
        # we need to provide the information about the stitched part shape.
        # this should be moved to the dumper in the future
        if isinstance(self.dumper, PostProcessingStitchingDumperNoDD):
            output_dataset_args["stitching_sources_arr_shapes"] = tuple(
                [(abs(overlap), n_slices, self._stitching_constant_length) for overlap in self._axis_0_rel_final_shifts]
            )

        with self.dumper.OutputDatasetContext(**output_dataset_args):
            # note: output_dataset is a HDF5 dataset if final volume is an HDF5 volume else is a numpy array
            with _RawDatasetsContext(
                self._input_volumes,
                alignment_axis_1=self.configuration.alignment_axis_1,
            ) as raw_datasets:
                # note: raw_datasets can be numpy arrays or HDF5 dataset (in the case of HDF5Volume)
                # to speed up we read by bunch of dataset. For numpy array this doesn't change anything
                # but for HDF5 dataset this can speed up a lot the processing (depending on HDF5 dataset chunks)
                # note: we read through axis 1
                if isinstance(self.dumper, PostProcessingStitchingDumperNoDD):
                    self.dumper.raw_regions_hdf5_dataset = raw_datasets
                for bunch_start, bunch_end in PostProcessingStitching._data_bunch_iterator(
                    slices=self._slices_to_stitch, bunch_size=50
                ):
                    for data_frames in PostProcessingStitching._get_bunch_of_data(
                        bunch_start,
                        bunch_end,
                        step=step,
                        volumes=raw_datasets,
                        flip_lr_arr=self.configuration.flip_lr,
                        flip_ud_arr=self.configuration.flip_ud,
                    ):
                        if self.configuration.rescale_frames:
                            data_frames = self.rescale_frames(data_frames)
                        if self.configuration.normalization_by_sample.is_active():
                            data_frames = self.normalize_frame_by_sample(data_frames)

                        sf = PostProcessingStitching.stitch_frames(
                            frames=data_frames,
                            axis=self.axis,
                            output_dtype=data_type,
                            x_relative_shifts=self._axis_1_rel_final_shifts,
                            y_relative_shifts=self._axis_0_rel_final_shifts,
                            overlap_kernels=overlap_kernels,
                            dumper=self.dumper,
                            i_frame=y_index,
                            return_composition_cls=store_composition if y_index == 0 else False,
                            stitching_axis=self.axis,
                            check_inputs=y_index == 0,  # only check inputs on the first iteration
                        )
                        if y_index == 0 and store_composition:
                            _, self._frame_composition = sf

                        if self.progress is not None:
                            self.progress.update()
                        y_index += 1

    # alias to general API
    def _create_stitching(self, store_composition):
        """Alias of ``_create_stitched_volume`` (generic stitcher API)."""
        self._create_stitched_volume(store_composition=store_composition)
476
+
477
+
478
class _RawDatasetsContext(AbstractContextManager):
    """
    Give access to the data of all input volumes (target: used for volume stitching).

    - If a volume already holds its data, that data is used directly.
    - If the volume is an HDF5Volume, the HDF5 dataset is used (on disk). Files are opened on enter and
      closed on exit.
    - Otherwise, the volume is loaded in memory, which uses more memory.

    If the volumes do not all have the same size along axis 1, each data is wrapped in a PaddedRawData
    padded to the largest axis 1 size, according to ``alignment_axis_1``.
    """

    def __init__(self, volumes: tuple, alignment_axis_1) -> None:
        """
        :param volumes: volumes to access. Each must be a VolumeBase instance.
        :param alignment_axis_1: alignment used when padding along axis 1 (see ``add_padding``).
        :raises TypeError: if one of the volumes is not a VolumeBase instance.
        """
        super().__init__()
        for volume in volumes:
            if not isinstance(volume, VolumeBase):
                raise TypeError(
                    f"Volumes are expected to be an instance of {VolumeBase}. {type(volume)} provided instead"
                )

        self._volumes = volumes
        self.__file_handlers = []
        self._alignment_axis_1 = alignment_axis_1

    @property
    def alignment_axis_1(self):
        return self._alignment_axis_1

    def __enter__(self):
        """
        Return the list of datasets (numpy arrays, HDF5 datasets or PaddedRawData), one per volume.

        :raises ValueError: if no data can be found for a volume. Any file opened so far is closed first.
        """
        # handle the specific case of HDF5. Goal: avoid getting the full stitched volume in memory
        datasets = []
        axis_1_dims = {volume.get_volume_shape()[1] for volume in self._volumes}
        axis_1_dim = max(axis_1_dims)
        axis_1_need_padding = len(axis_1_dims) > 1

        try:
            for volume in self._volumes:
                if volume.data is not None:
                    data = volume.data
                elif isinstance(volume, HDF5Volume):
                    file_handler = HDF5File(volume.data_url.file_path(), mode="r")
                    # register the handler before accessing the dataset so it is closed even if access fails
                    self.__file_handlers.append(file_handler)
                    data = file_handler[volume.data_url.data_path()]
                # for other file format: load the full dataset in memory
                else:
                    data = volume.load_data(store=False)
                if data is None:
                    raise ValueError(f"No data found for volume {volume.get_identifier()}")
                if axis_1_need_padding:
                    data = self.add_padding(data=data, axis_1_dim=axis_1_dim, alignment=self.alignment_axis_1)
                datasets.append(data)
        except Exception:
            # if some errors happen during loading: release files opened so far
            self._close_file_handlers()
            raise

        return datasets

    def __exit__(self, exc_type, exc_value, traceback):
        """Close every opened file. Exceptions raised in the ``with`` body are not suppressed."""
        self._close_file_handlers()

    def _close_file_handlers(self):
        """Close every registered file handler and forget them (avoids closing twice)."""
        file_handlers, self.__file_handlers = self.__file_handlers, []
        # close each handler unconditionally: `close()` returns None, so it must not be chained with `and`
        for file_handler in file_handlers:
            file_handler.close()

    def add_padding(self, data: Union[h5py.Dataset, numpy.ndarray], axis_1_dim, alignment: AlignmentAxis1):
        """
        Wrap ``data`` in a PaddedRawData padded along axis 1 up to ``axis_1_dim``.

        :param alignment: where to put the data along axis 1: BACK pads before, FRONT pads after and
            CENTER splits the padding (the extra pixel, if any, goes after).
        :raises ValueError: if the alignment is not handled.
        """
        alignment = AlignmentAxis1.from_value(alignment)
        missing_width = axis_1_dim - data.shape[1]
        if alignment is AlignmentAxis1.BACK:
            axis_1_pad_width = (missing_width, 0)
        elif alignment is AlignmentAxis1.CENTER:
            half_width = missing_width // 2
            axis_1_pad_width = (half_width, missing_width - half_width)
        elif alignment is AlignmentAxis1.FRONT:
            axis_1_pad_width = (0, missing_width)
        else:
            raise ValueError(f"alignment {alignment} is not handled")

        return PaddedRawData(
            data=data,
            axis_1_pad_width=axis_1_pad_width,
        )
+ )