nabu 2024.1.9__py3-none-any.whl → 2024.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151)
  1. nabu/__init__.py +1 -1
  2. nabu/app/bootstrap.py +2 -3
  3. nabu/app/cast_volume.py +4 -2
  4. nabu/app/cli_configs.py +5 -0
  5. nabu/app/composite_cor.py +1 -1
  6. nabu/app/create_distortion_map_from_poly.py +5 -6
  7. nabu/app/diag_to_pix.py +7 -19
  8. nabu/app/diag_to_rot.py +14 -29
  9. nabu/app/double_flatfield.py +32 -44
  10. nabu/app/parse_reconstruction_log.py +3 -0
  11. nabu/app/reconstruct.py +53 -15
  12. nabu/app/reconstruct_helical.py +2 -2
  13. nabu/app/stitching.py +27 -13
  14. nabu/app/tests/test_reduce_dark_flat.py +4 -1
  15. nabu/cuda/kernel.py +11 -2
  16. nabu/cuda/processing.py +2 -2
  17. nabu/cuda/src/cone.cu +77 -0
  18. nabu/cuda/src/hierarchical_backproj.cu +271 -0
  19. nabu/cuda/utils.py +0 -6
  20. nabu/estimation/alignment.py +5 -19
  21. nabu/estimation/cor.py +173 -599
  22. nabu/estimation/cor_sino.py +356 -26
  23. nabu/estimation/focus.py +63 -11
  24. nabu/estimation/tests/test_cor.py +124 -58
  25. nabu/estimation/tests/test_focus.py +6 -6
  26. nabu/estimation/tilt.py +2 -1
  27. nabu/estimation/utils.py +5 -33
  28. nabu/io/__init__.py +1 -1
  29. nabu/io/cast_volume.py +1 -1
  30. nabu/io/reader.py +416 -21
  31. nabu/io/tests/test_readers.py +422 -0
  32. nabu/io/tests/test_writers.py +1 -102
  33. nabu/io/writer.py +4 -433
  34. nabu/opencl/kernel.py +14 -3
  35. nabu/opencl/processing.py +8 -0
  36. nabu/pipeline/config_validators.py +5 -2
  37. nabu/pipeline/datadump.py +12 -5
  38. nabu/pipeline/estimators.py +162 -188
  39. nabu/pipeline/fullfield/chunked.py +168 -92
  40. nabu/pipeline/fullfield/chunked_cuda.py +7 -3
  41. nabu/pipeline/fullfield/computations.py +2 -7
  42. nabu/pipeline/fullfield/dataset_validator.py +0 -4
  43. nabu/pipeline/fullfield/nabu_config.py +37 -13
  44. nabu/pipeline/fullfield/processconfig.py +22 -13
  45. nabu/pipeline/fullfield/reconstruction.py +13 -9
  46. nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
  47. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
  48. nabu/pipeline/helical/helical_reconstruction.py +1 -1
  49. nabu/pipeline/params.py +21 -1
  50. nabu/pipeline/processconfig.py +1 -12
  51. nabu/pipeline/reader.py +146 -0
  52. nabu/pipeline/tests/test_estimators.py +44 -72
  53. nabu/pipeline/utils.py +4 -2
  54. nabu/pipeline/writer.py +10 -2
  55. nabu/preproc/ccd_cuda.py +1 -1
  56. nabu/preproc/ctf.py +14 -7
  57. nabu/preproc/ctf_cuda.py +2 -3
  58. nabu/preproc/double_flatfield.py +5 -12
  59. nabu/preproc/double_flatfield_cuda.py +2 -2
  60. nabu/preproc/flatfield.py +5 -1
  61. nabu/preproc/flatfield_cuda.py +5 -1
  62. nabu/preproc/phase.py +24 -73
  63. nabu/preproc/phase_cuda.py +5 -8
  64. nabu/preproc/tests/test_ctf.py +11 -7
  65. nabu/preproc/tests/test_flatfield.py +67 -122
  66. nabu/preproc/tests/test_paganin.py +54 -30
  67. nabu/processing/azim.py +206 -0
  68. nabu/processing/convolution_cuda.py +1 -1
  69. nabu/processing/fft_cuda.py +15 -17
  70. nabu/processing/histogram.py +2 -0
  71. nabu/processing/histogram_cuda.py +2 -1
  72. nabu/processing/kernel_base.py +3 -0
  73. nabu/processing/muladd_cuda.py +1 -0
  74. nabu/processing/padding_opencl.py +1 -1
  75. nabu/processing/roll_opencl.py +1 -0
  76. nabu/processing/rotation_cuda.py +2 -2
  77. nabu/processing/tests/test_fft.py +17 -10
  78. nabu/processing/unsharp_cuda.py +1 -1
  79. nabu/reconstruction/cone.py +104 -40
  80. nabu/reconstruction/fbp.py +3 -0
  81. nabu/reconstruction/fbp_base.py +7 -2
  82. nabu/reconstruction/filtering.py +20 -7
  83. nabu/reconstruction/filtering_cuda.py +7 -1
  84. nabu/reconstruction/hbp.py +424 -0
  85. nabu/reconstruction/mlem.py +99 -0
  86. nabu/reconstruction/reconstructor.py +2 -0
  87. nabu/reconstruction/rings_cuda.py +19 -19
  88. nabu/reconstruction/sinogram_cuda.py +1 -0
  89. nabu/reconstruction/sinogram_opencl.py +3 -1
  90. nabu/reconstruction/tests/test_cone.py +10 -5
  91. nabu/reconstruction/tests/test_deringer.py +7 -6
  92. nabu/reconstruction/tests/test_fbp.py +124 -10
  93. nabu/reconstruction/tests/test_filtering.py +13 -11
  94. nabu/reconstruction/tests/test_halftomo.py +30 -4
  95. nabu/reconstruction/tests/test_mlem.py +91 -0
  96. nabu/reconstruction/tests/test_reconstructor.py +8 -3
  97. nabu/resources/dataset_analyzer.py +142 -92
  98. nabu/resources/gpu.py +1 -0
  99. nabu/resources/nxflatfield.py +134 -125
  100. nabu/resources/templates/id16a_fluo.conf +42 -0
  101. nabu/resources/tests/test_extract.py +10 -0
  102. nabu/resources/tests/test_nxflatfield.py +2 -2
  103. nabu/stitching/alignment.py +80 -24
  104. nabu/stitching/config.py +105 -68
  105. nabu/stitching/definitions.py +1 -0
  106. nabu/stitching/frame_composition.py +68 -60
  107. nabu/stitching/overlap.py +91 -51
  108. nabu/stitching/single_axis_stitching.py +32 -0
  109. nabu/stitching/slurm_utils.py +6 -6
  110. nabu/stitching/stitcher/__init__.py +0 -0
  111. nabu/stitching/stitcher/base.py +124 -0
  112. nabu/stitching/stitcher/dumper/__init__.py +3 -0
  113. nabu/stitching/stitcher/dumper/base.py +94 -0
  114. nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
  115. nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
  116. nabu/stitching/stitcher/post_processing.py +555 -0
  117. nabu/stitching/stitcher/pre_processing.py +1068 -0
  118. nabu/stitching/stitcher/single_axis.py +484 -0
  119. nabu/stitching/stitcher/stitcher.py +0 -0
  120. nabu/stitching/stitcher/y_stitcher.py +13 -0
  121. nabu/stitching/stitcher/z_stitcher.py +45 -0
  122. nabu/stitching/stitcher_2D.py +278 -0
  123. nabu/stitching/tests/test_config.py +12 -37
  124. nabu/stitching/tests/test_frame_composition.py +33 -59
  125. nabu/stitching/tests/test_overlap.py +149 -7
  126. nabu/stitching/tests/test_utils.py +1 -1
  127. nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
  128. nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
  129. nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
  130. nabu/stitching/utils/__init__.py +1 -0
  131. nabu/stitching/utils/post_processing.py +281 -0
  132. nabu/stitching/utils/tests/test_post-processing.py +21 -0
  133. nabu/stitching/{utils.py → utils/utils.py} +79 -52
  134. nabu/stitching/y_stitching.py +27 -0
  135. nabu/stitching/z_stitching.py +32 -2263
  136. nabu/testutils.py +1 -152
  137. nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
  138. nabu/utils.py +158 -61
  139. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/METADATA +10 -3
  140. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/RECORD +144 -121
  141. nabu/io/tiffwriter_zmm.py +0 -99
  142. nabu/pipeline/fallback_utils.py +0 -149
  143. nabu/pipeline/helical/tests/test_accumulator.py +0 -158
  144. nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
  145. nabu/pipeline/helical/tests/test_strategy.py +0 -61
  146. nabu/pipeline/helical/utils.py +0 -51
  147. nabu/pipeline/tests/test_chunk_reader.py +0 -74
  148. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
  149. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/WHEEL +0 -0
  150. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
  151. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
nabu/resources/templates/id16a_fluo.conf ADDED
@@ -0,0 +1,42 @@
1
+ #
2
+ # ESRF ID16a fluo-tomography
3
+ #
4
+
5
+ #
6
+ # Write here your custom configuration as a python dictionary.
7
+ # Any parameter not present will take the default value.
8
+ #
9
+
10
+ [dataset]
11
+ hdf5_entry = all
12
+
13
+ [preproc]
14
+
15
+ flatfield = 0
16
+ flat_distortion_correction_enabled = 0
17
+ take_logarithm = 0
18
+
19
+
20
+
21
+ [phase]
22
+
23
+ method = none
24
+
25
+
26
+ [reconstruction]
27
+
28
+ method = mlem
29
+ rotation_axis_position = 0.
30
+ cor_options =
31
+ translation_movements_file =
32
+ enable_halftomo = 0
33
+ clip_outer_circle = 1
34
+ centered_axis = 1
35
+ iterations = 200
36
+
37
+
38
+ [output]
39
+
40
+ location =
41
+ file_format = tiff
42
+ tiff_single_file = 1
nabu/resources/tests/test_extract.py ADDED
@@ -0,0 +1,10 @@
1
+ import pytest
2
+ from nabu.utils import list_match_queries
3
+
4
+
5
+ def test_list_match_queries():
6
+
7
+ # entry0000 .... entry0099
8
+ avail = ["entry%04d" % i for i in range(100)]
9
+ query = "entry0000"
10
+ list_match_queries()
nabu/resources/tests/test_nxflatfield.py CHANGED
@@ -96,9 +96,9 @@ class TestNXFlatField:
96
96
  for s in self.params["darks_pos"]:
97
97
  expected_darks[s.start] = self._reduction_func["darks"](data_volume[s.start : s.stop], axis=0)
98
98
 
99
- flats = {idx: get_data(dataset_info.flats[idx]) for idx in dataset_info.flats.keys()}
99
+ flats = dataset_info.flats
100
100
  for idx in flats.keys():
101
101
  assert np.allclose(flats[idx], expected_flats[idx])
102
- darks = {idx: get_data(dataset_info.darks[idx]) for idx in dataset_info.darks.keys()}
102
+ darks = dataset_info.darks
103
103
  for idx in darks.keys():
104
104
  assert np.allclose(darks[idx], expected_darks[idx])
nabu/stitching/alignment.py CHANGED
@@ -8,52 +8,108 @@ from nabu.io.utils import DatasetReader
8
8
 
9
9
 
10
10
  class AlignmentAxis2(_Enum):
11
+ """Specific alignment named to help users orienting themself with specific name"""
12
+
11
13
  CENTER = "center"
12
14
  LEFT = "left"
13
15
  RIGTH = "right"
14
16
 
15
17
 
16
18
  class AlignmentAxis1(_Enum):
19
+ """Specific alignment named to help users orienting themself with specific name"""
20
+
17
21
  FRONT = "front"
18
22
  CENTER = "center"
19
23
  BACK = "back"
20
24
 
21
25
 
22
- def align_horizontally(data: numpy.ndarray, alignment: AlignmentAxis2, new_width: int, pad_mode="constant"):
26
+ class _Alignment(_Enum):
27
+ """Internal alignment to be used for 2D alignment"""
28
+
29
+ LOWER_BOUNDARY = "lower boundary"
30
+ HIGHER_BOUNDARY = "higher boundary"
31
+ CENTER = "center"
32
+
33
+ @classmethod
34
+ def from_value(cls, value):
35
+ # cast the AlignmentAxis1 and AlignmentAxis2 values to fit the generic definition
36
+ if value in ("front", "left", AlignmentAxis1.FRONT, AlignmentAxis2.LEFT):
37
+ return _Alignment.LOWER_BOUNDARY
38
+ elif value in ("back", "right", AlignmentAxis1.BACK, AlignmentAxis2.RIGTH):
39
+ return _Alignment.HIGHER_BOUNDARY
40
+ elif value in (AlignmentAxis1.CENTER, AlignmentAxis2.CENTER):
41
+ return _Alignment.CENTER
42
+ else:
43
+ return super().from_value(value)
44
+
45
+
46
+ def align_frame(
47
+ data: numpy.ndarray, alignment: _Alignment, alignment_axis: int, new_aligned_axis_size: int, pad_mode="constant"
48
+ ):
23
49
  """
24
- Align data horizontally to make sure new data width will ne `new_width`.
50
+ Align 2D array to extend if size along `alignment_axis` to `new_aligned_axis_size`.
25
51
 
26
- :param numpy.ndarray data: data to align
52
+ :param numpy.ndarray data: data (frame) to align (2D numpy array)
53
+ :param alignment_axis: axis along which we want to align the frame. Must be in (0, 1)
27
54
  :param HAlignment alignment: alignment strategy
28
55
  :param int new_width: output data width
29
56
  """
30
- current_width = data.shape[-1]
31
- alignment = AlignmentAxis2.from_value(alignment)
57
+ if alignment_axis not in (0, 1):
58
+ raise ValueError(f"alignment_axis should be in (0, 1). Get {alignment_axis}")
59
+ alignment = _Alignment.from_value(alignment)
32
60
 
33
- if current_width > new_width:
34
- raise ValueError(f"data.shape[-1] ({data.shape[-1]}) > new_width ({new_width}). Unable to crop data")
35
- elif current_width == new_width:
61
+ aligned_axis_size = data.shape[alignment_axis]
62
+
63
+ if aligned_axis_size > new_aligned_axis_size:
64
+ raise ValueError(
65
+ f"data.shape[alignment_axis] ({data.shape[alignment_axis]}) > new_aligned_axis_size ({new_aligned_axis_size}). Unable to crop data"
66
+ )
67
+ elif aligned_axis_size == new_aligned_axis_size:
36
68
  return data
37
69
  else:
38
- if alignment is AlignmentAxis2.CENTER:
39
- left_width = (new_width - current_width) // 2
40
- right_width = (new_width - current_width) - left_width
41
- elif alignment is AlignmentAxis2.LEFT:
42
- left_width = 0
43
- right_width = new_width - current_width
44
- elif alignment is AlignmentAxis2.RIGTH:
45
- left_width = new_width - current_width
46
- right_width = 0
70
+ if alignment is _Alignment.CENTER:
71
+ lower_boundary = (new_aligned_axis_size - aligned_axis_size) // 2
72
+ higher_boundary = (new_aligned_axis_size - aligned_axis_size) - lower_boundary
73
+ elif alignment is _Alignment.LOWER_BOUNDARY:
74
+ lower_boundary = 0
75
+ higher_boundary = new_aligned_axis_size - aligned_axis_size
76
+ elif alignment is _Alignment.HIGHER_BOUNDARY:
77
+ lower_boundary = new_aligned_axis_size - aligned_axis_size
78
+ higher_boundary = 0
47
79
  else:
48
80
  raise ValueError(f"alignment {alignment.value} is not handled")
49
81
 
50
- assert left_width >= 0, f"pad width must be positive - left width isn't ({left_width})"
51
- assert right_width >= 0, f"pad width must be positive - right width isn't ({right_width})"
52
- return numpy.pad(
53
- data,
54
- pad_width=((0, 0), (left_width, right_width)),
55
- mode=pad_mode,
56
- )
82
+ assert lower_boundary >= 0, f"pad size must be positive - lower boundary isn't ({lower_boundary})"
83
+ assert higher_boundary >= 0, f"pad size must be positive - higher boundary isn't ({higher_boundary})"
84
+
85
+ if alignment_axis == 1:
86
+ return numpy.pad(
87
+ data,
88
+ pad_width=((0, 0), (lower_boundary, higher_boundary)),
89
+ mode=pad_mode,
90
+ )
91
+ elif alignment_axis == 0:
92
+ return numpy.pad(
93
+ data,
94
+ pad_width=((lower_boundary, higher_boundary), (0, 0)),
95
+ mode=pad_mode,
96
+ )
97
+ else:
98
+ raise ValueError("alignment_axis should be in (0, 1)")
99
+
100
+
101
+ def align_horizontally(data: numpy.ndarray, alignment: AlignmentAxis2, new_width: int, pad_mode="constant"):
102
+ """
103
+ Align data horizontally to make sure new data width will ne `new_width`.
104
+
105
+ :param numpy.ndarray data: data to align
106
+ :param HAlignment alignment: alignment strategy
107
+ :param int new_width: output data width
108
+ """
109
+ alignment = AlignmentAxis2.from_value(alignment).value
110
+ return align_frame(
111
+ data=data, alignment=alignment, new_aligned_axis_size=new_width, pad_mode=pad_mode, alignment_axis=1
112
+ )
57
113
 
58
114
 
59
115
  class PaddedRawData:
nabu/stitching/config.py CHANGED
@@ -1,33 +1,3 @@
1
- # coding: utf-8
2
- # /*##########################################################################
3
- #
4
- # Copyright (c) 2016-2017 European Synchrotron Radiation Facility
5
- #
6
- # Permission is hereby granted, free of charge, to any person obtaining a copy
7
- # of this software and associated documentation files (the "Software"), to deal
8
- # in the Software without restriction, including without limitation the rights
9
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
- # copies of the Software, and to permit persons to whom the Software is
11
- # furnished to do so, subject to the following conditions:
12
- #
13
- # The above copyright notice and this permission notice shall be included in
14
- # all copies or substantial portions of the Software.
15
- #
16
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22
- # THE SOFTWARE.
23
- #
24
- # ###########################################################################*/
25
-
26
- __authors__ = ["H. Payno"]
27
- __license__ = "MIT"
28
- __date__ = "10/05/2022"
29
-
30
-
31
1
  from math import ceil
32
2
  from typing import Any, Iterable, Optional, Union, Sized
33
3
  from dataclasses import dataclass
@@ -44,7 +14,7 @@ from ..pipeline.config_validators import (
44
14
  from ..utils import concatenate_dict, convert_str_to_tuple
45
15
  from ..io.utils import get_output_volume
46
16
  from .overlap import OverlapStitchingStrategy
47
- from .utils import ShiftAlgorithm
17
+ from .utils.utils import ShiftAlgorithm
48
18
  from .definitions import StitchingType
49
19
  from .alignment import AlignmentAxis1, AlignmentAxis2
50
20
  from pyunitsystem.metricsystem import MetricSystem
@@ -66,9 +36,9 @@ OUTPUT_SECTION = "output"
66
36
 
67
37
  INPUTS_SECTION = "inputs"
68
38
 
69
- Z_PRE_PROC_SECTION = "z-preproc"
39
+ PRE_PROC_SECTION = "preproc"
70
40
 
71
- Z_POST_PROC_SECTION = "z-postproc"
41
+ POST_PROC_SECTION = "postproc"
72
42
 
73
43
  INPUT_DATASETS_FIELD = "input_datasets"
74
44
 
@@ -134,6 +104,8 @@ ALIGNMENT_AXIS_1_FIELD = "alignment_axis_1"
134
104
 
135
105
  PAD_MODE_FIELD = "pad_mode"
136
106
 
107
+ AVOID_DATA_DUPLICATION_FIELD = "avoid_data_duplication"
108
+
137
109
  # SLURM
138
110
 
139
111
  SLURM_SECTION = "slurm"
@@ -539,6 +511,10 @@ class StitchingConfiguration:
539
511
 
540
512
  normalization_by_sample: NormalizationBySample = None
541
513
 
514
+ duplicate_data: bool = True
515
+ """when possible (for HDF5) avoid duplicating data as-much-much-as-possible. Overlaping region between two frames will be duplicated. Remaining will be 'raw_data' for volume.
516
+ For projection flat field will be applied"""
517
+
542
518
  @property
543
519
  def stitching_type(self):
544
520
  raise NotImplementedError("Base class")
@@ -657,6 +633,12 @@ class StitchingConfiguration:
657
633
  "help": f"pad mode to use for frame alignment. Valid values are 'constant', 'edge', 'linear_ramp', maximum', 'mean', 'median', 'minimum', 'reflect', 'symmetric', 'wrap', and 'empty'. See nupy.pad documentation for details",
658
634
  "type": "advanced",
659
635
  },
636
+ AVOID_DATA_DUPLICATION_FIELD: {
637
+ "default": "1",
638
+ "help": "When possible (stitching on reconstructed volume and HDF5 volume as input and output) create link to original data instead of duplicating it all. Warning: this will create relative link between the stiched volume and the original reconstructed volume.",
639
+ "validator": boolean_validator,
640
+ "type": "advanced",
641
+ },
660
642
  },
661
643
  OUTPUT_SECTION: {
662
644
  OVERWRITE_RESULTS_FIELD: {
@@ -771,6 +753,7 @@ class StitchingConfiguration:
771
753
  RESCALE_FRAMES: self.rescale_frames,
772
754
  RESCALE_PARAMS: _dict_to_str(self.rescale_params or {}),
773
755
  STITCHING_KERNELS_EXTRA_PARAMS: _dict_to_str(self.stitching_kernels_extra_params or {}),
756
+ AVOID_DATA_DUPLICATION_FIELD: not self.duplicate_data,
774
757
  },
775
758
  OUTPUT_SECTION: {
776
759
  OVERWRITE_RESULTS_FIELD: int(
@@ -781,20 +764,39 @@ class StitchingConfiguration:
781
764
  }
782
765
 
783
766
 
767
+ class SingleAxisConfigMetaClass(type):
768
+ """
769
+ Metaclass for single axis stitcher in order to aggregate dumper class and axis
770
+
771
+ warning: this class is used by tomwer as well
772
+ """
773
+
774
+ def __new__(mcls, name, bases, attrs, axis=None):
775
+ # assert axis is not None
776
+ mcls = super().__new__(mcls, name, bases, attrs)
777
+ mcls._axis = axis
778
+ return mcls
779
+
780
+
784
781
  @dataclass
785
- class ZStitchingConfiguration(StitchingConfiguration):
782
+ class SingleAxisStitchingConfiguration(StitchingConfiguration, metaclass=SingleAxisConfigMetaClass):
786
783
  """
787
784
  base class to define z-stitching parameters
788
785
  """
789
786
 
790
- slices: Union[
791
- slice, tuple, None
792
- ] = None # slices to reconstruct. Over axis 0 for pre-processing, over axis 1 for post-processing. If None will reconstruct all
787
+ slices: Union[slice, tuple, None] = (
788
+ None # slices to reconstruct. Over axis 0 for pre-processing, over axis 1 for post-processing. If None will reconstruct all
789
+ )
793
790
 
794
791
  alignment_axis_2: AlignmentAxis2 = AlignmentAxis2.CENTER
795
792
 
796
793
  pad_mode: str = "constant" # pad mode to be used for alignment
797
794
 
795
+ @property
796
+ def axis(self) -> int:
797
+ # self._axis is defined by the metaclass
798
+ return self._axis
799
+
798
800
  def settle_inputs(self) -> None:
799
801
  self.settle_slices()
800
802
 
@@ -826,7 +828,7 @@ class ZStitchingConfiguration(StitchingConfiguration):
826
828
 
827
829
 
828
830
  @dataclass
829
- class PreProcessedZStitchingConfiguration(ZStitchingConfiguration):
831
+ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfiguration):
830
832
  """
831
833
  base class to define z-stitching parameters
832
834
  """
@@ -839,7 +841,15 @@ class PreProcessedZStitchingConfiguration(ZStitchingConfiguration):
839
841
 
840
842
  @property
841
843
  def stitching_type(self) -> StitchingType:
842
- return StitchingType.Z_PREPROC
844
+ if self.axis == 0:
845
+ return StitchingType.Z_PREPROC
846
+ elif self.axis == 1:
847
+ return StitchingType.Y_PREPROC
848
+ else:
849
+ raise ValueError(
850
+ "unexpected axis value. Only stitching over axis 0 (aka z) and 1 (aka y) are handled. Current axis value is %s",
851
+ self.axis,
852
+ )
843
853
 
844
854
  def get_output_object(self):
845
855
  return NXtomoScan(
@@ -853,9 +863,11 @@ class PreProcessedZStitchingConfiguration(ZStitchingConfiguration):
853
863
 
854
864
  def settle_input_scans(self):
855
865
  self.input_scans = [
856
- Factory.create_tomo_object_from_identifier(identifier)
857
- if isinstance(identifier, (str, ScanIdentifier))
858
- else identifier
866
+ (
867
+ Factory.create_tomo_object_from_identifier(identifier)
868
+ if isinstance(identifier, (str, ScanIdentifier))
869
+ else identifier
870
+ )
859
871
  for identifier in self.input_scans
860
872
  ]
861
873
 
@@ -918,7 +930,7 @@ class PreProcessedZStitchingConfiguration(ZStitchingConfiguration):
918
930
  return concatenate_dict(
919
931
  super().to_dict(),
920
932
  {
921
- Z_PRE_PROC_SECTION: {
933
+ PRE_PROC_SECTION: {
922
934
  DATA_FILE_FIELD: self.output_file_path,
923
935
  DATA_PATH_FIELD: self.output_data_path,
924
936
  NEXUS_VERSION_FIELD: self.output_nexus_version,
@@ -935,9 +947,9 @@ class PreProcessedZStitchingConfiguration(ZStitchingConfiguration):
935
947
  @staticmethod
936
948
  def get_description_dict() -> dict:
937
949
  return concatenate_dict(
938
- ZStitchingConfiguration.get_description_dict(),
950
+ SingleAxisStitchingConfiguration.get_description_dict(),
939
951
  {
940
- Z_PRE_PROC_SECTION: {
952
+ PRE_PROC_SECTION: {
941
953
  DATA_FILE_FIELD: {
942
954
  "default": "",
943
955
  "help": "output nxtomo file path",
@@ -957,8 +969,8 @@ class PreProcessedZStitchingConfiguration(ZStitchingConfiguration):
957
969
  },
958
970
  )
959
971
 
960
- @staticmethod
961
- def from_dict(config: dict):
972
+ @classmethod
973
+ def from_dict(cls, config: dict):
962
974
  if not isinstance(config, dict):
963
975
  raise TypeError(f"config is expected to be a dict and not {type(config)}")
964
976
  inputs_scans_str = config.get(INPUTS_SECTION, {}).get(INPUT_DATASETS_FIELD, None)
@@ -967,9 +979,9 @@ class PreProcessedZStitchingConfiguration(ZStitchingConfiguration):
967
979
  else:
968
980
  input_scans = identifiers_as_str_to_instances(inputs_scans_str)
969
981
 
970
- output_file_path = config.get(Z_PRE_PROC_SECTION, {}).get(DATA_FILE_FIELD, None)
982
+ output_file_path = config.get(PRE_PROC_SECTION, {}).get(DATA_FILE_FIELD, None)
971
983
 
972
- nexus_version = config.get(Z_PRE_PROC_SECTION, {}).get(NEXUS_VERSION_FIELD, None)
984
+ nexus_version = config.get(PRE_PROC_SECTION, {}).get(NEXUS_VERSION_FIELD, None)
973
985
  if nexus_version in (None, ""):
974
986
  nexus_version = nxtomo.LATEST_VERSION
975
987
  else:
@@ -980,7 +992,7 @@ class PreProcessedZStitchingConfiguration(ZStitchingConfiguration):
980
992
  else:
981
993
  pixel_size = float(pixel_size) / MetricSystem.MM
982
994
 
983
- return PreProcessedZStitchingConfiguration(
995
+ return cls(
984
996
  stitching_strategy=OverlapStitchingStrategy.from_value(
985
997
  config[STITCHING_SECTION].get(
986
998
  STITCHING_STRATEGY_FIELD,
@@ -1006,7 +1018,7 @@ class PreProcessedZStitchingConfiguration(ZStitchingConfiguration):
1006
1018
  ),
1007
1019
  input_scans=input_scans,
1008
1020
  output_file_path=output_file_path,
1009
- output_data_path=config.get(Z_PRE_PROC_SECTION, {}).get(DATA_PATH_FIELD, "entry_from_stitchig"),
1021
+ output_data_path=config.get(PRE_PROC_SECTION, {}).get(DATA_PATH_FIELD, "entry_from_stitchig"),
1010
1022
  overwrite_results=config[STITCHING_SECTION].get(OVERWRITE_RESULTS_FIELD, True),
1011
1023
  output_nexus_version=nexus_version,
1012
1024
  slices=_slices_to_list_or_slice(config[INPUTS_SECTION].get(STITCHING_SLICES, None)),
@@ -1026,12 +1038,13 @@ class PreProcessedZStitchingConfiguration(ZStitchingConfiguration):
1026
1038
  config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER)
1027
1039
  ),
1028
1040
  pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"),
1041
+ duplicate_data=not config[STITCHING_SECTION].get(AVOID_DATA_DUPLICATION_FIELD, False),
1029
1042
  normalization_by_sample=NormalizationBySample.from_dict(config.get(NORMALIZATION_BY_SAMPLE_SECTION, {})),
1030
1043
  )
1031
1044
 
1032
1045
 
1033
1046
  @dataclass
1034
- class PostProcessedZStitchingConfiguration(ZStitchingConfiguration):
1047
+ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfiguration):
1035
1048
  """
1036
1049
  base class to define z-stitching parameters
1037
1050
  """
@@ -1043,7 +1056,10 @@ class PostProcessedZStitchingConfiguration(ZStitchingConfiguration):
1043
1056
 
1044
1057
  @property
1045
1058
  def stitching_type(self) -> StitchingType:
1046
- return StitchingType.Z_POSTPROC
1059
+ if self.axis == 0:
1060
+ return StitchingType.Z_POSTPROC
1061
+ else:
1062
+ raise ValueError(f"unexpected axis value. Only stitching over axis 0 (aka z) is handled. Not {self.axis}")
1047
1063
 
1048
1064
  def get_output_object(self):
1049
1065
  return self.output_volume
@@ -1054,9 +1070,11 @@ class PostProcessedZStitchingConfiguration(ZStitchingConfiguration):
1054
1070
 
1055
1071
  def settle_input_volumes(self):
1056
1072
  self.input_volumes = [
1057
- Factory.create_tomo_object_from_identifier(identifier)
1058
- if isinstance(identifier, (str, VolumeIdentifier))
1059
- else identifier
1073
+ (
1074
+ Factory.create_tomo_object_from_identifier(identifier)
1075
+ if isinstance(identifier, (str, VolumeIdentifier))
1076
+ else identifier
1077
+ )
1060
1078
  for identifier in self.input_volumes
1061
1079
  ]
1062
1080
 
@@ -1117,8 +1135,8 @@ class PostProcessedZStitchingConfiguration(ZStitchingConfiguration):
1117
1135
  self.slices = slices
1118
1136
  return slices, n_slices
1119
1137
 
1120
- @staticmethod
1121
- def from_dict(config: dict):
1138
+ @classmethod
1139
+ def from_dict(cls, config: dict):
1122
1140
  if not isinstance(config, dict):
1123
1141
  raise TypeError(f"config is expected to be a dict and not {type(config)}")
1124
1142
  inputs_volumes_str = config.get(INPUTS_SECTION, {}).get(INPUT_DATASETS_FIELD, None)
@@ -1127,7 +1145,7 @@ class PostProcessedZStitchingConfiguration(ZStitchingConfiguration):
1127
1145
  else:
1128
1146
  input_volumes = identifiers_as_str_to_instances(inputs_volumes_str)
1129
1147
  overwrite_results = config[STITCHING_SECTION].get(OVERWRITE_RESULTS_FIELD, True) in ("1", True, "True", 1)
1130
- output_volume = config.get(Z_POST_PROC_SECTION, {}).get(OUTPUT_VOLUME, None)
1148
+ output_volume = config.get(POST_PROC_SECTION, {}).get(OUTPUT_VOLUME, None)
1131
1149
  if output_volume is not None:
1132
1150
  output_volume = Factory.create_tomo_object_from_identifier(output_volume)
1133
1151
  output_volume.overwrite = overwrite_results
@@ -1138,8 +1156,8 @@ class PostProcessedZStitchingConfiguration(ZStitchingConfiguration):
1138
1156
  else:
1139
1157
  voxel_size = float(voxel_size) * MetricSystem.MM
1140
1158
 
1141
- # on the next section the one with a default value qre the optionnal one
1142
- return PostProcessedZStitchingConfiguration(
1159
+ # on the next section the one with a default value qre the optional one
1160
+ return cls(
1143
1161
  stitching_strategy=OverlapStitchingStrategy.from_value(
1144
1162
  config[STITCHING_SECTION].get(
1145
1163
  STITCHING_STRATEGY_FIELD,
@@ -1178,6 +1196,7 @@ class PostProcessedZStitchingConfiguration(ZStitchingConfiguration):
1178
1196
  config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER)
1179
1197
  ),
1180
1198
  pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"),
1199
+ duplicate_data=not config[STITCHING_SECTION].get(AVOID_DATA_DUPLICATION_FIELD, False),
1181
1200
  normalization_by_sample=NormalizationBySample.from_dict(config.get(NORMALIZATION_BY_SAMPLE_SECTION, {})),
1182
1201
  )
1183
1202
 
@@ -1194,10 +1213,10 @@ class PostProcessedZStitchingConfiguration(ZStitchingConfiguration):
1194
1213
  INPUT_DATASETS_FIELD: [volume.get_identifier().to_str() for volume in self.input_volumes],
1195
1214
  INPUT_VOXEL_SIZE_MM: voxel_size_mm,
1196
1215
  },
1197
- Z_POST_PROC_SECTION: {
1198
- OUTPUT_VOLUME: self.output_volume.get_identifier().to_str()
1199
- if self.output_volume is not None
1200
- else "",
1216
+ POST_PROC_SECTION: {
1217
+ OUTPUT_VOLUME: (
1218
+ self.output_volume.get_identifier().to_str() if self.output_volume is not None else ""
1219
+ ),
1201
1220
  },
1202
1221
  STITCHING_SECTION: {
1203
1222
  ALIGNMENT_AXIS_1_FIELD: self.alignment_axis_1.value,
@@ -1208,9 +1227,9 @@ class PostProcessedZStitchingConfiguration(ZStitchingConfiguration):
1208
1227
  @staticmethod
1209
1228
  def get_description_dict() -> dict:
1210
1229
  return concatenate_dict(
1211
- ZStitchingConfiguration.get_description_dict(),
1230
+ SingleAxisStitchingConfiguration.get_description_dict(),
1212
1231
  {
1213
- Z_POST_PROC_SECTION: {
1232
+ POST_PROC_SECTION: {
1214
1233
  OUTPUT_VOLUME: {
1215
1234
  "default": "",
1216
1235
  "help": "identifier of the output volume. Like hdf5:volume:[file_path]?path=[data_path] for an HDF5 volume",
@@ -1249,13 +1268,15 @@ def dict_to_config_obj(config: dict):
1249
1268
  raise TypeError
1250
1269
  stitching_type = config.get(STITCHING_SECTION, {}).get(STITCHING_TYPE_FIELD, None)
1251
1270
  if stitching_type is None:
1252
- raise ValueError("Unagle to find stitching type from config dict")
1271
+ raise ValueError("Unable to find stitching type from config dict")
1253
1272
  else:
1254
1273
  stitching_type = StitchingType.from_value(stitching_type)
1255
1274
  if stitching_type is StitchingType.Z_POSTPROC:
1256
1275
  return PostProcessedZStitchingConfiguration.from_dict(config)
1257
1276
  elif stitching_type is StitchingType.Z_PREPROC:
1258
1277
  return PreProcessedZStitchingConfiguration.from_dict(config)
1278
+ elif stitching_type is StitchingType.Y_PREPROC:
1279
+ return PreProcessedYStitchingConfiguration.from_dict(config)
1259
1280
  else:
1260
1281
  raise NotImplementedError(f"stitching type {stitching_type.value} not handled yet")
1261
1282
 
@@ -1277,10 +1298,26 @@ def get_default_stitching_config(stitching_type: Optional[Union[StitchingType, s
1277
1298
  return z_postproc_stitching_config
1278
1299
  elif stitching_type is StitchingType.Z_PREPROC:
1279
1300
  return z_preproc_stitching_config
1301
+ elif stitching_type is StitchingType.Y_PREPROC:
1302
+ return y_preproc_stitching_config
1280
1303
  else:
1281
1304
  raise NotImplementedError
1282
1305
 
1283
1306
 
1307
+ class PreProcessedYStitchingConfiguration(PreProcessedSingleAxisStitchingConfiguration, axis=1):
1308
+ pass
1309
+
1310
+
1311
+ class PreProcessedZStitchingConfiguration(PreProcessedSingleAxisStitchingConfiguration, axis=0):
1312
+ pass
1313
+
1314
+
1315
+ class PostProcessedZStitchingConfiguration(PostProcessedSingleAxisStitchingConfiguration, axis=0):
1316
+ pass
1317
+
1318
+
1319
+ y_preproc_stitching_config = PreProcessedYStitchingConfiguration.get_description_dict()
1320
+
1284
1321
  z_preproc_stitching_config = PreProcessedZStitchingConfiguration.get_description_dict()
1285
1322
 
1286
1323
  z_postproc_stitching_config = PostProcessedZStitchingConfiguration.get_description_dict()
nabu/stitching/definitions.py CHANGED
@@ -2,5 +2,6 @@ from silx.utils.enum import Enum as _Enum
2
2
 
3
3
 
4
4
  class StitchingType(_Enum):
5
+ Y_PREPROC = "y-preproc"
5
6
  Z_PREPROC = "z-preproc"
6
7
  Z_POSTPROC = "z-postproc"