nabu 2024.1.9__py3-none-any.whl → 2024.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151) hide show
  1. nabu/__init__.py +1 -1
  2. nabu/app/bootstrap.py +2 -3
  3. nabu/app/cast_volume.py +4 -2
  4. nabu/app/cli_configs.py +5 -0
  5. nabu/app/composite_cor.py +1 -1
  6. nabu/app/create_distortion_map_from_poly.py +5 -6
  7. nabu/app/diag_to_pix.py +7 -19
  8. nabu/app/diag_to_rot.py +14 -29
  9. nabu/app/double_flatfield.py +32 -44
  10. nabu/app/parse_reconstruction_log.py +3 -0
  11. nabu/app/reconstruct.py +53 -15
  12. nabu/app/reconstruct_helical.py +2 -2
  13. nabu/app/stitching.py +27 -13
  14. nabu/app/tests/test_reduce_dark_flat.py +4 -1
  15. nabu/cuda/kernel.py +11 -2
  16. nabu/cuda/processing.py +2 -2
  17. nabu/cuda/src/cone.cu +77 -0
  18. nabu/cuda/src/hierarchical_backproj.cu +271 -0
  19. nabu/cuda/utils.py +0 -6
  20. nabu/estimation/alignment.py +5 -19
  21. nabu/estimation/cor.py +173 -599
  22. nabu/estimation/cor_sino.py +356 -26
  23. nabu/estimation/focus.py +63 -11
  24. nabu/estimation/tests/test_cor.py +124 -58
  25. nabu/estimation/tests/test_focus.py +6 -6
  26. nabu/estimation/tilt.py +2 -1
  27. nabu/estimation/utils.py +5 -33
  28. nabu/io/__init__.py +1 -1
  29. nabu/io/cast_volume.py +1 -1
  30. nabu/io/reader.py +416 -21
  31. nabu/io/tests/test_readers.py +422 -0
  32. nabu/io/tests/test_writers.py +1 -102
  33. nabu/io/writer.py +4 -433
  34. nabu/opencl/kernel.py +14 -3
  35. nabu/opencl/processing.py +8 -0
  36. nabu/pipeline/config_validators.py +5 -2
  37. nabu/pipeline/datadump.py +12 -5
  38. nabu/pipeline/estimators.py +162 -188
  39. nabu/pipeline/fullfield/chunked.py +168 -92
  40. nabu/pipeline/fullfield/chunked_cuda.py +7 -3
  41. nabu/pipeline/fullfield/computations.py +2 -7
  42. nabu/pipeline/fullfield/dataset_validator.py +0 -4
  43. nabu/pipeline/fullfield/nabu_config.py +37 -13
  44. nabu/pipeline/fullfield/processconfig.py +22 -13
  45. nabu/pipeline/fullfield/reconstruction.py +13 -9
  46. nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
  47. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
  48. nabu/pipeline/helical/helical_reconstruction.py +1 -1
  49. nabu/pipeline/params.py +21 -1
  50. nabu/pipeline/processconfig.py +1 -12
  51. nabu/pipeline/reader.py +146 -0
  52. nabu/pipeline/tests/test_estimators.py +44 -72
  53. nabu/pipeline/utils.py +4 -2
  54. nabu/pipeline/writer.py +10 -2
  55. nabu/preproc/ccd_cuda.py +1 -1
  56. nabu/preproc/ctf.py +14 -7
  57. nabu/preproc/ctf_cuda.py +2 -3
  58. nabu/preproc/double_flatfield.py +5 -12
  59. nabu/preproc/double_flatfield_cuda.py +2 -2
  60. nabu/preproc/flatfield.py +5 -1
  61. nabu/preproc/flatfield_cuda.py +5 -1
  62. nabu/preproc/phase.py +24 -73
  63. nabu/preproc/phase_cuda.py +5 -8
  64. nabu/preproc/tests/test_ctf.py +11 -7
  65. nabu/preproc/tests/test_flatfield.py +67 -122
  66. nabu/preproc/tests/test_paganin.py +54 -30
  67. nabu/processing/azim.py +206 -0
  68. nabu/processing/convolution_cuda.py +1 -1
  69. nabu/processing/fft_cuda.py +15 -17
  70. nabu/processing/histogram.py +2 -0
  71. nabu/processing/histogram_cuda.py +2 -1
  72. nabu/processing/kernel_base.py +3 -0
  73. nabu/processing/muladd_cuda.py +1 -0
  74. nabu/processing/padding_opencl.py +1 -1
  75. nabu/processing/roll_opencl.py +1 -0
  76. nabu/processing/rotation_cuda.py +2 -2
  77. nabu/processing/tests/test_fft.py +17 -10
  78. nabu/processing/unsharp_cuda.py +1 -1
  79. nabu/reconstruction/cone.py +104 -40
  80. nabu/reconstruction/fbp.py +3 -0
  81. nabu/reconstruction/fbp_base.py +7 -2
  82. nabu/reconstruction/filtering.py +20 -7
  83. nabu/reconstruction/filtering_cuda.py +7 -1
  84. nabu/reconstruction/hbp.py +424 -0
  85. nabu/reconstruction/mlem.py +99 -0
  86. nabu/reconstruction/reconstructor.py +2 -0
  87. nabu/reconstruction/rings_cuda.py +19 -19
  88. nabu/reconstruction/sinogram_cuda.py +1 -0
  89. nabu/reconstruction/sinogram_opencl.py +3 -1
  90. nabu/reconstruction/tests/test_cone.py +10 -5
  91. nabu/reconstruction/tests/test_deringer.py +7 -6
  92. nabu/reconstruction/tests/test_fbp.py +124 -10
  93. nabu/reconstruction/tests/test_filtering.py +13 -11
  94. nabu/reconstruction/tests/test_halftomo.py +30 -4
  95. nabu/reconstruction/tests/test_mlem.py +91 -0
  96. nabu/reconstruction/tests/test_reconstructor.py +8 -3
  97. nabu/resources/dataset_analyzer.py +142 -92
  98. nabu/resources/gpu.py +1 -0
  99. nabu/resources/nxflatfield.py +134 -125
  100. nabu/resources/templates/id16a_fluo.conf +42 -0
  101. nabu/resources/tests/test_extract.py +10 -0
  102. nabu/resources/tests/test_nxflatfield.py +2 -2
  103. nabu/stitching/alignment.py +80 -24
  104. nabu/stitching/config.py +105 -68
  105. nabu/stitching/definitions.py +1 -0
  106. nabu/stitching/frame_composition.py +68 -60
  107. nabu/stitching/overlap.py +91 -51
  108. nabu/stitching/single_axis_stitching.py +32 -0
  109. nabu/stitching/slurm_utils.py +6 -6
  110. nabu/stitching/stitcher/__init__.py +0 -0
  111. nabu/stitching/stitcher/base.py +124 -0
  112. nabu/stitching/stitcher/dumper/__init__.py +3 -0
  113. nabu/stitching/stitcher/dumper/base.py +94 -0
  114. nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
  115. nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
  116. nabu/stitching/stitcher/post_processing.py +555 -0
  117. nabu/stitching/stitcher/pre_processing.py +1068 -0
  118. nabu/stitching/stitcher/single_axis.py +484 -0
  119. nabu/stitching/stitcher/stitcher.py +0 -0
  120. nabu/stitching/stitcher/y_stitcher.py +13 -0
  121. nabu/stitching/stitcher/z_stitcher.py +45 -0
  122. nabu/stitching/stitcher_2D.py +278 -0
  123. nabu/stitching/tests/test_config.py +12 -37
  124. nabu/stitching/tests/test_frame_composition.py +33 -59
  125. nabu/stitching/tests/test_overlap.py +149 -7
  126. nabu/stitching/tests/test_utils.py +1 -1
  127. nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
  128. nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
  129. nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
  130. nabu/stitching/utils/__init__.py +1 -0
  131. nabu/stitching/utils/post_processing.py +281 -0
  132. nabu/stitching/utils/tests/test_post-processing.py +21 -0
  133. nabu/stitching/{utils.py → utils/utils.py} +79 -52
  134. nabu/stitching/y_stitching.py +27 -0
  135. nabu/stitching/z_stitching.py +32 -2263
  136. nabu/testutils.py +1 -152
  137. nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
  138. nabu/utils.py +158 -61
  139. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/METADATA +10 -3
  140. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/RECORD +144 -121
  141. nabu/io/tiffwriter_zmm.py +0 -99
  142. nabu/pipeline/fallback_utils.py +0 -149
  143. nabu/pipeline/helical/tests/test_accumulator.py +0 -158
  144. nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
  145. nabu/pipeline/helical/tests/test_strategy.py +0 -61
  146. nabu/pipeline/helical/utils.py +0 -51
  147. nabu/pipeline/tests/test_chunk_reader.py +0 -74
  148. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
  149. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/WHEEL +0 -0
  150. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
  151. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
@@ -1,2279 +1,48 @@
1
- # coding: utf-8
2
- # /*##########################################################################
3
- #
4
- # Copyright (c) 2016-2017 European Synchrotron Radiation Facility
5
- #
6
- # Permission is hereby granted, free of charge, to any person obtaining a copy
7
- # of this software and associated documentation files (the "Software"), to deal
8
- # in the Software without restriction, including without limitation the rights
9
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
- # copies of the Software, and to permit persons to whom the Software is
11
- # furnished to do so, subject to the following conditions:
12
- #
13
- # The above copyright notice and this permission notice shall be included in
14
- # all copies or substantial portions of the Software.
15
- #
16
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22
- # THE SOFTWARE.
23
- #
24
- # ###########################################################################*/
1
+ from typing import Union
25
2
 
26
- __authors__ = ["H. Payno"]
27
- __license__ = "MIT"
28
- __date__ = "10/05/2022"
29
-
30
-
31
- import os
32
- from copy import copy
33
- from datetime import datetime
34
- from typing import Optional, Union, Iterable
35
- import numpy
36
- from math import ceil
37
- from contextlib import AbstractContextManager
38
- import h5py
39
- import logging
40
- from scipy.ndimage import shift as shift_scipy
41
- from functools import lru_cache as cache
42
-
43
- from silx.io.utils import get_data
44
- from silx.io.url import DataUrl
45
- from silx.io.dictdump import dicttonx
46
-
47
- from nxtomo.nxobject.nxdetector import ImageKey
48
- from nxtomo.nxobject.nxtransformations import NXtransformations
49
- from nxtomo.paths.nxtomo import get_paths as _get_nexus_paths
50
- from nxtomo.utils.transformation import build_matrix, LRDetTransformation, UDDetTransformation
51
-
52
- from tomoscan.io import HDF5File
53
- from tomoscan.esrf.scan.utils import cwd_context
54
3
  from tomoscan.identifier import BaseIdentifier
55
- from tomoscan.esrf import NXtomoScan, EDFTomoScan
56
- from tomoscan.volumebase import VolumeBase
57
- from tomoscan.esrf.volume import HDF5Volume
58
- from tomoscan.serie import Serie
59
- from tomoscan.factory import Factory as TomoscanFactory
60
- from tomoscan.utils.volume import concatenate as concatenate_volumes
61
- from tomoscan.esrf.scan.utils import (
62
- get_compacted_dataslices,
63
- ) # this version has a 'return_url_set' needed here. At one point they should be merged together
64
- from pyunitsystem.metricsystem import MetricSystem
65
-
66
- from nxtomo.application.nxtomo import NXtomo
67
- from silx.io.dictdump import dicttonx
68
-
69
- from nabu.io.utils import DatasetReader
70
- from nabu.stitching.frame_composition import ZFrameComposition
71
- from nabu.stitching.utils import find_projections_relative_shifts, find_volumes_relative_shifts, ShiftAlgorithm
4
+ from nabu.stitching.stitcher.z_stitcher import PreProcessingZStitcher as PreProcessZStitcher
5
+ from nabu.stitching.stitcher.z_stitcher import (
6
+ PostProcessingZStitcher as PostProcessZStitcher,
7
+ PostProcessingZStitcherNoDD as PostProcessZStitcherNoDD,
8
+ )
72
9
  from nabu.stitching.config import (
73
- CROSS_CORRELATION_SLICE_FIELD,
74
10
  PreProcessedZStitchingConfiguration,
75
11
  PostProcessedZStitchingConfiguration,
76
- ZStitchingConfiguration,
77
- KEY_IMG_REG_METHOD,
78
- KEY_RESCALE_MIN_PERCENTILES,
79
- KEY_RESCALE_MAX_PERCENTILES,
80
- KEY_THRESHOLD_FREQUENCY,
81
- )
82
- from nabu.stitching.alignment import align_horizontally, AlignmentAxis1
83
- from nabu.utils import Progress
84
- from nabu import version as nabu_version
85
- from nabu.io.writer import get_datetime
86
- from .overlap import (
87
- ZStichOverlapKernel,
88
- check_overlaps,
89
12
  )
90
- from .. import version as nabu_version
91
- from nabu.io.writer import get_datetime
92
- from nabu.misc.utils import rescale_data
93
- from nabu.stitching.alignment import PaddedRawData
94
- from nabu.stitching.sample_normalization import normalize_frame as normalize_frame_by_sample
95
13
 
96
- _logger = logging.getLogger(__name__)
97
14
 
98
-
99
- def z_stitching(configuration: ZStitchingConfiguration, progress=None) -> BaseIdentifier:
15
+ def z_stitching(
16
+ configuration: Union[PreProcessedZStitchingConfiguration, PostProcessedZStitchingConfiguration], progress=None
17
+ ) -> BaseIdentifier:
100
18
  """
101
- Apply stitching from provided configuration.
19
+ Apply stitching from provided configuration. Along axis 0 (aka z)
102
20
  Return a DataUrl with the created NXtomo or Volume
103
- """
21
+
22
+ like:
23
+ axis 0
24
+ ^
25
+ |
26
+ x-ray |
27
+ --------> ------> axis 2
28
+ /
29
+ /
30
+ axis 1
31
+ """
32
+ stitcher = None
33
+ assert configuration.axis is not None
104
34
  if isinstance(configuration, PreProcessedZStitchingConfiguration):
105
- stitcher = PreProcessZStitcher(configuration=configuration, progress=progress)
35
+ if configuration.axis == 0:
36
+ stitcher = PreProcessZStitcher(configuration=configuration, progress=progress)
106
37
  elif isinstance(configuration, PostProcessedZStitchingConfiguration):
107
- stitcher = PostProcessZStitcher(configuration=configuration, progress=progress)
108
- else:
38
+ assert configuration.axis == 0
39
+ if configuration.duplicate_data:
40
+ stitcher = PostProcessZStitcher(configuration=configuration, progress=progress)
41
+ else:
42
+ stitcher = PostProcessZStitcherNoDD(configuration=configuration, progress=progress)
43
+
44
+ if stitcher is None:
109
45
  raise TypeError(
110
46
  f"configuration is expected to be in {(PreProcessedZStitchingConfiguration, PostProcessedZStitchingConfiguration)}. {type(configuration)} provided"
111
47
  )
112
48
  return stitcher.stitch()
113
-
114
-
115
- class ZStitcher:
116
- @staticmethod
117
- def param_is_auto(param):
118
- return param in ("auto", ("auto",))
119
-
120
- def __init__(self, configuration, progress: Progress = None) -> None:
121
- if not isinstance(configuration, ZStitchingConfiguration):
122
- raise TypeError
123
-
124
- # flag to check if the serie has been ordered yet or not
125
- self._configuration = copy(configuration)
126
- # copy configuration because we will edit it
127
- self._frame_composition = None
128
- self._progress = progress
129
- self._overlap_kernels = []
130
- # kernels to create the stitching on overlaps.
131
-
132
- self._axis_0_rel_shifts = []
133
- self._axis_2_rel_shifts = []
134
- # shift between upper and lower frames
135
-
136
- self._stitching_width = None
137
- # stitching width: larger volume width. Other volume will be pad
138
-
139
- # z serie must be defined from daughter class
140
- assert hasattr(self, "_z_serie")
141
-
142
- def shifts_is_scalar(shifts):
143
- return isinstance(shifts, ShiftAlgorithm) or numpy.isscalar(shifts)
144
-
145
- # 'expend' shift algorithm
146
- if shifts_is_scalar(self.configuration.axis_0_pos_px):
147
- self.configuration.axis_0_pos_px = [
148
- self.configuration.axis_0_pos_px,
149
- ] * (len(self.z_serie) - 1)
150
- if shifts_is_scalar(self.configuration.axis_1_pos_px):
151
- self.configuration.axis_1_pos_px = [
152
- self.configuration.axis_1_pos_px,
153
- ] * (len(self.z_serie) - 1)
154
- if shifts_is_scalar(self.configuration.axis_2_pos_px):
155
- self.configuration.axis_2_pos_px = [
156
- self.configuration.axis_2_pos_px,
157
- ] * (len(self.z_serie) - 1)
158
- if numpy.isscalar(self.configuration.axis_0_params):
159
- self.configuration.axis_0_params = [
160
- self.configuration.axis_0_params,
161
- ] * (len(self.z_serie) - 1)
162
- if numpy.isscalar(self.configuration.axis_1_params):
163
- self.configuration.axis_1_params = [
164
- self.configuration.axis_1_params,
165
- ] * (len(self.z_serie) - 1)
166
- if numpy.isscalar(self.configuration.axis_2_params):
167
- self.configuration.axis_2_params = [
168
- self.configuration.axis_2_params,
169
- ] * (len(self.z_serie) - 1)
170
-
171
- @property
172
- def frame_composition(self):
173
- return self._frame_composition
174
-
175
- def get_final_axis_positions_in_px(self) -> dict:
176
- """
177
- :return: dict with tomo object identifier (str) as key and a tuple of position in pixel (axis_0_pos, axis_1_pos, axis_2_pos)
178
- :rtype: dict
179
- """
180
- final_shifts_axis_0 = [
181
- 0,
182
- ]
183
- final_shifts_axis_0.extend(self._axis_0_rel_shifts)
184
- final_shifts_axis_0 = numpy.array(final_shifts_axis_0)
185
-
186
- final_shifts_axis_2 = [
187
- 0,
188
- ]
189
- final_shifts_axis_2.extend(self._axis_2_rel_shifts)
190
- final_shifts_axis_2 = numpy.array(final_shifts_axis_2)
191
-
192
- estimated_shifts_axis_0 = self._axis_0_estimated_shifts.copy()
193
- estimated_shifts_axis_0.insert(0, 0)
194
-
195
- final_pos = {}
196
- previous_shift = 0
197
- for tomo_obj, pos_axis_0, pos_axis_2, final_shift_axis_0, estimated_shift_axis_0, final_shift_axis_2 in zip(
198
- self.z_serie,
199
- self.configuration.axis_0_pos_px,
200
- self.configuration.axis_2_pos_px,
201
- estimated_shifts_axis_0,
202
- final_shifts_axis_0,
203
- final_shifts_axis_2,
204
- ):
205
- # warning estimated_shift is the estimatation from the overlap. So playes no role here
206
- final_pos[tomo_obj.get_identifier().to_str()] = (
207
- pos_axis_0 - (final_shift_axis_0 - estimated_shift_axis_0) + previous_shift,
208
- None, # axis 1 is not handled for now
209
- pos_axis_2 + final_shift_axis_2,
210
- )
211
- previous_shift += final_shift_axis_0 - estimated_shift_axis_0
212
- return final_pos
213
-
214
- def from_abs_pos_to_rel_pos(self, abs_position: tuple):
215
- """
216
- return relative position from on object to the other but in relative this time
217
- :param tuple abs_position: tuple containing the absolute positions
218
- :return: len(abs_position) - 1 relative position
219
- :rtype: tuple
220
- """
221
- return tuple([pos_obj_b - pos_obj_a for (pos_obj_a, pos_obj_b) in zip(abs_position[:-1], abs_position[1:])])
222
-
223
- def from_rel_pos_to_abs_pos(self, rel_positions: tuple, init_pos: int):
224
- """
225
- return absolute positions from a tuple of relative position and an initial position
226
- :param tuple rel_positions: tuple containing the absolute positions
227
- :return: len(rel_positions) + 1 relative position
228
- :rtype: tuple
229
- """
230
- abs_pos = [
231
- init_pos,
232
- ]
233
- for rel_pos in rel_positions:
234
- abs_pos.append(abs_pos[-1] + rel_pos)
235
- return abs_pos
236
-
237
- def stitch(self, store_composition: bool = True) -> BaseIdentifier:
238
- """
239
- Apply expected stitch from configuration and return the DataUrl of the object created
240
-
241
- :param bool store_composition: if True then store the composition used for stitching in frame_composition.
242
- So it can be reused by third part (like tomwer) to display composition made
243
- """
244
- raise NotImplementedError("base class")
245
-
246
- def settle_flips(self):
247
- """
248
- User can provide some information on existing flips at frame level.
249
- The goal of this step is to get one flip_lr and on flip_ud value per scan or volume
250
- """
251
- if numpy.isscalar(self.configuration.flip_lr):
252
- self.configuration.flip_lr = tuple([self.configuration.flip_lr] * len(self.z_serie))
253
- else:
254
- if not len(self.configuration.flip_lr) == len(self.z_serie):
255
- raise ValueError("flip_lr expects a scalar value or one value per element to stitch")
256
- self.configuration.flip_lr = tuple(self.configuration.flip_lr)
257
- for elmt in self.configuration.flip_lr:
258
- if not isinstance(elmt, bool):
259
- raise TypeError
260
-
261
- if numpy.isscalar(self.configuration.flip_ud):
262
- self.configuration.flip_ud = tuple([self.configuration.flip_ud] * len(self.z_serie))
263
- else:
264
- if not len(self.configuration.flip_ud) == len(self.z_serie):
265
- raise ValueError("flip_ud expects a scalar value or one value per element to stitch")
266
- self.configuration.flip_ud = tuple(self.configuration.flip_ud)
267
- for elmt in self.configuration.flip_ud:
268
- if not isinstance(elmt, bool):
269
- raise TypeError
270
-
271
- def _compute_shifts(self):
272
- """
273
- after this stage the final shifts must be determine
274
- """
275
- raise NotImplementedError("base class")
276
-
277
- def _createOverlapKernels(self):
278
- """
279
- after this stage the overlap kernels must be created and with the final overlap size
280
- """
281
- if self._axis_0_rel_shifts is None or len(self._axis_0_rel_shifts) == 0:
282
- raise RuntimeError(
283
- "axis 0 shifts have not been defined yet. Please define them before calling this function"
284
- )
285
-
286
- overlap_size = self.configuration.axis_0_params.get("overlap_size", None)
287
- if overlap_size in (None, "None", ""):
288
- overlap_size = -1
289
- else:
290
- overlap_size = int(overlap_size)
291
-
292
- self._stitching_width = max([get_obj_width(obj) for obj in self.z_serie])
293
-
294
- for axis_0_shift in self._axis_0_rel_shifts:
295
- if overlap_size == -1:
296
- height = abs(axis_0_shift)
297
- else:
298
- height = overlap_size
299
-
300
- self._overlap_kernels.append(
301
- ZStichOverlapKernel(
302
- frame_width=self._stitching_width,
303
- stitching_strategy=self.configuration.stitching_strategy,
304
- overlap_size=height,
305
- extra_params=self.configuration.stitching_kernels_extra_params,
306
- )
307
- )
308
-
309
- @property
310
- def z_serie(self) -> Serie:
311
- return self._z_serie
312
-
313
- @property
314
- def configuration(self) -> ZStitchingConfiguration:
315
- return self._configuration
316
-
317
- @property
318
- def progress(self) -> Optional[Progress]:
319
- return self._progress
320
-
321
- @staticmethod
322
- def get_overlap_areas(
323
- upper_frame: numpy.ndarray,
324
- lower_frame: numpy.ndarray,
325
- upper_frame_key_line: int,
326
- lower_frame_key_line: int,
327
- overlap_size: int,
328
- stitching_axis: int,
329
- ):
330
- """
331
- return the requested area from lower_frame and upper_frame.
332
-
333
- Lower_frame contains at the end of it the 'real overlap' with the upper_frame.
334
- Upper_frame contains the 'real overlap' at the end of it.
335
-
336
- For some reason the user can ask the stitching height to be smaller than the `real overlap`.
337
-
338
- Here are some drawing to have a better of view of those regions:
339
-
340
- .. image:: images/stitching/z_stitch_real_overlap.png
341
- :width: 600
342
-
343
- .. image:: z_stitch_stitch_height.png
344
- :width: 600
345
- """
346
- assert stitching_axis in (0, 1, 2)
347
- for pf, pn in zip((lower_frame_key_line, upper_frame_key_line), ("lower_frame", "upper_frame")):
348
- if not isinstance(pf, (int, numpy.number)):
349
- raise TypeError(f"{pn} is expected to be a number. {type(pf)} provided")
350
- assert overlap_size >= 0
351
-
352
- lf_start = ceil(lower_frame_key_line - overlap_size / 2)
353
- lf_end = ceil(lower_frame_key_line + overlap_size / 2)
354
- uf_start = ceil(upper_frame_key_line - overlap_size / 2)
355
- uf_end = ceil(upper_frame_key_line + overlap_size / 2)
356
-
357
- lf_start, lf_end = min(lf_start, lf_end), max(lf_start, lf_end)
358
- uf_start, uf_end = min(uf_start, uf_end), max(uf_start, uf_end)
359
- if lf_start < 0 or uf_start < 0:
360
- raise ValueError(
361
- f"requested overlap ({overlap_size}) is incoherent with key line positions ({lower_frame_key_line}, {upper_frame_key_line}) - expected to be smaller."
362
- )
363
- overlap_upper = upper_frame[uf_start:uf_end]
364
- overlap_lower = lower_frame[lf_start:lf_end]
365
- if not overlap_upper.shape == overlap_lower.shape:
366
- # maybe in the future: try to reduce one according to the other ????
367
- raise RuntimeError(
368
- f"lower and upper frame have different overlap size ({overlap_upper.shape} vs {overlap_lower.shape})"
369
- )
370
- return overlap_upper, overlap_lower
371
-
372
- @staticmethod
373
- def _data_bunch_iterator(slices, bunch_size):
374
- """util to get indices by bunch until we reach n_frames"""
375
- if isinstance(slices, slice):
376
- # note: slice step is handled at a different level
377
- start = end = slices.start
378
-
379
- while True:
380
- start, end = end, min((end + bunch_size), slices.stop)
381
- yield (start, end)
382
- if end >= slices.stop:
383
- break
384
- # in the case of non-contiguous frames
385
- elif isinstance(slices, Iterable):
386
- for s in slices:
387
- yield (s, s + 1)
388
- else:
389
- raise TypeError(f"slices is provided as {type(slices)}. When Iterable or slice is expected")
390
-
391
- def rescale_frames(self, frames: tuple):
392
- """
393
- rescale_frames if requested by the configuration
394
- """
395
- _logger.info("apply rescale frames")
396
-
397
- def cast_percentile(percentile) -> int:
398
- if isinstance(percentile, str):
399
- percentile.replace(" ", "").rstrip("%")
400
- return int(percentile)
401
-
402
- rescale_min_percentile = cast_percentile(self.configuration.rescale_params.get(KEY_RESCALE_MIN_PERCENTILES, 0))
403
- rescale_max_percentile = cast_percentile(
404
- self.configuration.rescale_params.get(KEY_RESCALE_MAX_PERCENTILES, 100)
405
- )
406
-
407
- new_min = numpy.percentile(frames[0], rescale_min_percentile)
408
- new_max = numpy.percentile(frames[0], rescale_max_percentile)
409
-
410
- def rescale(data):
411
- # FIXME: takes time because browse several time the dataset, twice for percentiles and twices to get min and max when calling rescale_data...
412
- data_min = numpy.percentile(data, rescale_min_percentile)
413
- data_max = numpy.percentile(data, rescale_max_percentile)
414
- return rescale_data(data, new_min=new_min, new_max=new_max, data_min=data_min, data_max=data_max)
415
-
416
- return tuple([rescale(data) for data in frames])
417
-
418
- def normalize_frame_by_sample(self, frames: tuple):
419
- """
420
- normalize frame from a sample picked on the left or the right
421
- """
422
- _logger.info("apply normalization by a sample")
423
- return tuple(
424
- [
425
- normalize_frame_by_sample(
426
- frame=frame,
427
- side=self.configuration.normalization_by_sample.side,
428
- method=self.configuration.normalization_by_sample.method,
429
- margin_before_sample=self.configuration.normalization_by_sample.margin,
430
- sample_width=self.configuration.normalization_by_sample.width,
431
- )
432
- for frame in frames
433
- ]
434
- )
435
-
436
- @staticmethod
437
- def stitch_frames(
438
- frames: Union[tuple, numpy.ndarray],
439
- x_relative_shifts: tuple,
440
- y_relative_shifts: tuple,
441
- output_dtype: numpy.ndarray,
442
- stitching_axis: int,
443
- overlap_kernels: tuple,
444
- output_dataset: Optional[Union[h5py.Dataset, numpy.ndarray]] = None,
445
- dump_frame_fct=None,
446
- check_inputs=True,
447
- shift_mode="nearest",
448
- i_frame=None,
449
- return_composition_cls=False,
450
- alignment="center",
451
- pad_mode="constant",
452
- new_width: Optional[int] = None,
453
- ) -> numpy.ndarray:
454
- """
455
- shift frames according to provided `shifts` (as y, x tuples) then stitch all the shifted frames together and
456
- save them to output_dataset.
457
-
458
- :param tuple frames: element must be a DataUrl or a 2D numpy array
459
- """
460
- if check_inputs:
461
- if len(frames) < 2:
462
- raise ValueError(f"Not enought frames provided for stitching ({len(frames)} provided)")
463
- if len(frames) != len(x_relative_shifts) + 1:
464
- raise ValueError(
465
- f"Incoherent number of shift provided ({len(x_relative_shifts)}) compare to number of frame ({len(frames)}). len(frames) - 1 expected"
466
- )
467
- if len(x_relative_shifts) != len(overlap_kernels):
468
- raise ValueError(
469
- f"expect to have the same number of x_relative_shifts ({len(x_relative_shifts)}) and y_overlap ({len(overlap_kernels)})"
470
- )
471
- if len(y_relative_shifts) != len(overlap_kernels):
472
- raise ValueError(
473
- f"expect to have the same number of y_relative_shifts ({len(y_relative_shifts)}) and y_overlap ({len(overlap_kernels)})"
474
- )
475
-
476
- relative_positions = [(0, 0)]
477
- for y_rel_pos, x_rel_pos in zip(y_relative_shifts, x_relative_shifts):
478
- relative_positions.append(
479
- (
480
- y_rel_pos + relative_positions[-1][0],
481
- x_rel_pos + relative_positions[-1][1],
482
- )
483
- )
484
- check_overlaps(
485
- frames=tuple(frames),
486
- positions=tuple(relative_positions),
487
- axis=0,
488
- raise_error=False,
489
- )
490
-
491
- def check_frame_is_2d(frame):
492
- if frame.ndim != 2:
493
- raise ValueError(f"2D frame expected when {frame.ndim}D provided")
494
-
495
- # step_0 load data if from url
496
- data = []
497
- for frame in frames:
498
- if isinstance(frame, DataUrl):
499
- data_frame = get_data(frame)
500
- if check_inputs:
501
- check_frame_is_2d(data_frame)
502
- data.append(data_frame)
503
- elif isinstance(frame, numpy.ndarray):
504
- if check_inputs:
505
- check_frame_is_2d(frame)
506
- data.append(frame)
507
- else:
508
- raise TypeError(f"frames are expected to be DataUrl or 2D numpy array. Not {type(frame)}")
509
-
510
- # step 1: shift each frames (except the first one)
511
- x_shifted_data = [data[0]]
512
- for frame, x_relative_shift in zip(data[1:], x_relative_shifts):
513
- # note: for now we only shift data in x. the y shift is handled in the FrameComposition
514
- x_relative_shift = numpy.asarray(x_relative_shift).astype(numpy.int8)
515
- if x_relative_shift == 0:
516
- shifted_frame = frame
517
- else:
518
- # TO speed up: should use the Fourier transform
519
- shifted_frame = shift_scipy(
520
- frame,
521
- mode=shift_mode,
522
- shift=[0, -x_relative_shift],
523
- order=1,
524
- )
525
- x_shifted_data.append(shifted_frame)
526
-
527
- # step 2: create stitched frame
528
- res = stitch_vertically_raw_frames(
529
- frames=x_shifted_data,
530
- key_lines=(
531
- [
532
- (int(frame.shape[stitching_axis] - abs(y_relative_shift / 2)), int(abs(y_relative_shift / 2)))
533
- for y_relative_shift, frame in zip(y_relative_shifts, frames)
534
- ]
535
- ),
536
- overlap_kernels=overlap_kernels,
537
- check_inputs=check_inputs,
538
- output_dtype=output_dtype,
539
- return_composition_cls=return_composition_cls,
540
- alignment=alignment,
541
- pad_mode=pad_mode,
542
- new_width=new_width,
543
- )
544
- if return_composition_cls:
545
- stitched_frame, _ = res
546
- else:
547
- stitched_frame = res
548
-
549
- # step 3: dump stitched frame
550
- if output_dataset is not None and i_frame is not None:
551
- dump_frame_fct(
552
- output_dataset=output_dataset,
553
- index=i_frame,
554
- stitched_frame=stitched_frame,
555
- )
556
- return res
557
-
558
- @staticmethod
559
- @cache(maxsize=None)
560
- def _get_UD_flip_matrix():
561
- return UDDetTransformation().as_matrix()
562
-
563
- @staticmethod
564
- @cache(maxsize=None)
565
- def _get_LR_flip_matrix():
566
- return LRDetTransformation().as_matrix()
567
-
568
- @staticmethod
569
- @cache(maxsize=None)
570
- def _get_UD_AND_LR_flip_matrix():
571
- return numpy.matmul(
572
- ZStitcher._get_UD_flip_matrix(),
573
- ZStitcher._get_LR_flip_matrix(),
574
- )
575
-
576
-
577
- class PreProcessZStitcher(ZStitcher):
578
- def __init__(self, configuration, progress=None) -> None:
579
- # z serie must be defined first
580
- self._z_serie = Serie("z-serie", iterable=configuration.input_scans, use_identifiers=False)
581
- self._reading_orders = []
582
- self._x_flips = []
583
- self._y_flips = []
584
- # some scan can have been taken in the opposite order (so must be read on the opposite order one from the other)
585
- self._axis_0_estimated_shifts = None
586
- super().__init__(configuration, progress)
587
-
588
- # 'expend' auto shift request if only set once for all
589
- if numpy.isscalar(self.configuration.axis_0_pos_px):
590
- self.configuration.axis_0_pos_px = [
591
- self.configuration.axis_0_pos_px,
592
- ] * (len(self.z_serie) - 1)
593
- if numpy.isscalar(self.configuration.axis_1_pos_px):
594
- self.configuration.axis_1_pos_px = [
595
- self.configuration.axis_1_pos_px,
596
- ] * (len(self.z_serie) - 1)
597
- if numpy.isscalar(self.configuration.axis_2_pos_px):
598
- self.configuration.axis_2_pos_px = [
599
- self.configuration.axis_2_pos_px,
600
- ] * (len(self.z_serie) - 1)
601
-
602
- if self.configuration.axis_0_params is None:
603
- self.configuration.axis_0_params = {}
604
- if self.configuration.axis_1_params is None:
605
- self.configuration.axis_1_params = {}
606
- if self.configuration.axis_2_params is None:
607
- self.configuration.axis_2_params = {}
608
-
609
- @staticmethod
610
- def _dump_frame(output_dataset: h5py.Dataset, index: int, stitched_frame: numpy.ndarray):
611
- output_dataset[index] = stitched_frame
612
-
613
- @property
614
- def reading_orders(self):
615
- """
616
- as scan can be take on one direction or the order (rotation goes from X to Y then from Y to X)
617
- we might need to read data from one direction or another
618
- """
619
- return self._reading_orders
620
-
621
- @property
622
- def x_flips(self) -> list:
623
- return self._x_flips
624
-
625
- @property
626
- def y_flips(self) -> list:
627
- return self._y_flips
628
-
629
- def stitch(self, store_composition=True) -> BaseIdentifier:
630
- """
631
- :param bool return_composition: if True then return the frame composition (used by the GUI for example to display a background with the same class)
632
- """
633
- if self.progress is not None:
634
- self.progress.set_name("order scans")
635
- self._order_scans()
636
- if self.progress is not None:
637
- self.progress.set_name("check inputs")
638
- self._check_inputs()
639
- self.settle_flips()
640
- self._compute_positions_as_px()
641
- self._compute_axis_0_estimated_shifts()
642
- if self.progress is not None:
643
- self.progress.set_name("compute flat field")
644
- self._compute_reduced_flats_and_darks()
645
- if self.progress is not None:
646
- self.progress.set_name("compute shifts")
647
- self._compute_shifts()
648
- self._createOverlapKernels()
649
- if self.progress is not None:
650
- self.progress.set_name("stitch projections, save them and create NXtomo")
651
- self._create_nx_tomo(store_composition=store_composition)
652
- if self.progress is not None:
653
- self.progress.set_name("dump configuration")
654
- self._dump_stitching_configuration()
655
- stitched_scan = self.configuration.get_output_object()
656
- return stitched_scan.get_identifier()
657
-
658
- def _order_scans(self):
659
- """
660
- ensure scans are in z decreasing order
661
- """
662
-
663
- def get_min_z(scan):
664
- return scan.get_bounding_box(axis=0).min
665
-
666
- # order scans from higher z to lower z
667
- # if axis 0 position is provided then use directly it
668
- if self.configuration.axis_0_pos_px is not None and len(self.configuration.axis_0_pos_px) > 0:
669
- order = numpy.argsort(self.configuration.axis_0_pos_px)[::-1]
670
- sorted_z_serie = Serie(
671
- self.z_serie.name,
672
- numpy.take_along_axis(numpy.array(self.z_serie[:]), order, axis=0),
673
- use_identifiers=False,
674
- )
675
- else:
676
- # else use bounding box
677
- sorted_z_serie = Serie(
678
- self.z_serie.name,
679
- sorted(self.z_serie[:], key=get_min_z, reverse=True),
680
- use_identifiers=False,
681
- )
682
- if sorted_z_serie != self.z_serie:
683
- if sorted_z_serie[:] != self.z_serie[::-1]:
684
- raise ValueError("Unable to get comprehensive input. Z (decreasing) ordering is not respected.")
685
- else:
686
- _logger.warning(
687
- f"z decreasing order haven't been respected. Need to reorder z serie ({[str(scan) for scan in sorted_z_serie[:]]}). Will also reorder overlap height, stitching height and invert shifts"
688
- )
689
- if self.configuration.axis_0_pos_mm is not None:
690
- self.configuration.axis_0_pos_mm = self.configuration.axis_0_pos_mm[::-1]
691
- if self.configuration.axis_0_pos_px is not None:
692
- self.configuration.axis_0_pos_px = self.configuration.axis_0_pos_px[::-1]
693
- if self.configuration.axis_1_pos_mm is not None:
694
- self.configuration.axis_1_pos_mm = self.configuration.axis_1_pos_mm[::-1]
695
- if self.configuration.axis_1_pos_px is not None:
696
- self.configuration.axis_1_pos_px = self.configuration.axis_1_pos_px[::-1]
697
- if self.configuration.axis_2_pos_mm is not None:
698
- self.configuration.axis_2_pos_mm = self.configuration.axis_2_pos_mm[::-1]
699
- if self.configuration.axis_2_pos_px is not None:
700
- self.configuration.axis_2_pos_px = self.configuration.axis_2_pos_px[::-1]
701
- if not numpy.isscalar(self._configuration.flip_ud):
702
- self._configuration.flip_ud = self._configuration.flip_ud[::-1]
703
- if not numpy.isscalar(self._configuration.flip_lr):
704
- self._configuration.flip_ud = self._configuration.flip_lr[::-1]
705
-
706
- self._z_serie = sorted_z_serie
707
-
708
- def _check_inputs(self):
709
- """
710
- insure input data is coherent
711
- """
712
- n_scans = len(self.z_serie)
713
- if n_scans == 0:
714
- raise ValueError("no scan to stich together")
715
-
716
- for scan in self.z_serie:
717
- from tomoscan.scanbase import TomoScanBase
718
-
719
- if not isinstance(scan, TomoScanBase):
720
- raise TypeError(f"z-preproc stitching expects instances of {TomoScanBase}. {type(scan)} provided.")
721
-
722
- # check output file path and data path are provided
723
- if self.configuration.output_file_path in (None, ""):
724
- raise ValueError("outptu_file_path should be provided to the configuration")
725
- if self.configuration.output_data_path in (None, ""):
726
- raise ValueError("output_data_path should be provided to the configuration")
727
-
728
- # check number of shift provided
729
- for axis_pos_px, axis_name in zip(
730
- (
731
- self.configuration.axis_0_pos_px,
732
- self.configuration.axis_1_pos_px,
733
- self.configuration.axis_2_pos_px,
734
- self.configuration.axis_0_pos_mm,
735
- self.configuration.axis_1_pos_mm,
736
- self.configuration.axis_2_pos_mm,
737
- ),
738
- (
739
- "axis_0_pos_px",
740
- "axis_1_pos_px",
741
- "axis_2_pos_px",
742
- "axis_0_pos_mm",
743
- "axis_1_pos_mm",
744
- "axis_2_pos_mm",
745
- ),
746
- ):
747
- if isinstance(axis_pos_px, Iterable) and len(axis_pos_px) != (n_scans):
748
- raise ValueError(f"{axis_name} expect {n_scans} shift defined. Get {len(axis_pos_px)}")
749
-
750
- self._reading_orders = []
751
- # the first scan will define the expected reading orderd, and expected flip.
752
- # if all scan are flipped then we will keep it this way
753
- self._reading_orders.append(1)
754
-
755
- # check scans are coherent (nb projections, rotation angle, energy...)
756
- for scan_0, scan_1 in zip(self.z_serie[0:-1], self.z_serie[1:]):
757
- if len(scan_0.projections) != len(scan_1.projections):
758
- raise ValueError(f"{scan_0} and {scan_1} have a different number of projections")
759
- if isinstance(scan_0, NXtomoScan) and isinstance(scan_1, NXtomoScan):
760
- # check rotation (only of is an NXtomoScan)
761
- scan_0_angles = numpy.asarray(scan_0.rotation_angle)
762
- scan_0_projections_angles = scan_0_angles[
763
- numpy.asarray(scan_0.image_key_control) == ImageKey.PROJECTION.value
764
- ]
765
- scan_1_angles = numpy.asarray(scan_1.rotation_angle)
766
- scan_1_projections_angles = scan_1_angles[
767
- numpy.asarray(scan_1.image_key_control) == ImageKey.PROJECTION.value
768
- ]
769
- if not numpy.allclose(scan_0_projections_angles, scan_1_projections_angles, atol=10e-1):
770
- if numpy.allclose(
771
- scan_0_projections_angles,
772
- scan_1_projections_angles[::-1],
773
- atol=10e-1,
774
- ):
775
- reading_order = -1 * self._reading_orders[-1]
776
- else:
777
- raise ValueError(f"Angles from {scan_0} and {scan_1} are different")
778
- else:
779
- reading_order = 1 * self._reading_orders[-1]
780
- self._reading_orders.append(reading_order)
781
- # check energy
782
- if scan_0.energy is None:
783
- _logger.warning(f"no energy found for {scan_0}")
784
- elif not numpy.isclose(scan_0.energy, scan_1.energy, rtol=1e-03):
785
- _logger.warning(
786
- f"different energy found between {scan_0} ({scan_0.energy}) and {scan_1} ({scan_1.energy})"
787
- )
788
- # check FOV
789
- if not scan_0.field_of_view == scan_1.field_of_view:
790
- raise ValueError(f"{scan_0} and {scan_1} have different field of view")
791
- # check distance
792
- if scan_0.distance is None:
793
- _logger.warning(f"no distance found for {scan_0}")
794
- elif not numpy.isclose(scan_0.distance, scan_1.distance, rtol=10e-3):
795
- raise ValueError(f"{scan_0} and {scan_1} have different sample / detector distance")
796
- # check pixel size
797
- if not numpy.isclose(scan_0.x_pixel_size, scan_1.x_pixel_size):
798
- raise ValueError(
799
- f"{scan_0} and {scan_1} have different x pixel size. {scan_0.x_pixel_size} vs {scan_1.x_pixel_size}"
800
- )
801
- if not numpy.isclose(scan_0.y_pixel_size, scan_1.y_pixel_size):
802
- raise ValueError(
803
- f"{scan_0} and {scan_1} have different y pixel size. {scan_0.y_pixel_size} vs {scan_1.y_pixel_size}"
804
- )
805
- if scan_0.dim_1 != scan_1.dim_1:
806
- raise ValueError(
807
- f"projections width are expected to be the same. Not the canse for {scan_0} ({scan_0.dim_1} and {scan_1} ({scan_1.dim_1}))"
808
- )
809
-
810
- for scan in self.z_serie:
811
- # check x, y and z translation are constant (only if is an NXtomoScan)
812
- if isinstance(scan, NXtomoScan):
813
- if scan.x_translation is not None and not numpy.isclose(
814
- min(scan.x_translation), max(scan.x_translation)
815
- ):
816
- _logger.warning(
817
- "x translations appears to be evolving over time. Might end up with wrong stitching"
818
- )
819
- if scan.y_translation is not None and not numpy.isclose(
820
- min(scan.y_translation), max(scan.y_translation)
821
- ):
822
- _logger.warning(
823
- "y translations appears to be evolving over time. Might end up with wrong stitching"
824
- )
825
- if scan.z_translation is not None and not numpy.isclose(
826
- min(scan.z_translation), max(scan.z_translation)
827
- ):
828
- _logger.warning(
829
- "z translations appears to be evolving over time. Might end up with wrong stitching"
830
- )
831
-
832
- def _compute_positions_as_px(self):
833
- """insure we have or we can deduce an estimated position as pixel"""
834
-
835
- def get_position_as_px_on_axis(axis, pos_as_px, pos_as_mm):
836
- if pos_as_px is not None:
837
- if pos_as_mm is not None:
838
- raise ValueError(
839
- f"position of axis {axis} is provided twice: as mm and as px. Please provide one only ({pos_as_mm} vs {pos_as_px})"
840
- )
841
- else:
842
- return pos_as_px
843
-
844
- elif pos_as_mm is not None:
845
- # deduce from position given in configuration and pixel size
846
- axis_N_pos_px = []
847
- for scan, pos_in_mm in zip(self.z_serie, pos_as_mm):
848
- pixel_size_m = self.configuration.pixel_size or scan.pixel_size
849
- axis_N_pos_px.append((pos_in_mm / MetricSystem.MILLIMETER.value) / pixel_size_m)
850
- return axis_N_pos_px
851
- else:
852
- # deduce from motor position and pixel size
853
- axis_N_pos_px = []
854
- base_position_m = self.z_serie[0].get_bounding_box(axis=axis).min
855
- for scan in self.z_serie:
856
- pixel_size_m = self.configuration.pixel_size or scan.pixel_size
857
- scan_axis_bb = scan.get_bounding_box(axis=axis)
858
- axis_N_mean_pos_m = (scan_axis_bb.max - scan_axis_bb.min) / 2 + scan_axis_bb.min
859
- axis_N_mean_rel_pos_m = axis_N_mean_pos_m - base_position_m
860
- axis_N_pos_px.append(int(axis_N_mean_rel_pos_m / pixel_size_m))
861
- return axis_N_pos_px
862
-
863
- self.configuration.axis_0_pos_px = get_position_as_px_on_axis(
864
- axis=0,
865
- pos_as_px=self.configuration.axis_0_pos_px,
866
- pos_as_mm=self.configuration.axis_0_pos_mm,
867
- )
868
- self.configuration.axis_0_pos_mm = None
869
-
870
- self.configuration.axis_2_pos_px = get_position_as_px_on_axis(
871
- axis=2,
872
- pos_as_px=self.configuration.axis_2_pos_px,
873
- pos_as_mm=self.configuration.axis_2_pos_mm,
874
- )
875
- self.configuration.axis_2_pos_mm = None
876
-
877
- # add some log
878
- if self.configuration.axis_1_pos_mm is not None or self.configuration.axis_1_pos_px is not None:
879
- _logger.warning("axis 1 position is not handled by the z-stitcher. Will be ignored")
880
- axis_0_pos = ", ".join([f"{pos}px" for pos in self.configuration.axis_0_pos_px])
881
- axis_2_pos = ", ".join([f"{pos}px" for pos in self.configuration.axis_2_pos_px])
882
- _logger.info(f"axis 0 position to be used: " + axis_0_pos)
883
- _logger.info(f"axis 2 position to be used: " + axis_2_pos)
884
-
885
- def _compute_reduced_flats_and_darks(self):
886
- """
887
- make sure reduced dark and flats are existing otherwise compute them
888
- """
889
- for scan in self.z_serie:
890
- try:
891
- reduced_darks, darks_infos = scan.load_reduced_darks(return_info=True)
892
- except:
893
- _logger.info("no reduced dark found. Try to compute them.")
894
- if reduced_darks in (None, {}):
895
- reduced_darks, darks_infos = scan.compute_reduced_darks(return_info=True)
896
- try:
897
- # if we don't have write in the folder containing the .nx for example
898
- scan.save_reduced_darks(reduced_darks, darks_infos=darks_infos)
899
- except Exception as e:
900
- pass
901
- scan.set_reduced_darks(reduced_darks, darks_infos=darks_infos)
902
-
903
- try:
904
- reduced_flats, flats_infos = scan.load_reduced_flats(return_info=True)
905
- except:
906
- _logger.info("no reduced flats found. Try to compute them.")
907
- if reduced_flats in (None, {}):
908
- reduced_flats, flats_infos = scan.compute_reduced_flats(return_info=True)
909
- try:
910
- # if we don't have write in the folder containing the .nx for example
911
- scan.save_reduced_flats(reduced_flats, flats_infos=flats_infos)
912
- except Exception as e:
913
- pass
914
- scan.set_reduced_flats(reduced_flats, flats_infos=flats_infos)
915
-
916
- def _compute_shifts(self):
917
- """
918
- compute all shift requested (set to 'auto' in the configuration)
919
- """
920
- n_scans = len(self.configuration.input_scans)
921
- if n_scans == 0:
922
- raise ValueError("no scan to stich provided")
923
-
924
- projection_for_shift = self.configuration.slice_for_cross_correlation or "middle"
925
- y_rel_shifts = self._axis_0_estimated_shifts
926
- x_rel_shifts = self.from_abs_pos_to_rel_pos(self.configuration.axis_2_pos_px)
927
-
928
- final_rel_shifts = []
929
- for (
930
- scan_0,
931
- scan_1,
932
- order_s0,
933
- order_s1,
934
- x_rel_shift,
935
- y_rel_shift,
936
- ) in zip(
937
- self.z_serie[:-1],
938
- self.z_serie[1:],
939
- self.reading_orders[:-1],
940
- self.reading_orders[1:],
941
- x_rel_shifts,
942
- y_rel_shifts,
943
- ):
944
- x_cross_algo = self.configuration.axis_2_params.get(KEY_IMG_REG_METHOD, None)
945
- y_cross_algo = self.configuration.axis_0_params.get(KEY_IMG_REG_METHOD, None)
946
-
947
- # compute relative shift
948
- found_shift_y, found_shift_x = find_projections_relative_shifts(
949
- upper_scan=scan_0,
950
- lower_scan=scan_1,
951
- projection_for_shift=projection_for_shift,
952
- x_cross_correlation_function=x_cross_algo,
953
- y_cross_correlation_function=y_cross_algo,
954
- x_shifts_params=self.configuration.axis_2_params,
955
- y_shifts_params=self.configuration.axis_0_params,
956
- invert_order=order_s1 != order_s0,
957
- estimated_shifts=(y_rel_shift, x_rel_shift),
958
- )
959
- final_rel_shifts.append(
960
- (found_shift_y, found_shift_x),
961
- )
962
-
963
- # set back values. Now position should start at 0
964
- self._axis_0_rel_shifts = [final_shift[0] for final_shift in final_rel_shifts]
965
- self._axis_2_rel_shifts = [final_shift[1] for final_shift in final_rel_shifts]
966
- _logger.info(f"axis 2 relative shifts (x in radio ref) to be used will be {self._axis_0_rel_shifts}")
967
- print(f"axis 2 relative shifts (x in radio ref) to be used will be {self._axis_0_rel_shifts}")
968
- _logger.info(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_2_rel_shifts}")
969
- print(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_2_rel_shifts}")
970
-
971
- @staticmethod
972
- def _get_bunch_of_data(
973
- bunch_start: int,
974
- bunch_end: int,
975
- step: int,
976
- scans: tuple,
977
- scans_projections_indexes: tuple,
978
- reading_orders: tuple,
979
- flip_lr_arr: tuple,
980
- flip_ud_arr: tuple,
981
- ):
982
- """
983
- goal is to load contiguous projections as much as possible...
984
-
985
- :param int bunch_start: begining of the bunch
986
- :param int bunch_end: end of the bunch
987
- :param int scans: ordered scan for which we want to get data
988
- :param scans_projections_indexes: tuple with scans and scan projection indexes to be loaded
989
- :param tuple flip_lr_arr: extra information from the user to left-right flip frames
990
- :param tuple flip_ud_arr: extra information from the user to up-down flip frames
991
- :return: list of list. For each frame we want to stitch contains the (flat fielded) frames to stich together
992
- """
993
- assert len(scans) == len(scans_projections_indexes)
994
- assert isinstance(flip_lr_arr, tuple)
995
- assert isinstance(flip_ud_arr, tuple)
996
- assert isinstance(step, int)
997
- scans_proj_urls = []
998
- # for each scan store the real indices and the data url
999
-
1000
- for scan, scan_projection_indexes in zip(scans, scans_projections_indexes):
1001
- scan_proj_urls = {}
1002
- # for each scan get the list of url to be loaded
1003
- for i_proj in range(bunch_start, bunch_end):
1004
- if i_proj % step != 0:
1005
- continue
1006
- proj_index_in_full_scan = scan_projection_indexes[i_proj]
1007
- scan_proj_urls[proj_index_in_full_scan] = scan.projections[proj_index_in_full_scan]
1008
- scans_proj_urls.append(scan_proj_urls)
1009
-
1010
- # then load data
1011
- all_scan_final_data = numpy.empty((bunch_end - bunch_start, len(scans)), dtype=object)
1012
- from nabu.preproc.flatfield import FlatFieldArrays
1013
-
1014
- for i_scan, (scan_urls, scan_flip_lr, scan_flip_ud, reading_order) in enumerate(
1015
- zip(scans_proj_urls, flip_lr_arr, flip_ud_arr, reading_orders)
1016
- ):
1017
- i_frame = 0
1018
- _, set_of_compacted_slices = get_compacted_dataslices(scan_urls, return_url_set=True)
1019
- for _, url in set_of_compacted_slices.items():
1020
- scan = scans[i_scan]
1021
- url = DataUrl(
1022
- file_path=url.file_path(),
1023
- data_path=url.data_path(),
1024
- scheme="silx",
1025
- data_slice=url.data_slice(),
1026
- )
1027
- raw_radios = get_data(url)[::reading_order]
1028
- radio_indices = url.data_slice()
1029
- if isinstance(radio_indices, slice):
1030
- step = radio_indices.step if radio_indices is not None else 1
1031
- radio_indices = numpy.arange(
1032
- start=radio_indices.start,
1033
- stop=radio_indices.stop,
1034
- step=step,
1035
- dtype=numpy.int16,
1036
- )
1037
-
1038
- missing = []
1039
- if len(scan.reduced_flats) == 0:
1040
- missing = "flats"
1041
- if len(scan.reduced_darks) == 0:
1042
- missing = "darks"
1043
-
1044
- if len(missing) > 0:
1045
- _logger.warning(f"missing {'and'.join(missing)}. Unable to do flat field correction")
1046
- ff_arrays = None
1047
- data = raw_radios
1048
- else:
1049
- has_reduced_metadata = (
1050
- scan.reduced_flats_infos is not None
1051
- and len(scan.reduced_flats_infos.machine_electric_current) > 0
1052
- and scan.reduced_darks_infos is not None
1053
- and len(scan.reduced_darks_infos.machine_electric_current) > 0
1054
- )
1055
- if not has_reduced_metadata:
1056
- _logger.warning("no metadata about current found. Won't normalize according to machine current")
1057
-
1058
- ff_arrays = FlatFieldArrays(
1059
- radios_shape=(len(radio_indices), scan.dim_2, scan.dim_1),
1060
- flats=scan.reduced_flats,
1061
- darks=scan.reduced_darks,
1062
- radios_indices=radio_indices,
1063
- radios_srcurrent=scan.electric_current[radio_indices] if has_reduced_metadata else None,
1064
- flats_srcurrent=(
1065
- scan.reduced_flats_infos.machine_electric_current if has_reduced_metadata else None
1066
- ),
1067
- )
1068
- # note: we need to cast radios to float 32. Darks and flats are cast to anyway
1069
- data = ff_arrays.normalize_radios(raw_radios.astype(numpy.float32))
1070
-
1071
- transformations = list(scans[i_scan].get_detector_transformations(tuple()))
1072
- if scan_flip_lr:
1073
- transformations.append(LRDetTransformation())
1074
- if scan_flip_ud:
1075
- transformations.append(UDDetTransformation())
1076
-
1077
- transformation_matrix_det_space = build_matrix(transformations)
1078
- if transformation_matrix_det_space is None or numpy.allclose(
1079
- transformation_matrix_det_space, numpy.identity(3)
1080
- ):
1081
- flip_ud = False
1082
- flip_lr = False
1083
- elif numpy.array_equal(transformation_matrix_det_space, ZStitcher._get_UD_flip_matrix()):
1084
- flip_ud = True
1085
- flip_lr = False
1086
- elif numpy.allclose(transformation_matrix_det_space, ZStitcher._get_LR_flip_matrix()):
1087
- flip_ud = False
1088
- flip_lr = True
1089
- elif numpy.allclose(transformation_matrix_det_space, ZStitcher._get_UD_AND_LR_flip_matrix()):
1090
- flip_ud = True
1091
- flip_lr = True
1092
- else:
1093
- raise ValueError("case not handled... For now only handle up-down flip as left-right flip")
1094
-
1095
- for frame in data:
1096
- if flip_ud:
1097
- frame = numpy.flipud(frame)
1098
- if flip_lr:
1099
- frame = numpy.fliplr(frame)
1100
- all_scan_final_data[i_frame, i_scan] = frame
1101
- i_frame += 1
1102
-
1103
- return all_scan_final_data
1104
-
1105
- def _compute_axis_0_estimated_shifts(self):
1106
- axis_0_pos_px = self.configuration.axis_0_pos_px
1107
- self._axis_0_estimated_shifts = []
1108
- # compute overlap along axis 0
1109
- for upper_scan, lower_scan, upper_scan_axis_0_pos, lower_scan_axis_0_pos in zip(
1110
- self.z_serie[:-1], self.z_serie[1:], axis_0_pos_px[:-1], axis_0_pos_px[1:]
1111
- ):
1112
- upper_scan_pos = upper_scan_axis_0_pos - upper_scan.dim_2 / 2
1113
- lower_scan_high_pos = lower_scan_axis_0_pos + lower_scan.dim_2 / 2
1114
- # simple test of overlap. More complete test are runned by check_overlaps later
1115
- if lower_scan_high_pos <= upper_scan_pos:
1116
- raise ValueError(f"no overlap found between {upper_scan} and {lower_scan}")
1117
- self._axis_0_estimated_shifts.append(
1118
- int(lower_scan_high_pos - upper_scan_pos) # overlap are expected to be int for now
1119
- )
1120
-
1121
- def _create_nx_tomo(self, store_composition: bool = False):
1122
- """
1123
- create final NXtomo with stitched frames.
1124
- Policy: save all projections flat fielded. So this NXtomo will only contain projections (no dark and no flat).
1125
- But nabu will be able to reconstruct it with field `flatfield` set to False
1126
- """
1127
- nx_tomo = NXtomo()
1128
-
1129
- nx_tomo.energy = self.z_serie[0].energy
1130
- start_times = list(filter(None, [scan.start_time for scan in self.z_serie]))
1131
- end_times = list(filter(None, [scan.end_time for scan in self.z_serie]))
1132
-
1133
- if len(start_times) > 0:
1134
- nx_tomo.start_time = (
1135
- numpy.asarray([numpy.datetime64(start_time) for start_time in start_times]).min().astype(datetime)
1136
- )
1137
- else:
1138
- _logger.warning("Unable to find any start_time from input")
1139
- if len(end_times) > 0:
1140
- nx_tomo.end_time = (
1141
- numpy.asarray([numpy.datetime64(end_time) for end_time in end_times]).max().astype(datetime)
1142
- )
1143
- else:
1144
- _logger.warning("Unable to find any end_time from input")
1145
-
1146
- title = ";".join([scan.sequence_name or "" for scan in self.z_serie])
1147
- nx_tomo.title = f"stitch done from {title}"
1148
-
1149
- self._slices_to_stitch, n_proj = self.configuration.settle_slices()
1150
-
1151
- # handle detector (without frames)
1152
- nx_tomo.instrument.detector.field_of_view = self.z_serie[0].field_of_view
1153
- nx_tomo.instrument.detector.distance = self.z_serie[0].distance
1154
- nx_tomo.instrument.detector.x_pixel_size = self.z_serie[0].x_pixel_size
1155
- nx_tomo.instrument.detector.y_pixel_size = self.z_serie[0].y_pixel_size
1156
- nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
1157
- nx_tomo.instrument.detector.tomo_n = n_proj
1158
- # note: stitching process insure unflipping of frames. So make sure transformations is defined as an empty set
1159
- nx_tomo.instrument.detector.transformations = NXtransformations()
1160
-
1161
- if isinstance(self.z_serie[0], NXtomoScan):
1162
- # note: first scan is always the reference as order to read data (so no rotation_angle inversion here)
1163
- rotation_angle = numpy.asarray(self.z_serie[0].rotation_angle)
1164
- nx_tomo.sample.rotation_angle = rotation_angle[
1165
- numpy.asarray(self.z_serie[0].image_key_control) == ImageKey.PROJECTION.value
1166
- ]
1167
- elif isinstance(self.z_serie[0], EDFTomoScan):
1168
- nx_tomo.sample.rotation_angle = numpy.linspace(
1169
- start=0, stop=self.z_serie[0].scan_range, num=self.z_serie[0].tomo_n
1170
- )
1171
- else:
1172
- raise NotImplementedError(
1173
- f"scan type ({type(self.z_serie[0])} is not handled)",
1174
- NXtomoScan,
1175
- isinstance(self.z_serie[0], NXtomoScan),
1176
- )
1177
-
1178
- # do a sub selection of the rotation angle if a we are only computing a part of the slices
1179
- def apply_slices_selection(array, slices):
1180
- if isinstance(slices, slice):
1181
- return array[slices.start : slices.stop : 1]
1182
- elif isinstance(slices, Iterable):
1183
- return list([array[index] for index in slices])
1184
- else:
1185
- raise RuntimeError("slices must be instance of a slice or of an iterable")
1186
-
1187
- nx_tomo.sample.rotation_angle = apply_slices_selection(
1188
- array=nx_tomo.sample.rotation_angle, slices=self._slices_to_stitch
1189
- )
1190
-
1191
- # handle sample
1192
- n_frames = n_proj
1193
- if False not in [isinstance(scan, NXtomoScan) for scan in self.z_serie]:
1194
- # we consider the new x, y and z position to be at the center of the one created
1195
- x_translation = [scan.x_translation for scan in self.z_serie if scan.x_translation is not None]
1196
- nx_tomo.sample.x_translation = [numpy.asarray(x_translation).mean()] * n_frames
1197
- y_translation = [scan.y_translation for scan in self.z_serie if scan.y_translation is not None]
1198
- nx_tomo.sample.y_translation = [numpy.asarray(y_translation).mean()] * n_frames
1199
- z_translation = [scan.z_translation for scan in self.z_serie if scan.z_translation is not None]
1200
- nx_tomo.sample.z_translation = [numpy.asarray(z_translation).mean()] * n_frames
1201
-
1202
- nx_tomo.sample.name = self.z_serie[0].sample_name
1203
-
1204
- # compute stiched frame shape
1205
- stitched_frame_shape = (
1206
- n_proj,
1207
- (
1208
- numpy.asarray([scan.dim_2 for scan in self.z_serie]).sum()
1209
- - numpy.asarray([abs(overlap) for overlap in self._axis_0_rel_shifts]).sum()
1210
- ),
1211
- self._stitching_width,
1212
- )
1213
-
1214
- # get expected output dataset first (just in case output and input files are the same)
1215
- first_proj_idx = sorted(self.z_serie[0].projections.keys())[0]
1216
- first_proj_url = self.z_serie[0].projections[first_proj_idx]
1217
- if h5py.is_hdf5(first_proj_url.file_path()):
1218
- first_proj_url = DataUrl(
1219
- file_path=first_proj_url.file_path(),
1220
- data_path=first_proj_url.data_path(),
1221
- scheme="h5py",
1222
- )
1223
-
1224
- # first save the NXtomo entry without the frame
1225
- # dicttonx will fail if the folder does not exists
1226
- dir_name = os.path.dirname(self.configuration.output_file_path)
1227
- if dir_name not in (None, ""):
1228
- os.makedirs(dir_name, exist_ok=True)
1229
- nx_tomo.save(
1230
- file_path=self.configuration.output_file_path,
1231
- data_path=self.configuration.output_data_path,
1232
- nexus_path_version=self.configuration.output_nexus_version,
1233
- overwrite=self.configuration.overwrite_results,
1234
- )
1235
-
1236
- transformation_matrices = {
1237
- scan.get_identifier()
1238
- .to_str()
1239
- .center(80, "-"): numpy.array2string(build_matrix(scan.get_detector_transformations(tuple())))
1240
- for scan in self.z_serie
1241
- }
1242
- _logger.info(
1243
- "scan detector transformation matrices are:\n"
1244
- "\n".join(["/n".join(item) for item in transformation_matrices.items()])
1245
- )
1246
-
1247
- _logger.info(
1248
- f"reading order is {self.reading_orders}",
1249
- )
1250
-
1251
- def get_output_data_type():
1252
- return numpy.float32 # because we will apply flat field correction on it and they are not raw data
1253
- # scan = self.z_serie[0]
1254
- # radio_url = tuple(scan.projections.values())[0]
1255
- # assert isinstance(radio_url, DataUrl)
1256
- # data = get_data(radio_url)
1257
- # return data.dtype
1258
-
1259
- output_dtype = get_output_data_type()
1260
- # append frames ("instrument/detactor/data" dataset)
1261
- with HDF5File(self.configuration.output_file_path, mode="a") as h5f:
1262
- # note: nx_tomo.save already handles the possible overwrite conflict by removing
1263
- # self.configuration.output_file_path or raising an error
1264
-
1265
- stitched_frame_path = "/".join(
1266
- [
1267
- self.configuration.output_data_path,
1268
- _get_nexus_paths(self.configuration.output_nexus_version).PROJ_PATH,
1269
- ]
1270
- )
1271
- projection_dataset = h5f.create_dataset(
1272
- name=stitched_frame_path,
1273
- shape=stitched_frame_shape,
1274
- dtype=output_dtype,
1275
- )
1276
- # TODO: we could also create in several time and create a virtual dataset from it.
1277
- scans_projections_indexes = []
1278
- for scan, reverse in zip(self.z_serie, self.reading_orders):
1279
- scans_projections_indexes.append(sorted(scan.projections.keys(), reverse=(reverse == -1)))
1280
- if self.progress:
1281
- self.progress.set_max_advancement(len(scan.projections.keys()))
1282
-
1283
- if isinstance(self._slices_to_stitch, slice):
1284
- step = self._slices_to_stitch.step or 1
1285
- else:
1286
- step = 1
1287
- i_proj = 0
1288
- for bunch_start, bunch_end in PreProcessZStitcher._data_bunch_iterator(
1289
- slices=self._slices_to_stitch, bunch_size=50
1290
- ):
1291
- for data_frames in PreProcessZStitcher._get_bunch_of_data(
1292
- bunch_start,
1293
- bunch_end,
1294
- step=step,
1295
- scans=self.z_serie,
1296
- scans_projections_indexes=scans_projections_indexes,
1297
- flip_ud_arr=self.configuration.flip_ud,
1298
- flip_lr_arr=self.configuration.flip_lr,
1299
- reading_orders=self.reading_orders,
1300
- ):
1301
- if self.configuration.rescale_frames:
1302
- data_frames = self.rescale_frames(data_frames)
1303
- if self.configuration.normalization_by_sample.is_active():
1304
- data_frames = self.normalize_frame_by_sample(data_frames)
1305
-
1306
- sf = ZStitcher.stitch_frames(
1307
- frames=data_frames,
1308
- x_relative_shifts=self._axis_2_rel_shifts,
1309
- y_relative_shifts=self._axis_0_rel_shifts,
1310
- output_dataset=projection_dataset,
1311
- overlap_kernels=self._overlap_kernels,
1312
- i_frame=i_proj,
1313
- output_dtype=output_dtype,
1314
- dump_frame_fct=self._dump_frame,
1315
- return_composition_cls=store_composition if i_proj == 0 else False,
1316
- stitching_axis=0,
1317
- pad_mode=self.configuration.pad_mode,
1318
- alignment=self.configuration.alignment_axis_2,
1319
- new_width=self._stitching_width,
1320
- check_inputs=i_proj == 0, # on process check on the first iteration
1321
- )
1322
- if i_proj == 0 and store_composition:
1323
- _, self._frame_composition = sf
1324
- if self.progress is not None:
1325
- self.progress.increase_advancement()
1326
-
1327
- i_proj += 1
1328
-
1329
- # create link to this dataset that can be missing
1330
- # "data/data" link
1331
- if "data" in h5f[self.configuration.output_data_path]:
1332
- data_group = h5f[self.configuration.output_data_path]["data"]
1333
- if not stitched_frame_path.startswith("/"):
1334
- stitched_frame_path = "/" + stitched_frame_path
1335
- data_group["data"] = h5py.SoftLink(stitched_frame_path)
1336
- if "default" not in h5f[self.configuration.output_data_path].attrs:
1337
- h5f[self.configuration.output_data_path].attrs["default"] = "data"
1338
- for attr_name, attr_value in zip(
1339
- ("NX_class", "SILX_style/axis_scale_types", "signal"),
1340
- ("NXdata", ["linear", "linear"], "data"),
1341
- ):
1342
- if attr_name not in data_group.attrs:
1343
- data_group.attrs[attr_name] = attr_value
1344
-
1345
- return nx_tomo
1346
-
1347
- def _dump_stitching_configuration(self):
1348
- """dump configuration used for stitching at the NXtomo entry"""
1349
- process_name = "stitching_configuration"
1350
- config_dict = self.configuration.to_dict()
1351
- # adding nabu specific information
1352
- nabu_process_info = {
1353
- "@NX_class": "NXentry",
1354
- f"{process_name}@NX_class": "NXprocess",
1355
- f"{process_name}/program": "nabu-stitching",
1356
- f"{process_name}/version": nabu_version,
1357
- f"{process_name}/date": get_datetime(),
1358
- f"{process_name}/configuration": config_dict,
1359
- }
1360
-
1361
- dicttonx(
1362
- nabu_process_info,
1363
- h5file=self.configuration.output_file_path,
1364
- h5path=self.configuration.output_data_path,
1365
- update_mode="replace",
1366
- mode="a",
1367
- )
1368
-
1369
-
1370
- class PostProcessZStitcher(ZStitcher):
1371
- def __init__(self, configuration, progress: Progress = None) -> None:
1372
- self._input_volumes = configuration.input_volumes
1373
- self.__output_data_type = None
1374
-
1375
- self._z_serie = Serie("z-serie", iterable=self._input_volumes, use_identifiers=False)
1376
- super().__init__(configuration, progress)
1377
-
1378
- @staticmethod
1379
- def _dump_frame(output_dataset: h5py.Dataset, index: int, stitched_frame: numpy.ndarray):
1380
- # fix numpy array direction to be coherent with input data
1381
- # stitched_frame is received a `z-down` frame and we want to return in the same orientation (`z-up`)
1382
- output_dataset[:, index, :] = stitched_frame
1383
-
1384
- def stitch(self, store_composition=True) -> BaseIdentifier:
1385
- """
1386
- Apply expected stitch from configuration and return the DataUrl of the object created
1387
- """
1388
- if self.progress is not None:
1389
- self.progress.set_name("order volumes")
1390
- self._order_volumes()
1391
- if self.progress is not None:
1392
- self.progress.set_name("check inputs")
1393
- self._check_inputs()
1394
- self.settle_flips()
1395
- if self.progress is not None:
1396
- self.progress.set_name("compute shifts")
1397
- self._compute_positions_as_px()
1398
- self._compute_axis_0_estimated_shifts()
1399
- self._compute_shifts()
1400
- self._createOverlapKernels()
1401
- if self.progress is not None:
1402
- self.progress.set_name("stitch volumes")
1403
- self._create_stitched_volume(store_composition=store_composition)
1404
- if self.progress is not None:
1405
- self.progress.set_name("dump configuration")
1406
- self._dump_stitching_configuration()
1407
- return self.configuration.output_volume.get_identifier()
1408
-
1409
- def _order_volumes(self):
1410
- """
1411
- ensure scans are in z increasing order
1412
- """
1413
-
1414
- def get_min_z(volume):
1415
- try:
1416
- bb = volume.get_bounding_box(axis="z")
1417
- except ValueError: # if missing information
1418
- bb = None
1419
- if bb is not None:
1420
- return bb.min
1421
- else:
1422
- # if can't find bounding box (missing metadata to the volume
1423
- # try to get it from the scan
1424
- metadata = volume.metadata or volume.load_metadata()
1425
- scan_location = metadata.get("nabu_config", {}).get("dataset", {}).get("location", None)
1426
- scan_entry = metadata.get("nabu_config", {}).get("dataset", {}).get("hdf5_entry", None)
1427
- if scan_location is not None:
1428
- # this work around (until most volume have position metadata) works only for Hdf5volume
1429
- with cwd_context(os.path.dirname(volume.file_path)):
1430
- o_scan = NXtomoScan(scan_location, scan_entry)
1431
- bb_acqui = o_scan.get_bounding_box(axis=None)
1432
- # for next step volume position will be required.
1433
- # if you can find it set it directly
1434
- volume.position = (numpy.array(bb_acqui.max) - numpy.array(bb_acqui.min)) / 2.0 + numpy.array(
1435
- bb_acqui.min
1436
- )
1437
- # for now translation are stored in pixel size ref instead of real_pixel_size
1438
- volume.pixel_size = o_scan.x_real_pixel_size
1439
- if bb_acqui is not None:
1440
- return bb_acqui.min[0]
1441
- raise ValueError("Unable to find volume position. Unable to deduce z position")
1442
-
1443
- try:
1444
- # order volumes from higher z to lower z
1445
- # if axis 0 position is provided then use directly it
1446
- if self.configuration.axis_0_pos_px is not None and len(self.configuration.axis_0_pos_px) > 0:
1447
- order = numpy.argsort(self.configuration.axis_0_pos_px)
1448
- sorted_z_serie = Serie(
1449
- self.z_serie.name,
1450
- numpy.take_along_axis(numpy.array(self.z_serie[:]), order, axis=0)[::-1],
1451
- use_identifiers=False,
1452
- )
1453
- else:
1454
- # else use bounding box
1455
- sorted_z_serie = Serie(
1456
- self.z_serie.name,
1457
- sorted(self.z_serie[:], key=get_min_z, reverse=True),
1458
- use_identifiers=False,
1459
- )
1460
- except ValueError:
1461
- _logger.warning(
1462
- "Unable to find volume positions in metadata. Expect the volume to be ordered already (decreasing along axis 0.)"
1463
- )
1464
- else:
1465
- if sorted_z_serie == self.z_serie:
1466
- pass
1467
- elif sorted_z_serie != self.z_serie:
1468
- if sorted_z_serie[:] != self.z_serie[::-1]:
1469
- raise ValueError(
1470
- "Unable to get comprehensive input. ordering along axis 0 is not respected (decreasing)."
1471
- )
1472
- else:
1473
- _logger.warning(
1474
- f"z decreasing order haven't been respected. Need to reorder z serie ({[str(scan) for scan in sorted_z_serie[:]]}). Will also reorder positions"
1475
- )
1476
- if self.configuration.axis_0_pos_mm is not None:
1477
- self.configuration.axis_0_pos_mm = self.configuration.axis_0_pos_mm[::-1]
1478
- if self.configuration.axis_0_pos_px is not None:
1479
- self.configuration.axis_0_pos_px = self.configuration.axis_0_pos_px[::-1]
1480
- if self.configuration.axis_1_pos_mm is not None:
1481
- self.configuration.axis_1_pos_mm = self.configuration.axis_1_pos_mm[::-1]
1482
- if self.configuration.axis_1_pos_px is not None:
1483
- self.configuration.axis_1_pos_px = self.configuration.axis_1_pos_px[::-1]
1484
- if self.configuration.axis_2_pos_mm is not None:
1485
- self.configuration.axis_2_pos_mm = self.configuration.axis_2_pos_mm[::-1]
1486
- if self.configuration.axis_2_pos_px is not None:
1487
- self.configuration.axis_2_pos_px = self.configuration.axis_2_pos_px[::-1]
1488
- if not numpy.isscalar(self._configuration.flip_ud):
1489
- self._configuration.flip_ud = self._configuration.flip_ud[::-1]
1490
- if not numpy.isscalar(self._configuration.flip_lr):
1491
- self._configuration.flip_ud = self._configuration.flip_lr[::-1]
1492
-
1493
- self._z_serie = sorted_z_serie
1494
-
1495
- def _compute_positions_as_px(self):
1496
- """compute if necessary position other axis 0 from volume metadata"""
1497
-
1498
- def get_position_as_px_on_axis(axis, pos_as_px, pos_as_mm):
1499
- if pos_as_px is not None:
1500
- if pos_as_mm is not None:
1501
- raise ValueError(
1502
- f"position of axis {axis} is provided twice: as mm and as px. Please provide one only ({pos_as_mm} vs {pos_as_px})"
1503
- )
1504
- else:
1505
- return pos_as_px
1506
-
1507
- elif pos_as_mm is not None:
1508
- # deduce from position given in configuration and pixel size
1509
- axis_N_pos_px = []
1510
- for volume, pos_in_mm in zip(self.z_serie, pos_as_mm):
1511
- voxel_size_m = self.configuration.voxel_size or volume.voxel_size
1512
- axis_N_pos_px.append((pos_in_mm / MetricSystem.MILLIMETER.value) / voxel_size_m[0])
1513
- return axis_N_pos_px
1514
- else:
1515
- # deduce from motor position and pixel size
1516
- axis_N_pos_px = []
1517
- base_position_m = self.z_serie[0].get_bounding_box(axis=axis).min
1518
- for volume in self.z_serie:
1519
- voxel_size_m = self.configuration.voxel_size or volume.voxel_size
1520
- volume_axis_bb = volume.get_bounding_box(axis=axis)
1521
- axis_N_mean_pos_m = (volume_axis_bb.max - volume_axis_bb.min) / 2 + volume_axis_bb.min
1522
- axis_N_mean_rel_pos_m = axis_N_mean_pos_m - base_position_m
1523
- axis_N_pos_px.append(int(axis_N_mean_rel_pos_m / voxel_size_m[0]))
1524
- return axis_N_pos_px
1525
-
1526
- self.configuration.axis_0_pos_px = get_position_as_px_on_axis(
1527
- axis=0,
1528
- pos_as_px=self.configuration.axis_0_pos_px,
1529
- pos_as_mm=self.configuration.axis_0_pos_mm,
1530
- )
1531
- self.configuration.axis_0_pos_mm = None
1532
-
1533
- self.configuration.axis_2_pos_px = get_position_as_px_on_axis(
1534
- axis=2,
1535
- pos_as_px=self.configuration.axis_2_pos_px,
1536
- pos_as_mm=self.configuration.axis_2_pos_mm,
1537
- )
1538
- self.configuration.axis_2_pos_mm = None
1539
-
1540
- def _compute_axis_0_estimated_shifts(self):
1541
- axis_0_pos_px = self.configuration.axis_0_pos_px
1542
- self._axis_0_estimated_shifts = []
1543
- # compute overlap along axis 0
1544
- for upper_volume, lower_volume, upper_volume_axis_0_pos, lower_volume_axis_0_pos in zip(
1545
- self.z_serie[:-1], self.z_serie[1:], axis_0_pos_px[:-1], axis_0_pos_px[1:]
1546
- ):
1547
- upper_volume_low_pos = upper_volume_axis_0_pos - upper_volume.get_volume_shape()[0] / 2
1548
- lower_volume_high_pos = lower_volume_axis_0_pos + lower_volume.get_volume_shape()[0] / 2
1549
- self._axis_0_estimated_shifts.append(
1550
- int(lower_volume_high_pos - upper_volume_low_pos) # overlap are expected to be int for now
1551
- )
1552
-
1553
- def _compute_shifts(self):
1554
- n_volumes = len(self.configuration.input_volumes)
1555
- if n_volumes == 0:
1556
- raise ValueError("no scan to stich provided")
1557
-
1558
- slice_for_shift = self.configuration.slice_for_cross_correlation or "middle"
1559
- y_rel_shifts = self._axis_0_estimated_shifts
1560
- x_rel_shifts = self.from_abs_pos_to_rel_pos(self.configuration.axis_2_pos_px)
1561
- dim_axis_1 = max([volume.get_volume_shape()[1] for volume in self.z_serie])
1562
-
1563
- final_rel_shifts = []
1564
- for (
1565
- upper_volume,
1566
- lower_volume,
1567
- x_rel_shift,
1568
- y_rel_shift,
1569
- flip_ud_upper,
1570
- flip_ud_lower,
1571
- ) in zip(
1572
- self.z_serie[:-1],
1573
- self.z_serie[1:],
1574
- x_rel_shifts,
1575
- y_rel_shifts,
1576
- self.configuration.flip_ud[:-1],
1577
- self.configuration.flip_ud[1:],
1578
- ):
1579
- x_cross_algo = self.configuration.axis_2_params.get(KEY_IMG_REG_METHOD, None)
1580
- y_cross_algo = self.configuration.axis_0_params.get(KEY_IMG_REG_METHOD, None)
1581
-
1582
- # compute relative shift
1583
- found_shift_y, found_shift_x = find_volumes_relative_shifts(
1584
- upper_volume=upper_volume,
1585
- lower_volume=lower_volume,
1586
- dtype=self.get_output_data_type(),
1587
- dim_axis_1=dim_axis_1,
1588
- slice_for_shift=slice_for_shift,
1589
- x_cross_correlation_function=x_cross_algo,
1590
- y_cross_correlation_function=y_cross_algo,
1591
- x_shifts_params=self.configuration.axis_2_params,
1592
- y_shifts_params=self.configuration.axis_0_params,
1593
- estimated_shifts=(y_rel_shift, x_rel_shift),
1594
- flip_ud_lower_frame=flip_ud_lower,
1595
- flip_ud_upper_frame=flip_ud_upper,
1596
- alignment_axis_1=self.configuration.alignment_axis_1,
1597
- alignment_axis_2=self.configuration.alignment_axis_2,
1598
- )
1599
- final_rel_shifts.append(
1600
- (found_shift_y, found_shift_x),
1601
- )
1602
-
1603
- # set back values. Now position should start at 0
1604
- self._axis_0_rel_shifts = [final_shift[0] for final_shift in final_rel_shifts]
1605
- self._axis_2_rel_shifts = [final_shift[1] for final_shift in final_rel_shifts]
1606
- _logger.info(f"axis 2 relative shifts (x in radio ref) to be used will be {self._axis_2_rel_shifts}")
1607
- print(f"axis 2 relative shifts (x in radio ref) to be used will be {self._axis_2_rel_shifts}")
1608
- _logger.info(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_0_rel_shifts}")
1609
- print(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_0_rel_shifts}")
1610
-
1611
- def _dump_stitching_configuration(self):
1612
- voxel_size = self._input_volumes[0].voxel_size
1613
-
1614
- def get_position():
1615
- # the z-serie is z-ordered from higher to lower. We can reuse this with pixel size and shape to
1616
- # compute the position of the stitched volume
1617
- if voxel_size is None:
1618
- return None
1619
- return numpy.array(self._input_volumes[0].position) + voxel_size * (
1620
- numpy.array(self._input_volumes[0].get_volume_shape()) / 2.0
1621
- - numpy.array(self.configuration.output_volume.get_volume_shape()) / 2.0
1622
- )
1623
-
1624
- self.configuration.output_volume.voxel_size = voxel_size or ""
1625
- try:
1626
- self.configuration.output_volume.position = get_position()
1627
- except Exception:
1628
- self.configuration.output_volume.position = numpy.array([0, 0, 0])
1629
-
1630
- self.configuration.output_volume.metadata.update(
1631
- {
1632
- "program": "nabu-stitching",
1633
- "version": nabu_version,
1634
- "date": get_datetime(),
1635
- "configuration": self.configuration.to_dict(),
1636
- }
1637
- )
1638
- self.configuration.output_volume.save_metadata()
1639
-
1640
- def _check_inputs(self):
1641
- """
1642
- insure input data is coherent
1643
- """
1644
- # check input volume
1645
- if self.configuration.output_volume is None:
1646
- raise ValueError("input volume should be provided")
1647
-
1648
- n_volumes = len(self.z_serie)
1649
- if n_volumes == 0:
1650
- raise ValueError("no scan to stich together")
1651
-
1652
- if not isinstance(self.configuration.output_volume, VolumeBase):
1653
- raise TypeError(f"make sure we return a volume identifier not {(type(self.configuration.output_volume))}")
1654
-
1655
- # check axis 0 position
1656
- if isinstance(self.configuration.axis_0_pos_px, Iterable) and len(self.configuration.axis_0_pos_px) != (
1657
- n_volumes
1658
- ):
1659
- raise ValueError(f"expect {n_volumes} overlap defined. Get {len(self.configuration.axis_0_pos_px)}")
1660
- if isinstance(self.configuration.axis_0_pos_mm, Iterable) and len(self.configuration.axis_0_pos_mm) != (
1661
- n_volumes
1662
- ):
1663
- raise ValueError(f"expect {n_volumes} overlap defined. Get {len(self.configuration.axis_0_pos_mm)}")
1664
-
1665
- # check axis 1 position
1666
- if isinstance(self.configuration.axis_1_pos_px, Iterable) and len(self.configuration.axis_1_pos_px) != (
1667
- n_volumes
1668
- ):
1669
- raise ValueError(f"expect {n_volumes} overlap defined. Get {len(self.configuration.axis_1_pos_px)}")
1670
- if isinstance(self.configuration.axis_1_pos_mm, Iterable) and len(self.configuration.axis_1_pos_mm) != (
1671
- n_volumes
1672
- ):
1673
- raise ValueError(f"expect {n_volumes} overlap defined. Get {len(self.configuration.axis_1_pos_mm)}")
1674
-
1675
- # check axis 2 position
1676
- if isinstance(self.configuration.axis_2_pos_px, Iterable) and len(self.configuration.axis_2_pos_px) != (
1677
- n_volumes
1678
- ):
1679
- raise ValueError(f"expect {n_volumes} overlap defined. Get {len(self.configuration.axis_2_pos_px)}")
1680
- if isinstance(self.configuration.axis_2_pos_mm, Iterable) and len(self.configuration.axis_2_pos_mm) != (
1681
- n_volumes
1682
- ):
1683
- raise ValueError(f"expect {n_volumes} overlap defined. Get {len(self.configuration.axis_2_pos_mm)}")
1684
-
1685
- self._reading_orders = []
1686
- # the first scan will define the expected reading orderd, and expected flip.
1687
- # if all scan are flipped then we will keep it this way
1688
- self._reading_orders.append(1)
1689
-
1690
- def get_output_data_type(self):
1691
- if self.__output_data_type is None:
1692
-
1693
- def find_output_data_type():
1694
- first_vol = self._input_volumes[0]
1695
- if first_vol.data is not None:
1696
- return first_vol.data.dtype
1697
- elif isinstance(first_vol, HDF5Volume):
1698
- with DatasetReader(first_vol.data_url) as vol_dataset:
1699
- return vol_dataset.dtype
1700
- else:
1701
- return first_vol.load_data(store=False).dtype
1702
-
1703
- self.__output_data_type = find_output_data_type()
1704
- return self.__output_data_type
1705
-
1706
- def _create_stitched_volume(self, store_composition: bool):
1707
- overlap_kernels = self._overlap_kernels
1708
- self._slices_to_stitch, n_slices = self.configuration.settle_slices()
1709
-
1710
- # sync overwrite_results with volume overwrite parameter
1711
- self.configuration.output_volume.overwrite = self.configuration.overwrite_results
1712
-
1713
- # init final volume
1714
- final_volume = self.configuration.output_volume
1715
- final_volume_shape = (
1716
- int(
1717
- numpy.asarray([volume.get_volume_shape()[0] for volume in self._input_volumes]).sum()
1718
- - numpy.asarray([abs(overlap) for overlap in self._axis_0_rel_shifts]).sum(),
1719
- ),
1720
- n_slices,
1721
- self._stitching_width,
1722
- )
1723
-
1724
- data_type = self.get_output_data_type()
1725
-
1726
- if self.progress:
1727
- self.progress.set_max_advancement(final_volume_shape[1])
1728
-
1729
- y_index = 0
1730
- if isinstance(self._slices_to_stitch, slice):
1731
- step = self._slices_to_stitch.step or 1
1732
- else:
1733
- step = 1
1734
- with PostProcessZStitcher._FinalDatasetContext(
1735
- volume=final_volume, volume_shape=final_volume_shape, dtype=data_type
1736
- ) as output_dataset:
1737
- # note: output_dataset is a HDF5 dataset if final volume is an HDF5 volume else is a numpy array
1738
- with PostProcessZStitcher._RawDatasetsContext(
1739
- self._input_volumes,
1740
- alignment_axis_1=self.configuration.alignment_axis_1,
1741
- ) as raw_datasets:
1742
- # note: raw_datasets can be numpy arrays or HDF5 dataset (in the case of HDF5Volume)
1743
- # to speed up we read by bunch of dataset. For numpy array this doesn't change anything
1744
- # but for HDF5 dataset this can speed up a lot the processing (depending on HDF5 dataset chuncks)
1745
- # note: we read trhough axis 1
1746
- for bunch_start, bunch_end in PostProcessZStitcher._data_bunch_iterator(
1747
- slices=self._slices_to_stitch, bunch_size=50
1748
- ):
1749
- for data_frames in PostProcessZStitcher._get_bunch_of_data(
1750
- bunch_start,
1751
- bunch_end,
1752
- step=step,
1753
- volumes=raw_datasets,
1754
- flip_lr_arr=self.configuration.flip_lr,
1755
- flip_ud_arr=self.configuration.flip_ud,
1756
- ):
1757
- if self.configuration.rescale_frames:
1758
- data_frames = self.rescale_frames(data_frames)
1759
- if self.configuration.normalization_by_sample.is_active():
1760
- data_frames = self.normalize_frame_by_sample(data_frames)
1761
-
1762
- sf = ZStitcher.stitch_frames(
1763
- frames=data_frames,
1764
- x_relative_shifts=self._axis_2_rel_shifts,
1765
- y_relative_shifts=self._axis_0_rel_shifts,
1766
- overlap_kernels=overlap_kernels,
1767
- output_dataset=output_dataset,
1768
- dump_frame_fct=self._dump_frame,
1769
- i_frame=y_index,
1770
- output_dtype=data_type,
1771
- return_composition_cls=store_composition if y_index == 0 else False,
1772
- stitching_axis=0,
1773
- check_inputs=y_index == 0, # on process check on the first iteration
1774
- )
1775
- if y_index == 0 and store_composition:
1776
- _, self._frame_composition = sf
1777
-
1778
- if self.progress is not None:
1779
- self.progress.increase_advancement()
1780
- y_index += 1
1781
-
1782
- @staticmethod
1783
- def _get_bunch_of_data(
1784
- bunch_start: int,
1785
- bunch_end: int,
1786
- step: int,
1787
- volumes: tuple,
1788
- flip_lr_arr: bool,
1789
- flip_ud_arr: bool,
1790
- ):
1791
- """
1792
- goal is to load contiguous frames as much as possible...
1793
- return for each volume the bunch of slice along axis 1
1794
- warning: they can have different shapes
1795
- """
1796
-
1797
- def get_sub_volume(volume, flip_lr, flip_ud):
1798
- sub_volume = volume[:, bunch_start:bunch_end:step, :]
1799
- if flip_lr:
1800
- sub_volume = numpy.fliplr(sub_volume)
1801
- if flip_ud:
1802
- sub_volume = numpy.flipud(sub_volume)
1803
- return sub_volume
1804
-
1805
- sub_volumes = [
1806
- get_sub_volume(volume, flip_lr, flip_ud)
1807
- for volume, flip_lr, flip_ud in zip(volumes, flip_lr_arr, flip_ud_arr)
1808
- ]
1809
- # generator on it self: we want to iterate over the y axis
1810
- n_slices_in_bunch = ceil((bunch_end - bunch_start) / step)
1811
- assert isinstance(n_slices_in_bunch, int)
1812
- for i in range(n_slices_in_bunch):
1813
- yield [sub_volume[:, i, :] for sub_volume in sub_volumes]
1814
-
1815
- class _FinalDatasetContext(AbstractContextManager):
1816
- """Manager to create the data volume and save it (data only !). target: used for volume stitching
1817
- In the case of HDF5 we want to save this directly in the file to avoid
1818
- keeping the full volume in memory.
1819
- Insure also contain processing will be common between the different processing
1820
- """
1821
-
1822
- def __init__(self, volume: VolumeBase, volume_shape: tuple, dtype: numpy.dtype) -> None:
1823
- super().__init__()
1824
- if not isinstance(volume, VolumeBase):
1825
- raise TypeError(
1826
- f"Volume is expected to be an instance of {VolumeBase}. {type(volume)} provided instead"
1827
- )
1828
-
1829
- self._volume = volume
1830
- self._volume_shape = volume_shape
1831
- self.__file_handler = None
1832
- self._dtype = dtype
1833
-
1834
- def __enter__(self):
1835
- # handle the specific case of HDF5. Goal: avoid getting the full stitched volume in memory
1836
- if isinstance(self._volume, HDF5Volume):
1837
- self.__file_handler = HDF5File(self._volume.data_url.file_path(), mode="a")
1838
- # if need to delete an existing dataset
1839
- if self._volume.overwrite and self._volume.data_path in self.__file_handler:
1840
- try:
1841
- del self.__file_handler[self._volume.data_path]
1842
- except Exception as e:
1843
- _logger.error(f"Fail to overwrite data. Reason is {e}")
1844
- data = None
1845
- self.__file_handler.close()
1846
- return data
1847
-
1848
- # create dataset
1849
- try:
1850
- data = self.__file_handler.create_dataset(
1851
- self._volume.data_url.data_path(),
1852
- shape=self._volume_shape,
1853
- dtype=self._dtype,
1854
- )
1855
- except Exception as e2:
1856
- _logger.error(f"Fail to create final dataset. Reason is {e2}")
1857
- data = None
1858
- self.__file_handler.close()
1859
- # for other file format: create the full dataset in memory
1860
- else:
1861
- data = numpy.empty(self._volume_shape, dtype=self._dtype)
1862
- return data
1863
-
1864
- def __exit__(self, *exc):
1865
- if self.__file_handler is not None:
1866
- return self.__file_handler.close()
1867
- else:
1868
- self._volume.save_data()
1869
-
1870
- class _RawDatasetsContext(AbstractContextManager):
1871
- """
1872
- return volume data for all input volume (target: used for volume stitching).
1873
- If the volume is an HDF5Volume then the HDF5 dataset will be used (on disk)
1874
- If the volume is of another type then it will be loaded in memory then used (more memory consuming)
1875
- """
1876
-
1877
- def __init__(self, volumes: tuple, alignment_axis_1) -> None:
1878
- super().__init__()
1879
- for volume in volumes:
1880
- if not isinstance(volume, VolumeBase):
1881
- raise TypeError(
1882
- f"Volumes are expected to be an instance of {VolumeBase}. {type(volume)} provided instead"
1883
- )
1884
-
1885
- self._volumes = volumes
1886
- self.__file_handlers = []
1887
- self._alignment_axis_1 = alignment_axis_1
1888
-
1889
- @property
1890
- def alignment_axis_1(self):
1891
- return self._alignment_axis_1
1892
-
1893
- def __enter__(self):
1894
- # handle the specific case of HDF5. Goal: avoid getting the full stitched volume in memory
1895
- datasets = []
1896
- shapes = {volume.get_volume_shape()[1] for volume in self._volumes}
1897
- axis_1_dim = max(shapes)
1898
- axis_1_need_padding = len(shapes) > 1
1899
-
1900
- try:
1901
- for volume in self._volumes:
1902
- if volume.data is not None:
1903
- data = volume.data
1904
- elif isinstance(volume, HDF5Volume):
1905
- file_handler = HDF5File(volume.data_url.file_path(), mode="r")
1906
- dataset = file_handler[volume.data_url.data_path()]
1907
- data = dataset
1908
- self.__file_handlers.append(file_handler)
1909
- # for other file format: load the full dataset in memory
1910
- else:
1911
- data = volume.load_data(store=False)
1912
- if data is None:
1913
- raise ValueError(f"No data found for volume {volume.get_identifier()}")
1914
- if axis_1_need_padding:
1915
- data = self.add_padding(data=data, axis_1_dim=axis_1_dim, alignment=self.alignment_axis_1)
1916
- datasets.append(data)
1917
- except Exception as e:
1918
- # if some errors happen during loading HDF5
1919
- for file_handled in self.__file_handlers:
1920
- file_handled.close()
1921
- raise e
1922
-
1923
- return datasets
1924
-
1925
- def __exit__(self, *exc):
1926
- success = True
1927
- for file_handler in self.__file_handlers:
1928
- success = success and file_handler.close()
1929
- return success
1930
-
1931
- def add_padding(self, data: Union[h5py.Dataset, numpy.ndarray], axis_1_dim, alignment: AlignmentAxis1):
1932
- alignment = AlignmentAxis1.from_value(alignment)
1933
- if alignment is AlignmentAxis1.BACK:
1934
- axis_1_pad_width = (axis_1_dim - data.shape[1], 0)
1935
- elif alignment is AlignmentAxis1.CENTER:
1936
- half_width = int((axis_1_dim - data.shape[1]) / 2)
1937
- axis_1_pad_width = (half_width, axis_1_dim - data.shape[1] - half_width)
1938
- elif alignment is AlignmentAxis1.FRONT:
1939
- axis_1_pad_width = (0, axis_1_dim - data.shape[1])
1940
- else:
1941
- raise ValueError(f"alignment {alignment} is not handled")
1942
-
1943
- return PaddedRawData(
1944
- data=data,
1945
- axis_1_pad_width=axis_1_pad_width,
1946
- )
1947
-
1948
-
1949
- def stitch_vertically_raw_frames(
1950
- frames: tuple,
1951
- key_lines: tuple,
1952
- overlap_kernels: Union[ZStichOverlapKernel, tuple],
1953
- output_dtype: numpy.dtype = numpy.float32,
1954
- check_inputs=True,
1955
- raw_frames_compositions: Optional[ZFrameComposition] = None,
1956
- overlap_frames_compositions: Optional[ZFrameComposition] = None,
1957
- return_composition_cls=False,
1958
- alignment="center",
1959
- pad_mode="constant",
1960
- new_width: Optional[int] = None,
1961
- ) -> numpy.ndarray:
1962
- """
1963
- stitches raw frames (already shifted and flat fielded !!!) together using
1964
- raw stitching (no pixel interpolation, y_overlap_in_px is expected to be a int).
1965
- Sttiching is done vertically (along the y axis of the frame ref)
1966
-
1967
- | --------------
1968
- | | |
1969
- | | Frame 1 | --------------
1970
- | | | | Frame 1 |
1971
- | -------------- | |
1972
- Y | --> stitching |~ stitching ~|
1973
- | -------------- | |
1974
- | | | | Frame 2 |
1975
- | | Frame 2 | --------------
1976
- | | |
1977
- | --------------
1978
- |
1979
-
1980
- returns stitched_projection, raw_img_1, raw_img_2, computed_overlap
1981
- proj_0 and pro_1 are already expected to be in a row. Having stitching_height_in_px in common. At top of proj_0
1982
- and at bottom of proj_1
1983
-
1984
- :param tuple frames: tuple of 2D numpy array. Expected to be Z up oriented at this stage
1985
- :param tuple key_lines: for each jonction define the two lines to overlaid (from the upper and the lower frames). In the reference where 0 is the bottom line of the image.
1986
- :param overlap_kernels: ZStichOverlapKernel overlap kernel to be used or a list of kernel (one per overlap). Define startegy and overlap heights
1987
- :param numpy.dtype output_dtype: dataset dtype. For now must be provided because flat field corrcetion change data type (numpy.float32 for now)
1988
- :param bool check_inputs: if True will do more test on inputs parameters like checking frame shapes, coherence of the request.. As it can be time consuming it is optional
1989
- :param raw_frames_compositions: pre computed raw frame composition. If not provided will compute them. allow providing it to speed up calculation
1990
- :param overlap_frames_compositions: pre computed stitched frame composition. If not provided will compute them. allow providing it to speed up calculation
1991
- :param bool return_frame_compositions: if False return simply the stitched frames. Else return a tuple with stitching frame and the dictionnary with the composition frames...
1992
- """
1993
- assert overlap_kernels is not None, "overlap kernels must be provided"
1994
-
1995
- if check_inputs:
1996
-
1997
- def check_frame(proj):
1998
- if not isinstance(proj, numpy.ndarray) and proj.ndim == 2:
1999
- raise ValueError(f"frames are expected to be 2D numpy array")
2000
-
2001
- [check_frame(frame) for frame in frames]
2002
- for frame_0, frame_1 in zip(frames[:-1], frames[1:]):
2003
- if not (frame_0.ndim == frame_1.ndim == 2):
2004
- raise ValueError("Frames are expected to be 2D")
2005
-
2006
- for frame_0, frame_1, kernel in zip(frames[:-1], frames[1:], overlap_kernels):
2007
- if frame_0.shape[0] < kernel.overlap_size:
2008
- raise ValueError(
2009
- f"frame_0 height ({frame_0.shape[0]}) is less than kernel overlap ({kernel.overlap_size})"
2010
- )
2011
- if frame_1.shape[0] < kernel.overlap_size:
2012
- raise ValueError(
2013
- f"frame_1 height ({frame_1.shape[0]}) is less than kernel overlap ({kernel.overlap_size})"
2014
- )
2015
- if not len(key_lines) == len(overlap_kernels):
2016
- raise ValueError("we expect to have the same number of key_lines then the number of kernel")
2017
- else:
2018
- for key_line in key_lines:
2019
- for value in key_line:
2020
- if not isinstance(value, (int, numpy.integer)):
2021
- raise TypeError(f"key_line is expected to be an integer. {type(key_line)} provided")
2022
- elif value < 0:
2023
- raise ValueError(f"key lines are expected to be positive values. Get {value} as key line value")
2024
-
2025
- if new_width is None:
2026
- new_width = max([frame.shape[-1] for frame in frames])
2027
- frames = tuple(
2028
- [
2029
- align_horizontally(
2030
- data=frame,
2031
- alignment=alignment,
2032
- new_width=new_width,
2033
- pad_mode=pad_mode,
2034
- )
2035
- for frame in frames
2036
- ]
2037
- )
2038
-
2039
- # step 1: create numpy array that will contain stitching
2040
- # if raw composition doesn't exists create it
2041
- if raw_frames_compositions is None:
2042
- raw_frames_compositions = ZFrameComposition.compute_raw_frame_compositions(
2043
- frames=frames,
2044
- overlap_kernels=overlap_kernels,
2045
- key_lines=key_lines,
2046
- stitching_axis=0,
2047
- )
2048
- new_frame_height = raw_frames_compositions.global_end_y[-1] - raw_frames_compositions.global_start_y[0]
2049
- stitched_projection_shape = (
2050
- # here we only handle frames because shift are already done
2051
- int(new_frame_height),
2052
- new_width,
2053
- )
2054
- stitch_array = numpy.empty(stitched_projection_shape, dtype=output_dtype)
2055
-
2056
- # step 2: set raw data
2057
- # fill stitch array with raw data raw data
2058
- raw_frames_compositions.compose(
2059
- output_frame=stitch_array,
2060
- input_frames=frames,
2061
- )
2062
-
2063
- # step 3 set stitched data
2064
-
2065
- # 3.1 create stitched overlaps
2066
- stitched_overlap = []
2067
- for frame_0, frame_1, kernel, key_line in zip(frames[:-1], frames[1:], overlap_kernels, key_lines):
2068
- assert kernel.overlap_size >= 0
2069
- frame_0_overlap, frame_1_overlap = ZStitcher.get_overlap_areas(
2070
- upper_frame=frame_0,
2071
- lower_frame=frame_1,
2072
- upper_frame_key_line=key_line[0],
2073
- lower_frame_key_line=key_line[1],
2074
- overlap_size=kernel.overlap_size,
2075
- stitching_axis=0,
2076
- )
2077
-
2078
- assert (
2079
- frame_0_overlap.shape[0] == frame_1_overlap.shape[0] == kernel.overlap_size
2080
- ), f"{frame_0_overlap.shape[0]} == {frame_1_overlap.shape[0]} == {kernel.overlap_size}"
2081
-
2082
- stitched_overlap.append(
2083
- kernel.stitch(
2084
- frame_0_overlap,
2085
- frame_1_overlap,
2086
- )[0]
2087
- )
2088
- # 3.2 fill stitched overlap on output array
2089
- if overlap_frames_compositions is None:
2090
- overlap_frames_compositions = ZFrameComposition.compute_stitch_frame_composition(
2091
- frames=frames,
2092
- overlap_kernels=overlap_kernels,
2093
- key_lines=key_lines,
2094
- stitching_axis=0,
2095
- )
2096
- overlap_frames_compositions.compose(
2097
- output_frame=stitch_array,
2098
- input_frames=stitched_overlap,
2099
- )
2100
- if return_composition_cls:
2101
- return (
2102
- stitch_array,
2103
- {
2104
- "raw_compositon": raw_frames_compositions,
2105
- "overlap_compositon": overlap_frames_compositions,
2106
- },
2107
- )
2108
-
2109
- return stitch_array
2110
-
2111
-
2112
class StitchingPostProcAggregation:
    """
    Aggregate the partial results of a remote stitching.

    For remote stitching, each process stitches a part of the volume or of the
    projections. Once all of them are finished, this class aggregates the parts
    into the final volume or NXtomo.

    Be careful when modifying the API: a tomwer class relies on it.

    :param ZStitchingConfiguration stitching_config: configuration of the stitching
    :param Optional[dict] futures: mapping of tomo object identifier -> future of the
        job producing that object. Mutually exclusive with ``existing_objs_ids``.
    :param Optional[tuple] existing_objs_ids: identifiers of tomo objects that already
        exist. Mutually exclusive with ``futures``.
    :raises TypeError: if ``stitching_config`` is not a ZStitchingConfiguration
    :raises ValueError: if both or neither of ``futures`` / ``existing_objs_ids`` are provided
    """

    def __init__(
        self,
        stitching_config: ZStitchingConfiguration,
        futures: Optional[dict] = None,
        existing_objs_ids: Optional[tuple] = None,
    ) -> None:
        if not isinstance(stitching_config, ZStitchingConfiguration):
            raise TypeError(f"stitching_config should be an instance of {ZStitchingConfiguration}")
        # exactly one source of tomo objects must be given
        if (existing_objs_ids is None) == (futures is None):
            raise ValueError("Either existing_objs or futures should be provided (can't provide both)")
        self._futures = futures
        self._stitching_config = stitching_config
        self._existing_objs_ids = existing_objs_ids

    @property
    def futures(self):
        # TODO: deprecate it ?
        return self._futures

    @property
    def stitching_config(self) -> ZStitchingConfiguration:
        return self._stitching_config

    def retrieve_tomo_objects(self) -> list:
        """
        Return the tomo objects to be aggregated, built from their identifiers.

        The identifiers come either from ``existing_objs_ids`` or from the keys of
        ``futures``. In the latter case, the result of every future is waited for first.

        :raises RuntimeError: if at least one job failed or was canceled
        """
        if self._existing_objs_ids is not None:
            scan_ids = self._existing_objs_ids
        else:
            scan_ids = self._wait_for_futures()
        return [TomoscanFactory.create_tomo_object_from_identifier(scan_id) for scan_id in scan_ids]

    def _wait_for_futures(self) -> tuple:
        """Wait for every job and return the identifiers they produced; raise RuntimeError if any failed or was canceled."""
        _logger.info("wait for slurm job to be completed")
        done_ids = []
        failed = []  # (future, exception) pairs
        canceled = []
        for obj_id, future in self.futures.items():
            try:
                future.result()
            except Exception as exc:
                # result() re-raises the job's exception (or a cancellation error):
                # collect it so that every problem is reported at once.
                if future.cancelled():
                    canceled.append(future)
                else:
                    failed.append((future, exc))
            else:
                done_ids.append(obj_id)

        if failed:
            # if some job failed: useless to do the concatenation
            exceptions = " ; ".join(f"{job} : {exc}" for job, exc in failed)
            raise RuntimeError(
                f"some job failed. Won't do the concatenation. Exceptions are {exceptions}"
            ) from failed[0][1]
        if canceled:
            # if some job was canceled: useless to do the concatenation
            jobs = " ; ".join(str(job) for job in canceled)
            raise RuntimeError(f"some jobs were canceled. Won't do the concatenation. Jobs are {jobs}")
        return tuple(done_ids)

    @staticmethod
    def _get_nx_process_location(data_path: str) -> tuple:
        """
        Split an HDF5 data path into (parent path, NXprocess name).

        The process name is the last path component suffixed with "_stitching".
        The parent path falls back to "/" when there is no parent component
        (e.g. "entry" or "/entry").
        """
        parts = data_path.split("/")
        process_name = parts[-1] + "_stitching"
        parent_path = "/".join(parts[:-1]) or "/"
        return parent_path, process_name

    def dump_stiching_config_as_nx_process(self, file_path: str, data_path: str, overwrite: bool, process_name: str):
        """
        Write the stitching configuration as an NXprocess group.

        The group ``process_name`` is created under ``data_path`` in ``file_path``. It
        contains the configuration (as a dict), the program name, the nabu version and
        the date. With ``overwrite``, dicttonx is called with update_mode="replace";
        otherwise it is called with update_mode="add".
        """
        dict_to_dump = {
            process_name: {
                "config": self._stitching_config.to_dict(),
                "program": "nabu-stitching",
                "version": nabu_version,
                "date": get_datetime(),
            },
            f"{process_name}@NX_class": "NXprocess",
        }

        dicttonx(
            dict_to_dump,
            h5file=file_path,
            h5path=data_path,
            update_mode="replace" if overwrite else "add",
            mode="a",
        )

    def process(self) -> None:
        """
        Main function: retrieve the partial results and aggregate them.

        - PreProcessedZStitchingConfiguration: load every partial NXtomo, concatenate
          them and save the result to ``output_file_path`` / ``output_data_path``.
        - PostProcessedZStitchingConfiguration: concatenate the partial volumes along
          axis 1 into ``output_volume``.

        The stitching configuration is then dumped as an NXprocess. For volumes, this
        happens only when the output volume is an HDF5Volume.

        :raises TypeError: if the stitching configuration type is not handled
        :raises RuntimeError: if some remote job failed or was canceled
        """
        config = self.stitching_config
        if isinstance(config, PreProcessedZStitchingConfiguration):
            # 1: case of a pre-processing stitching
            self._aggregate_nx_tomos(config)
        elif isinstance(config, PostProcessedZStitchingConfiguration):
            # 2: case of a post-processing stitching
            self._aggregate_volumes(config)
        else:
            raise TypeError(f"stitching_config type ({type(self.stitching_config)}) not handled")

    def _aggregate_nx_tomos(self, config) -> None:
        """Concatenate the partial NXtomos of a pre-processing stitching, save the result and dump the NXprocess."""
        scans = self.retrieve_tomo_objects()
        _logger.info("all job succeeded. Concatenate results")
        nx_tomos = [NXtomo().load(file_path=scan.master_file, data_path=scan.entry) for scan in scans]
        final_nx_tomo = NXtomo.concatenate(nx_tomos)
        final_nx_tomo.save(
            file_path=config.output_file_path,
            data_path=config.output_data_path,
            overwrite=config.overwrite_results,
        )

        data_path, process_name = self._get_nx_process_location(config.output_data_path)
        self.dump_stiching_config_as_nx_process(
            file_path=config.output_file_path,
            data_path=data_path,
            process_name=process_name,
            overwrite=config.overwrite_results,
        )

    def _aggregate_volumes(self, config) -> None:
        """Concatenate the partial volumes of a post-processing stitching; dump the NXprocess for HDF5 output."""
        outputs_sub_volumes = self.retrieve_tomo_objects()
        _logger.info("all job succeeded. Concatenate results")
        concatenate_volumes(
            output_volume=config.output_volume,
            volumes=tuple(outputs_sub_volumes),
            axis=1,
        )

        if isinstance(config.output_volume, HDF5Volume):
            metadata_url = config.output_volume.metadata_url
            data_path, process_name = self._get_nx_process_location(metadata_url.data_path())
            self.dump_stiching_config_as_nx_process(
                file_path=metadata_url.file_path(),
                data_path=data_path,
                process_name=process_name,
                overwrite=config.overwrite_results,
            )
2268
-
2269
-
2270
def get_obj_width(obj: Union[NXtomoScan, VolumeBase]) -> int:
    """
    Get the width of a tomo object.

    :param obj: an NXtomoScan (its ``dim_1`` is returned) or a VolumeBase (the last
        dimension of its volume shape is returned)
    :raises TypeError: for any other object type
    """
    if isinstance(obj, NXtomoScan):
        width = obj.dim_1
    elif isinstance(obj, VolumeBase):
        volume_shape = obj.get_volume_shape()
        width = volume_shape[-1]
    else:
        raise TypeError(f"obj type ({type(obj)}) is not handled")
    return width