nabu 2024.1.10__py3-none-any.whl → 2024.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. nabu/__init__.py +1 -1
  2. nabu/app/bootstrap.py +2 -3
  3. nabu/app/cast_volume.py +4 -2
  4. nabu/app/cli_configs.py +5 -0
  5. nabu/app/composite_cor.py +1 -1
  6. nabu/app/create_distortion_map_from_poly.py +5 -6
  7. nabu/app/diag_to_pix.py +7 -19
  8. nabu/app/diag_to_rot.py +14 -29
  9. nabu/app/double_flatfield.py +32 -44
  10. nabu/app/parse_reconstruction_log.py +3 -0
  11. nabu/app/reconstruct.py +53 -15
  12. nabu/app/reconstruct_helical.py +2 -2
  13. nabu/app/stitching.py +27 -13
  14. nabu/app/tests/__init__.py +0 -0
  15. nabu/app/tests/test_reduce_dark_flat.py +4 -1
  16. nabu/cuda/kernel.py +11 -2
  17. nabu/cuda/processing.py +2 -2
  18. nabu/cuda/src/cone.cu +77 -0
  19. nabu/cuda/src/hierarchical_backproj.cu +271 -0
  20. nabu/cuda/utils.py +0 -6
  21. nabu/estimation/alignment.py +5 -19
  22. nabu/estimation/cor.py +173 -599
  23. nabu/estimation/cor_sino.py +356 -26
  24. nabu/estimation/focus.py +63 -11
  25. nabu/estimation/tests/test_cor.py +124 -58
  26. nabu/estimation/tests/test_focus.py +6 -6
  27. nabu/estimation/tilt.py +2 -1
  28. nabu/estimation/utils.py +5 -33
  29. nabu/io/__init__.py +1 -1
  30. nabu/io/cast_volume.py +1 -1
  31. nabu/io/reader.py +416 -21
  32. nabu/io/tests/test_readers.py +422 -0
  33. nabu/io/tests/test_writers.py +1 -102
  34. nabu/io/writer.py +4 -433
  35. nabu/opencl/kernel.py +14 -3
  36. nabu/opencl/processing.py +8 -0
  37. nabu/pipeline/config_validators.py +5 -2
  38. nabu/pipeline/datadump.py +12 -5
  39. nabu/pipeline/estimators.py +162 -188
  40. nabu/pipeline/fullfield/chunked.py +168 -92
  41. nabu/pipeline/fullfield/chunked_cuda.py +7 -3
  42. nabu/pipeline/fullfield/computations.py +2 -7
  43. nabu/pipeline/fullfield/dataset_validator.py +0 -4
  44. nabu/pipeline/fullfield/nabu_config.py +37 -13
  45. nabu/pipeline/fullfield/processconfig.py +22 -13
  46. nabu/pipeline/fullfield/reconstruction.py +13 -9
  47. nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
  48. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
  49. nabu/pipeline/helical/helical_reconstruction.py +1 -1
  50. nabu/pipeline/params.py +21 -1
  51. nabu/pipeline/processconfig.py +1 -12
  52. nabu/pipeline/reader.py +146 -0
  53. nabu/pipeline/tests/test_estimators.py +44 -72
  54. nabu/pipeline/utils.py +4 -2
  55. nabu/pipeline/writer.py +10 -2
  56. nabu/preproc/ccd_cuda.py +1 -1
  57. nabu/preproc/ctf.py +14 -7
  58. nabu/preproc/ctf_cuda.py +2 -3
  59. nabu/preproc/double_flatfield.py +5 -12
  60. nabu/preproc/double_flatfield_cuda.py +2 -2
  61. nabu/preproc/flatfield.py +5 -1
  62. nabu/preproc/flatfield_cuda.py +5 -1
  63. nabu/preproc/phase.py +24 -73
  64. nabu/preproc/phase_cuda.py +5 -8
  65. nabu/preproc/tests/test_ctf.py +11 -7
  66. nabu/preproc/tests/test_flatfield.py +67 -122
  67. nabu/preproc/tests/test_paganin.py +54 -30
  68. nabu/processing/azim.py +206 -0
  69. nabu/processing/convolution_cuda.py +1 -1
  70. nabu/processing/fft_cuda.py +15 -17
  71. nabu/processing/histogram.py +2 -0
  72. nabu/processing/histogram_cuda.py +2 -1
  73. nabu/processing/kernel_base.py +3 -0
  74. nabu/processing/muladd_cuda.py +1 -0
  75. nabu/processing/padding_opencl.py +1 -1
  76. nabu/processing/roll_opencl.py +1 -0
  77. nabu/processing/rotation_cuda.py +2 -2
  78. nabu/processing/tests/test_fft.py +17 -10
  79. nabu/processing/unsharp_cuda.py +1 -1
  80. nabu/reconstruction/cone.py +104 -40
  81. nabu/reconstruction/fbp.py +3 -0
  82. nabu/reconstruction/fbp_base.py +7 -2
  83. nabu/reconstruction/filtering.py +20 -7
  84. nabu/reconstruction/filtering_cuda.py +7 -1
  85. nabu/reconstruction/hbp.py +424 -0
  86. nabu/reconstruction/mlem.py +99 -0
  87. nabu/reconstruction/reconstructor.py +2 -0
  88. nabu/reconstruction/rings_cuda.py +19 -19
  89. nabu/reconstruction/sinogram_cuda.py +1 -0
  90. nabu/reconstruction/sinogram_opencl.py +3 -1
  91. nabu/reconstruction/tests/test_cone.py +10 -5
  92. nabu/reconstruction/tests/test_deringer.py +7 -6
  93. nabu/reconstruction/tests/test_fbp.py +124 -10
  94. nabu/reconstruction/tests/test_filtering.py +13 -11
  95. nabu/reconstruction/tests/test_halftomo.py +30 -4
  96. nabu/reconstruction/tests/test_mlem.py +91 -0
  97. nabu/reconstruction/tests/test_reconstructor.py +8 -3
  98. nabu/resources/dataset_analyzer.py +142 -92
  99. nabu/resources/gpu.py +1 -0
  100. nabu/resources/nxflatfield.py +134 -125
  101. nabu/resources/templates/id16a_fluo.conf +42 -0
  102. nabu/resources/tests/test_extract.py +10 -0
  103. nabu/resources/tests/test_nxflatfield.py +2 -2
  104. nabu/stitching/alignment.py +80 -24
  105. nabu/stitching/config.py +105 -68
  106. nabu/stitching/definitions.py +1 -0
  107. nabu/stitching/frame_composition.py +68 -60
  108. nabu/stitching/overlap.py +91 -51
  109. nabu/stitching/single_axis_stitching.py +32 -0
  110. nabu/stitching/slurm_utils.py +6 -6
  111. nabu/stitching/stitcher/__init__.py +0 -0
  112. nabu/stitching/stitcher/base.py +124 -0
  113. nabu/stitching/stitcher/dumper/__init__.py +3 -0
  114. nabu/stitching/stitcher/dumper/base.py +94 -0
  115. nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
  116. nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
  117. nabu/stitching/stitcher/post_processing.py +555 -0
  118. nabu/stitching/stitcher/pre_processing.py +1068 -0
  119. nabu/stitching/stitcher/single_axis.py +484 -0
  120. nabu/stitching/stitcher/stitcher.py +0 -0
  121. nabu/stitching/stitcher/y_stitcher.py +13 -0
  122. nabu/stitching/stitcher/z_stitcher.py +45 -0
  123. nabu/stitching/stitcher_2D.py +278 -0
  124. nabu/stitching/tests/test_config.py +12 -37
  125. nabu/stitching/tests/test_frame_composition.py +33 -59
  126. nabu/stitching/tests/test_overlap.py +149 -7
  127. nabu/stitching/tests/test_utils.py +1 -1
  128. nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
  129. nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
  130. nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
  131. nabu/stitching/utils/__init__.py +1 -0
  132. nabu/stitching/utils/post_processing.py +281 -0
  133. nabu/stitching/utils/tests/test_post-processing.py +21 -0
  134. nabu/stitching/{utils.py → utils/utils.py} +79 -52
  135. nabu/stitching/y_stitching.py +27 -0
  136. nabu/stitching/z_stitching.py +32 -2281
  137. nabu/testutils.py +1 -152
  138. nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
  139. nabu/utils.py +158 -61
  140. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/METADATA +24 -17
  141. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/RECORD +145 -121
  142. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/WHEEL +1 -1
  143. nabu/io/tiffwriter_zmm.py +0 -99
  144. nabu/pipeline/fallback_utils.py +0 -149
  145. nabu/pipeline/helical/tests/test_accumulator.py +0 -158
  146. nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
  147. nabu/pipeline/helical/tests/test_strategy.py +0 -61
  148. nabu/pipeline/helical/utils.py +0 -51
  149. nabu/pipeline/tests/test_chunk_reader.py +0 -74
  150. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
  151. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
  152. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,484 @@
