nabu 2023.2.1__py3-none-any.whl → 2024.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183)
  1. doc/conf.py +1 -1
  2. doc/doc_config.py +32 -0
  3. nabu/__init__.py +2 -1
  4. nabu/app/bootstrap_stitching.py +1 -1
  5. nabu/app/cli_configs.py +122 -2
  6. nabu/app/composite_cor.py +27 -2
  7. nabu/app/correct_rot.py +70 -0
  8. nabu/app/create_distortion_map_from_poly.py +42 -18
  9. nabu/app/diag_to_pix.py +358 -0
  10. nabu/app/diag_to_rot.py +449 -0
  11. nabu/app/generate_header.py +4 -3
  12. nabu/app/histogram.py +2 -2
  13. nabu/app/multicor.py +6 -1
  14. nabu/app/parse_reconstruction_log.py +151 -0
  15. nabu/app/prepare_weights_double.py +83 -22
  16. nabu/app/reconstruct.py +5 -1
  17. nabu/app/reconstruct_helical.py +7 -0
  18. nabu/app/reduce_dark_flat.py +6 -3
  19. nabu/app/rotate.py +4 -4
  20. nabu/app/stitching.py +16 -2
  21. nabu/app/tests/test_reduce_dark_flat.py +18 -2
  22. nabu/app/validator.py +4 -4
  23. nabu/cuda/convolution.py +8 -376
  24. nabu/cuda/fft.py +4 -0
  25. nabu/cuda/kernel.py +4 -4
  26. nabu/cuda/medfilt.py +5 -158
  27. nabu/cuda/padding.py +5 -71
  28. nabu/cuda/processing.py +23 -2
  29. nabu/cuda/src/ElementOp.cu +78 -0
  30. nabu/cuda/src/backproj.cu +28 -2
  31. nabu/cuda/src/fourier_wavelets.cu +2 -2
  32. nabu/cuda/src/normalization.cu +23 -0
  33. nabu/cuda/src/padding.cu +2 -2
  34. nabu/cuda/src/transpose.cu +16 -0
  35. nabu/cuda/utils.py +39 -0
  36. nabu/estimation/alignment.py +10 -1
  37. nabu/estimation/cor.py +808 -38
  38. nabu/estimation/cor_sino.py +7 -9
  39. nabu/estimation/tests/test_cor.py +85 -3
  40. nabu/io/reader.py +26 -18
  41. nabu/io/tests/test_cast_volume.py +3 -3
  42. nabu/io/tests/test_detector_distortion.py +3 -3
  43. nabu/io/tiffwriter_zmm.py +2 -2
  44. nabu/io/utils.py +14 -4
  45. nabu/io/writer.py +5 -3
  46. nabu/misc/fftshift.py +6 -0
  47. nabu/misc/histogram.py +5 -285
  48. nabu/misc/histogram_cuda.py +8 -104
  49. nabu/misc/kernel_base.py +3 -121
  50. nabu/misc/padding_base.py +5 -69
  51. nabu/misc/processing_base.py +3 -107
  52. nabu/misc/rotation.py +5 -62
  53. nabu/misc/rotation_cuda.py +5 -65
  54. nabu/misc/transpose.py +6 -0
  55. nabu/misc/unsharp.py +3 -78
  56. nabu/misc/unsharp_cuda.py +5 -52
  57. nabu/misc/unsharp_opencl.py +8 -85
  58. nabu/opencl/fft.py +6 -0
  59. nabu/opencl/kernel.py +21 -6
  60. nabu/opencl/padding.py +5 -72
  61. nabu/opencl/processing.py +27 -5
  62. nabu/opencl/src/backproj.cl +3 -3
  63. nabu/opencl/src/fftshift.cl +65 -12
  64. nabu/opencl/src/padding.cl +2 -2
  65. nabu/opencl/src/roll.cl +96 -0
  66. nabu/opencl/src/transpose.cl +16 -0
  67. nabu/pipeline/config_validators.py +63 -3
  68. nabu/pipeline/dataset_validator.py +2 -2
  69. nabu/pipeline/estimators.py +193 -35
  70. nabu/pipeline/fullfield/chunked.py +34 -17
  71. nabu/pipeline/fullfield/chunked_cuda.py +7 -5
  72. nabu/pipeline/fullfield/computations.py +48 -13
  73. nabu/pipeline/fullfield/nabu_config.py +13 -13
  74. nabu/pipeline/fullfield/processconfig.py +10 -5
  75. nabu/pipeline/fullfield/reconstruction.py +1 -2
  76. nabu/pipeline/helical/fbp.py +5 -0
  77. nabu/pipeline/helical/filtering.py +12 -9
  78. nabu/pipeline/helical/gridded_accumulator.py +179 -33
  79. nabu/pipeline/helical/helical_chunked_regridded.py +262 -151
  80. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +4 -11
  81. nabu/pipeline/helical/helical_reconstruction.py +56 -18
  82. nabu/pipeline/helical/span_strategy.py +1 -1
  83. nabu/pipeline/helical/tests/test_accumulator.py +4 -0
  84. nabu/pipeline/params.py +23 -2
  85. nabu/pipeline/processconfig.py +3 -8
  86. nabu/pipeline/tests/test_chunk_reader.py +78 -0
  87. nabu/pipeline/tests/test_estimators.py +120 -2
  88. nabu/pipeline/utils.py +25 -0
  89. nabu/pipeline/writer.py +2 -0
  90. nabu/preproc/ccd_cuda.py +9 -7
  91. nabu/preproc/ctf.py +21 -26
  92. nabu/preproc/ctf_cuda.py +25 -25
  93. nabu/preproc/double_flatfield.py +14 -2
  94. nabu/preproc/double_flatfield_cuda.py +7 -11
  95. nabu/preproc/flatfield_cuda.py +23 -27
  96. nabu/preproc/phase.py +19 -24
  97. nabu/preproc/phase_cuda.py +21 -21
  98. nabu/preproc/shift_cuda.py +58 -28
  99. nabu/preproc/tests/test_ctf.py +5 -5
  100. nabu/preproc/tests/test_double_flatfield.py +2 -2
  101. nabu/preproc/tests/test_vshift.py +13 -2
  102. nabu/processing/__init__.py +0 -0
  103. nabu/processing/convolution_cuda.py +375 -0
  104. nabu/processing/fft_base.py +163 -0
  105. nabu/processing/fft_cuda.py +256 -0
  106. nabu/processing/fft_opencl.py +54 -0
  107. nabu/processing/fftshift.py +134 -0
  108. nabu/processing/histogram.py +286 -0
  109. nabu/processing/histogram_cuda.py +103 -0
  110. nabu/processing/kernel_base.py +126 -0
  111. nabu/processing/medfilt_cuda.py +159 -0
  112. nabu/processing/muladd.py +29 -0
  113. nabu/processing/muladd_cuda.py +68 -0
  114. nabu/processing/padding_base.py +71 -0
  115. nabu/processing/padding_cuda.py +75 -0
  116. nabu/processing/padding_opencl.py +77 -0
  117. nabu/processing/processing_base.py +123 -0
  118. nabu/processing/roll_opencl.py +64 -0
  119. nabu/processing/rotation.py +63 -0
  120. nabu/processing/rotation_cuda.py +66 -0
  121. nabu/processing/tests/__init__.py +0 -0
  122. nabu/processing/tests/test_fft.py +268 -0
  123. nabu/processing/tests/test_fftshift.py +71 -0
  124. nabu/{misc → processing}/tests/test_histogram.py +2 -4
  125. nabu/{cuda → processing}/tests/test_medfilt.py +1 -1
  126. nabu/processing/tests/test_muladd.py +54 -0
  127. nabu/{cuda → processing}/tests/test_padding.py +119 -75
  128. nabu/processing/tests/test_roll.py +63 -0
  129. nabu/{misc → processing}/tests/test_rotation.py +3 -2
  130. nabu/processing/tests/test_transpose.py +72 -0
  131. nabu/{misc → processing}/tests/test_unsharp.py +41 -8
  132. nabu/processing/transpose.py +126 -0
  133. nabu/processing/unsharp.py +79 -0
  134. nabu/processing/unsharp_cuda.py +53 -0
  135. nabu/processing/unsharp_opencl.py +75 -0
  136. nabu/reconstruction/fbp.py +34 -10
  137. nabu/reconstruction/fbp_base.py +35 -16
  138. nabu/reconstruction/fbp_opencl.py +7 -12
  139. nabu/reconstruction/filtering.py +2 -2
  140. nabu/reconstruction/filtering_cuda.py +13 -14
  141. nabu/reconstruction/filtering_opencl.py +3 -4
  142. nabu/reconstruction/projection.py +2 -0
  143. nabu/reconstruction/rings.py +158 -1
  144. nabu/reconstruction/rings_cuda.py +218 -58
  145. nabu/reconstruction/sinogram_cuda.py +16 -12
  146. nabu/reconstruction/tests/test_deringer.py +116 -14
  147. nabu/reconstruction/tests/test_fbp.py +22 -31
  148. nabu/reconstruction/tests/test_filtering.py +11 -2
  149. nabu/resources/dataset_analyzer.py +89 -26
  150. nabu/resources/nxflatfield.py +2 -2
  151. nabu/resources/tests/test_nxflatfield.py +1 -1
  152. nabu/resources/utils.py +9 -2
  153. nabu/stitching/alignment.py +184 -0
  154. nabu/stitching/config.py +241 -39
  155. nabu/stitching/definitions.py +6 -0
  156. nabu/stitching/frame_composition.py +4 -2
  157. nabu/stitching/overlap.py +99 -3
  158. nabu/stitching/sample_normalization.py +60 -0
  159. nabu/stitching/slurm_utils.py +10 -10
  160. nabu/stitching/tests/test_alignment.py +99 -0
  161. nabu/stitching/tests/test_config.py +16 -1
  162. nabu/stitching/tests/test_overlap.py +68 -2
  163. nabu/stitching/tests/test_sample_normalization.py +49 -0
  164. nabu/stitching/tests/test_slurm_utils.py +5 -5
  165. nabu/stitching/tests/test_utils.py +3 -33
  166. nabu/stitching/tests/test_z_stitching.py +391 -22
  167. nabu/stitching/utils.py +144 -202
  168. nabu/stitching/z_stitching.py +309 -126
  169. nabu/testutils.py +18 -0
  170. nabu/thirdparty/tomocupy_remove_stripe.py +586 -0
  171. nabu/utils.py +32 -6
  172. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/LICENSE +1 -1
  173. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/METADATA +5 -5
  174. nabu-2024.1.0rc3.dist-info/RECORD +296 -0
  175. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/WHEEL +1 -1
  176. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/entry_points.txt +5 -1
  177. nabu/conftest.py +0 -14
  178. nabu/opencl/fftshift.py +0 -92
  179. nabu/opencl/tests/test_fftshift.py +0 -55
  180. nabu/opencl/tests/test_padding.py +0 -84
  181. nabu-2023.2.1.dist-info/RECORD +0 -252
  182. /nabu/cuda/src/{fftshift.cu → dfi_fftshift.cu} +0 -0
  183. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/top_level.txt +0 -0
