nabu 2024.1.10__py3-none-any.whl → 2024.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. nabu/__init__.py +1 -1
  2. nabu/app/bootstrap.py +2 -3
  3. nabu/app/cast_volume.py +4 -2
  4. nabu/app/cli_configs.py +5 -0
  5. nabu/app/composite_cor.py +1 -1
  6. nabu/app/create_distortion_map_from_poly.py +5 -6
  7. nabu/app/diag_to_pix.py +7 -19
  8. nabu/app/diag_to_rot.py +14 -29
  9. nabu/app/double_flatfield.py +32 -44
  10. nabu/app/parse_reconstruction_log.py +3 -0
  11. nabu/app/reconstruct.py +53 -15
  12. nabu/app/reconstruct_helical.py +2 -2
  13. nabu/app/stitching.py +27 -13
  14. nabu/app/tests/__init__.py +0 -0
  15. nabu/app/tests/test_reduce_dark_flat.py +4 -1
  16. nabu/cuda/kernel.py +11 -2
  17. nabu/cuda/processing.py +2 -2
  18. nabu/cuda/src/cone.cu +77 -0
  19. nabu/cuda/src/hierarchical_backproj.cu +271 -0
  20. nabu/cuda/utils.py +0 -6
  21. nabu/estimation/alignment.py +5 -19
  22. nabu/estimation/cor.py +173 -599
  23. nabu/estimation/cor_sino.py +356 -26
  24. nabu/estimation/focus.py +63 -11
  25. nabu/estimation/tests/test_cor.py +124 -58
  26. nabu/estimation/tests/test_focus.py +6 -6
  27. nabu/estimation/tilt.py +2 -1
  28. nabu/estimation/utils.py +5 -33
  29. nabu/io/__init__.py +1 -1
  30. nabu/io/cast_volume.py +1 -1
  31. nabu/io/reader.py +416 -21
  32. nabu/io/tests/test_readers.py +422 -0
  33. nabu/io/tests/test_writers.py +1 -102
  34. nabu/io/writer.py +4 -433
  35. nabu/opencl/kernel.py +14 -3
  36. nabu/opencl/processing.py +8 -0
  37. nabu/pipeline/config_validators.py +5 -2
  38. nabu/pipeline/datadump.py +12 -5
  39. nabu/pipeline/estimators.py +162 -188
  40. nabu/pipeline/fullfield/chunked.py +168 -92
  41. nabu/pipeline/fullfield/chunked_cuda.py +7 -3
  42. nabu/pipeline/fullfield/computations.py +2 -7
  43. nabu/pipeline/fullfield/dataset_validator.py +0 -4
  44. nabu/pipeline/fullfield/nabu_config.py +37 -13
  45. nabu/pipeline/fullfield/processconfig.py +22 -13
  46. nabu/pipeline/fullfield/reconstruction.py +13 -9
  47. nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
  48. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
  49. nabu/pipeline/helical/helical_reconstruction.py +1 -1
  50. nabu/pipeline/params.py +21 -1
  51. nabu/pipeline/processconfig.py +1 -12
  52. nabu/pipeline/reader.py +146 -0
  53. nabu/pipeline/tests/test_estimators.py +44 -72
  54. nabu/pipeline/utils.py +4 -2
  55. nabu/pipeline/writer.py +10 -2
  56. nabu/preproc/ccd_cuda.py +1 -1
  57. nabu/preproc/ctf.py +14 -7
  58. nabu/preproc/ctf_cuda.py +2 -3
  59. nabu/preproc/double_flatfield.py +5 -12
  60. nabu/preproc/double_flatfield_cuda.py +2 -2
  61. nabu/preproc/flatfield.py +5 -1
  62. nabu/preproc/flatfield_cuda.py +5 -1
  63. nabu/preproc/phase.py +24 -73
  64. nabu/preproc/phase_cuda.py +5 -8
  65. nabu/preproc/tests/test_ctf.py +11 -7
  66. nabu/preproc/tests/test_flatfield.py +67 -122
  67. nabu/preproc/tests/test_paganin.py +54 -30
  68. nabu/processing/azim.py +206 -0
  69. nabu/processing/convolution_cuda.py +1 -1
  70. nabu/processing/fft_cuda.py +15 -17
  71. nabu/processing/histogram.py +2 -0
  72. nabu/processing/histogram_cuda.py +2 -1
  73. nabu/processing/kernel_base.py +3 -0
  74. nabu/processing/muladd_cuda.py +1 -0
  75. nabu/processing/padding_opencl.py +1 -1
  76. nabu/processing/roll_opencl.py +1 -0
  77. nabu/processing/rotation_cuda.py +2 -2
  78. nabu/processing/tests/test_fft.py +17 -10
  79. nabu/processing/unsharp_cuda.py +1 -1
  80. nabu/reconstruction/cone.py +104 -40
  81. nabu/reconstruction/fbp.py +3 -0
  82. nabu/reconstruction/fbp_base.py +7 -2
  83. nabu/reconstruction/filtering.py +20 -7
  84. nabu/reconstruction/filtering_cuda.py +7 -1
  85. nabu/reconstruction/hbp.py +424 -0
  86. nabu/reconstruction/mlem.py +99 -0
  87. nabu/reconstruction/reconstructor.py +2 -0
  88. nabu/reconstruction/rings_cuda.py +19 -19
  89. nabu/reconstruction/sinogram_cuda.py +1 -0
  90. nabu/reconstruction/sinogram_opencl.py +3 -1
  91. nabu/reconstruction/tests/test_cone.py +10 -5
  92. nabu/reconstruction/tests/test_deringer.py +7 -6
  93. nabu/reconstruction/tests/test_fbp.py +124 -10
  94. nabu/reconstruction/tests/test_filtering.py +13 -11
  95. nabu/reconstruction/tests/test_halftomo.py +30 -4
  96. nabu/reconstruction/tests/test_mlem.py +91 -0
  97. nabu/reconstruction/tests/test_reconstructor.py +8 -3
  98. nabu/resources/dataset_analyzer.py +142 -92
  99. nabu/resources/gpu.py +1 -0
  100. nabu/resources/nxflatfield.py +134 -125
  101. nabu/resources/templates/id16a_fluo.conf +42 -0
  102. nabu/resources/tests/test_extract.py +10 -0
  103. nabu/resources/tests/test_nxflatfield.py +2 -2
  104. nabu/stitching/alignment.py +80 -24
  105. nabu/stitching/config.py +105 -68
  106. nabu/stitching/definitions.py +1 -0
  107. nabu/stitching/frame_composition.py +68 -60
  108. nabu/stitching/overlap.py +91 -51
  109. nabu/stitching/single_axis_stitching.py +32 -0
  110. nabu/stitching/slurm_utils.py +6 -6
  111. nabu/stitching/stitcher/__init__.py +0 -0
  112. nabu/stitching/stitcher/base.py +124 -0
  113. nabu/stitching/stitcher/dumper/__init__.py +3 -0
  114. nabu/stitching/stitcher/dumper/base.py +94 -0
  115. nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
  116. nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
  117. nabu/stitching/stitcher/post_processing.py +555 -0
  118. nabu/stitching/stitcher/pre_processing.py +1068 -0
  119. nabu/stitching/stitcher/single_axis.py +484 -0
  120. nabu/stitching/stitcher/stitcher.py +0 -0
  121. nabu/stitching/stitcher/y_stitcher.py +13 -0
  122. nabu/stitching/stitcher/z_stitcher.py +45 -0
  123. nabu/stitching/stitcher_2D.py +278 -0
  124. nabu/stitching/tests/test_config.py +12 -37
  125. nabu/stitching/tests/test_frame_composition.py +33 -59
  126. nabu/stitching/tests/test_overlap.py +149 -7
  127. nabu/stitching/tests/test_utils.py +1 -1
  128. nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
  129. nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
  130. nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
  131. nabu/stitching/utils/__init__.py +1 -0
  132. nabu/stitching/utils/post_processing.py +281 -0
  133. nabu/stitching/utils/tests/test_post-processing.py +21 -0
  134. nabu/stitching/{utils.py → utils/utils.py} +79 -52
  135. nabu/stitching/y_stitching.py +27 -0
  136. nabu/stitching/z_stitching.py +32 -2281
  137. nabu/testutils.py +1 -152
  138. nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
  139. nabu/utils.py +158 -61
  140. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/METADATA +24 -17
  141. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/RECORD +145 -121
  142. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/WHEEL +1 -1
  143. nabu/io/tiffwriter_zmm.py +0 -99
  144. nabu/pipeline/fallback_utils.py +0 -149
  145. nabu/pipeline/helical/tests/test_accumulator.py +0 -158
  146. nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
  147. nabu/pipeline/helical/tests/test_strategy.py +0 -61
  148. nabu/pipeline/helical/utils.py +0 -51
  149. nabu/pipeline/tests/test_chunk_reader.py +0 -74
  150. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
  151. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
  152. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1068 @@
1
+ import numpy
2
+ import logging
3
+ import h5py
4
+ import os
5
+ from typing import Iterable
6
+ from silx.io.url import DataUrl
7
+ from silx.io.utils import get_data
8
+ from datetime import datetime
9
+
10
+ from nxtomo.nxobject.nxdetector import ImageKey
11
+ from nxtomo.application.nxtomo import NXtomo
12
+ from nxtomo.nxobject.nxtransformations import NXtransformations
13
+ from nxtomo.utils.transformation import build_matrix, DetYFlipTransformation, DetZFlipTransformation
14
+ from nxtomo.paths.nxtomo import get_paths as _get_nexus_paths
15
+
16
+ from tomoscan.io import HDF5File
17
+ from tomoscan.series import Series
18
+ from tomoscan.esrf import NXtomoScan, EDFTomoScan
19
+ from tomoscan.esrf.scan.utils import (
20
+ get_compacted_dataslices,
21
+ ) # this version has a 'return_url_set' needed here. At one point they should be merged together
22
+ from nabu.stitching.config import (
23
+ PreProcessedSingleAxisStitchingConfiguration,
24
+ KEY_IMG_REG_METHOD,
25
+ )
26
+ from nabu.stitching.utils import find_projections_relative_shifts
27
+ from functools import lru_cache as cache
28
+ from .single_axis import SingleAxisStitcher
29
+ from pyunitsystem.metricsystem import MetricSystem
30
+
31
+
32
+ _logger = logging.getLogger(__name__)
33
+
34
+
35
+ class PreProcessingStitching(SingleAxisStitcher):
36
+ """
37
+ loader to be used when save data during pre-processing stitching (on projections). Output is expected to be an NXtomo
38
+
39
+ warning: axis are provided according to the `acquisition space <https://tomo.gitlab-pages.esrf.fr/bliss-tomo/master/modelization.html>`_
40
+ """
41
+
42
def __init__(self, configuration, progress=None) -> None:
    """
    Create a stitcher working on projections (pre-processing stitching).

    :param configuration: stitching configuration. Must be a
        PreProcessedSingleAxisStitchingConfiguration instance.
    :param progress: optional progress handler, forwarded to the parent class.
    :raises TypeError: if `configuration` is not a
        PreProcessedSingleAxisStitchingConfiguration.
    """
    if not isinstance(configuration, PreProcessedSingleAxisStitchingConfiguration):
        raise TypeError(
            f"configuration is expected to be an instance of {PreProcessedSingleAxisStitchingConfiguration}. Get {type(configuration)} instead"
        )
    super().__init__(configuration, progress=progress)
    self._series = Series("series", iterable=configuration.input_scans, use_identifiers=False)
    self._reading_orders = []
    # TODO: rename flips to axis_0_flips, axis_1_flips, axis_2_flips...
    self._x_flips = []
    self._y_flips = []
    self._z_flips = []

    # 'expand' a position given once for all scans into one value per scan.
    # Positions are per scan: check_inputs requires exactly len(series) values and
    # order_input_tomo_objects sorts the series with them, so the expanded list
    # must contain len(series) items. All three axes are handled.
    n_scans = len(self.series)
    for pos_px_name in ("axis_0_pos_px", "axis_1_pos_px", "axis_2_pos_px"):
        pos_px = getattr(self.configuration, pos_px_name)
        if numpy.isscalar(pos_px):
            setattr(self.configuration, pos_px_name, [pos_px] * n_scans)

    # make sure per-axis parameters are always dictionaries (_compute_shifts calls .get on them)
    for params_name in ("axis_0_params", "axis_1_params", "axis_2_params"):
        if getattr(self.configuration, params_name) is None:
            setattr(self.configuration, params_name, {})
76
+
77
def pre_processing_computation(self):
    """Run the work needed before stitching: compute the reduced darks and flats."""
    reduce_darks_and_flats = self.compute_reduced_flats_and_darks
    reduce_darks_and_flats()
79
+
80
@property
def stitching_axis_in_frame_space(self):
    """
    Return the index of the stitching axis inside a projection frame.

    Acquisition axis 0 maps to frame axis 0 and acquisition axis 1 maps to frame axis 1.

    :raises NotImplementedError: when stitching along axis 2 (would require interpolating
        frames between rotation angles) or along any value outside (0, 1, 2).
    """
    current_axis = self.axis
    for frame_axis in (0, 1):
        if current_axis == frame_axis:
            return frame_axis
    if current_axis == 2:
        raise NotImplementedError(
            "pre-processing stitching along axis 2 is not handled. This would require to do interpolation between frame along the rotation angle. Just not possible"
        )
    raise NotImplementedError(f"stitching axis must be in (0, 1, 2). Get {current_axis}")
92
+
93
@property
def x_flips(self) -> list:
    """Flip flags stored for the x direction (the internal `_x_flips` list itself, not a copy)."""
    flips = self._x_flips
    return flips
96
+
97
@property
def y_flips(self) -> list:
    """Flip flags stored for the y direction (the internal `_y_flips` list itself, not a copy)."""
    flips = self._y_flips
    return flips
100
+
101
def order_input_tomo_objects(self):
    """
    Order the input scans by decreasing position along the stitching axis.

    When positions are provided for the stitching axis (`axis_0_pos_px` or `axis_1_pos_px`),
    they are used for the ordering. Otherwise the minimum of each scan bounding box along the
    stitching axis is used.

    The input must already be in decreasing order, or in exactly the reverse order. In the
    reverse case, the series is replaced by the reordered one. Every per-scan configuration
    list is also reversed so it stays aligned with the scans:

    * positions in mm / px for the three axes,
    * `flip_ud` and `flip_lr` when they are not scalars.

    :raises ValueError: if the stitching axis is not 0 or 1, or if the input order is neither
        decreasing nor its exact reverse.
    """

    def get_min_bound(scan):
        return scan.get_bounding_box(axis=self.axis).min

    # order scans along the stitched axis
    if self.axis == 0:
        position_along_stitched_axis = self.configuration.axis_0_pos_px
    elif self.axis == 1:
        position_along_stitched_axis = self.configuration.axis_1_pos_px
    else:
        raise ValueError(
            "stitching cannot be done along axis 2 for pre-processing. This would require to interpolate frame between different rotation angle"
        )

    if position_along_stitched_axis is not None and len(position_along_stitched_axis) > 0:
        # positions are provided: use them directly (decreasing order)
        order = numpy.argsort(position_along_stitched_axis)[::-1]
        sorted_series = Series(
            self.series.name,
            numpy.take_along_axis(numpy.array(self.series[:]), order, axis=0),
            use_identifiers=False,
        )
    else:
        # else use bounding box
        sorted_series = Series(
            self.series.name,
            sorted(self.series[:], key=get_min_bound, reverse=True),
            use_identifiers=False,
        )

    if sorted_series != self.series:
        if sorted_series[:] != self.series[::-1]:
            raise ValueError(
                f"Unable to get comprehensive input. Axis {self.axis} (decreasing) ordering is not respected."
            )
        _logger.warning(
            f"decreasing order haven't been respected. Need to reorder {self.serie_label} ({[str(scan) for scan in sorted_series[:]]}). Will also reorder overlap height, stitching height and invert shifts"
        )
        # reverse every per-scan position list so it stays aligned with the reordered scans
        for pos_name in (
            "axis_0_pos_mm",
            "axis_0_pos_px",
            "axis_1_pos_mm",
            "axis_1_pos_px",
            "axis_2_pos_mm",
            "axis_2_pos_px",
        ):
            positions = getattr(self.configuration, pos_name)
            if positions is not None:
                setattr(self.configuration, pos_name, positions[::-1])
        # each flip list is reversed into itself (flip_lr must not overwrite flip_ud)
        for flip_name in ("flip_ud", "flip_lr"):
            flips = getattr(self._configuration, flip_name)
            if not numpy.isscalar(flips):
                setattr(self._configuration, flip_name, flips[::-1])

    self._series = sorted_series
157
+
158
def check_inputs(self):
    """
    Ensure input data is coherent before stitching.

    The following are checked:

    * at least one scan is provided and every scan is a TomoScanBase instance;
    * output file path and output data path are set in the configuration;
    * each per-scan position attribute (axis_{0,1,2}_pos_{px,mm}), when iterable, holds one
      value per scan;
    * consecutive scans share the same number of projections, compatible rotation angles
      (NXtomoScan only), field of view, sample / detector distance and pixel sizes.

    It also rebuilds `self._reading_orders`, which holds one value per scan. The first scan
    gets 1. Each following scan gets the previous value, or its opposite when its projection
    angles match the previous scan's angles reversed. Pairs without angle information
    (non-NXtomoScan) keep the previous value.

    Missing energy or distance, energy mismatches and non-constant sample translations only
    emit warnings.

    :raises ValueError: on any of the incoherences listed above.
    :raises TypeError: if a scan is not a TomoScanBase instance.
    """
    from tomoscan.scanbase import TomoScanBase

    n_scans = len(self.series)
    if n_scans == 0:
        raise ValueError("no scan to stich together")

    for scan in self.series:
        if not isinstance(scan, TomoScanBase):
            raise TypeError(f"z-preproc stitching expects instances of {TomoScanBase}. {type(scan)} provided.")

    # check output file path and data path are provided
    if self.configuration.output_file_path in (None, ""):
        raise ValueError("output_file_path should be provided to the configuration")
    if self.configuration.output_data_path in (None, ""):
        raise ValueError("output_data_path should be provided to the configuration")

    # check one position is provided per scan (each attribute is checked under its own name)
    for axis_name in (
        "axis_0_pos_px",
        "axis_1_pos_px",
        "axis_2_pos_px",
        "axis_0_pos_mm",
        "axis_1_pos_mm",
        "axis_2_pos_mm",
    ):
        axis_pos = getattr(self.configuration, axis_name)
        if isinstance(axis_pos, Iterable) and len(axis_pos) != n_scans:
            raise ValueError(f"{axis_name} expect {n_scans} shift defined. Get {len(axis_pos)}")

    # the first scan defines the expected reading order (and expected flip).
    # If all scans are flipped we keep it this way.
    self._reading_orders = [1]

    # check scans are coherent (nb projections, rotation angle, energy...)
    for scan_0, scan_1 in zip(self.series[0:-1], self.series[1:]):
        if len(scan_0.projections) != len(scan_1.projections):
            raise ValueError(f"{scan_0} and {scan_1} have a different number of projections")

        if isinstance(scan_0, NXtomoScan) and isinstance(scan_1, NXtomoScan):
            # check rotation angles (only possible for NXtomoScan)
            scan_0_angles = numpy.asarray(scan_0.rotation_angle)
            scan_0_projections_angles = scan_0_angles[
                numpy.asarray(scan_0.image_key_control) == ImageKey.PROJECTION.value
            ]
            scan_1_angles = numpy.asarray(scan_1.rotation_angle)
            scan_1_projections_angles = scan_1_angles[
                numpy.asarray(scan_1.image_key_control) == ImageKey.PROJECTION.value
            ]
            # note: atol=10e-1 is an absolute tolerance of 1.0 on the angle values
            if numpy.allclose(scan_0_projections_angles, scan_1_projections_angles, atol=10e-1):
                reading_order = self._reading_orders[-1]
            elif numpy.allclose(scan_0_projections_angles, scan_1_projections_angles[::-1], atol=10e-1):
                # same angles acquired in the opposite direction: invert the reading order
                reading_order = -1 * self._reading_orders[-1]
            else:
                raise ValueError(f"Angles from {scan_0} and {scan_1} are different")
        else:
            # no angle information available: assume the same reading order as the previous scan
            reading_order = self._reading_orders[-1]
        # keep one reading order per scan, whatever the scan type
        self._reading_orders.append(reading_order)

        # check energy (missing values are only reported, never compared)
        if scan_0.energy is None:
            _logger.warning(f"no energy found for {scan_0}")
        elif scan_1.energy is None:
            _logger.warning(f"no energy found for {scan_1}")
        elif not numpy.isclose(scan_0.energy, scan_1.energy, rtol=1e-03):
            _logger.warning(
                f"different energy found between {scan_0} ({scan_0.energy}) and {scan_1} ({scan_1.energy})"
            )
        # check FOV
        if not scan_0.field_of_view == scan_1.field_of_view:
            raise ValueError(f"{scan_0} and {scan_1} have different field of view")
        # check distance (missing values are only reported, never compared)
        if scan_0.distance is None:
            _logger.warning(f"no distance found for {scan_0}")
        elif scan_1.distance is None:
            _logger.warning(f"no distance found for {scan_1}")
        elif not numpy.isclose(scan_0.distance, scan_1.distance, rtol=10e-3):
            raise ValueError(f"{scan_0} and {scan_1} have different sample / detector distance")
        # check pixel size
        if not numpy.isclose(scan_0.x_pixel_size, scan_1.x_pixel_size):
            raise ValueError(
                f"{scan_0} and {scan_1} have different x pixel size. {scan_0.x_pixel_size} vs {scan_1.x_pixel_size}"
            )
        if not numpy.isclose(scan_0.y_pixel_size, scan_1.y_pixel_size):
            raise ValueError(
                f"{scan_0} and {scan_1} have different y pixel size. {scan_0.y_pixel_size} vs {scan_1.y_pixel_size}"
            )

    # check x, y and z translation are constant (only if is an NXtomoScan)
    for scan in self.series:
        if not isinstance(scan, NXtomoScan):
            continue
        for translation_label in ("x", "y", "z"):
            translation = getattr(scan, f"{translation_label}_translation")
            if translation is None or len(translation) == 0:
                continue
            if not numpy.isclose(min(translation), max(translation)):
                _logger.warning(
                    f"{translation_label} translations appears to be evolving over time. Might end up with wrong stitching"
                )
277
+
278
def _compute_positions_as_px(self):
    """
    Ensure the configuration holds one position in pixel per scan for axes 0, 1 and 2.

    For each axis N the position is taken, by priority:

    1. from `axis_N_pos_px` when provided (providing `axis_N_pos_mm` as well is an error);
    2. from `axis_N_pos_mm`, converted to meters and divided by the pixel size
       (configuration `pixel_size`, falling back to the scan `pixel_size`);
    3. from the center of each scan bounding box along the axis, relative to the bounding box
       minimum of the first scan, divided by the pixel size and truncated to int.

    If this fails with a ValueError on the stitching axis, the error is propagated. On any
    other axis, the positions are set to zero. `axis_N_pos_mm` are cleared afterwards, so
    only pixel positions remain in the configuration.

    :raises ValueError: if the position along the stitching axis cannot be determined.
    """

    def get_position_as_px_on_axis(axis, pos_as_px, pos_as_mm):
        if pos_as_px is not None:
            if pos_as_mm is not None:
                raise ValueError(
                    f"position of axis {axis} is provided twice: as mm and as px. Please provide one only ({pos_as_mm} vs {pos_as_px})"
                )
            return pos_as_px

        if pos_as_mm is not None:
            # deduce from position given in configuration and pixel size.
            # MetricSystem values are expressed in meters (MILLIMETER -> 1e-3), so multiplying
            # converts a millimeter value into meters before dividing by the pixel size (m).
            axis_N_pos_px = []
            for scan, pos_in_mm in zip(self.series, pos_as_mm):
                pixel_size_m = self.configuration.pixel_size or scan.pixel_size
                axis_N_pos_px.append((pos_in_mm * MetricSystem.MILLIMETER.value) / pixel_size_m)
            return axis_N_pos_px

        # deduce from motor position (bounding box center) and pixel size
        axis_N_pos_px = []
        base_position_m = self.series[0].get_bounding_box(axis=axis).min
        for scan in self.series:
            pixel_size_m = self.configuration.pixel_size or scan.pixel_size
            scan_axis_bb = scan.get_bounding_box(axis=axis)
            axis_N_mean_pos_m = (scan_axis_bb.max - scan_axis_bb.min) / 2 + scan_axis_bb.min
            axis_N_mean_rel_pos_m = axis_N_mean_pos_m - base_position_m
            axis_N_pos_px.append(int(axis_N_mean_rel_pos_m / pixel_size_m))
        return axis_N_pos_px

    # Axis 2 positions are not used by the stitcher. Record whether the user provided some
    # *before* the loop below, which always fills axis_2_pos_px and clears the mm positions.
    axis_2_provided = (
        self.configuration.axis_2_pos_mm is not None or self.configuration.axis_2_pos_px is not None
    )

    for axis in (0, 1, 2):
        property_px_name = f"axis_{axis}_pos_px"
        property_mm_name = f"axis_{axis}_pos_mm"
        assert hasattr(
            self.configuration, property_px_name
        ), f"configuration API changed. should have {property_px_name}"
        assert hasattr(
            self.configuration, property_mm_name
        ), f"configuration API changed. should have {property_mm_name}"
        try:
            new_px_position = get_position_as_px_on_axis(
                axis=axis,
                pos_as_px=getattr(self.configuration, property_px_name),
                pos_as_mm=getattr(self.configuration, property_mm_name),
            )
        except ValueError:
            # when unable to find the position
            if axis == self.axis:
                # cannot process without a position along the stitching axis
                raise
            _logger.warning(f"Unable to find position over axis {axis}. Set them to zero")
            new_px_position = numpy.array([0] * len(self.series))
        setattr(self.configuration, property_px_name, new_px_position)

    # clear position in mm as the one we will use are the px one
    self.configuration.axis_0_pos_mm = None
    self.configuration.axis_1_pos_mm = None
    self.configuration.axis_2_pos_mm = None

    # add some log
    if axis_2_provided:
        _logger.warning("axis 2 position is not handled by the stitcher. Will be ignored")
    axis_0_pos = ", ".join([f"{pos}px" for pos in self.configuration.axis_0_pos_px])
    axis_1_pos = ", ".join([f"{pos}px" for pos in self.configuration.axis_1_pos_px])
    axis_2_pos = ", ".join([f"{pos}px" for pos in self.configuration.axis_2_pos_px])
    _logger.info("axis 0 position to be used: " + axis_0_pos)
    _logger.info("axis 1 position to be used: " + axis_1_pos)
    _logger.info("axis 2 position to be used: " + axis_2_pos)
    _logger.info(f"stitching will be applied along axis: {self.axis}")
368
+
369
def compute_estimated_shifts(self):
    """
    Compute the initial relative shifts between consecutive scans from their pixel positions.

    Along the stitching axis, the value stored per pair of consecutive scans is their overlap,
    in pixel and truncated to int. Each frame is considered centered on its position, with a
    size along the stitching axis of `dim_2` (axis 0) or `dim_1` (axis 1).

    Along the other frame axis, relative shifts come from `from_abs_pos_to_rel_pos`. Along
    axis 2 they are zero.

    :raises ValueError: if two consecutive scans do not overlap along the stitching axis.
    :raises NotImplementedError: if the stitching axis is neither 0 nor 1.
    """

    def append_overlaps(target, positions, dim_name):
        # appends one int overlap per consecutive pair of scans (series order) to `target`
        pairs = zip(self.series[:-1], self.series[1:], positions[:-1], positions[1:])
        for first_scan, second_scan, first_pos, second_pos in pairs:
            first_low = first_pos - getattr(first_scan, dim_name) / 2
            second_high = second_pos + getattr(second_scan, dim_name) / 2
            # simple test of overlap. More complete test are run by check_overlaps later
            if second_high <= first_low:
                raise ValueError(f"no overlap found between {first_scan} and {second_scan}")
            # overlap are expected to be int for now
            target.append(int(second_high - first_low))

    stitching_axis = self.axis
    if stitching_axis == 0:
        # stitching over axis 0 (aka z): frame extent along this axis is dim_2
        stitched_positions = self.configuration.axis_0_pos_px
        self._axis_0_rel_ini_shifts = []
        append_overlaps(self._axis_0_rel_ini_shifts, stitched_positions, "dim_2")
        self._axis_1_rel_ini_shifts = self.from_abs_pos_to_rel_pos(self.configuration.axis_1_pos_px)
        self._axis_2_rel_ini_shifts = [0.0] * (len(self.series) - 1)
    elif stitching_axis == 1:
        # stitching over axis 1 (Y in acquisition reference, x in frame reference): extent is dim_1
        stitched_positions = self.configuration.axis_1_pos_px
        self._axis_1_rel_ini_shifts = []
        append_overlaps(self._axis_1_rel_ini_shifts, stitched_positions, "dim_1")
        self._axis_0_rel_ini_shifts = self.from_abs_pos_to_rel_pos(self.configuration.axis_0_pos_px)
        self._axis_2_rel_ini_shifts = [0.0] * (len(self.series) - 1)
    else:
        raise NotImplementedError("stitching only forseen for axis 0 and 1 for now")
408
+
409
def _compute_shifts(self):
    """
    Compute the final relative shifts between consecutive scans.

    For each pair of consecutive scans, `find_projections_relative_shifts` is called with:

    * the initial shifts (`_axis_0_rel_ini_shifts`, `_axis_1_rel_ini_shifts`) as estimate,
    * the reading orders of both scans,
    * the image registration method and parameters from the configuration
      (`axis_0_params` for the frame y direction, `axis_1_params` for the frame x direction).

    Results are stored in `_axis_0_rel_final_shifts` (y), `_axis_1_rel_final_shifts` (x) and
    `_axis_2_rel_final_shifts` (always 0.0).

    :raises ValueError: if no input scan is provided.
    :raises NotImplementedError: if the stitching axis is neither 0 nor 1.
    """
    n_scans = len(self.configuration.input_scans)
    if n_scans == 0:
        raise ValueError("no scan to stich provided")

    projection_for_shift = self.configuration.slice_for_cross_correlation or "middle"
    if self.axis not in (0, 1):
        raise NotImplementedError("only stitching over axis 0 and 1 are handled for pre-processing stitching")

    # registration methods do not depend on the scan pair: read them once
    x_cross_algo = self.configuration.axis_1_params.get(KEY_IMG_REG_METHOD, None)
    y_cross_algo = self.configuration.axis_0_params.get(KEY_IMG_REG_METHOD, None)

    final_rel_shifts = []
    for scan_0, scan_1, order_s0, order_s1, x_rel_shift, y_rel_shift in zip(
        self.series[:-1],
        self.series[1:],
        self.reading_orders[:-1],
        self.reading_orders[1:],
        self._axis_1_rel_ini_shifts,
        self._axis_0_rel_ini_shifts,
    ):
        # compute relative shift
        found_shift_y, found_shift_x = find_projections_relative_shifts(
            upper_scan=scan_0,
            lower_scan=scan_1,
            projection_for_shift=projection_for_shift,
            x_cross_correlation_function=x_cross_algo,
            y_cross_correlation_function=y_cross_algo,
            x_shifts_params=self.configuration.axis_1_params,  # image x map acquisition axis 1 (Y)
            y_shifts_params=self.configuration.axis_0_params,  # image y map acquisition axis 0 (Z)
            invert_order=order_s1 != order_s0,
            estimated_shifts=(y_rel_shift, x_rel_shift),
            axis=self.axis,
        )
        final_rel_shifts.append((found_shift_y, found_shift_x))

    # set back values. Now position should start at 0
    self._axis_0_rel_final_shifts = [final_shift[0] for final_shift in final_rel_shifts]
    self._axis_1_rel_final_shifts = [final_shift[1] for final_shift in final_rel_shifts]
    self._axis_2_rel_final_shifts = [0.0] * len(final_rel_shifts)
    _logger.info(f"axis 0 relative shifts (y in radio ref) to be used will be {self._axis_0_rel_final_shifts}")
    _logger.info(f"axis 1 relative shifts (x in radio ref) to be used will be {self._axis_1_rel_final_shifts}")
466
+
467
+ def _create_nx_tomo(self, store_composition: bool = False):
468
+ """
469
+ create final NXtomo with stitched frames.
470
+ Policy: save all projections flat fielded. So this NXtomo will only contain projections (no dark and no flat).
471
+ But nabu will be able to reconstruct it with field `flatfield` set to False
472
+ """
473
+ nx_tomo = NXtomo()
474
+
475
+ nx_tomo.energy = self.series[0].energy
476
+ start_times = list(filter(None, [scan.start_time for scan in self.series]))
477
+ end_times = list(filter(None, [scan.end_time for scan in self.series]))
478
+
479
+ if len(start_times) > 0:
480
+ nx_tomo.start_time = (
481
+ numpy.asarray([numpy.datetime64(start_time) for start_time in start_times]).min().astype(datetime)
482
+ )
483
+ else:
484
+ _logger.warning("Unable to find any start_time from input")
485
+ if len(end_times) > 0:
486
+ nx_tomo.end_time = (
487
+ numpy.asarray([numpy.datetime64(end_time) for end_time in end_times]).max().astype(datetime)
488
+ )
489
+ else:
490
+ _logger.warning("Unable to find any end_time from input")
491
+
492
+ title = ";".join([scan.sequence_name or "" for scan in self.series])
493
+ nx_tomo.title = f"stitch done from {title}"
494
+
495
+ self._slices_to_stitch, n_proj = self.configuration.settle_slices()
496
+
497
+ # handle detector (without frames)
498
+ nx_tomo.instrument.detector.field_of_view = self.series[0].field_of_view
499
+ nx_tomo.instrument.detector.distance = self.series[0].distance
500
+ nx_tomo.instrument.detector.x_pixel_size = self.series[0].x_pixel_size
501
+ nx_tomo.instrument.detector.y_pixel_size = self.series[0].y_pixel_size
502
+ nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
503
+ nx_tomo.instrument.detector.tomo_n = n_proj
504
+ # note: stitching process insure un-flipping of frames. So make sure transformations is defined as an empty set
505
+ nx_tomo.instrument.detector.transformations = NXtransformations()
506
+
507
+ if isinstance(self.series[0], NXtomoScan):
508
+ # note: first scan is always the reference as order to read data (so no rotation_angle inversion here)
509
+ rotation_angle = numpy.asarray(self.series[0].rotation_angle)
510
+ nx_tomo.sample.rotation_angle = rotation_angle[
511
+ numpy.asarray(self.series[0].image_key_control) == ImageKey.PROJECTION.value
512
+ ]
513
+ elif isinstance(self.series[0], EDFTomoScan):
514
+ nx_tomo.sample.rotation_angle = numpy.linspace(
515
+ start=0, stop=self.series[0].scan_range, num=self.series[0].tomo_n
516
+ )
517
+ else:
518
+ raise NotImplementedError(
519
+ f"scan type ({type(self.series[0])} is not handled)",
520
+ NXtomoScan,
521
+ isinstance(self.series[0], NXtomoScan),
522
+ )
523
+
524
+ # do a sub selection of the rotation angle if a we are only computing a part of the slices
525
+ def apply_slices_selection(array, slices, allow_empty: bool = False):
526
+ if isinstance(slices, slice):
527
+ return array[slices.start : slices.stop : 1]
528
+ elif isinstance(slices, Iterable):
529
+ return list([array[index] for index in slices])
530
+ else:
531
+ raise RuntimeError("slices must be instance of a slice or of an iterable")
532
+
533
+ nx_tomo.sample.rotation_angle = apply_slices_selection(
534
+ array=nx_tomo.sample.rotation_angle, slices=self._slices_to_stitch
535
+ )
536
+
537
+ # handle sample
538
+ if False not in [isinstance(scan, NXtomoScan) for scan in self.series]:
539
+
540
+ def get_sample_translation_for_projs(scan: NXtomoScan, attr):
541
+ values = numpy.array(getattr(scan, attr))
542
+ mask = scan.image_key_control == ImageKey.PROJECTION.value
543
+ return values[mask]
544
+
545
+ # we consider the new x, y and z position to be at the center of the one created
546
+ x_translation = [
547
+ get_sample_translation_for_projs(scan, "x_translation")
548
+ for scan in self.series
549
+ if scan.x_translation is not None
550
+ ]
551
+ if len(x_translation) > 0:
552
+ # if there is some metadata about {x|y|z} translations
553
+ # we want to take the mean of each frame for each projections
554
+ x_translation = apply_slices_selection(
555
+ numpy.array(x_translation).mean(axis=0),
556
+ slices=self._slices_to_stitch,
557
+ )
558
+ else:
559
+ # if no NXtomo has information about x_translation.
560
+ # note: if at least one has missing values the numpy.Array(x_translation) with create an error as well
561
+ x_translation = [0.0] * n_proj
562
+ _logger.warning("Unable to fin input nxtomo x_translation values. Set it to 0.0")
563
+ nx_tomo.sample.x_translation = x_translation
564
+
565
+ y_translation = [
566
+ get_sample_translation_for_projs(scan, "y_translation")
567
+ for scan in self.series
568
+ if scan.y_translation is not None
569
+ ]
570
+ if len(y_translation) > 0:
571
+ y_translation = apply_slices_selection(
572
+ numpy.array(y_translation).mean(axis=0),
573
+ slices=self._slices_to_stitch,
574
+ )
575
+ else:
576
+ y_translation = [0.0] * n_proj
577
+ _logger.warning("Unable to fin input nxtomo y_translation values. Set it to 0.0")
578
+ nx_tomo.sample.y_translation = y_translation
579
+ z_translation = [
580
+ get_sample_translation_for_projs(scan, "z_translation")
581
+ for scan in self.series
582
+ if scan.z_translation is not None
583
+ ]
584
+ if len(z_translation) > 0:
585
+ z_translation = apply_slices_selection(
586
+ numpy.array(z_translation).mean(axis=0),
587
+ slices=self._slices_to_stitch,
588
+ )
589
+ else:
590
+ z_translation = [0.0] * n_proj
591
+ _logger.warning("Unable to fin input nxtomo z_translation values. Set it to 0.0")
592
+ nx_tomo.sample.z_translation = z_translation
593
+
594
+ nx_tomo.sample.name = self.series[0].sample_name
595
+
596
+ # compute stitched frame shape
597
+ if self.axis == 0:
598
+ stitched_frame_shape = (
599
+ n_proj,
600
+ (
601
+ numpy.asarray([scan.dim_2 for scan in self.series]).sum()
602
+ - numpy.asarray([abs(overlap) for overlap in self._axis_0_rel_final_shifts]).sum()
603
+ ),
604
+ self._stitching_constant_length,
605
+ )
606
+ elif self.axis == 1:
607
+ stitched_frame_shape = (
608
+ n_proj,
609
+ self._stitching_constant_length,
610
+ (
611
+ numpy.asarray([scan.dim_1 for scan in self.series]).sum()
612
+ - numpy.asarray([abs(overlap) for overlap in self._axis_1_rel_final_shifts]).sum()
613
+ ),
614
+ )
615
+ else:
616
+ raise NotImplementedError("stitching on pre-processing along axis 2 (x-ray direction) is not handled")
617
+
618
+ if stitched_frame_shape[0] < 1 or stitched_frame_shape[1] < 1 or stitched_frame_shape[2] < 1:
619
+ raise RuntimeError(f"Error in stitched frame shape calculation. {stitched_frame_shape} found.")
620
+ # get expected output dataset first (just in case output and input files are the same)
621
+ first_proj_idx = sorted(self.series[0].projections.keys())[0]
622
+ first_proj_url = self.series[0].projections[first_proj_idx]
623
+ if h5py.is_hdf5(first_proj_url.file_path()):
624
+ first_proj_url = DataUrl(
625
+ file_path=first_proj_url.file_path(),
626
+ data_path=first_proj_url.data_path(),
627
+ scheme="h5py",
628
+ )
629
+
630
+ # first save the NXtomo entry without the frame
631
+ # dicttonx will fail if the folder does not exists
632
+ dir_name = os.path.dirname(self.configuration.output_file_path)
633
+ if dir_name not in (None, ""):
634
+ os.makedirs(dir_name, exist_ok=True)
635
+ nx_tomo.save(
636
+ file_path=self.configuration.output_file_path,
637
+ data_path=self.configuration.output_data_path,
638
+ nexus_path_version=self.configuration.output_nexus_version,
639
+ overwrite=self.configuration.overwrite_results,
640
+ )
641
+
642
+ transformation_matrices = {
643
+ scan.get_identifier()
644
+ .to_str()
645
+ .center(80, "-"): numpy.array2string(build_matrix(scan.get_detector_transformations(tuple())))
646
+ for scan in self.series
647
+ }
648
+ _logger.info(
649
+ "scan detector transformation matrices are:\n"
650
+ "\n".join(["/n".join(item) for item in transformation_matrices.items()])
651
+ )
652
+
653
+ _logger.info(
654
+ f"reading order is {self.reading_orders}",
655
+ )
656
+
657
+ def get_output_data_type():
658
+ return numpy.float32 # because we will apply flat field correction on it and they are not raw data
659
+
660
+ output_dtype = get_output_data_type()
661
+ # append frames ("instrument/detector/data" dataset)
662
+ with HDF5File(self.configuration.output_file_path, mode="a") as h5f:
663
+ # note: nx_tomo.save already handles the possible overwrite conflict by removing
664
+ # self.configuration.output_file_path or raising an error
665
+
666
+ stitched_frame_path = "/".join(
667
+ [
668
+ self.configuration.output_data_path,
669
+ _get_nexus_paths(self.configuration.output_nexus_version).PROJ_PATH,
670
+ ]
671
+ )
672
+ self.dumper.output_dataset = h5f.create_dataset(
673
+ name=stitched_frame_path,
674
+ shape=stitched_frame_shape,
675
+ dtype=output_dtype,
676
+ )
677
+ # TODO: we could also create in several time and create a virtual dataset from it.
678
+ scans_projections_indexes = []
679
+ for scan, reverse in zip(self.series, self.reading_orders):
680
+ scans_projections_indexes.append(sorted(scan.projections.keys(), reverse=(reverse == -1)))
681
+ if self.progress:
682
+ self.progress.total = self.get_n_slices_to_stitch()
683
+
684
+ if isinstance(self._slices_to_stitch, slice):
685
+ step = self._slices_to_stitch.step or 1
686
+ else:
687
+ step = 1
688
+ i_proj = 0
689
+ for bunch_start, bunch_end in self._data_bunch_iterator(slices=self._slices_to_stitch, bunch_size=50):
690
+ for data_frames in self._get_bunch_of_data(
691
+ bunch_start,
692
+ bunch_end,
693
+ step=step,
694
+ scans=self.series,
695
+ scans_projections_indexes=scans_projections_indexes,
696
+ flip_ud_arr=self.configuration.flip_ud,
697
+ flip_lr_arr=self.configuration.flip_lr,
698
+ reading_orders=self.reading_orders,
699
+ ):
700
+ if self.configuration.rescale_frames:
701
+ data_frames = self.rescale_frames(data_frames)
702
+ if self.configuration.normalization_by_sample.is_active():
703
+ data_frames = self.normalize_frame_by_sample(data_frames)
704
+
705
+ sf = SingleAxisStitcher.stitch_frames(
706
+ frames=data_frames,
707
+ axis=self.axis,
708
+ x_relative_shifts=self._axis_1_rel_final_shifts,
709
+ y_relative_shifts=self._axis_0_rel_final_shifts,
710
+ overlap_kernels=self._overlap_kernels,
711
+ i_frame=i_proj,
712
+ output_dtype=output_dtype,
713
+ dumper=self.dumper,
714
+ return_composition_cls=store_composition if i_proj == 0 else False,
715
+ stitching_axis=self.axis,
716
+ pad_mode=self.configuration.pad_mode,
717
+ alignment=self.configuration.alignment_axis_2,
718
+ new_width=self._stitching_constant_length,
719
+ check_inputs=i_proj == 0, # on process check on the first iteration
720
+ )
721
+ if i_proj == 0 and store_composition:
722
+ _, self._frame_composition = sf
723
+ if self.progress is not None:
724
+ self.progress.update()
725
+
726
+ i_proj += 1
727
+
728
+ # create link to this dataset that can be missing
729
+ # "data/data" link
730
+ if "data" in h5f[self.configuration.output_data_path]:
731
+ data_group = h5f[self.configuration.output_data_path]["data"]
732
+ if not stitched_frame_path.startswith("/"):
733
+ stitched_frame_path = "/" + stitched_frame_path
734
+ data_group["data"] = h5py.SoftLink(stitched_frame_path)
735
+ if "default" not in h5f[self.configuration.output_data_path].attrs:
736
+ h5f[self.configuration.output_data_path].attrs["default"] = "data"
737
+ for attr_name, attr_value in zip(
738
+ ("NX_class", "SILX_style/axis_scale_types", "signal"),
739
+ ("NXdata", ["linear", "linear"], "data"),
740
+ ):
741
+ if attr_name not in data_group.attrs:
742
+ data_group.attrs[attr_name] = attr_value
743
+
744
+ return nx_tomo
745
+
746
+ def _create_stitching(self, store_composition):
747
+ self._create_nx_tomo(store_composition=store_composition)
748
+
749
+ @staticmethod
750
+ def get_bunch_of_data(
751
+ bunch_start: int,
752
+ bunch_end: int,
753
+ step: int,
754
+ scans: tuple,
755
+ scans_projections_indexes: tuple,
756
+ reading_orders: tuple,
757
+ flip_lr_arr: tuple,
758
+ flip_ud_arr: tuple,
759
+ ):
760
+ """
761
+ goal is to load contiguous projections as much as possible...
762
+
763
+ :param int bunch_start: begining of the bunch
764
+ :param int bunch_end: end of the bunch
765
+ :param int scans: ordered scan for which we want to get data
766
+ :param scans_projections_indexes: tuple with scans and scan projection indexes to be loaded
767
+ :param tuple flip_lr_arr: extra information from the user to left-right flip frames
768
+ :param tuple flip_ud_arr: extra information from the user to up-down flip frames
769
+ :return: list of list. For each frame we want to stitch contains the (flat fielded) frames to stich together
770
+ """
771
+ assert len(scans) == len(scans_projections_indexes)
772
+ assert isinstance(flip_lr_arr, tuple)
773
+ assert isinstance(flip_ud_arr, tuple)
774
+ assert isinstance(step, int)
775
+ scans_proj_urls = []
776
+ # for each scan store the real indices and the data url
777
+
778
+ for scan, scan_projection_indexes in zip(scans, scans_projections_indexes):
779
+ scan_proj_urls = {}
780
+ # for each scan get the list of url to be loaded
781
+ for i_proj in range(bunch_start, bunch_end):
782
+ if i_proj % step != 0:
783
+ continue
784
+ proj_index_in_full_scan = scan_projection_indexes[i_proj]
785
+ scan_proj_urls[proj_index_in_full_scan] = scan.projections[proj_index_in_full_scan]
786
+ scans_proj_urls.append(scan_proj_urls)
787
+
788
+ # then load data
789
+ all_scan_final_data = numpy.empty((bunch_end - bunch_start, len(scans)), dtype=object)
790
+ from nabu.preproc.flatfield import FlatFieldArrays
791
+
792
+ for i_scan, (scan_urls, scan_flip_lr, scan_flip_ud, reading_order) in enumerate(
793
+ zip(scans_proj_urls, flip_lr_arr, flip_ud_arr, reading_orders)
794
+ ):
795
+ i_frame = 0
796
+ _, set_of_compacted_slices = get_compacted_dataslices(scan_urls, return_url_set=True)
797
+ for _, url in set_of_compacted_slices.items():
798
+ scan = scans[i_scan]
799
+ url = DataUrl(
800
+ file_path=url.file_path(),
801
+ data_path=url.data_path(),
802
+ scheme="silx",
803
+ data_slice=url.data_slice(),
804
+ )
805
+ raw_radios = get_data(url)[::reading_order]
806
+ radio_indices = url.data_slice()
807
+ if isinstance(radio_indices, slice):
808
+ step = radio_indices.step if radio_indices is not None else 1
809
+ radio_indices = numpy.arange(
810
+ start=radio_indices.start,
811
+ stop=radio_indices.stop,
812
+ step=step,
813
+ dtype=numpy.int16,
814
+ )
815
+
816
+ missing = []
817
+ if len(scan.reduced_flats) == 0:
818
+ missing = "flats"
819
+ if len(scan.reduced_darks) == 0:
820
+ missing = "darks"
821
+
822
+ if len(missing) > 0:
823
+ _logger.warning(f"missing {'and'.join(missing)}. Unable to do flat field correction")
824
+ ff_arrays = None
825
+ data = raw_radios
826
+ else:
827
+ has_reduced_metadata = (
828
+ scan.reduced_flats_infos is not None
829
+ and len(scan.reduced_flats_infos.machine_electric_current) > 0
830
+ and scan.reduced_darks_infos is not None
831
+ and len(scan.reduced_darks_infos.machine_electric_current) > 0
832
+ )
833
+ if not has_reduced_metadata:
834
+ _logger.warning("no metadata about current found. Won't normalize according to machine current")
835
+
836
+ ff_arrays = FlatFieldArrays(
837
+ radios_shape=(len(radio_indices), scan.dim_2, scan.dim_1),
838
+ flats=scan.reduced_flats,
839
+ darks=scan.reduced_darks,
840
+ radios_indices=radio_indices,
841
+ radios_srcurrent=scan.electric_current[radio_indices] if has_reduced_metadata else None,
842
+ flats_srcurrent=(
843
+ scan.reduced_flats_infos.machine_electric_current if has_reduced_metadata else None
844
+ ),
845
+ )
846
+ # note: we need to cast radios to float 32. Darks and flats are cast to anyway
847
+ data = ff_arrays.normalize_radios(raw_radios.astype(numpy.float32))
848
+
849
+ transformations = list(scans[i_scan].get_detector_transformations(tuple()))
850
+ if scan_flip_lr:
851
+ transformations.append(DetZFlipTransformation(flip=True))
852
+ if scan_flip_ud:
853
+ transformations.append(DetYFlipTransformation(flip=True))
854
+
855
+ transformation_matrix_det_space = build_matrix(transformations)
856
+ if transformation_matrix_det_space is None or numpy.allclose(
857
+ transformation_matrix_det_space, numpy.identity(3)
858
+ ):
859
+ flip_ud = False
860
+ flip_lr = False
861
+ elif numpy.array_equal(transformation_matrix_det_space, PreProcessingStitching._get_UD_flip_matrix()):
862
+ flip_ud = True
863
+ flip_lr = False
864
+ elif numpy.allclose(transformation_matrix_det_space, PreProcessingStitching._get_LR_flip_matrix()):
865
+ flip_ud = False
866
+ flip_lr = True
867
+ elif numpy.allclose(
868
+ transformation_matrix_det_space, PreProcessingStitching._get_UD_AND_LR_flip_matrix()
869
+ ):
870
+ flip_ud = True
871
+ flip_lr = True
872
+ else:
873
+ raise ValueError("case not handled... For now only handle up-down flip as left-right flip")
874
+
875
+ for frame in data:
876
+ if flip_ud:
877
+ frame = numpy.flipud(frame)
878
+ if flip_lr:
879
+ frame = numpy.fliplr(frame)
880
+ all_scan_final_data[i_frame, i_scan] = frame
881
+ i_frame += 1
882
+
883
+ return all_scan_final_data
884
+
885
+ def compute_reduced_flats_and_darks(self):
886
+ """
887
+ make sure reduced dark and flats are existing otherwise compute them
888
+ """
889
+ for scan in self.series:
890
+ try:
891
+ reduced_darks, darks_infos = scan.load_reduced_darks(return_info=True)
892
+ except:
893
+ _logger.info("no reduced dark found. Try to compute them.")
894
+ if reduced_darks in (None, {}):
895
+ reduced_darks, darks_infos = scan.compute_reduced_darks(return_info=True)
896
+ try:
897
+ # if we don't have write in the folder containing the .nx for example
898
+ scan.save_reduced_darks(reduced_darks, darks_infos=darks_infos)
899
+ except Exception as e:
900
+ pass
901
+ scan.set_reduced_darks(reduced_darks, darks_infos=darks_infos)
902
+
903
+ try:
904
+ reduced_flats, flats_infos = scan.load_reduced_flats(return_info=True)
905
+ except:
906
+ _logger.info("no reduced flats found. Try to compute them.")
907
+ if reduced_flats in (None, {}):
908
+ reduced_flats, flats_infos = scan.compute_reduced_flats(return_info=True)
909
+ try:
910
+ # if we don't have write in the folder containing the .nx for example
911
+ scan.save_reduced_flats(reduced_flats, flats_infos=flats_infos)
912
+ except Exception as e:
913
+ pass
914
+ scan.set_reduced_flats(reduced_flats, flats_infos=flats_infos)
915
+
916
+ @staticmethod
917
+ @cache(maxsize=None)
918
+ def _get_UD_flip_matrix():
919
+ return DetYFlipTransformation(flip=True).as_matrix()
920
+
921
+ @staticmethod
922
+ @cache(maxsize=None)
923
+ def _get_LR_flip_matrix():
924
+ return DetZFlipTransformation(flip=True).as_matrix()
925
+
926
+ @staticmethod
927
+ @cache(maxsize=None)
928
+ def _get_UD_AND_LR_flip_matrix():
929
+ return numpy.matmul(
930
+ PreProcessingStitching._get_UD_flip_matrix(),
931
+ PreProcessingStitching._get_LR_flip_matrix(),
932
+ )
933
+
934
+ @staticmethod
935
+ def _get_bunch_of_data(
936
+ bunch_start: int,
937
+ bunch_end: int,
938
+ step: int,
939
+ scans: tuple,
940
+ scans_projections_indexes: tuple,
941
+ reading_orders: tuple,
942
+ flip_lr_arr: tuple,
943
+ flip_ud_arr: tuple,
944
+ ):
945
+ """
946
+ goal is to load contiguous projections as much as possible...
947
+
948
+ :param int bunch_start: begining of the bunch
949
+ :param int bunch_end: end of the bunch
950
+ :param int scans: ordered scan for which we want to get data
951
+ :param scans_projections_indexes: tuple with scans and scan projection indexes to be loaded
952
+ :param tuple flip_lr_arr: extra information from the user to left-right flip frames
953
+ :param tuple flip_ud_arr: extra information from the user to up-down flip frames
954
+ :return: list of list. For each frame we want to stitch contains the (flat fielded) frames to stich together
955
+ """
956
+ assert len(scans) == len(scans_projections_indexes)
957
+ assert isinstance(flip_lr_arr, tuple)
958
+ assert isinstance(flip_ud_arr, tuple)
959
+ assert isinstance(step, int)
960
+ scans_proj_urls = []
961
+ # for each scan store the real indices and the data url
962
+
963
+ for scan, scan_projection_indexes in zip(scans, scans_projections_indexes):
964
+ scan_proj_urls = {}
965
+ # for each scan get the list of url to be loaded
966
+ for i_proj in range(bunch_start, bunch_end):
967
+ if i_proj % step != 0:
968
+ continue
969
+ proj_index_in_full_scan = scan_projection_indexes[i_proj]
970
+ scan_proj_urls[proj_index_in_full_scan] = scan.projections[proj_index_in_full_scan]
971
+ scans_proj_urls.append(scan_proj_urls)
972
+
973
+ # then load data
974
+ all_scan_final_data = numpy.empty((bunch_end - bunch_start, len(scans)), dtype=object)
975
+ from nabu.preproc.flatfield import FlatFieldArrays
976
+
977
+ for i_scan, (scan_urls, scan_flip_lr, scan_flip_ud, reading_order) in enumerate(
978
+ zip(scans_proj_urls, flip_lr_arr, flip_ud_arr, reading_orders)
979
+ ):
980
+ i_frame = 0
981
+ _, set_of_compacted_slices = get_compacted_dataslices(scan_urls, return_url_set=True)
982
+ for _, url in set_of_compacted_slices.items():
983
+ scan = scans[i_scan]
984
+ url = DataUrl(
985
+ file_path=url.file_path(),
986
+ data_path=url.data_path(),
987
+ scheme="silx",
988
+ data_slice=url.data_slice(),
989
+ )
990
+ raw_radios = get_data(url)[::reading_order]
991
+ radio_indices = url.data_slice()
992
+ if isinstance(radio_indices, slice):
993
+ step = radio_indices.step if radio_indices is not None else 1
994
+ radio_indices = numpy.arange(
995
+ start=radio_indices.start,
996
+ stop=radio_indices.stop,
997
+ step=step,
998
+ dtype=numpy.int16,
999
+ )
1000
+
1001
+ missing = []
1002
+ if len(scan.reduced_flats) == 0:
1003
+ missing = "flats"
1004
+ if len(scan.reduced_darks) == 0:
1005
+ missing = "darks"
1006
+
1007
+ if len(missing) > 0:
1008
+ _logger.warning(f"missing {'and'.join(missing)}. Unable to do flat field correction")
1009
+ ff_arrays = None
1010
+ data = raw_radios
1011
+ else:
1012
+ has_reduced_metadata = (
1013
+ scan.reduced_flats_infos is not None
1014
+ and len(scan.reduced_flats_infos.machine_electric_current) > 0
1015
+ and scan.reduced_darks_infos is not None
1016
+ and len(scan.reduced_darks_infos.machine_electric_current) > 0
1017
+ )
1018
+ if not has_reduced_metadata:
1019
+ _logger.warning("no metadata about current found. Won't normalize according to machine current")
1020
+
1021
+ ff_arrays = FlatFieldArrays(
1022
+ radios_shape=(len(radio_indices), scan.dim_2, scan.dim_1),
1023
+ flats=scan.reduced_flats,
1024
+ darks=scan.reduced_darks,
1025
+ radios_indices=radio_indices,
1026
+ radios_srcurrent=scan.electric_current[radio_indices] if has_reduced_metadata else None,
1027
+ flats_srcurrent=(
1028
+ scan.reduced_flats_infos.machine_electric_current if has_reduced_metadata else None
1029
+ ),
1030
+ )
1031
+ # note: we need to cast radios to float 32. Darks and flats are cast to anyway
1032
+ data = ff_arrays.normalize_radios(raw_radios.astype(numpy.float32))
1033
+
1034
+ transformations = list(scans[i_scan].get_detector_transformations(tuple()))
1035
+ if scan_flip_lr:
1036
+ transformations.append(DetZFlipTransformation(flip=True))
1037
+ if scan_flip_ud:
1038
+ transformations.append(DetYFlipTransformation(flip=True))
1039
+
1040
+ transformation_matrix_det_space = build_matrix(transformations)
1041
+ if transformation_matrix_det_space is None or numpy.allclose(
1042
+ transformation_matrix_det_space, numpy.identity(3)
1043
+ ):
1044
+ flip_ud = False
1045
+ flip_lr = False
1046
+ elif numpy.array_equal(transformation_matrix_det_space, PreProcessingStitching._get_UD_flip_matrix()):
1047
+ flip_ud = True
1048
+ flip_lr = False
1049
+ elif numpy.allclose(transformation_matrix_det_space, PreProcessingStitching._get_LR_flip_matrix()):
1050
+ flip_ud = False
1051
+ flip_lr = True
1052
+ elif numpy.allclose(
1053
+ transformation_matrix_det_space, PreProcessingStitching._get_UD_AND_LR_flip_matrix()
1054
+ ):
1055
+ flip_ud = True
1056
+ flip_lr = True
1057
+ else:
1058
+ raise ValueError("case not handled... For now only handle up-down flip as left-right flip")
1059
+
1060
+ for frame in data:
1061
+ if flip_ud:
1062
+ frame = numpy.flipud(frame)
1063
+ if flip_lr:
1064
+ frame = numpy.fliplr(frame)
1065
+ all_scan_final_data[i_frame, i_scan] = frame
1066
+ i_frame += 1
1067
+
1068
+ return all_scan_final_data