1
+ import h5py
2
+ import numpy
3
+ import logging
4
+ from math import ceil
5
+ from typing import Optional, Iterable, Union
6
+ from tomoscan.series import Series
7
+ from tomoscan.identifier import BaseIdentifier
8
+ from nabu.stitching.stitcher.base import _StitcherBase, get_obj_constant_side_length
9
+ from nabu.stitching.stitcher_2D import stitch_raw_frames
10
+ from nabu.stitching.utils.utils import ShiftAlgorithm, from_slice_to_n_elements
11
+ from nabu.stitching.overlap import (
12
+ check_overlaps,
13
+ ImageStichOverlapKernel,
14
+ )
15
+ from nabu.stitching.config import (
16
+ SingleAxisStitchingConfiguration,
17
+ KEY_RESCALE_MIN_PERCENTILES,
18
+ KEY_RESCALE_MAX_PERCENTILES,
19
+ )
20
+ from nabu.misc.utils import rescale_data
21
+ from nabu.stitching.sample_normalization import normalize_frame as normalize_frame_by_sample
22
+ from nabu.stitching.stitcher.dumper.base import DumperBase
23
+ from silx.io.utils import get_data
24
+ from silx.io.url import DataUrl
25
+ from scipy.ndimage import shift as shift_scipy
26
+
27
+
28
+ _logger = logging.getLogger(__name__)
29
+
30
+
31
+ PROGRESS_BAR_STITCH_VOL_DESC = "stitch volumes"
32
+ # description of the progress bar used when stitching volume.
33
+ # Needed to retrieve advancement from file when stitching remotely
34
+
35
+
36
class _SingleAxisMetaClass(type):
    """
    Metaclass for single axis stitchers, aggregating the stitched axis and the dumper class.

    Both values are given as class keywords, e.g. ``class Foo(Base, axis=0, dumper_cls=MyDumper)``,
    and are stored on the created class as ``_axis`` and ``_dumperCls`` (both default to None).
    """

    def __new__(mcls, name, bases, attrs, axis=None, dumper_cls=None):
        # note: the class keywords are consumed here and not forwarded to type.__new__
        new_cls = super().__new__(mcls, name, bases, attrs)
        new_cls._axis, new_cls._dumperCls = axis, dumper_cls
        return new_cls