nabu/stitching/utils.py CHANGED
@@ -1,18 +1,24 @@
1
1
  from distutils.version import StrictVersion
2
-
3
2
  from typing import Optional, Union
4
3
  import logging
4
+ import functools
5
5
  import numpy
6
+ from scipy.ndimage import affine_transform
6
7
  from tomoscan.scanbase import TomoScanBase
7
- from nabu.misc import fourier_filters
8
- from nabu.stitching.overlap import OverlapStitchingStrategy, ZStichOverlapKernel
9
- from nabu.estimation.alignment import AlignmentBase
10
- from nabu.resources.dataset_analyzer import HDF5DatasetAnalyzer
8
+ from tomoscan.volumebase import VolumeBase
9
+ from nxtomo.utils.transformation import build_matrix, UDDetTransformation
11
10
  from silx.utils.enum import Enum as _Enum
12
- from scipy.ndimage import shift as scipy_shift
13
11
  from scipy.fft import rfftn as local_fftn
14
12
  from scipy.fft import irfftn as local_ifftn
15
- from nabu.resources.nxflatfield import update_dataset_info_flats_darks
13
+ from silx.utils.enum import Enum as _Enum
14
+ from nxtomo.utils.transformation import build_matrix, UDDetTransformation
15
+ from tomoscan.scanbase import TomoScanBase
16
+ from .overlap import OverlapStitchingStrategy, ZStichOverlapKernel
17
+ from .alignment import AlignmentAxis1, AlignmentAxis2, PaddedRawData
18
+ from ..misc import fourier_filters
19
+ from ..estimation.alignment import AlignmentBase
20
+ from ..resources.dataset_analyzer import HDF5DatasetAnalyzer
21
+ from ..resources.nxflatfield import update_dataset_info_flats_darks
16
22
 
