nabu 2023.2.1__py3-none-any.whl → 2024.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183) hide show
  1. doc/conf.py +1 -1
  2. doc/doc_config.py +32 -0
  3. nabu/__init__.py +2 -1
  4. nabu/app/bootstrap_stitching.py +1 -1
  5. nabu/app/cli_configs.py +122 -2
  6. nabu/app/composite_cor.py +27 -2
  7. nabu/app/correct_rot.py +70 -0
  8. nabu/app/create_distortion_map_from_poly.py +42 -18
  9. nabu/app/diag_to_pix.py +358 -0
  10. nabu/app/diag_to_rot.py +449 -0
  11. nabu/app/generate_header.py +4 -3
  12. nabu/app/histogram.py +2 -2
  13. nabu/app/multicor.py +6 -1
  14. nabu/app/parse_reconstruction_log.py +151 -0
  15. nabu/app/prepare_weights_double.py +83 -22
  16. nabu/app/reconstruct.py +5 -1
  17. nabu/app/reconstruct_helical.py +7 -0
  18. nabu/app/reduce_dark_flat.py +6 -3
  19. nabu/app/rotate.py +4 -4
  20. nabu/app/stitching.py +16 -2
  21. nabu/app/tests/test_reduce_dark_flat.py +18 -2
  22. nabu/app/validator.py +4 -4
  23. nabu/cuda/convolution.py +8 -376
  24. nabu/cuda/fft.py +4 -0
  25. nabu/cuda/kernel.py +4 -4
  26. nabu/cuda/medfilt.py +5 -158
  27. nabu/cuda/padding.py +5 -71
  28. nabu/cuda/processing.py +23 -2
  29. nabu/cuda/src/ElementOp.cu +78 -0
  30. nabu/cuda/src/backproj.cu +28 -2
  31. nabu/cuda/src/fourier_wavelets.cu +2 -2
  32. nabu/cuda/src/normalization.cu +23 -0
  33. nabu/cuda/src/padding.cu +2 -2
  34. nabu/cuda/src/transpose.cu +16 -0
  35. nabu/cuda/utils.py +39 -0
  36. nabu/estimation/alignment.py +10 -1
  37. nabu/estimation/cor.py +808 -38
  38. nabu/estimation/cor_sino.py +7 -9
  39. nabu/estimation/tests/test_cor.py +85 -3
  40. nabu/io/reader.py +26 -18
  41. nabu/io/tests/test_cast_volume.py +3 -3
  42. nabu/io/tests/test_detector_distortion.py +3 -3
  43. nabu/io/tiffwriter_zmm.py +2 -2
  44. nabu/io/utils.py +14 -4
  45. nabu/io/writer.py +5 -3
  46. nabu/misc/fftshift.py +6 -0
  47. nabu/misc/histogram.py +5 -285
  48. nabu/misc/histogram_cuda.py +8 -104
  49. nabu/misc/kernel_base.py +3 -121
  50. nabu/misc/padding_base.py +5 -69
  51. nabu/misc/processing_base.py +3 -107
  52. nabu/misc/rotation.py +5 -62
  53. nabu/misc/rotation_cuda.py +5 -65
  54. nabu/misc/transpose.py +6 -0
  55. nabu/misc/unsharp.py +3 -78
  56. nabu/misc/unsharp_cuda.py +5 -52
  57. nabu/misc/unsharp_opencl.py +8 -85
  58. nabu/opencl/fft.py +6 -0
  59. nabu/opencl/kernel.py +21 -6
  60. nabu/opencl/padding.py +5 -72
  61. nabu/opencl/processing.py +27 -5
  62. nabu/opencl/src/backproj.cl +3 -3
  63. nabu/opencl/src/fftshift.cl +65 -12
  64. nabu/opencl/src/padding.cl +2 -2
  65. nabu/opencl/src/roll.cl +96 -0
  66. nabu/opencl/src/transpose.cl +16 -0
  67. nabu/pipeline/config_validators.py +63 -3
  68. nabu/pipeline/dataset_validator.py +2 -2
  69. nabu/pipeline/estimators.py +193 -35
  70. nabu/pipeline/fullfield/chunked.py +34 -17
  71. nabu/pipeline/fullfield/chunked_cuda.py +7 -5
  72. nabu/pipeline/fullfield/computations.py +48 -13
  73. nabu/pipeline/fullfield/nabu_config.py +13 -13
  74. nabu/pipeline/fullfield/processconfig.py +10 -5
  75. nabu/pipeline/fullfield/reconstruction.py +1 -2
  76. nabu/pipeline/helical/fbp.py +5 -0
  77. nabu/pipeline/helical/filtering.py +12 -9
  78. nabu/pipeline/helical/gridded_accumulator.py +179 -33
  79. nabu/pipeline/helical/helical_chunked_regridded.py +262 -151
  80. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +4 -11
  81. nabu/pipeline/helical/helical_reconstruction.py +56 -18
  82. nabu/pipeline/helical/span_strategy.py +1 -1
  83. nabu/pipeline/helical/tests/test_accumulator.py +4 -0
  84. nabu/pipeline/params.py +23 -2
  85. nabu/pipeline/processconfig.py +3 -8
  86. nabu/pipeline/tests/test_chunk_reader.py +78 -0
  87. nabu/pipeline/tests/test_estimators.py +120 -2
  88. nabu/pipeline/utils.py +25 -0
  89. nabu/pipeline/writer.py +2 -0
  90. nabu/preproc/ccd_cuda.py +9 -7
  91. nabu/preproc/ctf.py +21 -26
  92. nabu/preproc/ctf_cuda.py +25 -25
  93. nabu/preproc/double_flatfield.py +14 -2
  94. nabu/preproc/double_flatfield_cuda.py +7 -11
  95. nabu/preproc/flatfield_cuda.py +23 -27
  96. nabu/preproc/phase.py +19 -24
  97. nabu/preproc/phase_cuda.py +21 -21
  98. nabu/preproc/shift_cuda.py +58 -28
  99. nabu/preproc/tests/test_ctf.py +5 -5
  100. nabu/preproc/tests/test_double_flatfield.py +2 -2
  101. nabu/preproc/tests/test_vshift.py +13 -2
  102. nabu/processing/__init__.py +0 -0
  103. nabu/processing/convolution_cuda.py +375 -0
  104. nabu/processing/fft_base.py +163 -0
  105. nabu/processing/fft_cuda.py +256 -0
  106. nabu/processing/fft_opencl.py +54 -0
  107. nabu/processing/fftshift.py +134 -0
  108. nabu/processing/histogram.py +286 -0
  109. nabu/processing/histogram_cuda.py +103 -0
  110. nabu/processing/kernel_base.py +126 -0
  111. nabu/processing/medfilt_cuda.py +159 -0
  112. nabu/processing/muladd.py +29 -0
  113. nabu/processing/muladd_cuda.py +68 -0
  114. nabu/processing/padding_base.py +71 -0
  115. nabu/processing/padding_cuda.py +75 -0
  116. nabu/processing/padding_opencl.py +77 -0
  117. nabu/processing/processing_base.py +123 -0
  118. nabu/processing/roll_opencl.py +64 -0
  119. nabu/processing/rotation.py +63 -0
  120. nabu/processing/rotation_cuda.py +66 -0
  121. nabu/processing/tests/__init__.py +0 -0
  122. nabu/processing/tests/test_fft.py +268 -0
  123. nabu/processing/tests/test_fftshift.py +71 -0
  124. nabu/{misc → processing}/tests/test_histogram.py +2 -4
  125. nabu/{cuda → processing}/tests/test_medfilt.py +1 -1
  126. nabu/processing/tests/test_muladd.py +54 -0
  127. nabu/{cuda → processing}/tests/test_padding.py +119 -75
  128. nabu/processing/tests/test_roll.py +63 -0
  129. nabu/{misc → processing}/tests/test_rotation.py +3 -2
  130. nabu/processing/tests/test_transpose.py +72 -0
  131. nabu/{misc → processing}/tests/test_unsharp.py +41 -8
  132. nabu/processing/transpose.py +126 -0
  133. nabu/processing/unsharp.py +79 -0
  134. nabu/processing/unsharp_cuda.py +53 -0
  135. nabu/processing/unsharp_opencl.py +75 -0
  136. nabu/reconstruction/fbp.py +34 -10
  137. nabu/reconstruction/fbp_base.py +35 -16
  138. nabu/reconstruction/fbp_opencl.py +7 -12
  139. nabu/reconstruction/filtering.py +2 -2
  140. nabu/reconstruction/filtering_cuda.py +13 -14
  141. nabu/reconstruction/filtering_opencl.py +3 -4
  142. nabu/reconstruction/projection.py +2 -0
  143. nabu/reconstruction/rings.py +158 -1
  144. nabu/reconstruction/rings_cuda.py +218 -58
  145. nabu/reconstruction/sinogram_cuda.py +16 -12
  146. nabu/reconstruction/tests/test_deringer.py +116 -14
  147. nabu/reconstruction/tests/test_fbp.py +22 -31
  148. nabu/reconstruction/tests/test_filtering.py +11 -2
  149. nabu/resources/dataset_analyzer.py +89 -26
  150. nabu/resources/nxflatfield.py +2 -2
  151. nabu/resources/tests/test_nxflatfield.py +1 -1
  152. nabu/resources/utils.py +9 -2
  153. nabu/stitching/alignment.py +184 -0
  154. nabu/stitching/config.py +241 -39
  155. nabu/stitching/definitions.py +6 -0
  156. nabu/stitching/frame_composition.py +4 -2
  157. nabu/stitching/overlap.py +99 -3
  158. nabu/stitching/sample_normalization.py +60 -0
  159. nabu/stitching/slurm_utils.py +10 -10
  160. nabu/stitching/tests/test_alignment.py +99 -0
  161. nabu/stitching/tests/test_config.py +16 -1
  162. nabu/stitching/tests/test_overlap.py +68 -2
  163. nabu/stitching/tests/test_sample_normalization.py +49 -0
  164. nabu/stitching/tests/test_slurm_utils.py +5 -5
  165. nabu/stitching/tests/test_utils.py +3 -33
  166. nabu/stitching/tests/test_z_stitching.py +391 -22
  167. nabu/stitching/utils.py +144 -202
  168. nabu/stitching/z_stitching.py +309 -126
  169. nabu/testutils.py +18 -0
  170. nabu/thirdparty/tomocupy_remove_stripe.py +586 -0
  171. nabu/utils.py +32 -6
  172. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/LICENSE +1 -1
  173. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/METADATA +5 -5
  174. nabu-2024.1.0rc3.dist-info/RECORD +296 -0
  175. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/WHEEL +1 -1
  176. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/entry_points.txt +5 -1
  177. nabu/conftest.py +0 -14
  178. nabu/opencl/fftshift.py +0 -92
  179. nabu/opencl/tests/test_fftshift.py +0 -55
  180. nabu/opencl/tests/test_padding.py +0 -84
  181. nabu-2023.2.1.dist-info/RECORD +0 -252
  182. /nabu/cuda/src/{fftshift.cu → dfi_fftshift.cu} +0 -0
  183. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/top_level.txt +0 -0
@@ -37,30 +37,33 @@ from math import ceil
37
37
  from contextlib import AbstractContextManager
38
38
  import h5py
39
39
  import logging
40
-
41
40
  from scipy.ndimage import shift as shift_scipy
41
+ from functools import lru_cache as cache
42
42
 
43
43
  from silx.io.utils import get_data
44
44
  from silx.io.url import DataUrl
45
45
  from silx.io.dictdump import dicttonx
46
46
 
47
+ from nxtomo.nxobject.nxdetector import ImageKey
48
+ from nxtomo.nxobject.nxtransformations import NXtransformations
49
+ from nxtomo.paths.nxtomo import get_paths as _get_nexus_paths
50
+ from nxtomo.utils.transformation import build_matrix, LRDetTransformation, UDDetTransformation
51
+
47
52
  from tomoscan.io import HDF5File
48
53
  from tomoscan.esrf.scan.utils import cwd_context
49
54
  from tomoscan.identifier import BaseIdentifier
50
- from tomoscan.esrf import HDF5TomoScan, EDFTomoScan
55
+ from tomoscan.esrf import NXtomoScan, EDFTomoScan
51
56
  from tomoscan.volumebase import VolumeBase
52
57
  from tomoscan.esrf.volume import HDF5Volume
53
58
  from tomoscan.serie import Serie
54
- from tomoscan.esrf.scan.hdf5scan import ImageKey
55
- from tomoscan.nexus.paths.nxtomo import get_paths as _get_nexus_paths
56
59
  from tomoscan.factory import Factory as TomoscanFactory
57
60
  from tomoscan.utils.volume import concatenate as concatenate_volumes
58
61
  from tomoscan.esrf.scan.utils import (
59
62
  get_compacted_dataslices,
60
63
  ) # this version has a 'return_url_set' needed here. At one point they should be merged together
61
- from tomoscan.unitsystem.metricsystem import MetricSystem
64
+ from pyunitsystem.metricsystem import MetricSystem
62
65
 
63
- from nxtomomill.nexus import NXtomo
66
+ from nxtomo.application.nxtomo import NXtomo
64
67
  from silx.io.dictdump import dicttonx
65
68
 
66
69
  from nabu.io.utils import DatasetReader
@@ -76,16 +79,19 @@ from nabu.stitching.config import (
76
79
  KEY_RESCALE_MAX_PERCENTILES,
77
80
  KEY_THRESHOLD_FREQUENCY,
78
81
  )
82
+ from nabu.stitching.alignment import align_horizontally, AlignmentAxis1
79
83
  from nabu.utils import Progress
80
84
  from nabu import version as nabu_version
81
85
  from nabu.io.writer import get_datetime
82
86
  from .overlap import (
83
87
  ZStichOverlapKernel,
88
+ check_overlaps,
84
89
  )
85
90
  from .. import version as nabu_version
86
91
  from nabu.io.writer import get_datetime
87
92
  from nabu.misc.utils import rescale_data
88
-
93
+ from nabu.stitching.alignment import PaddedRawData
94
+ from nabu.stitching.sample_normalization import normalize_frame as normalize_frame_by_sample
89
95
 
90
96
  _logger = logging.getLogger(__name__)
91
97
 
@@ -127,6 +133,9 @@ class ZStitcher:
127
133
  self._axis_2_rel_shifts = []
128
134
  # shift between upper and lower frames
129
135
 
136
+ self._stitching_width = None
137
+ # stitching width: larger volume width. Other volume will be pad
138
+
130
139
  # z serie must be defined from daughter class
131
140
  assert hasattr(self, "_z_serie")
132
141
 
@@ -184,6 +193,7 @@ class ZStitcher:
184
193
  estimated_shifts_axis_0.insert(0, 0)
185
194
 
186
195
  final_pos = {}
196
+ previous_shift = 0
187
197
  for tomo_obj, pos_axis_0, pos_axis_2, final_shift_axis_0, estimated_shift_axis_0, final_shift_axis_2 in zip(
188
198
  self.z_serie,
189
199
  self.configuration.axis_0_pos_px,
@@ -194,10 +204,11 @@ class ZStitcher:
194
204
  ):
195
205
  # warning estimated_shift is the estimatation from the overlap. So playes no role here
196
206
  final_pos[tomo_obj.get_identifier().to_str()] = (
197
- pos_axis_0 + (final_shift_axis_0 - estimated_shift_axis_0),
207
+ pos_axis_0 - (final_shift_axis_0 - estimated_shift_axis_0) + previous_shift,
198
208
  None, # axis 1 is not handled for now
199
209
  pos_axis_2 + final_shift_axis_2,
200
210
  )
211
+ previous_shift += final_shift_axis_0 - estimated_shift_axis_0
201
212
  return final_pos
202
213
 
203
214
  def from_abs_pos_to_rel_pos(self, abs_position: tuple):
@@ -278,21 +289,17 @@ class ZStitcher:
278
289
  else:
279
290
  overlap_size = int(overlap_size)
280
291
 
281
- for axis_0_shift, obj in zip(self._axis_0_rel_shifts, self.z_serie):
292
+ self._stitching_width = max([get_obj_width(obj) for obj in self.z_serie])
293
+
294
+ for axis_0_shift in self._axis_0_rel_shifts:
282
295
  if overlap_size == -1:
283
296
  height = abs(axis_0_shift)
284
297
  else:
285
298
  height = overlap_size
286
- if isinstance(obj, HDF5TomoScan):
287
- frame_width = obj.dim_1
288
- elif isinstance(obj, VolumeBase):
289
- frame_width = obj.get_volume_shape()[2]
290
- else:
291
- raise TypeError(f"obj type ({type(obj)}) is not handled")
292
299
 
293
300
  self._overlap_kernels.append(
294
301
  ZStichOverlapKernel(
295
- frame_width=frame_width,
302
+ frame_width=self._stitching_width,
296
303
  stitching_strategy=self.configuration.stitching_strategy,
297
304
  overlap_size=height,
298
305
  extra_params=self.configuration.stitching_kernels_extra_params,
@@ -385,6 +392,7 @@ class ZStitcher:
385
392
  """
386
393
  rescale_frames if requested by the configuration
387
394
  """
395
+ _logger.info("apply rescale frames")
388
396
 
389
397
  def cast_percentile(percentile) -> int:
390
398
  if isinstance(percentile, str):
@@ -407,9 +415,27 @@ class ZStitcher:
407
415
 
408
416
  return tuple([rescale(data) for data in frames])
409
417
 
418
+ def normalize_frame_by_sample(self, frames: tuple):
419
+ """
420
+ normalize frame from a sample picked on the left or the right
421
+ """
422
+ _logger.info("apply normalization by a sample")
423
+ return tuple(
424
+ [
425
+ normalize_frame_by_sample(
426
+ frame=frame,
427
+ side=self.configuration.normalization_by_sample.side,
428
+ method=self.configuration.normalization_by_sample.method,
429
+ margin_before_sample=self.configuration.normalization_by_sample.margin,
430
+ sample_width=self.configuration.normalization_by_sample.width,
431
+ )
432
+ for frame in frames
433
+ ]
434
+ )
435
+
410
436
  @staticmethod
411
437
  def stitch_frames(
412
- frames: tuple,
438
+ frames: Union[tuple, numpy.ndarray],
413
439
  x_relative_shifts: tuple,
414
440
  y_relative_shifts: tuple,
415
441
  output_dtype: numpy.ndarray,
@@ -421,6 +447,9 @@ class ZStitcher:
421
447
  shift_mode="nearest",
422
448
  i_frame=None,
423
449
  return_composition_cls=False,
450
+ alignment="center",
451
+ pad_mode="constant",
452
+ new_width: Optional[int] = None,
424
453
  ) -> numpy.ndarray:
425
454
  """
426
455
  shift frames according to provided `shifts` (as y, x tuples) then stitch all the shifted frames together and
@@ -444,6 +473,21 @@ class ZStitcher:
444
473
  f"expect to have the same number of y_relative_shifts ({len(y_relative_shifts)}) and y_overlap ({len(overlap_kernels)})"
445
474
  )
446
475
 
476
+ relative_positions = [(0, 0)]
477
+ for y_rel_pos, x_rel_pos in zip(y_relative_shifts, x_relative_shifts):
478
+ relative_positions.append(
479
+ (
480
+ y_rel_pos + relative_positions[-1][0],
481
+ x_rel_pos + relative_positions[-1][1],
482
+ )
483
+ )
484
+ check_overlaps(
485
+ frames=tuple(frames),
486
+ positions=tuple(relative_positions),
487
+ axis=0,
488
+ raise_error=False,
489
+ )
490
+
447
491
  def check_frame_is_2d(frame):
448
492
  if frame.ndim != 2:
449
493
  raise ValueError(f"2D frame expected when {frame.ndim}D provided")
@@ -462,6 +506,7 @@ class ZStitcher:
462
506
  data.append(frame)
463
507
  else:
464
508
  raise TypeError(f"frames are expected to be DataUrl or 2D numpy array. Not {type(frame)}")
509
+
465
510
  # step 1: shift each frames (except the first one)
466
511
  x_shifted_data = [data[0]]
467
512
  for frame, x_relative_shift in zip(data[1:], x_relative_shifts):
@@ -492,6 +537,9 @@ class ZStitcher:
492
537
  check_inputs=check_inputs,
493
538
  output_dtype=output_dtype,
494
539
  return_composition_cls=return_composition_cls,
540
+ alignment=alignment,
541
+ pad_mode=pad_mode,
542
+ new_width=new_width,
495
543
  )
496
544
  if return_composition_cls:
497
545
  stitched_frame, _ = res
@@ -507,6 +555,24 @@ class ZStitcher:
507
555
  )
508
556
  return res
509
557
 
558
+ @staticmethod
559
+ @cache(maxsize=None)
560
+ def _get_UD_flip_matrix():
561
+ return UDDetTransformation().as_matrix()
562
+
563
+ @staticmethod
564
+ @cache(maxsize=None)
565
+ def _get_LR_flip_matrix():
566
+ return LRDetTransformation().as_matrix()
567
+
568
+ @staticmethod
569
+ @cache(maxsize=None)
570
+ def _get_UD_AND_LR_flip_matrix():
571
+ return numpy.matmul(
572
+ ZStitcher._get_UD_flip_matrix(),
573
+ ZStitcher._get_LR_flip_matrix(),
574
+ )
575
+
510
576
 
511
577
  class PreProcessZStitcher(ZStitcher):
512
578
  def __init__(self, configuration, progress=None) -> None:
@@ -681,13 +747,6 @@ class PreProcessZStitcher(ZStitcher):
681
747
  if isinstance(axis_pos_px, Iterable) and len(axis_pos_px) != (n_scans):
682
748
  raise ValueError(f"{axis_name} expect {n_scans} shift defined. Get {len(axis_pos_px)}")
683
749
 
684
- for scan in self.z_serie:
685
- if scan.x_flipped is None or scan.y_flipped is None:
686
- _logger.warning(
687
- f"Found at least one scan with no frame flips information ({scan}). Will consider those are unflipped. Might end up with some inverted frame errors."
688
- )
689
- break
690
-
691
750
  self._reading_orders = []
692
751
  # the first scan will define the expected reading orderd, and expected flip.
693
752
  # if all scan are flipped then we will keep it this way
@@ -697,8 +756,8 @@ class PreProcessZStitcher(ZStitcher):
697
756
  for scan_0, scan_1 in zip(self.z_serie[0:-1], self.z_serie[1:]):
698
757
  if len(scan_0.projections) != len(scan_1.projections):
699
758
  raise ValueError(f"{scan_0} and {scan_1} have a different number of projections")
700
- if isinstance(scan_0, HDF5TomoScan) and isinstance(scan_1, HDF5TomoScan):
701
- # check rotation (only of is an HDF5TomoScan)
759
+ if isinstance(scan_0, NXtomoScan) and isinstance(scan_1, NXtomoScan):
760
+ # check rotation (only of is an NXtomoScan)
702
761
  scan_0_angles = numpy.asarray(scan_0.rotation_angle)
703
762
  scan_0_projections_angles = scan_0_angles[
704
763
  numpy.asarray(scan_0.image_key_control) == ImageKey.PROJECTION.value
@@ -749,8 +808,8 @@ class PreProcessZStitcher(ZStitcher):
749
808
  )
750
809
 
751
810
  for scan in self.z_serie:
752
- # check x, y and z translation are constant (only if is an HDF5TomoScan)
753
- if isinstance(scan, HDF5TomoScan):
811
+ # check x, y and z translation are constant (only if is an NXtomoScan)
812
+ if isinstance(scan, NXtomoScan):
754
813
  if scan.x_translation is not None and not numpy.isclose(
755
814
  min(scan.x_translation), max(scan.x_translation)
756
815
  ):
@@ -1002,23 +1061,43 @@ class PreProcessZStitcher(ZStitcher):
1002
1061
  darks=scan.reduced_darks,
1003
1062
  radios_indices=radio_indices,
1004
1063
  radios_srcurrent=scan.electric_current[radio_indices] if has_reduced_metadata else None,
1005
- flats_srcurrent=scan.reduced_flats_infos.machine_electric_current
1006
- if has_reduced_metadata
1007
- else None,
1064
+ flats_srcurrent=(
1065
+ scan.reduced_flats_infos.machine_electric_current if has_reduced_metadata else None
1066
+ ),
1008
1067
  )
1009
1068
  # note: we need to cast radios to float 32. Darks and flats are cast to anyway
1010
1069
  data = ff_arrays.normalize_radios(raw_radios.astype(numpy.float32))
1011
1070
 
1012
- flip_lr = scans[i_scan].get_x_flipped(default=False) ^ scan_flip_lr
1013
- flip_ud = scans[i_scan].get_y_flipped(default=False) ^ scan_flip_ud
1071
+ transformations = list(scans[i_scan].get_detector_transformations(tuple()))
1072
+ if scan_flip_lr:
1073
+ transformations.append(LRDetTransformation())
1074
+ if scan_flip_ud:
1075
+ transformations.append(UDDetTransformation())
1076
+
1077
+ transformation_matrix_det_space = build_matrix(transformations)
1078
+ if transformation_matrix_det_space is None or numpy.allclose(
1079
+ transformation_matrix_det_space, numpy.identity(3)
1080
+ ):
1081
+ flip_ud = False
1082
+ flip_lr = False
1083
+ elif numpy.array_equal(transformation_matrix_det_space, ZStitcher._get_UD_flip_matrix()):
1084
+ flip_ud = True
1085
+ flip_lr = False
1086
+ elif numpy.allclose(transformation_matrix_det_space, ZStitcher._get_LR_flip_matrix()):
1087
+ flip_ud = False
1088
+ flip_lr = True
1089
+ elif numpy.allclose(transformation_matrix_det_space, ZStitcher._get_UD_AND_LR_flip_matrix()):
1090
+ flip_ud = True
1091
+ flip_lr = True
1092
+ else:
1093
+ raise ValueError("case not handled... For now only handle up-down flip as left-right flip")
1094
+
1014
1095
  for frame in data:
1015
- f_frame = frame
1016
- if flip_lr:
1017
- f_frame = numpy.fliplr(f_frame)
1018
1096
  if flip_ud:
1019
- f_frame = numpy.flipud(f_frame)
1020
-
1021
- all_scan_final_data[i_frame, i_scan] = f_frame
1097
+ frame = numpy.flipud(frame)
1098
+ if flip_lr:
1099
+ frame = numpy.fliplr(frame)
1100
+ all_scan_final_data[i_frame, i_scan] = frame
1022
1101
  i_frame += 1
1023
1102
 
1024
1103
  return all_scan_final_data
@@ -1032,7 +1111,9 @@ class PreProcessZStitcher(ZStitcher):
1032
1111
  ):
1033
1112
  upper_scan_pos = upper_scan_axis_0_pos - upper_scan.dim_2 / 2
1034
1113
  lower_scan_high_pos = lower_scan_axis_0_pos + lower_scan.dim_2 / 2
1035
- assert lower_scan_high_pos > upper_scan_pos, f"no overlap found between {upper_scan} and {lower_scan}"
1114
+ # simple test of overlap. More complete test are runned by check_overlaps later
1115
+ if lower_scan_high_pos <= upper_scan_pos:
1116
+ raise ValueError(f"no overlap found between {upper_scan} and {lower_scan}")
1036
1117
  self._axis_0_estimated_shifts.append(
1037
1118
  int(lower_scan_high_pos - upper_scan_pos) # overlap are expected to be int for now
1038
1119
  )
@@ -1074,11 +1155,10 @@ class PreProcessZStitcher(ZStitcher):
1074
1155
  nx_tomo.instrument.detector.y_pixel_size = self.z_serie[0].y_pixel_size
1075
1156
  nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
1076
1157
  nx_tomo.instrument.detector.tomo_n = n_proj
1077
- # note: stitching process insure unflipping of frames
1078
- nx_tomo.instrument.detector.x_flipped = False
1079
- nx_tomo.instrument.detector.y_flipped = False
1158
+ # note: stitching process insure unflipping of frames. So make sure transformations is defined as an empty set
1159
+ nx_tomo.instrument.detector.transformations = NXtransformations()
1080
1160
 
1081
- if isinstance(self.z_serie[0], HDF5TomoScan):
1161
+ if isinstance(self.z_serie[0], NXtomoScan):
1082
1162
  # note: first scan is always the reference as order to read data (so no rotation_angle inversion here)
1083
1163
  rotation_angle = numpy.asarray(self.z_serie[0].rotation_angle)
1084
1164
  nx_tomo.sample.rotation_angle = rotation_angle[
@@ -1091,8 +1171,8 @@ class PreProcessZStitcher(ZStitcher):
1091
1171
  else:
1092
1172
  raise NotImplementedError(
1093
1173
  f"scan type ({type(self.z_serie[0])} is not handled)",
1094
- HDF5TomoScan,
1095
- isinstance(self.z_serie[0], HDF5TomoScan),
1174
+ NXtomoScan,
1175
+ isinstance(self.z_serie[0], NXtomoScan),
1096
1176
  )
1097
1177
 
1098
1178
  # do a sub selection of the rotation angle if a we are only computing a part of the slices
@@ -1110,7 +1190,7 @@ class PreProcessZStitcher(ZStitcher):
1110
1190
 
1111
1191
  # handle sample
1112
1192
  n_frames = n_proj
1113
- if False not in [isinstance(scan, HDF5TomoScan) for scan in self.z_serie]:
1193
+ if False not in [isinstance(scan, NXtomoScan) for scan in self.z_serie]:
1114
1194
  # we consider the new x, y and z position to be at the center of the one created
1115
1195
  x_translation = [scan.x_translation for scan in self.z_serie if scan.x_translation is not None]
1116
1196
  nx_tomo.sample.x_translation = [numpy.asarray(x_translation).mean()] * n_frames
@@ -1128,7 +1208,7 @@ class PreProcessZStitcher(ZStitcher):
1128
1208
  numpy.asarray([scan.dim_2 for scan in self.z_serie]).sum()
1129
1209
  - numpy.asarray([abs(overlap) for overlap in self._axis_0_rel_shifts]).sum()
1130
1210
  ),
1131
- self.z_serie[0].dim_1,
1211
+ self._stitching_width,
1132
1212
  )
1133
1213
 
1134
1214
  # get expected output dataset first (just in case output and input files are the same)
@@ -1151,11 +1231,15 @@ class PreProcessZStitcher(ZStitcher):
1151
1231
  overwrite=self.configuration.overwrite_results,
1152
1232
  )
1153
1233
 
1234
+ transformation_matrices = {
1235
+ scan.get_identifier()
1236
+ .to_str()
1237
+ .center(80, "-"): numpy.array2string(build_matrix(scan.get_detector_transformations(tuple())))
1238
+ for scan in self.z_serie
1239
+ }
1154
1240
  _logger.info(
1155
- f"scan x flipped are {','.join([str(scan.get_x_flipped(default=False)) for scan in self.z_serie])}"
1156
- )
1157
- _logger.info(
1158
- f"scan y flipped are {','.join([str(scan.get_y_flipped(default=False)) for scan in self.z_serie])}"
1241
+ "scan detector transformation matrices are:\n"
1242
+ "\n".join(["/n".join(item) for item in transformation_matrices.items()])
1159
1243
  )
1160
1244
 
1161
1245
  _logger.info(
@@ -1214,6 +1298,9 @@ class PreProcessZStitcher(ZStitcher):
1214
1298
  ):
1215
1299
  if self.configuration.rescale_frames:
1216
1300
  data_frames = self.rescale_frames(data_frames)
1301
+ if self.configuration.normalization_by_sample.is_active():
1302
+ data_frames = self.normalize_frame_by_sample(data_frames)
1303
+
1217
1304
  sf = ZStitcher.stitch_frames(
1218
1305
  frames=data_frames,
1219
1306
  x_relative_shifts=self._axis_2_rel_shifts,
@@ -1225,6 +1312,10 @@ class PreProcessZStitcher(ZStitcher):
1225
1312
  dump_frame_fct=self._dump_frame,
1226
1313
  return_composition_cls=store_composition if i_proj == 0 else False,
1227
1314
  stitching_axis=0,
1315
+ pad_mode=self.configuration.pad_mode,
1316
+ alignment=self.configuration.alignment_axis_2,
1317
+ new_width=self._stitching_width,
1318
+ check_inputs=i_proj == 0, # on process check on the first iteration
1228
1319
  )
1229
1320
  if i_proj == 0 and store_composition:
1230
1321
  _, self._frame_composition = sf
@@ -1277,6 +1368,7 @@ class PreProcessZStitcher(ZStitcher):
1277
1368
  class PostProcessZStitcher(ZStitcher):
1278
1369
  def __init__(self, configuration, progress: Progress = None) -> None:
1279
1370
  self._input_volumes = configuration.input_volumes
1371
+ self.__output_data_type = None
1280
1372
 
1281
1373
  self._z_serie = Serie("z-serie", iterable=self._input_volumes, use_identifiers=False)
1282
1374
  super().__init__(configuration, progress)
@@ -1333,7 +1425,7 @@ class PostProcessZStitcher(ZStitcher):
1333
1425
  if scan_location is not None:
1334
1426
  # this work around (until most volume have position metadata) works only for Hdf5volume
1335
1427
  with cwd_context(os.path.dirname(volume.file_path)):
1336
- o_scan = HDF5TomoScan(scan_location, scan_entry)
1428
+ o_scan = NXtomoScan(scan_location, scan_entry)
1337
1429
  bb_acqui = o_scan.get_bounding_box(axis=None)
1338
1430
  # for next step volume position will be required.
1339
1431
  # if you can find it set it directly
@@ -1414,7 +1506,7 @@ class PostProcessZStitcher(ZStitcher):
1414
1506
  # deduce from position given in configuration and pixel size
1415
1507
  axis_N_pos_px = []
1416
1508
  for volume, pos_in_mm in zip(self.z_serie, pos_as_mm):
1417
- voxel_size_m = self.configuration.voxel_size or volume.pixel_size
1509
+ voxel_size_m = self.configuration.voxel_size or volume.voxel_size
1418
1510
  axis_N_pos_px.append((pos_in_mm / MetricSystem.MILLIMETER.value) / voxel_size_m[0])
1419
1511
  return axis_N_pos_px
1420
1512
  else:
@@ -1422,7 +1514,7 @@ class PostProcessZStitcher(ZStitcher):
1422
1514
  axis_N_pos_px = []
1423
1515
  base_position_m = self.z_serie[0].get_bounding_box(axis=axis).min
1424
1516
  for volume in self.z_serie:
1425
- voxel_size_m = self.configuration.voxel_size or volume.pixel_size
1517
+ voxel_size_m = self.configuration.voxel_size or volume.voxel_size
1426
1518
  volume_axis_bb = volume.get_bounding_box(axis=axis)
1427
1519
  axis_N_mean_pos_m = (volume_axis_bb.max - volume_axis_bb.min) / 2 + volume_axis_bb.min
1428
1520
  axis_N_mean_rel_pos_m = axis_N_mean_pos_m - base_position_m
@@ -1464,6 +1556,7 @@ class PostProcessZStitcher(ZStitcher):
1464
1556
  slice_for_shift = self.configuration.slice_for_cross_correlation or "middle"
1465
1557
  y_rel_shifts = self._axis_0_estimated_shifts
1466
1558
  x_rel_shifts = self.from_abs_pos_to_rel_pos(self.configuration.axis_2_pos_px)
1559
+ dim_axis_1 = max([volume.get_volume_shape()[1] for volume in self.z_serie])
1467
1560
 
1468
1561
  final_rel_shifts = []
1469
1562
  for (
@@ -1488,6 +1581,8 @@ class PostProcessZStitcher(ZStitcher):
1488
1581
  found_shift_y, found_shift_x = find_volumes_relative_shifts(
1489
1582
  upper_volume=upper_volume,
1490
1583
  lower_volume=lower_volume,
1584
+ dtype=self.get_output_data_type(),
1585
+ dim_axis_1=dim_axis_1,
1491
1586
  slice_for_shift=slice_for_shift,
1492
1587
  x_cross_correlation_function=x_cross_algo,
1493
1588
  y_cross_correlation_function=y_cross_algo,
@@ -1496,6 +1591,8 @@ class PostProcessZStitcher(ZStitcher):
1496
1591
  estimated_shifts=(y_rel_shift, x_rel_shift),
1497
1592
  flip_ud_lower_frame=flip_ud_lower,
1498
1593
  flip_ud_upper_frame=flip_ud_upper,
1594
+ alignment_axis_1=self.configuration.alignment_axis_1,
1595
+ alignment_axis_2=self.configuration.alignment_axis_2,
1499
1596
  )
1500
1597
  final_rel_shifts.append(
1501
1598
  (found_shift_y, found_shift_x),
@@ -1504,10 +1601,10 @@ class PostProcessZStitcher(ZStitcher):
1504
1601
  # set back values. Now position should start at 0
1505
1602
  self._axis_0_rel_shifts = [final_shift[0] for final_shift in final_rel_shifts]
1506
1603
  self._axis_2_rel_shifts = [final_shift[1] for final_shift in final_rel_shifts]
1507
- _logger.info(f"axis 2 relative shifts (x in radio ref) to be used will be {self._axis_0_rel_shifts}")
1508
- print(f"axis 2 relative shifts (x in radio ref) to be used will be {self._axis_0_rel_shifts}")
1509
- _logger.info(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_2_rel_shifts}")
1510
- print(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_2_rel_shifts}")
1604
+ _logger.info(f"axis 2 relative shifts (x in radio ref) to be used will be {self._axis_2_rel_shifts}")
1605
+ print(f"axis 2 relative shifts (x in radio ref) to be used will be {self._axis_2_rel_shifts}")
1606
+ _logger.info(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_0_rel_shifts}")
1607
+ print(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_0_rel_shifts}")
1511
1608
 
1512
1609
  def _dump_stitching_configuration(self):
1513
1610
  voxel_size = self._input_volumes[0].voxel_size
@@ -1583,22 +1680,27 @@ class PostProcessZStitcher(ZStitcher):
1583
1680
  ):
1584
1681
  raise ValueError(f"expect {n_volumes} overlap defined. Get {len(self.configuration.axis_2_pos_mm)}")
1585
1682
 
1586
- yz_shape = None
1587
- for volume in self.configuration.input_volumes:
1588
- assert isinstance(volume, VolumeBase)
1589
- volume_shape = volume.get_volume_shape()
1590
- if volume_shape is None:
1591
- raise ValueError("Unable to load volume shape (probably no data found from {volume.get_identifier()}")
1592
- if yz_shape is None:
1593
- yz_shape = volume_shape[1:]
1594
- elif yz_shape != volume_shape[1:]:
1595
- raise ValueError("Input volumes have incoherent (yz) shapes. Unable to stitch it together")
1596
-
1597
1683
  self._reading_orders = []
1598
1684
  # the first scan will define the expected reading orderd, and expected flip.
1599
1685
  # if all scan are flipped then we will keep it this way
1600
1686
  self._reading_orders.append(1)
1601
1687
 
1688
+ def get_output_data_type(self):
1689
+ if self.__output_data_type is None:
1690
+
1691
+ def find_output_data_type():
1692
+ first_vol = self._input_volumes[0]
1693
+ if first_vol.data is not None:
1694
+ return first_vol.data.dtype
1695
+ elif isinstance(first_vol, HDF5Volume):
1696
+ with DatasetReader(first_vol.data_url) as vol_dataset:
1697
+ return vol_dataset.dtype
1698
+ else:
1699
+ return first_vol.load_data(store=False).dtype
1700
+
1701
+ self.__output_data_type = find_output_data_type()
1702
+ return self.__output_data_type
1703
+
1602
1704
  def _create_stitched_volume(self, store_composition: bool):
1603
1705
  overlap_kernels = self._overlap_kernels
1604
1706
  self._slices_to_stitch, n_slices = self.configuration.settle_slices()
@@ -1614,20 +1716,10 @@ class PostProcessZStitcher(ZStitcher):
1614
1716
  - numpy.asarray([abs(overlap) for overlap in self._axis_0_rel_shifts]).sum(),
1615
1717
  ),
1616
1718
  n_slices,
1617
- self._input_volumes[0].get_volume_shape()[2],
1719
+ self._stitching_width,
1618
1720
  )
1619
1721
 
1620
- def get_output_data_type():
1621
- first_vol = self._input_volumes[0]
1622
- if first_vol.data is not None:
1623
- return first_vol.data.dtype
1624
- elif isinstance(first_vol, HDF5Volume):
1625
- with DatasetReader(first_vol.data_url) as vol_dataset:
1626
- return vol_dataset.dtype
1627
- else:
1628
- return first_vol.load_data(store=False).dtype
1629
-
1630
- data_type = get_output_data_type()
1722
+ data_type = self.get_output_data_type()
1631
1723
 
1632
1724
  if self.progress:
1633
1725
  self.progress.set_max_advancement(final_volume_shape[1])
@@ -1641,7 +1733,10 @@ class PostProcessZStitcher(ZStitcher):
1641
1733
  volume=final_volume, volume_shape=final_volume_shape, dtype=data_type
1642
1734
  ) as output_dataset:
1643
1735
  # note: output_dataset is a HDF5 dataset if final volume is an HDF5 volume else is a numpy array
1644
- with PostProcessZStitcher._RawDatasetsContext(self._input_volumes) as raw_datasets:
1736
+ with PostProcessZStitcher._RawDatasetsContext(
1737
+ self._input_volumes,
1738
+ alignment_axis_1=self.configuration.alignment_axis_1,
1739
+ ) as raw_datasets:
1645
1740
  # note: raw_datasets can be numpy arrays or HDF5 dataset (in the case of HDF5Volume)
1646
1741
  # to speed up we read by bunch of dataset. For numpy array this doesn't change anything
1647
1742
  # but for HDF5 dataset this can speed up a lot the processing (depending on HDF5 dataset chuncks)
@@ -1659,6 +1754,9 @@ class PostProcessZStitcher(ZStitcher):
1659
1754
  ):
1660
1755
  if self.configuration.rescale_frames:
1661
1756
  data_frames = self.rescale_frames(data_frames)
1757
+ if self.configuration.normalization_by_sample.is_active():
1758
+ data_frames = self.normalize_frame_by_sample(data_frames)
1759
+
1662
1760
  sf = ZStitcher.stitch_frames(
1663
1761
  frames=data_frames,
1664
1762
  x_relative_shifts=self._axis_2_rel_shifts,
@@ -1670,6 +1768,7 @@ class PostProcessZStitcher(ZStitcher):
1670
1768
  output_dtype=data_type,
1671
1769
  return_composition_cls=store_composition if y_index == 0 else False,
1672
1770
  stitching_axis=0,
1771
+ check_inputs=y_index == 0, # on process check on the first iteration
1673
1772
  )
1674
1773
  if y_index == 0 and store_composition:
1675
1774
  _, self._frame_composition = sf
@@ -1680,7 +1779,12 @@ class PostProcessZStitcher(ZStitcher):
1680
1779
 
1681
1780
  @staticmethod
1682
1781
  def _get_bunch_of_data(
1683
- bunch_start: int, bunch_end: int, step: int, volumes: tuple, flip_lr_arr: bool, flip_ud_arr: bool
1782
+ bunch_start: int,
1783
+ bunch_end: int,
1784
+ step: int,
1785
+ volumes: tuple,
1786
+ flip_lr_arr: bool,
1787
+ flip_ud_arr: bool,
1684
1788
  ):
1685
1789
  """
1686
1790
  goal is to load contiguous frames as much as possible...
@@ -1768,7 +1872,7 @@ class PostProcessZStitcher(ZStitcher):
1768
1872
  If the volume is of another type then it will be loaded in memory then used (more memory consuming)
1769
1873
  """
1770
1874
 
1771
- def __init__(self, volumes: tuple) -> None:
1875
+ def __init__(self, volumes: tuple, alignment_axis_1) -> None:
1772
1876
  super().__init__()
1773
1877
  for volume in volumes:
1774
1878
  if not isinstance(volume, VolumeBase):
@@ -1778,25 +1882,36 @@ class PostProcessZStitcher(ZStitcher):
1778
1882
 
1779
1883
  self._volumes = volumes
1780
1884
  self.__file_handlers = []
1885
+ self._alignment_axis_1 = alignment_axis_1
1886
+
1887
+ @property
1888
+ def alignment_axis_1(self):
1889
+ return self._alignment_axis_1
1781
1890
 
1782
1891
  def __enter__(self):
1783
1892
  # handle the specific case of HDF5. Goal: avoid getting the full stitched volume in memory
1784
1893
  datasets = []
1894
+ shapes = {volume.get_volume_shape()[1] for volume in self._volumes}
1895
+ axis_1_dim = max(shapes)
1896
+ axis_1_need_padding = len(shapes) > 1
1897
+
1785
1898
  try:
1786
1899
  for volume in self._volumes:
1787
1900
  if volume.data is not None:
1788
- datasets.append(volume.data)
1901
+ data = volume.data
1789
1902
  elif isinstance(volume, HDF5Volume):
1790
1903
  file_handler = HDF5File(filename=volume.data_url.file_path(), mode="r")
1791
1904
  dataset = file_handler[volume.data_url.data_path()]
1792
- datasets.append(dataset)
1905
+ data = dataset
1793
1906
  self.__file_handlers.append(file_handler)
1794
1907
  # for other file format: load the full dataset in memory
1795
1908
  else:
1796
1909
  data = volume.load_data(store=False)
1797
1910
  if data is None:
1798
1911
  raise ValueError(f"No data found for volume {volume.get_identifier()}")
1799
- datasets.append(data)
1912
+ if axis_1_need_padding:
1913
+ data = self.add_padding(data=data, axis_1_dim=axis_1_dim, alignment=self.alignment_axis_1)
1914
+ datasets.append(data)
1800
1915
  except Exception as e:
1801
1916
  # if some errors happen during loading HDF5
1802
1917
  for file_handled in self.__file_handlers:
@@ -1811,16 +1926,36 @@ class PostProcessZStitcher(ZStitcher):
1811
1926
  success = success and file_handler.close()
1812
1927
  return success
1813
1928
 
1929
+ def add_padding(self, data: Union[h5py.Dataset, numpy.ndarray], axis_1_dim, alignment: AlignmentAxis1):
1930
+ alignment = AlignmentAxis1.from_value(alignment)
1931
+ if alignment is AlignmentAxis1.BACK:
1932
+ axis_1_pad_width = (axis_1_dim - data.shape[1], 0)
1933
+ elif alignment is AlignmentAxis1.CENTER:
1934
+ half_width = int((axis_1_dim - data.shape[1]) / 2)
1935
+ axis_1_pad_width = (half_width, axis_1_dim - data.shape[1] - half_width)
1936
+ elif alignment is AlignmentAxis1.FRONT:
1937
+ axis_1_pad_width = (0, axis_1_dim - data.shape[1])
1938
+ else:
1939
+ raise ValueError(f"alignment {alignment} is not handled")
1940
+
1941
+ return PaddedRawData(
1942
+ data=data,
1943
+ axis_1_pad_width=axis_1_pad_width,
1944
+ )
1945
+
1814
1946
 
1815
1947
  def stitch_vertically_raw_frames(
1816
1948
  frames: tuple,
1817
1949
  key_lines: tuple,
1818
- overlap_kernels: Optional[Union[ZStichOverlapKernel, tuple]],
1950
+ overlap_kernels: Union[ZStichOverlapKernel, tuple],
1819
1951
  output_dtype: numpy.dtype = numpy.float32,
1820
1952
  check_inputs=True,
1821
1953
  raw_frames_compositions: Optional[ZFrameComposition] = None,
1822
1954
  overlap_frames_compositions: Optional[ZFrameComposition] = None,
1823
1955
  return_composition_cls=False,
1956
+ alignment="center",
1957
+ pad_mode="constant",
1958
+ new_width: Optional[int] = None,
1824
1959
  ) -> numpy.ndarray:
1825
1960
  """
1826
1961
  stitches raw frames (already shifted and flat fielded !!!) together using
@@ -1865,8 +2000,7 @@ def stitch_vertically_raw_frames(
1865
2000
  for frame_0, frame_1 in zip(frames[:-1], frames[1:]):
1866
2001
  if not (frame_0.ndim == frame_1.ndim == 2):
1867
2002
  raise ValueError("Frames are expected to be 2D")
1868
- if frame_0.shape[1] != frame_1.shape[1]:
1869
- raise ValueError("Both projections are expected to have the same width")
2003
+
1870
2004
  for frame_0, frame_1, kernel in zip(frames[:-1], frames[1:], overlap_kernels):
1871
2005
  if frame_0.shape[0] < kernel.overlap_size:
1872
2006
  raise ValueError(
@@ -1886,6 +2020,20 @@ def stitch_vertically_raw_frames(
1886
2020
  elif value < 0:
1887
2021
  raise ValueError(f"key lines are expected to be positive values. Get {value} as key line value")
1888
2022
 
2023
+ if new_width is None:
2024
+ new_width = max([frame.shape[-1] for frame in frames])
2025
+ frames = tuple(
2026
+ [
2027
+ align_horizontally(
2028
+ data=frame,
2029
+ alignment=alignment,
2030
+ new_width=new_width,
2031
+ pad_mode=pad_mode,
2032
+ )
2033
+ for frame in frames
2034
+ ]
2035
+ )
2036
+
1889
2037
  # step 1: create numpy array that will contain stitching
1890
2038
  # if raw composition doesn't exists create it
1891
2039
  if raw_frames_compositions is None:
@@ -1899,7 +2047,7 @@ def stitch_vertically_raw_frames(
1899
2047
  stitched_projection_shape = (
1900
2048
  # here we only handle frames because shift are already done
1901
2049
  int(new_frame_height),
1902
- frames[0].shape[1],
2050
+ new_width,
1903
2051
  )
1904
2052
  stitch_array = numpy.empty(stitched_projection_shape, dtype=output_dtype)
1905
2053
 
@@ -1966,18 +2114,67 @@ class StitchingPostProcAggregation:
1966
2114
 
1967
2115
  This is the goal of this class.
1968
2116
  Please be careful with API. This is already inheriting from a tomwer class
2117
+
2118
+ :param ZStitchingConfiguration stitching_config: configuration of the stitching configuration
2119
+ :param Optional[tuple] futures: futures that just runned
2120
+ :param Optional[tuple] existing_objs: futures that just runned
2121
+ :param
1969
2122
  """
1970
2123
 
1971
- def __init__(self, futures, stitching_config) -> None:
1972
- if isinstance(stitching_config, (PreProcessedZStitchingConfiguration, PostProcessedZStitchingConfiguration)):
1973
- raise TypeError
2124
+ def __init__(
2125
+ self,
2126
+ stitching_config: ZStitchingConfiguration,
2127
+ futures: Optional[tuple] = None,
2128
+ existing_objs_ids: Optional[tuple] = None,
2129
+ ) -> None:
2130
+ if not isinstance(stitching_config, (ZStitchingConfiguration)):
2131
+ raise TypeError(f"stitching_config should be an instance of {ZStitchingConfiguration}")
2132
+ if not ((existing_objs_ids is None) ^ (futures is None)):
2133
+ raise ValueError("Either existing_objs or futures should be provided (can't provide both)")
1974
2134
  self._futures = futures
1975
2135
  self._stitching_config = stitching_config
2136
+ self._existing_objs_ids = existing_objs_ids
1976
2137
 
1977
2138
  @property
1978
2139
  def futures(self):
2140
+ # TODO: deprecate it ?
1979
2141
  return self._futures
1980
2142
 
2143
+ def retrieve_tomo_objects(self) -> tuple():
2144
+ """
2145
+ Return tomo objects to be stitched together. Either from future or from existing_objs
2146
+ """
2147
+ if self._existing_objs_ids is not None:
2148
+ scan_ids = self._existing_objs_ids
2149
+ else:
2150
+ results = {}
2151
+ _logger.info(f"wait for slurm job to be completed")
2152
+ for obj_id, future in self.futures.items():
2153
+ results[obj_id] = future.result()
2154
+
2155
+ failed = tuple(
2156
+ filter(
2157
+ lambda x: x.exception() is not None,
2158
+ self.futures.values(),
2159
+ )
2160
+ )
2161
+ if len(failed) > 0:
2162
+ # if some job failed: unseless to do the concatenation
2163
+ exceptions = " ; ".join([f"{job} : {job.exception()}" for job in failed])
2164
+ raise RuntimeError(f"some job failed. Won't do the concatenation. Exceptiosn are {exceptions}")
2165
+
2166
+ canceled = tuple(
2167
+ filter(
2168
+ lambda x: x.cancelled(),
2169
+ self.futures.values(),
2170
+ )
2171
+ )
2172
+ if len(canceled) > 0:
2173
+ # if some job canceled: unseless to do the concatenation
2174
+ raise RuntimeError(f"some job failed. Won't do the concatenation. Jobs are {' ; '.join(canceled)}")
2175
+ scan_ids = results.keys()
2176
+ return [TomoscanFactory.create_tomo_object_from_identifier(scan_id) for scan_id in scan_ids]
2177
+
1981
2178
  def dump_stiching_config_as_nx_process(self, file_path: str, data_path: str, overwrite: bool, process_name: str):
1982
2179
  dict_to_dump = {
1983
2180
  process_name: {
@@ -2005,40 +2202,14 @@ class StitchingPostProcAggregation:
2005
2202
  """
2006
2203
  main function
2007
2204
  """
2008
- # retrive results (and wait if some processing are not finished)
2009
- results = {}
2010
- _logger.info(f"wait for slurm job to be completed")
2011
- for obj_id, future in self.futures.items():
2012
- results[obj_id] = future.result()
2013
-
2014
- failed = tuple(
2015
- filter(
2016
- lambda x: x.exception() is not None,
2017
- self.futures.values(),
2018
- )
2019
- )
2020
- if len(failed) > 0:
2021
- # if some job failed: unseless to do the concatenation
2022
- exceptions = " ; ".join([f"{job} : {job.exception()}" for job in failed])
2023
- raise RuntimeError(f"some job failed. Won't do the concatenation. Exceptiosn are {exceptions}")
2024
-
2025
- canceled = tuple(
2026
- filter(
2027
- lambda x: x.cancelled(),
2028
- self.futures.values(),
2029
- )
2030
- )
2031
- if len(canceled) > 0:
2032
- # if some job canceled: unseless to do the concatenation
2033
- raise RuntimeError(f"some job failed. Won't do the concatenation. Jobs are {' ; '.join(canceled)}")
2034
2205
 
2035
2206
  # concatenate result
2036
2207
  _logger.info("all job succeeded. Concatenate results")
2037
2208
  if isinstance(self._stitching_config, PreProcessedZStitchingConfiguration):
2038
2209
  # 1: case of a pre-processing stitching
2210
+ scans = self.retrieve_tomo_objects()
2039
2211
  nx_tomos = []
2040
- for result in results.keys():
2041
- scan = TomoscanFactory.create_tomo_object_from_identifier(result)
2212
+ for scan in scans:
2042
2213
  nx_tomos.append(
2043
2214
  NXtomo().load(
2044
2215
  file_path=scan.master_file,
@@ -2069,9 +2240,7 @@ class StitchingPostProcAggregation:
2069
2240
 
2070
2241
  elif isinstance(self.stitching_config, PostProcessedZStitchingConfiguration):
2071
2242
  # 2: case of a post-processing stitching
2072
- outputs_sub_volumes = [
2073
- TomoscanFactory.create_tomo_object_from_identifier(result) for result in results.keys()
2074
- ]
2243
+ outputs_sub_volumes = self.retrieve_tomo_objects()
2075
2244
  concatenate_volumes(
2076
2245
  output_volume=self.stitching_config.output_volume,
2077
2246
  volumes=tuple(outputs_sub_volumes),
@@ -2092,3 +2261,17 @@ class StitchingPostProcAggregation:
2092
2261
  process_name=process_name,
2093
2262
  overwrite=self.stitching_config.overwrite_results,
2094
2263
  )
2264
+ else:
2265
+ raise TypeError(f"stitching_config type ({type(self.stitching_config)}) not handled")
2266
+
2267
+
2268
+ def get_obj_width(obj: Union[NXtomoScan, VolumeBase]) -> int:
2269
+ """
2270
+ return tomo object width
2271
+ """
2272
+ if isinstance(obj, NXtomoScan):
2273
+ return obj.dim_1
2274
+ elif isinstance(obj, VolumeBase):
2275
+ return obj.get_volume_shape()[-1]
2276
+ else:
2277
+ raise TypeError(f"obj type ({type(obj)}) is not handled")