46
+
47
+
48
class SingleAxisStitcher(_StitcherBase, metaclass=_SingleAxisMetaClass):
    """
    Any single-axis base class.

    The stitched axis and the dumper class are given as class keywords (``axis=...``, ``dumper_cls=...``)
    and are made available as ``_axis`` / ``_dumperCls`` by `_SingleAxisMetaClass`.
    """

    def __init__(self, configuration, *args, **kwargs) -> None:
        super().__init__(configuration, *args, **kwargs)
        # instantiate the dumper declared at class level (if any)
        if self._dumperCls is not None:
            self._dumper = self._dumperCls(configuration=configuration)
        else:
            self._dumper = None

        # initial shifts: shift between two juxtaposed objects along axis 0 / 1 / 2,
        # found from position metadata or given by the user
        self._axis_0_rel_ini_shifts = []
        self._axis_1_rel_ini_shifts = []
        self._axis_2_rel_ini_shifts = []

        # shifts to add once refined: shift over axis 0 / 1 / 2 found once refined by the cross correlation algorithm
        self._axis_0_rel_final_shifts = []
        self._axis_1_rel_final_shifts = []
        self._axis_2_rel_final_shifts = []

        # slices to be stitched. Obtained from calling Configuration.settle_slices
        self._slices_to_stitch = None

        # stitching width: larger volume width. Other volumes will be padded
        self._stitching_constant_length = None

        def shifts_is_scalar(shifts):
            return isinstance(shifts, ShiftAlgorithm) or numpy.isscalar(shifts)

        # 'expand' a single position / shift algorithm to one value per pair of consecutive objects.
        # All three axes must be handled (axis_2 was previously skipped because axis_1 was checked twice).
        for attr_name in ("axis_0_pos_px", "axis_1_pos_px", "axis_2_pos_px"):
            value = getattr(self.configuration, attr_name)
            if shifts_is_scalar(value):
                setattr(self.configuration, attr_name, [value] * (len(self.series) - 1))
        # same for the per-axis parameters
        for attr_name in ("axis_0_params", "axis_1_params", "axis_2_params"):
            value = getattr(self.configuration, attr_name)
            if numpy.isscalar(value):
                setattr(self.configuration, attr_name, [value] * (len(self.series) - 1))

    @property
    def axis(self) -> int:
        """Stitched axis (3D space), as given by the ``axis`` class keyword"""
        return self._axis

    @property
    def dumper(self):
        """Dumper instance created from the ``dumper_cls`` class keyword (None if not provided)"""
        return self._dumper

    @property
    def stitching_axis_in_frame_space(self):
        """
        stitching is operated in 2D (frame) space. So the axis in frame space is different than the one in 3D ebs-tomo space (https://tomo.gitlab-pages.esrf.fr/bliss-tomo/master/modelization.html)
        """
        raise NotImplementedError("Base class")

    def stitch(self, store_composition: bool = True) -> BaseIdentifier:
        """
        Run the complete stitching workflow: order and check inputs, settle flips, compute shifts,
        create the overlap kernels, stitch and finally save the configuration through the dumper.

        :param store_composition: forwarded to `_create_stitching`
        :return: identifier of the stitched output (as provided by the dumper)
        """
        if self.progress is not None:
            self.progress.set_description("order scans")
        self.order_input_tomo_objects()
        if self.progress is not None:
            self.progress.set_description("check inputs")
        self.check_inputs()
        self.settle_flips()

        if self.progress is not None:
            self.progress.set_description("compute shifts")
        self._compute_positions_as_px()
        self.pre_processing_computation()

        self.compute_estimated_shifts()
        self._compute_shifts()
        self._createOverlapKernels()
        if self.progress is not None:
            self.progress.set_description(PROGRESS_BAR_STITCH_VOL_DESC)

        self._create_stitching(store_composition=store_composition)
        if self.progress is not None:
            self.progress.set_description("dump configuration")
        self.dumper.save_configuration()
        return self.dumper.output_identifier

    @property
    def serie_label(self) -> str:
        """return serie name for logs"""
        return "single axis serie"

    def get_n_slices_to_stitch(self):
        """
        Return the number of slices to be stitched.

        :raises RuntimeError: if the slices have not been settled yet
        """
        if self._slices_to_stitch is None:
            raise RuntimeError("Slices needs to be settled first")
        return from_slice_to_n_elements(self._slices_to_stitch)

    def get_final_axis_positions_in_px(self) -> dict:
        """
        compute the final position (**in pixel**) from the initial position of the first object and the final relative shift computed (1)
        (1): the final relative shift is obtained from the initial shift (from motor position of provided by the user) + the refinement shift from cross correlation algorithm
        :return: dict with tomo object identifier (str) as key and a tuple of position in pixel (axis_0_pos, axis_1_pos, axis_2_pos)
        """

        def final_positions(initial_positions, rel_ini_shifts, rel_final_shifts):
            # refinement of each relative shift (final - initial), the first object being the reference (0.0)
            refinement = numpy.concatenate(
                (
                    numpy.atleast_1d(0.0),
                    numpy.array(rel_final_shifts) - numpy.array(rel_ini_shifts),
                )
            )
            # refinements accumulate along the serie
            return initial_positions + numpy.cumsum(refinement)

        final_pos_axis_0 = final_positions(
            self.configuration.axis_0_pos_px, self._axis_0_rel_ini_shifts, self._axis_0_rel_final_shifts
        )
        final_pos_axis_1 = final_positions(
            self.configuration.axis_1_pos_px, self._axis_1_rel_ini_shifts, self._axis_1_rel_final_shifts
        )
        final_pos_axis_2 = final_positions(
            self.configuration.axis_2_pos_px, self._axis_2_rel_ini_shifts, self._axis_2_rel_final_shifts
        )

        assert len(final_pos_axis_0) == len(final_pos_axis_1)
        assert len(final_pos_axis_0) == len(final_pos_axis_2)
        assert len(final_pos_axis_0) == len(self.series)

        return {
            tomo_obj.get_identifier().to_str(): (pos_0, pos_1, pos_2)
            for tomo_obj, (pos_0, pos_1, pos_2) in zip(
                self.series, zip(final_pos_axis_0, final_pos_axis_1, final_pos_axis_2)
            )
        }

    def settle_flips(self):
        """
        User can provide some information on existing flips at frame level.
        The goal of this step is to get one flip_lr and one flip_ud value per scan or volume.

        :raises ValueError: if a sequence of flips does not contain one value per element to stitch
        :raises TypeError: if a sequence of flips contains non-boolean values
        """
        self.configuration.flip_lr = self._settle_flip(self.configuration.flip_lr, name="flip_lr")
        self.configuration.flip_ud = self._settle_flip(self.configuration.flip_ud, name="flip_ud")

    def _settle_flip(self, flip, name: str) -> tuple:
        """Return a tuple with one flip value per element of the serie (helper of `settle_flips`)"""
        if numpy.isscalar(flip):
            return tuple([flip] * len(self.series))
        if not len(flip) == len(self.series):
            raise ValueError(f"{name} expects a scalar value or one value per element to stitch")
        flip = tuple(flip)
        for elmt in flip:
            if not isinstance(elmt, bool):
                raise TypeError(f"{name} values are expected to be booleans. Got {type(elmt)}")
        return flip

    def _createOverlapKernels(self):
        """
        Create one overlap kernel per relative shift along the stitched axis.
        After this stage the overlap kernels are created with the final overlap size.

        :raises NotImplementedError: if the stitched axis is not 0, 1 or 2
        :raises RuntimeError: if the final shifts along the stitched axis are not defined yet
        """
        if self.axis == 0:
            stitched_axis_rel_shifts = self._axis_0_rel_final_shifts
            stitched_axis_params = self.configuration.axis_0_params
        elif self.axis == 1:
            stitched_axis_rel_shifts = self._axis_1_rel_final_shifts
            stitched_axis_params = self.configuration.axis_1_params
        elif self.axis == 2:
            stitched_axis_rel_shifts = self._axis_2_rel_final_shifts
            stitched_axis_params = self.configuration.axis_2_params
        else:
            raise NotImplementedError

        if stitched_axis_rel_shifts is None or len(stitched_axis_rel_shifts) == 0:
            raise RuntimeError(
                f"axis {self.axis} shifts have not been defined yet. Please define them before calling this function"
            )

        # -1 means "use the (absolute) shift as overlap size"
        overlap_size = stitched_axis_params.get("overlap_size", None)
        if overlap_size in (None, "None", ""):
            overlap_size = -1
        else:
            overlap_size = int(overlap_size)

        self._stitching_constant_length = max(
            [get_obj_constant_side_length(obj, axis=self.axis) for obj in self.series]
        )

        for stitched_axis_shift in stitched_axis_rel_shifts:
            if overlap_size == -1:
                height = abs(stitched_axis_shift)
            else:
                height = overlap_size

            self._overlap_kernels.append(
                ImageStichOverlapKernel(
                    stitching_axis=self.stitching_axis_in_frame_space,
                    frame_unstitched_axis_size=self._stitching_constant_length,
                    stitching_strategy=self.configuration.stitching_strategy,
                    overlap_size=height,
                    extra_params=self.configuration.stitching_kernels_extra_params,
                )
            )

    @property
    def series(self) -> Series:
        return self._series

    @property
    def configuration(self) -> SingleAxisStitchingConfiguration:
        return self._configuration

    @property
    def progress(self):
        return self._progress

    @staticmethod
    def _data_bunch_iterator(slices, bunch_size):
        """
        Util to get (start, end) indices by bunch until we reach the end of `slices`.

        :param slices: a slice (contiguous frames) or an iterable of indices (non-contiguous frames).
                       For a slice, the step is handled at a different level.
        :param bunch_size: maximal number of frames per bunch (slice case only)
        :raises TypeError: if `slices` is neither a slice nor an Iterable
        """
        if isinstance(slices, slice):
            # note: slice step is handled at a different level
            start = end = slices.start

            while True:
                start, end = end, min((end + bunch_size), slices.stop)
                yield (start, end)
                if end >= slices.stop:
                    break
        # in the case of non-contiguous frames
        elif isinstance(slices, Iterable):
            for s in slices:
                yield (s, s + 1)
        else:
            raise TypeError(f"slices is provided as {type(slices)}. When Iterable or slice is expected")

    def rescale_frames(self, frames: tuple):
        """
        Rescale frames if requested by the configuration.

        All frames are rescaled to the [min_percentile, max_percentile] range of the first frame.
        Percentiles can be given as int or as str (such as "10%").
        """
        _logger.info("apply rescale frames")

        def cast_percentile(percentile) -> int:
            if isinstance(percentile, str):
                # str methods return a new string: the result must be kept
                percentile = percentile.replace(" ", "").rstrip("%")
            return int(percentile)

        rescale_min_percentile = cast_percentile(self.configuration.rescale_params.get(KEY_RESCALE_MIN_PERCENTILES, 0))
        rescale_max_percentile = cast_percentile(
            self.configuration.rescale_params.get(KEY_RESCALE_MAX_PERCENTILES, 100)
        )

        new_min = numpy.percentile(frames[0], rescale_min_percentile)
        new_max = numpy.percentile(frames[0], rescale_max_percentile)

        def rescale(data):
            # FIXME: takes time because browse several time the dataset, twice for percentiles and twices to get min and max when calling rescale_data...
            data_min = numpy.percentile(data, rescale_min_percentile)
            data_max = numpy.percentile(data, rescale_max_percentile)
            return rescale_data(data, new_min=new_min, new_max=new_max, data_min=data_min, data_max=data_max)

        return tuple([rescale(data) for data in frames])

    def normalize_frame_by_sample(self, frames: tuple):
        """
        normalize frame from a sample picked on the left or the right
        """
        _logger.info("apply normalization by a sample")
        # note: the call below resolves to the module-level `normalize_frame_by_sample` function, not to this method
        return tuple(
            [
                normalize_frame_by_sample(
                    frame=frame,
                    side=self.configuration.normalization_by_sample.side,
                    method=self.configuration.normalization_by_sample.method,
                    margin_before_sample=self.configuration.normalization_by_sample.margin,
                    sample_width=self.configuration.normalization_by_sample.width,
                )
                for frame in frames
            ]
        )

    @staticmethod
    def stitch_frames(
        frames: Union[tuple, numpy.ndarray],
        axis,
        x_relative_shifts: tuple,
        y_relative_shifts: tuple,
        output_dtype: numpy.dtype,
        stitching_axis: int,
        overlap_kernels: tuple,
        dumper: DumperBase = None,
        check_inputs=True,
        shift_mode="nearest",
        i_frame=None,
        return_composition_cls=False,
        alignment="center",
        pad_mode="constant",
        new_width: Optional[int] = None,
    ) -> Union[numpy.ndarray, tuple]:
        """
        Shift frames according to the provided relative shifts, then stitch all the shifted frames together.
        If a dumper is provided, the stitched frame is saved through it.

        :param tuple frames: element must be a DataUrl or a 2D numpy array
        :param axis: forwarded to `check_overlaps`
        :param x_relative_shifts: relative shifts along frame axis 1 between two consecutive frames
        :param y_relative_shifts: relative shifts along frame axis 0 between two consecutive frames
        :param output_dtype: dtype of the stitched frame
        :param stitching_axis: stitching axis in frame space (0 or 1)
        :param overlap_kernels: one overlap kernel per pair of consecutive frames
        :param dumper: optional dumper used to save the stitched frame
        :param check_inputs: if True, check the consistency of the inputs
        :param shift_mode: mode used by `scipy.ndimage.shift`
        :param i_frame: index of the frame, forwarded to the dumper
        :param return_composition_cls: if True return (stitched_frame, composition_cls) instead of stitched_frame
        :param alignment: forwarded to `stitch_raw_frames`
        :param pad_mode: forwarded to `stitch_raw_frames`
        :param new_width: forwarded to `stitch_raw_frames` as `new_unstitched_axis_size`
        :raises ValueError: on inconsistent inputs (when `check_inputs` is True)
        :raises TypeError: if a frame is neither a DataUrl nor a numpy array
        """
        if check_inputs:
            if len(frames) < 2:
                raise ValueError(f"Not enought frames provided for stitching ({len(frames)} provided)")
            if len(frames) != len(x_relative_shifts) + 1:
                raise ValueError(
                    f"Incoherent number of shift provided ({len(x_relative_shifts)}) compare to number of frame ({len(frames)}). len(frames) - 1 expected"
                )
            if len(x_relative_shifts) != len(overlap_kernels):
                raise ValueError(
                    f"expect to have the same number of x_relative_shifts ({len(x_relative_shifts)}) and y_overlap ({len(overlap_kernels)})"
                )
            if len(y_relative_shifts) != len(overlap_kernels):
                raise ValueError(
                    f"expect to have the same number of y_relative_shifts ({len(y_relative_shifts)}) and y_overlap ({len(overlap_kernels)})"
                )

            relative_positions = [(0, 0, 0)]
            for y_rel_pos, x_rel_pos in zip(y_relative_shifts, x_relative_shifts):
                relative_positions.append(
                    (
                        y_rel_pos + relative_positions[-1][0],
                        0,  # position over axis 1 (aka y) is not handled yet
                        x_rel_pos + relative_positions[-1][2],
                    )
                )
            check_overlaps(
                frames=tuple(frames),
                positions=tuple(relative_positions),
                axis=axis,
                raise_error=False,
            )

        def check_frame_is_2d(frame):
            if frame.ndim != 2:
                raise ValueError(f"2D frame expected when {frame.ndim}D provided")

        # step 0: load data if provided as url
        data = []
        for frame in frames:
            if isinstance(frame, DataUrl):
                frame_data = get_data(frame)
            elif isinstance(frame, numpy.ndarray):
                frame_data = frame
            else:
                raise TypeError(f"frames are expected to be DataUrl or 2D numpy array. Not {type(frame)}")
            if check_inputs:
                check_frame_is_2d(frame_data)
            data.append(frame_data)

        # step 1: shift each frame (except the first one)
        if stitching_axis == 0:
            relative_shift_along_stitched_axis = y_relative_shifts
            relative_shift_along_unstitched_axis = x_relative_shifts
        elif stitching_axis == 1:
            relative_shift_along_stitched_axis = x_relative_shifts
            relative_shift_along_unstitched_axis = y_relative_shifts
        else:
            raise NotImplementedError("")

        shifted_data = [data[0]]
        for frame, relative_shift in zip(data[1:], relative_shift_along_unstitched_axis):
            # note: for now frames are only shifted along the unstitched axis. The shift along the stitched axis is handled in the FrameComposition
            # note: a wide integer type is required: int8 would silently wrap shifts larger than 127 pixels
            relative_shift = numpy.asarray(relative_shift).astype(numpy.int64)
            if relative_shift == 0:
                shifted_frame = frame
            else:
                # TO speed up: should use the Fourier transform
                shifted_frame = shift_scipy(
                    frame,
                    mode=shift_mode,
                    shift=[0, -relative_shift] if stitching_axis == 0 else [-relative_shift, 0],
                    order=1,
                )
            shifted_data.append(shifted_frame)

        # step 2: create stitched frame.
        # key lines are computed from the loaded data: a DataUrl has no shape
        key_lines = [
            (int(frame.shape[stitching_axis] - abs(relative_shift / 2)), int(abs(relative_shift / 2)))
            for relative_shift, frame in zip(relative_shift_along_stitched_axis, data)
        ]
        stitched_frame, composition_cls = stitch_raw_frames(
            frames=shifted_data,
            key_lines=key_lines,
            overlap_kernels=overlap_kernels,
            check_inputs=check_inputs,
            output_dtype=output_dtype,
            return_composition_cls=True,
            alignment=alignment,
            pad_mode=pad_mode,
            new_unstitched_axis_size=new_width,
        )
        # dumper is optional (defaults to None)
        if dumper is not None:
            dumper.save_stitched_frame(
                stitched_frame=stitched_frame,
                composition_cls=composition_cls,
                i_frame=i_frame,
                axis=1,
            )

        if return_composition_cls:
            return stitched_frame, composition_cls
        else:
            return stitched_frame
File without changes
@@ -0,0 +1,13 @@
1
+ from nabu.stitching.stitcher.pre_processing import PreProcessingStitching
2
+ from .dumper import PreProcessingStitchingDumper
3
+
4
+
5
class PreProcessingYStitcher(
    PreProcessingStitching,
    axis=1,
    dumper_cls=PreProcessingStitchingDumper,
):
    """Pre-processing stitcher along axis 1 (y)."""

    @property
    def serie_label(self) -> str:
        """Name of the serie, as displayed in logs"""
        return "y-serie"
@@ -0,0 +1,45 @@
1
+ from nabu.stitching.stitcher.pre_processing import PreProcessingStitching
2
+ from nabu.stitching.stitcher.post_processing import PostProcessingStitching
3
+ from .dumper import PreProcessingStitchingDumper, PostProcessingStitchingDumperNoDD, PostProcessingStitchingDumper
4
+ from nabu.stitching.stitcher.single_axis import _SingleAxisMetaClass
5
+
6
+
7
class PreProcessingZStitcher(
    PreProcessingStitching,
    dumper_cls=PreProcessingStitchingDumper,
    axis=0,
):
    """Pre-processing stitcher along axis 0 (z)."""

    def check_inputs(self):
        """
        Ensure input data is coherent.

        On top of the parent checks, consecutive scans must have the same projection width (``dim_1``).

        :raises ValueError: if two consecutive scans have different projection widths
        """
        super().check_inputs()

        for scan_0, scan_1 in zip(self.series[0:-1], self.series[1:]):
            if scan_0.dim_1 != scan_1.dim_1:
                raise ValueError(
                    f"projections width are expected to be the same. Not the case for {scan_0} ({scan_0.dim_1}) and {scan_1} ({scan_1.dim_1})"
                )
24
+
25
+
26
class PostProcessingZStitcher(
    PostProcessingStitching,
    axis=0,
    dumper_cls=PostProcessingStitchingDumper,
    metaclass=_SingleAxisMetaClass,
):
    """Post-processing stitcher along axis 0 (z), saving data through `PostProcessingStitchingDumper`."""

    @property
    def serie_label(self) -> str:
        """Name of the serie, as displayed in logs"""
        return "z-serie"
35
+
36
+
37
class PostProcessingZStitcherNoDD(
    PostProcessingStitching,
    axis=0,
    dumper_cls=PostProcessingStitchingDumperNoDD,
    metaclass=_SingleAxisMetaClass,
):
    """Post-processing stitcher along axis 0 (z), saving data through `PostProcessingStitchingDumperNoDD`."""

    @property
    def serie_label(self) -> str:
        """Name of the serie, as displayed in logs"""
        return "z-serie"