17
23
  try:
18
24
  import itk
@@ -40,7 +46,6 @@ class ShiftAlgorithm(_Enum):
40
46
 
41
47
  NABU_FFT = "nabu-fft"
42
48
  SKIMAGE = "skimage"
43
- SHIFT_GRID = "shift-grid"
44
49
  ITK_IMG_REG_V4 = "itk-img-reg-v4"
45
50
  NONE = "None"
46
51
 
@@ -94,7 +99,6 @@ def find_frame_relative_shifts(
94
99
  ):
95
100
  from nabu.stitching.config import (
96
101
  KEY_WINDOW_SIZE,
97
- KEY_SCORE_METHOD,
98
102
  KEY_LOW_PASS_FILTER,
99
103
  KEY_HIGH_PASS_FILTER,
100
104
  ) # avoid cyclic import
@@ -144,52 +148,49 @@ def find_frame_relative_shifts(
144
148
  initial_shifts = numpy.array(estimated_shifts).copy()
145
149
  extra_shifts = numpy.array([0.0, 0.0])
146
150
 
147
- # 2.0 call cross correlation function from the estimated cor from motors
148
- for axis, method, params in zip(
149
- (0, 1),
150
- (y_cross_correlation_function, x_cross_correlation_function),
151
- (y_shifts_params, x_shifts_params),
152
- ):
153
- if method is ShiftAlgorithm.NABU_FFT:
154
- extra_shifts[axis] = find_shift_correlate(img1=overlap_upper_frame, img2=overlap_lower_frame)[axis]
155
- elif method is ShiftAlgorithm.SKIMAGE:
156
- if not __has_sk_phase_correlation__:
157
- raise ValueError("scikit-image not installed. Cannot do phase correlation from it")
158
- else:
159
- found_shift, _, _ = phase_cross_correlation(
160
- reference_image=overlap_upper_frame, moving_image=overlap_lower_frame, space="real"
161
- )
162
- extra_shifts[axis] = found_shift[axis]
163
- elif method is ShiftAlgorithm.NONE: # None as a string in case some uers give this value
164
- # in the case we don't want to apply algorithm keep the initial 'guessed' shifts
165
- continue
166
- elif method is ShiftAlgorithm.SHIFT_GRID:
167
- if axis == 0:
168
- window_size = (int(y_shifts_params.get(KEY_WINDOW_SIZE, 200)), 0)
169
- elif axis == 1:
170
- window_size = (0, int(x_shifts_params.get(KEY_WINDOW_SIZE, 200)))
171
- score_method = params.get(KEY_SCORE_METHOD, ScoreMethod.STD)
172
- extra_shifts[axis] = -shift_grid_search(
173
- img_1=overlap_upper_frame,
174
- img_2=overlap_lower_frame,
175
- window_sizes=window_size,
176
- step_size=1,
177
- axis=(axis,),
178
- score_method=score_method,
179
- )[axis]
180
- elif method is ShiftAlgorithm.ITK_IMG_REG_V4:
181
- extra_shifts[axis] = find_shift_with_itk(img1=overlap_upper_frame, img2=overlap_lower_frame)[axis]
151
+ def skimage_proxy(img1, img2):
152
+ if not __has_sk_phase_correlation__:
153
+ raise ValueError("scikit-image not installed. Cannot do phase correlation from it")
182
154
  else:
183
- raise ValueError(f"requested cross correlation function not handled ({method})")
184
- final_rel_shifts = numpy.array(extra_shifts) + initial_shifts
155
+ found_shift, _, _ = phase_cross_correlation(reference_image=img1, moving_image=img2, space="real")
156
+ return -found_shift
157
+
158
+ shift_methods = {
159
+ ShiftAlgorithm.NABU_FFT: functools.partial(
160
+ find_shift_correlate, img1=overlap_upper_frame, img2=overlap_lower_frame
161
+ ),
162
+ ShiftAlgorithm.SKIMAGE: functools.partial(skimage_proxy, img1=overlap_upper_frame, img2=overlap_lower_frame),
163
+ ShiftAlgorithm.ITK_IMG_REG_V4: functools.partial(
164
+ find_shift_with_itk, img1=overlap_upper_frame, img2=overlap_lower_frame
165
+ ),
166
+ ShiftAlgorithm.NONE: functools.partial(lambda: (0.0, 0.0)),
167
+ }
168
+
169
+ res_algo = {}
170
+ for shift_alg in set((x_cross_correlation_function, y_cross_correlation_function)):
171
+ if shift_alg not in shift_methods:
172
+ raise ValueError(f"requested image alignment function not handled ({shift_alg})")
173
+ try:
174
+ res_algo[shift_alg] = shift_methods[shift_alg]()
175
+ except Exception as e:
176
+ _logger.error(f"Failed to find shift from {shift_alg.value}. Error is {e}")
177
+ res_algo[shift_alg] = (0, 0)
178
+
179
+ extra_shifts = (
180
+ res_algo[y_cross_correlation_function][0],
181
+ res_algo[x_cross_correlation_function][1],
182
+ )
185
183
 
184
+ final_rel_shifts = numpy.array(extra_shifts) + initial_shifts
186
185
  return tuple([int(shift) for shift in final_rel_shifts])
187
186
 
188
187
 
189
188
  def find_volumes_relative_shifts(
190
- upper_volume: numpy.ndarray,
191
- lower_volume: numpy.ndarray,
189
+ upper_volume: VolumeBase,
190
+ lower_volume: VolumeBase,
192
191
  estimated_shifts,
192
+ dim_axis_1: int,
193
+ dtype,
193
194
  flip_ud_upper_frame: bool = False,
194
195
  flip_ud_lower_frame: bool = False,
195
196
  slice_for_shift: Union[int, str] = "middle",
@@ -197,15 +198,59 @@ def find_volumes_relative_shifts(
197
198
  y_cross_correlation_function=None,
198
199
  x_shifts_params: Optional[dict] = None,
199
200
  y_shifts_params: Optional[dict] = None,
201
+ alignment_axis_2="center",
202
+ alignment_axis_1="center",
200
203
  ):
204
+ """
205
+
206
+ :param int dim_axis_1: axis 1 dimension (to handle axis 1 alignment)
207
+ """
201
208
  if y_shifts_params is None:
202
209
  y_shifts_params = {}
203
210
 
204
211
  if x_shifts_params is None:
205
212
  x_shifts_params = {}
206
213
 
207
- upper_frame = upper_volume.get_slice(slice_for_shift, axis=1)
208
- lower_frame = lower_volume.get_slice(slice_for_shift, axis=1)
214
+ alignment_axis_2 = AlignmentAxis2.from_value(alignment_axis_2)
215
+ alignment_axis_1 = AlignmentAxis1.from_value(alignment_axis_1)
216
+ assert dim_axis_1 > 0, "dim_axis_1 <= 0"
217
+
218
+ if isinstance(slice_for_shift, str):
219
+ if slice_for_shift == "first":
220
+ slice_for_shift = 0
221
+ elif slice_for_shift == "last":
222
+ slice_for_shift = dim_axis_1
223
+ elif slice_for_shift == "middle":
224
+ slice_for_shift = dim_axis_1 // 2
225
+ else:
226
+ raise ValueError("invalid slice provided to search shift", slice_for_shift)
227
+
228
+ def get_slice_along_axis_1(volume: VolumeBase, index: int):
229
+ assert isinstance(index, int), f"index should be an int, {type(index)} provided"
230
+ volume_shape = volume.get_volume_shape()
231
+ if alignment_axis_1 is AlignmentAxis1.BACK:
232
+ front_empty_width = dim_axis_1 - volume_shape[1]
233
+ if index < front_empty_width:
234
+ return PaddedRawData.get_empty_frame(shape=(volume_shape[0], volume_shape[2]), dtype=dtype)
235
+ else:
236
+ return volume.get_slice(index=index - front_empty_width, axis=1)
237
+ elif alignment_axis_1 is AlignmentAxis1.FRONT:
238
+ if index >= volume_shape[1]:
239
+ return PaddedRawData.get_empty_frame(shape=(volume_shape[0], volume_shape[2]), dtype=dtype)
240
+ else:
241
+ return volume.get_slice(index=index, axis=1)
242
+ elif alignment_axis_1 is AlignmentAxis1.CENTER:
243
+ front_empty_width = (dim_axis_1 - volume_shape[1]) // 2
244
+ back_empty_width = dim_axis_1 - front_empty_width
245
+ if index < front_empty_width or index > back_empty_width:
246
+ return PaddedRawData.get_empty_frame(shape=(volume_shape[0], volume_shape[2]), dtype=dtype)
247
+ else:
248
+ return volume.get_slice(index=index - front_empty_width, axis=1)
249
+ else:
250
+ raise TypeError(f"unmanaged alignment mode {alignment_axis_1.value}")
251
+
252
+ upper_frame = get_slice_along_axis_1(upper_volume, index=slice_for_shift)
253
+ lower_frame = get_slice_along_axis_1(lower_volume, index=slice_for_shift)
209
254
  if flip_ud_upper_frame:
210
255
  upper_frame = numpy.flipud(upper_frame.copy())
211
256
  if flip_ud_lower_frame:
@@ -214,14 +259,35 @@ def find_volumes_relative_shifts(
214
259
  from nabu.stitching.config import KEY_WINDOW_SIZE # avoid cyclic import
215
260
 
216
261
  w_window_size = int(y_shifts_params.get(KEY_WINDOW_SIZE, 400))
217
- start_overlap = max(estimated_shifts[0] - w_window_size // 2, 0)
218
- end_overlap = min(estimated_shifts[0] + w_window_size // 2, min(upper_frame.shape[0], lower_frame.shape[0]))
262
+ start_overlap = max(estimated_shifts[0] // 2 - w_window_size // 2, 0)
263
+ end_overlap = min(estimated_shifts[0] // 2 + w_window_size // 2, min(upper_frame.shape[0], lower_frame.shape[0]))
219
264
 
220
265
  if start_overlap == 0:
221
266
  overlap_upper_frame = upper_frame[-end_overlap:]
222
267
  else:
223
268
  overlap_upper_frame = upper_frame[-end_overlap:-start_overlap]
224
269
  overlap_lower_frame = lower_frame[start_overlap:end_overlap]
270
+
271
+ # align if necessary
272
+ if overlap_upper_frame.shape[1] != overlap_lower_frame.shape[1]:
273
+ overlap_frame_width = min(overlap_upper_frame.shape[1], overlap_lower_frame.shape[1])
274
+ if alignment_axis_2 is AlignmentAxis2.CENTER:
275
+ upper_frame_left_pos = overlap_upper_frame.shape[1] // 2 - overlap_frame_width // 2
276
+ upper_frame_right_pos = upper_frame_left_pos + overlap_frame_width
277
+ overlap_upper_frame = overlap_upper_frame[:, upper_frame_left_pos:upper_frame_right_pos]
278
+
279
+ lower_frame_left_pos = overlap_lower_frame.shape[1] // 2 - overlap_frame_width // 2
280
+ lower_frame_right_pos = lower_frame_left_pos + overlap_frame_width
281
+ overlap_lower_frame = overlap_lower_frame[:, lower_frame_left_pos:lower_frame_right_pos]
282
+ elif alignment_axis_2 is AlignmentAxis2.LEFT:
283
+ overlap_upper_frame = overlap_upper_frame[:, :overlap_frame_width]
284
+ overlap_lower_frame = overlap_lower_frame[:, :overlap_frame_width]
285
+ elif alignment_axis_2 is AlignmentAxis2.RIGTH:
286
+ overlap_upper_frame = overlap_upper_frame[:, -overlap_frame_width:]
287
+ overlap_lower_frame = overlap_lower_frame[:, -overlap_frame_width:]
288
+ else:
289
+ raise ValueError(f"Alignement {alignment_axis_2.value} is not handled")
290
+
225
291
  if not overlap_upper_frame.shape == overlap_lower_frame.shape:
226
292
  raise ValueError(f"Fail to get consistant overlap ({overlap_upper_frame.shape} vs {overlap_lower_frame.shape})")
227
293
 
@@ -293,7 +359,6 @@ def find_projections_relative_shifts(
293
359
  ):
294
360
  cor_options = x_shifts_params.copy()
295
361
  cor_options.pop("img_reg_method", None)
296
- cor_options.pop("score_method", None)
297
362
  # remove all none numeric options because estimate_cor will call 'literal_eval' on them
298
363
 
299
364
  upper_scan_dataset_info = HDF5DatasetAnalyzer(
@@ -326,16 +391,24 @@ def find_projections_relative_shifts(
326
391
 
327
392
  # } else we will compute shift from the flat projections
328
393
 
329
- def get_flat_fielded_proj(scan: TomoScanBase, proj_index: int, reverse: bool, revert_x: bool, revert_y):
394
+ def get_flat_fielded_proj(
395
+ scan: TomoScanBase, proj_index: int, reverse: bool, transformation_matrix: Optional[numpy.ndarray]
396
+ ):
330
397
  first_proj_idx = sorted(lower_scan.projections.keys(), reverse=reverse)[proj_index]
331
398
  ff = scan.flat_field_correction(
332
399
  (scan.projections[first_proj_idx],),
333
400
  (first_proj_idx,),
334
401
  )[0]
335
- if revert_x:
336
- ff = numpy.fliplr(ff)
337
- if revert_y:
338
- ff = numpy.flipud(ff)
402
+ assert ff.ndim == 2, f"expects a single 2D frame. Get something with {ff.ndim} dimensions"
403
+ if transformation_matrix is not None:
404
+ assert (
405
+ transformation_matrix.ndim == 2
406
+ ), f"expects a 2D transformation matrix. Get a {transformation_matrix.ndim} D"
407
+ if numpy.isclose(transformation_matrix[2, 2], -1):
408
+ transformation_matrix[2, :] = 0
409
+ transformation_matrix[0, 2] = 0
410
+ transformation_matrix[2, 2] = 1
411
+ ff = numpy.flipud(ff)
339
412
  return ff
340
413
 
341
414
  if isinstance(projection_for_shift, str):
@@ -357,26 +430,32 @@ def find_projections_relative_shifts(
357
430
  f"projection_for_shift is expected to be an int. Not {type(projection_for_shift)} - {projection_for_shift}"
358
431
  )
359
432
 
433
+ upper_scan_transformations = list(upper_scan.get_detector_transformations(tuple()))
434
+ if flip_ud_upper_frame:
435
+ upper_scan_transformations.append(UDDetTransformation())
436
+ upper_scan_trans_matrix = build_matrix(upper_scan_transformations)
437
+ lower_scan_transformations = list(lower_scan.get_detector_transformations(tuple()))
438
+ if flip_ud_lower_frame:
439
+ lower_scan_transformations.append(UDDetTransformation())
440
+ lower_scan_trans_matrix = build_matrix(lower_scan_transformations)
360
441
  upper_proj = get_flat_fielded_proj(
361
442
  upper_scan,
362
443
  projection_for_shift,
363
444
  reverse=False,
364
- revert_x=upper_scan.get_x_flipped(default=False),
365
- revert_y=upper_scan.get_y_flipped(default=False) ^ flip_ud_upper_frame,
445
+ transformation_matrix=upper_scan_trans_matrix,
366
446
  )
367
447
  lower_proj = get_flat_fielded_proj(
368
448
  lower_scan,
369
449
  projection_for_shift,
370
450
  reverse=invert_order,
371
- revert_x=lower_scan.get_x_flipped(default=False),
372
- revert_y=lower_scan.get_y_flipped(default=False) ^ flip_ud_lower_frame,
451
+ transformation_matrix=lower_scan_trans_matrix,
373
452
  )
374
453
 
375
454
  from nabu.stitching.config import KEY_WINDOW_SIZE # avoid cyclic import
376
455
 
377
456
  w_window_size = int(y_shifts_params.get(KEY_WINDOW_SIZE, 400))
378
- start_overlap = max(estimated_shifts[0] - w_window_size // 2, 0)
379
- end_overlap = min(estimated_shifts[0] + w_window_size // 2, min(upper_proj.shape[0], lower_proj.shape[0]))
457
+ start_overlap = max(estimated_shifts[0] // 2 - w_window_size // 2, 0)
458
+ end_overlap = min(estimated_shifts[0] // 2 + w_window_size // 2, min(upper_proj.shape[0], lower_proj.shape[0]))
380
459
  if start_overlap == 0:
381
460
  overlap_upper_frame = upper_proj[-end_overlap:]
382
461
  else:
@@ -410,144 +489,7 @@ def find_shift_correlate(img1, img2, padding_mode="reflect"):
410
489
 
411
490
  (f_vals, fv, fh) = alignment.extract_peak_region_2d(cc, cc_vs=cc_vs, cc_hs=cc_hs)
412
491
  shifts_vh = alignment.refine_max_position_2d(f_vals, fv, fh)
413
- return shifts_vh
414
-
415
-
416
- class ScoreMethod(_Enum):
417
- STD = "standard deviation"
418
- TV = "total variation"
419
- TV_INVERSE = "1 / (total variation)"
420
- STD_INVERSE = "1 / std"
421
-
422
- @classmethod
423
- def from_value(cls, value):
424
- if isinstance(value, str):
425
- # for string handle the case where value as been provided as 'value'. As there is spaces this can happen
426
- value = value.lstrip("'").rstrip("'")
427
- if value in ("tv", "TV"):
428
- return ScoreMethod.TV
429
- elif value in ("std", "STD"):
430
- return ScoreMethod.STD
431
- else:
432
- return super().from_value(value=value)
433
-
434
-
435
- def compute_score_contrast_std(data: numpy.ndarray):
436
- """
437
- Compute a contrast score by simply computing the standard deviation of
438
- the frame
439
- :param numpy.ndarray data: frame for which we should compute the score
440
- :return: score of the frame
441
- :rtype: float
442
- """
443
- if data is None:
444
- return None
445
- else:
446
- return data.std()
447
-
448
-
449
- def compute_tv_score(data: numpy.ndarray):
450
- """
451
- Compute the data score as image total variation
452
-
453
- :param numpy.ndarray data: frame for which we should compute the score
454
- :return: score of the frame
455
- :rtype: float
456
- """
457
- tv = numpy.sum(numpy.sqrt(numpy.gradient(data, axis=0) ** 2 + numpy.gradient(data, axis=1) ** 2))
458
- return tv
459
-
460
-
461
- def compute_score(img_1, img_2, shift, score_method, score_region, return_img=False):
462
- score_method = ScoreMethod.from_value(score_method)
463
-
464
- img_2 = scipy_shift(img_2, shift=shift)
465
-
466
- img_2_reduce = img_2[
467
- score_region[0].start : score_region[0].stop,
468
- score_region[1].start : score_region[1].stop,
469
- ]
470
- img_1_reduce = img_1[
471
- score_region[0].start : score_region[0].stop,
472
- score_region[1].start : score_region[1].stop,
473
- ]
474
-
475
- img_sum = img_1_reduce * 0.5 + img_2_reduce * 0.5
476
-
477
- if score_method is ScoreMethod.TV:
478
- result = compute_tv_score(img_sum)
479
- elif score_method is ScoreMethod.STD:
480
- result = compute_score_contrast_std(img_sum)
481
- elif score_method is ScoreMethod.TV_INVERSE:
482
- result = 1 / compute_tv_score(img_sum)
483
- elif score_method is ScoreMethod.STD_INVERSE:
484
- result = 1 / compute_score_contrast_std(img_sum)
485
- else:
486
- raise ValueError(f"{score_method} is not handled")
487
- if return_img:
488
- return result, img_sum
489
- else:
490
- return result
491
-
492
-
493
- def shift_grid_search(img_1, img_2, window_sizes: tuple, axis, step_size, score_method=ScoreMethod.STD):
494
- """
495
- we could consider adding weights to do the exact same operation that will be done for the stitching... ? overkilled ?
496
- :param tuple window_sizes: as y_size, x_size
497
- """
498
- if not isinstance(window_sizes, tuple) and len(window_sizes) != 2:
499
- raise TypeError(f"window_sizes is expected to be a tuple of two ints. {window_sizes} provided")
500
- if not img_1.ndim == img_2.ndim == 2:
501
- raise ValueError("image dimension should be 2D")
502
-
503
- for value in axis:
504
- if not value in (0, 1):
505
- raise ValueError(f"axis {value} is not handled")
506
- half_window_x = min((img_1.shape[1], img_2.shape[1], abs(window_sizes[1]))) // 2
507
- half_window_y = min((img_1.shape[0], img_2.shape[0], abs(window_sizes[0]))) // 2
508
- if 1 in axis:
509
- x_research = numpy.arange(-half_window_x, half_window_x, step_size)
510
- x_score_region = slice(
511
- int(half_window_x * 2),
512
- int(min(img_1.shape[1], img_2.shape[1]) - (half_window_x * 2)),
513
- )
514
- else:
515
- x_research = tuple((0,))
516
- x_score_region = slice(0, int(min(img_1.shape[1], img_2.shape[1])))
517
-
518
- if 0 in axis:
519
- y_research = numpy.arange(-half_window_y, half_window_y, step_size)
520
- y_score_region = slice(
521
- int((half_window_y * 2)),
522
- int(min(img_1.shape[0], img_2.shape[0]) - (half_window_y * 2)),
523
- )
524
- else:
525
- y_research = tuple((0,))
526
- y_score_region = slice(0, int(min(img_1.shape[0], img_2.shape[0])))
527
- score_region = (
528
- y_score_region,
529
- x_score_region,
530
- )
531
-
532
- score_width = score_region[1].stop - score_region[1].start
533
- if score_width < 10:
534
- _logger.warning("score_width seems very low. Try reducing window_sizes")
535
-
536
- best_score, best_shift = None, (0, 0)
537
- for x_shift in x_research:
538
- for y_shift in y_research:
539
- res = compute_score(
540
- img_1,
541
- img_2.copy(),
542
- shift=(y_shift, x_shift),
543
- score_method=score_method,
544
- score_region=score_region,
545
- )
546
- local_score = res
547
- if best_score is None or (local_score is not None and local_score > best_score):
548
- best_score = local_score
549
- best_shift = (y_shift, x_shift)
550
- return numpy.array(best_shift)
492
+ return -shifts_vh
551
493
 
552
494
 
553
495
  def find_shift_with_itk(img1: numpy.ndarray, img2: numpy.ndarray) -> tuple: