nabu 2024.1.9__py3-none-any.whl → 2024.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151)
  1. nabu/__init__.py +1 -1
  2. nabu/app/bootstrap.py +2 -3
  3. nabu/app/cast_volume.py +4 -2
  4. nabu/app/cli_configs.py +5 -0
  5. nabu/app/composite_cor.py +1 -1
  6. nabu/app/create_distortion_map_from_poly.py +5 -6
  7. nabu/app/diag_to_pix.py +7 -19
  8. nabu/app/diag_to_rot.py +14 -29
  9. nabu/app/double_flatfield.py +32 -44
  10. nabu/app/parse_reconstruction_log.py +3 -0
  11. nabu/app/reconstruct.py +53 -15
  12. nabu/app/reconstruct_helical.py +2 -2
  13. nabu/app/stitching.py +27 -13
  14. nabu/app/tests/test_reduce_dark_flat.py +4 -1
  15. nabu/cuda/kernel.py +11 -2
  16. nabu/cuda/processing.py +2 -2
  17. nabu/cuda/src/cone.cu +77 -0
  18. nabu/cuda/src/hierarchical_backproj.cu +271 -0
  19. nabu/cuda/utils.py +0 -6
  20. nabu/estimation/alignment.py +5 -19
  21. nabu/estimation/cor.py +173 -599
  22. nabu/estimation/cor_sino.py +356 -26
  23. nabu/estimation/focus.py +63 -11
  24. nabu/estimation/tests/test_cor.py +124 -58
  25. nabu/estimation/tests/test_focus.py +6 -6
  26. nabu/estimation/tilt.py +2 -1
  27. nabu/estimation/utils.py +5 -33
  28. nabu/io/__init__.py +1 -1
  29. nabu/io/cast_volume.py +1 -1
  30. nabu/io/reader.py +416 -21
  31. nabu/io/tests/test_readers.py +422 -0
  32. nabu/io/tests/test_writers.py +1 -102
  33. nabu/io/writer.py +4 -433
  34. nabu/opencl/kernel.py +14 -3
  35. nabu/opencl/processing.py +8 -0
  36. nabu/pipeline/config_validators.py +5 -2
  37. nabu/pipeline/datadump.py +12 -5
  38. nabu/pipeline/estimators.py +162 -188
  39. nabu/pipeline/fullfield/chunked.py +168 -92
  40. nabu/pipeline/fullfield/chunked_cuda.py +7 -3
  41. nabu/pipeline/fullfield/computations.py +2 -7
  42. nabu/pipeline/fullfield/dataset_validator.py +0 -4
  43. nabu/pipeline/fullfield/nabu_config.py +37 -13
  44. nabu/pipeline/fullfield/processconfig.py +22 -13
  45. nabu/pipeline/fullfield/reconstruction.py +13 -9
  46. nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
  47. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
  48. nabu/pipeline/helical/helical_reconstruction.py +1 -1
  49. nabu/pipeline/params.py +21 -1
  50. nabu/pipeline/processconfig.py +1 -12
  51. nabu/pipeline/reader.py +146 -0
  52. nabu/pipeline/tests/test_estimators.py +44 -72
  53. nabu/pipeline/utils.py +4 -2
  54. nabu/pipeline/writer.py +10 -2
  55. nabu/preproc/ccd_cuda.py +1 -1
  56. nabu/preproc/ctf.py +14 -7
  57. nabu/preproc/ctf_cuda.py +2 -3
  58. nabu/preproc/double_flatfield.py +5 -12
  59. nabu/preproc/double_flatfield_cuda.py +2 -2
  60. nabu/preproc/flatfield.py +5 -1
  61. nabu/preproc/flatfield_cuda.py +5 -1
  62. nabu/preproc/phase.py +24 -73
  63. nabu/preproc/phase_cuda.py +5 -8
  64. nabu/preproc/tests/test_ctf.py +11 -7
  65. nabu/preproc/tests/test_flatfield.py +67 -122
  66. nabu/preproc/tests/test_paganin.py +54 -30
  67. nabu/processing/azim.py +206 -0
  68. nabu/processing/convolution_cuda.py +1 -1
  69. nabu/processing/fft_cuda.py +15 -17
  70. nabu/processing/histogram.py +2 -0
  71. nabu/processing/histogram_cuda.py +2 -1
  72. nabu/processing/kernel_base.py +3 -0
  73. nabu/processing/muladd_cuda.py +1 -0
  74. nabu/processing/padding_opencl.py +1 -1
  75. nabu/processing/roll_opencl.py +1 -0
  76. nabu/processing/rotation_cuda.py +2 -2
  77. nabu/processing/tests/test_fft.py +17 -10
  78. nabu/processing/unsharp_cuda.py +1 -1
  79. nabu/reconstruction/cone.py +104 -40
  80. nabu/reconstruction/fbp.py +3 -0
  81. nabu/reconstruction/fbp_base.py +7 -2
  82. nabu/reconstruction/filtering.py +20 -7
  83. nabu/reconstruction/filtering_cuda.py +7 -1
  84. nabu/reconstruction/hbp.py +424 -0
  85. nabu/reconstruction/mlem.py +99 -0
  86. nabu/reconstruction/reconstructor.py +2 -0
  87. nabu/reconstruction/rings_cuda.py +19 -19
  88. nabu/reconstruction/sinogram_cuda.py +1 -0
  89. nabu/reconstruction/sinogram_opencl.py +3 -1
  90. nabu/reconstruction/tests/test_cone.py +10 -5
  91. nabu/reconstruction/tests/test_deringer.py +7 -6
  92. nabu/reconstruction/tests/test_fbp.py +124 -10
  93. nabu/reconstruction/tests/test_filtering.py +13 -11
  94. nabu/reconstruction/tests/test_halftomo.py +30 -4
  95. nabu/reconstruction/tests/test_mlem.py +91 -0
  96. nabu/reconstruction/tests/test_reconstructor.py +8 -3
  97. nabu/resources/dataset_analyzer.py +142 -92
  98. nabu/resources/gpu.py +1 -0
  99. nabu/resources/nxflatfield.py +134 -125
  100. nabu/resources/templates/id16a_fluo.conf +42 -0
  101. nabu/resources/tests/test_extract.py +10 -0
  102. nabu/resources/tests/test_nxflatfield.py +2 -2
  103. nabu/stitching/alignment.py +80 -24
  104. nabu/stitching/config.py +105 -68
  105. nabu/stitching/definitions.py +1 -0
  106. nabu/stitching/frame_composition.py +68 -60
  107. nabu/stitching/overlap.py +91 -51
  108. nabu/stitching/single_axis_stitching.py +32 -0
  109. nabu/stitching/slurm_utils.py +6 -6
  110. nabu/stitching/stitcher/__init__.py +0 -0
  111. nabu/stitching/stitcher/base.py +124 -0
  112. nabu/stitching/stitcher/dumper/__init__.py +3 -0
  113. nabu/stitching/stitcher/dumper/base.py +94 -0
  114. nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
  115. nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
  116. nabu/stitching/stitcher/post_processing.py +555 -0
  117. nabu/stitching/stitcher/pre_processing.py +1068 -0
  118. nabu/stitching/stitcher/single_axis.py +484 -0
  119. nabu/stitching/stitcher/stitcher.py +0 -0
  120. nabu/stitching/stitcher/y_stitcher.py +13 -0
  121. nabu/stitching/stitcher/z_stitcher.py +45 -0
  122. nabu/stitching/stitcher_2D.py +278 -0
  123. nabu/stitching/tests/test_config.py +12 -37
  124. nabu/stitching/tests/test_frame_composition.py +33 -59
  125. nabu/stitching/tests/test_overlap.py +149 -7
  126. nabu/stitching/tests/test_utils.py +1 -1
  127. nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
  128. nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
  129. nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
  130. nabu/stitching/utils/__init__.py +1 -0
  131. nabu/stitching/utils/post_processing.py +281 -0
  132. nabu/stitching/utils/tests/test_post-processing.py +21 -0
  133. nabu/stitching/{utils.py → utils/utils.py} +79 -52
  134. nabu/stitching/y_stitching.py +27 -0
  135. nabu/stitching/z_stitching.py +32 -2263
  136. nabu/testutils.py +1 -152
  137. nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
  138. nabu/utils.py +158 -61
  139. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/METADATA +10 -3
  140. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/RECORD +144 -121
  141. nabu/io/tiffwriter_zmm.py +0 -99
  142. nabu/pipeline/fallback_utils.py +0 -149
  143. nabu/pipeline/helical/tests/test_accumulator.py +0 -158
  144. nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
  145. nabu/pipeline/helical/tests/test_strategy.py +0 -61
  146. nabu/pipeline/helical/utils.py +0 -51
  147. nabu/pipeline/tests/test_chunk_reader.py +0 -74
  148. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
  149. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/WHEEL +0 -0
  150. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
  151. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,281 @@
1
+ import os
2
+ import logging
3
+ from typing import Optional, Union
4
+ from nabu import version as nabu_version
5
+ from nabu.stitching.config import (
6
+ PreProcessedSingleAxisStitchingConfiguration,
7
+ PostProcessedSingleAxisStitchingConfiguration,
8
+ SingleAxisStitchingConfiguration,
9
+ )
10
+ from nabu.stitching.stitcher.single_axis import PROGRESS_BAR_STITCH_VOL_DESC
11
+ from nabu.io.writer import get_datetime
12
+ from tomoscan.factory import Factory as TomoscanFactory
13
+ from silx.io.dictdump import dicttonx
14
+ from nxtomo.application.nxtomo import NXtomo
15
+ from tomoscan.utils.volume import concatenate as concatenate_volumes
16
+ from tomoscan.esrf.volume import HDF5Volume
17
+ from contextlib import AbstractContextManager
18
+ from threading import Thread
19
+ from time import sleep
20
+
21
+ _logger = logging.getLogger(__name__)
22
+
23
+
24
class StitchingPostProcAggregation:
    """
    Aggregate the partial results of a remote (slurm) stitching into the final volume or NXtomo.

    For remote stitching each process stitches a part of the volume or of the projections.
    Once all of them are finished, the parts are aggregated into a final volume
    (post-processing stitching) or a final NXtomo (pre-processing stitching).

    Please be careful with the API: tomwer already builds on top of this class.

    :param stitching_config: configuration of the stitching
    :param futures: dict mapping the identifier of each partial result to the future of the job
                    producing it. Exclusive with `existing_objs_ids`.
    :param existing_objs_ids: identifiers of already existing partial results. Exclusive with `futures`.
    :param progress_bars: tqdm progress bar of each job (job -> progress bar)
    """

    def __init__(
        self,
        stitching_config: SingleAxisStitchingConfiguration,
        futures: Optional[dict] = None,
        existing_objs_ids: Optional[tuple] = None,
        progress_bars: Optional[dict] = None,
    ) -> None:
        if not isinstance(stitching_config, SingleAxisStitchingConfiguration):
            raise TypeError(f"stitching_config should be an instance of {SingleAxisStitchingConfiguration}")
        # exactly one of `existing_objs_ids` and `futures` must be provided
        if not ((existing_objs_ids is None) ^ (futures is None)):
            raise ValueError("Either existing_objs or futures should be provided (can't provide both)")
        if progress_bars is not None and not isinstance(progress_bars, dict):
            raise TypeError(f"'progress_bars' should be None or an instance of a dict. Got {type(progress_bars)}")
        self._futures = futures
        self._stitching_config = stitching_config
        self._existing_objs_ids = existing_objs_ids
        self._progress_bars = progress_bars or {}

    @property
    def futures(self):
        return self._futures

    @property
    def progress_bars(self) -> dict:
        return self._progress_bars

    def retrieve_tomo_objects(self) -> list:
        """
        Return the tomo objects to be stitched together, either from the futures or from `existing_objs_ids`.

        When futures are used, this waits for all the jobs to be completed.

        :raises RuntimeError: if at least one job failed or was canceled
        """
        if self._existing_objs_ids is not None:
            scan_ids = self._existing_objs_ids
        else:
            _logger.info(
                "wait for slurm job to be completed. Advancement will be created once slurm job output file will be available"
            )
            scan_ids = self._wait_for_futures()
        return [TomoscanFactory.create_tomo_object_from_identifier(scan_id) for scan_id in scan_ids]

    def _wait_for_futures(self) -> tuple:
        """
        Wait for all the futures to complete and return the identifiers of the objects they produce.

        Every job is awaited before raising so that all the failures are reported at once.

        :raises RuntimeError: if at least one job failed or was canceled
        """
        failed = []
        canceled = []
        for future in self.futures.values():
            try:
                future.result()
            except Exception as exc:  # reported below within a single RuntimeError
                if future.cancelled():
                    canceled.append(future)
                else:
                    failed.append((future, exc))
        if failed:
            # if some job failed: useless to do the concatenation
            exceptions = " ; ".join(f"{job} : {exc}" for job, exc in failed)
            raise RuntimeError(
                f"some job failed. Won't do the concatenation. Exceptions are {exceptions}"
            ) from failed[0][1]
        if canceled:
            # if some job was canceled: useless to do the concatenation
            jobs = " ; ".join(str(job) for job in canceled)
            raise RuntimeError(f"some job were canceled. Won't do the concatenation. Jobs are {jobs}")
        return tuple(self.futures.keys())

    def dump_stitching_config_as_nx_process(self, file_path: str, data_path: str, overwrite: bool, process_name: str):
        """
        Write the stitching configuration as an NXprocess group named `process_name` under `data_path` of `file_path`.

        :param overwrite: if True, replace existing content ('replace' update mode), else only add ('add' mode)
        """
        dict_to_dump = {
            process_name: {
                "config": self._stitching_config.to_dict(),
                "program": "nabu-stitching",
                "version": nabu_version,
                "date": get_datetime(),
            },
            f"{process_name}@NX_class": "NXprocess",
        }

        dicttonx(
            dict_to_dump,
            h5file=file_path,
            h5path=data_path,
            update_mode="replace" if overwrite else "add",
            mode="a",
        )

    @property
    def stitching_config(self) -> SingleAxisStitchingConfiguration:
        return self._stitching_config

    def process(self) -> None:
        """
        Main function: aggregate the partial results into the final NXtomo (pre-processing stitching)
        or volume (post-processing stitching), then dump the stitching configuration as an NXprocess.

        :raises TypeError: if the type of the stitching configuration is not handled
        :raises RuntimeError: if a job failed / was canceled or did not produce its output file
        """
        if isinstance(self.stitching_config, PreProcessedSingleAxisStitchingConfiguration):
            # 1: case of a pre-processing stitching
            self._aggregate_nx_tomos()
        elif isinstance(self.stitching_config, PostProcessedSingleAxisStitchingConfiguration):
            # 2: case of a post-processing stitching
            self._aggregate_volumes()
        else:
            raise TypeError(f"stitching_config type ({type(self.stitching_config)}) not handled")

    def _aggregate_nx_tomos(self) -> None:
        """Concatenate the NXtomo produced by each job, save it and dump the configuration next to it."""
        with self.follow_progress():
            scans = self.retrieve_tomo_objects()
        _logger.info("all job succeeded. Concatenate results")

        nx_tomos = []
        for scan in scans:
            if not os.path.exists(scan.master_file):
                raise RuntimeError(
                    f"output file not created ({scan.master_file}). Stitching failed. "
                    "Please check slurm .out files to have more information. Most likely the slurm configuration is invalid. "
                    "(partition name not existing...)"
                )
            nx_tomos.append(
                NXtomo().load(
                    file_path=scan.master_file,
                    data_path=scan.entry,
                )
            )
        final_nx_tomo = NXtomo.concatenate(nx_tomos)
        final_nx_tomo.save(
            file_path=self.stitching_config.output_file_path,
            data_path=self.stitching_config.output_data_path,
            overwrite=self.stitching_config.overwrite_results,
        )

        # dump NXprocess if possible
        data_path, process_name = self._get_nx_process_location(self.stitching_config.output_data_path)
        self.dump_stitching_config_as_nx_process(
            file_path=self.stitching_config.output_file_path,
            data_path=data_path,
            process_name=process_name,
            overwrite=self.stitching_config.overwrite_results,
        )

    def _aggregate_volumes(self) -> None:
        """Concatenate the sub-volumes produced by each job and, for HDF5 output, dump the configuration."""
        with self.follow_progress():
            outputs_sub_volumes = self.retrieve_tomo_objects()
        _logger.info("all job succeeded. Concatenate results")

        concatenate_volumes(
            output_volume=self.stitching_config.output_volume,
            volumes=tuple(outputs_sub_volumes),
            axis=1,
        )

        if isinstance(self.stitching_config.output_volume, HDF5Volume):
            metadata_url = self.stitching_config.output_volume.metadata_url
            data_path, process_name = self._get_nx_process_location(metadata_url.data_path())
            self.dump_stitching_config_as_nx_process(
                file_path=metadata_url.file_path(),
                data_path=data_path,
                process_name=process_name,
                overwrite=self.stitching_config.overwrite_results,
            )

    @staticmethod
    def _get_nx_process_location(data_path: str) -> tuple:
        """
        Return (parent group path, NXprocess name) used to store the stitching configuration next to `data_path`.

        The NXprocess is named after the last path element suffixed with '_stitching'.
        A root-level path (like 'entry' or '/entry') gives '/' as parent group.
        """
        parts = data_path.split("/")
        process_name = parts[-1] + "_stitching"
        parent_path = "/".join(parts[:-1]) or "/"
        return parent_path, process_name

    def follow_progress(self) -> AbstractContextManager:
        """Return a context manager updating the job progress bars from the slurm job output files."""
        return SlurmStitchingFollowerContext(
            output_files_to_progress_bars={
                job._get_output_file_path(): progress_bar for (job, progress_bar) in self.progress_bars.items()
            }
        )
204
+
205
+
206
class SlurmStitchingFollowerContext(AbstractContextManager):
    """
    Context manager giving user feedback on a stitching running on slurm.

    Entering starts a follower thread that updates the progress bars. Exiting joins that
    thread (waiting at most 1.5 seconds) and then closes every progress bar.

    :param output_files_to_progress_bars: for each slurm job output file, the progress bar to update
    """

    def __init__(self, output_files_to_progress_bars: dict):
        follower = SlurmStitchingFollowerThread(file_to_progress_bar=output_files_to_progress_bars)
        self._update_thread = follower

    def __enter__(self) -> None:
        self._update_thread.start()

    def __exit__(self, *args, **kwargs):
        follower = self._update_thread
        follower.join(timeout=1.5)
        # closing cleans the display (progress bars are expected to use leave == False)
        for bar in follower.file_to_progress_bar.values():
            bar.close()
219
+
220
+
221
class SlurmStitchingFollowerThread(Thread):
    """
    Thread following the progression of stitching slurm job(s).

    Every `delay_time` seconds, read the last line of each slurm job .out file. If it is a tqdm line
    of the volume stitching progress bar, update the associated progress bar with its advancement.

    :param file_to_progress_bar: for each slurm .out file, the progress bar to update
    :param delay_time: time (in seconds) to wait between two reads of the .out files
    """

    def __init__(self, file_to_progress_bar: dict, delay_time: float = 0.5) -> None:
        super().__init__()
        self._stop_run = False
        self._wait_time = delay_time
        self._file_to_progress_bar = file_to_progress_bar
        self._first_run = True

    @property
    def file_to_progress_bar(self) -> dict:
        return self._file_to_progress_bar

    def run(self) -> None:
        while not self._stop_run:
            for file_path, progress_bar in self._file_to_progress_bar.items():
                if self._first_run:
                    # make sure each progress bar has been refreshed at least once
                    progress_bar.refresh()

                last_line = self._read_last_line(file_path)
                if last_line is None:
                    continue
                advancement = self.cast_progress_line_from_log(line=last_line)
                if advancement is not None:
                    progress_bar.n = advancement
                    progress_bar.refresh()

            self._first_run = False

            sleep(self._wait_time)

    @staticmethod
    def _read_last_line(file_path: str) -> Optional[str]:
        """
        Return the last line of `file_path`, or None if the file does not exist (yet), is empty or cannot be read.

        Read errors are logged (debug) and ignored so that a transient problem does not kill the follower thread.
        """
        if not os.path.exists(file_path):
            return None
        try:
            # decode explicitly as utf-8 (replacing undecodable bytes): tqdm output may contain non-ASCII
            # characters and the locale encoding of the follower process may not handle them.
            # Text mode (universal newlines) also splits lines on the '\r' used by tqdm refreshes.
            with open(file_path, "r", encoding="utf-8", errors="replace") as f:
                lines = f.readlines()
        except OSError as exc:
            _logger.debug(f"Failed to read slurm output file {file_path}: {exc}")
            return None
        return lines[-1] if lines else None

    def join(self, timeout: Union[float, None] = None) -> None:
        """Request the thread to stop, then wait for it (at most `timeout` seconds)."""
        self._stop_run = True
        return super().join(timeout)

    @staticmethod
    def cast_progress_line_from_log(line: str) -> Optional[float]:
        """
        Try to retrieve the advancement (in percentage) of the volume stitching from a log line.

        Only the part of the line after the stitching progress bar description is parsed, so that a
        percentage written before it (another bar, other log text) is not taken by mistake.

        :return: advancement as a float, or None if the line does not contain it
        """
        desc_index = line.find(PROGRESS_BAR_STITCH_VOL_DESC)
        if desc_index == -1:
            return None
        progress_part = line[desc_index + len(PROGRESS_BAR_STITCH_VOL_DESC) :]
        if "%" not in progress_part:
            return None

        str_before_percentage = progress_part.split("%")[0].split(" ")[-1]
        try:
            advancement = float(str_before_percentage)
        except ValueError:
            _logger.debug(f"Failed to retrieve advancement from log file. Value got is {str_before_percentage}")
            return None
        else:
            return advancement
@@ -0,0 +1,21 @@
1
+ import pytest
2
+
3
+ from nabu.stitching.stitcher.single_axis import PROGRESS_BAR_STITCH_VOL_DESC
4
+ from nabu.stitching.utils.post_processing import SlurmStitchingFollowerThread
5
+
6
+
7
@pytest.mark.parametrize(
    "test_case",
    [
        ("dump configuration: 100%|", None),
        ("stitching : 100%|", None),
        (f"{PROGRESS_BAR_STITCH_VOL_DESC}: 42%", 42.0),
        (f"{PROGRESS_BAR_STITCH_VOL_DESC}: 56% toto: 23%", 56.0),
        ("", None),
        ("my%", None),
    ],
)
def test_SlurmStitchingFollowerContext(test_case):
    """Check that the advancement written by tqdm in slurm log lines can be read back (or ignored)"""
    log_line, expected = test_case
    advancement = SlurmStitchingFollowerThread.cast_progress_line_from_log(log_line)
    assert advancement == expected
@@ -3,22 +3,18 @@ from typing import Optional, Union
3
3
  import logging
4
4
  import functools
5
5
  import numpy
6
- from scipy.ndimage import affine_transform
7
6
  from tomoscan.scanbase import TomoScanBase
8
7
  from tomoscan.volumebase import VolumeBase
9
- from nxtomo.utils.transformation import build_matrix, UDDetTransformation
8
+ from nxtomo.utils.transformation import build_matrix, DetYFlipTransformation
10
9
  from silx.utils.enum import Enum as _Enum
11
10
  from scipy.fft import rfftn as local_fftn
12
11
  from scipy.fft import irfftn as local_ifftn
13
- from silx.utils.enum import Enum as _Enum
14
- from nxtomo.utils.transformation import build_matrix, UDDetTransformation
15
- from tomoscan.scanbase import TomoScanBase
16
- from .overlap import OverlapStitchingStrategy, ZStichOverlapKernel
17
- from .alignment import AlignmentAxis1, AlignmentAxis2, PaddedRawData
18
- from ..misc import fourier_filters
19
- from ..estimation.alignment import AlignmentBase
20
- from ..resources.dataset_analyzer import HDF5DatasetAnalyzer
21
- from ..resources.nxflatfield import update_dataset_info_flats_darks
12
+ from ..overlap import OverlapStitchingStrategy, ImageStichOverlapKernel
13
+ from ..alignment import AlignmentAxis1, AlignmentAxis2, PaddedRawData
14
+ from ...misc import fourier_filters
15
+ from ...estimation.alignment import AlignmentBase
16
+ from ...resources.dataset_analyzer import HDF5DatasetAnalyzer
17
+ from ...resources.nxflatfield import update_dataset_info_flats_darks
22
18
 
23
19
  try:
24
20
  import itk
@@ -66,39 +62,22 @@ class ShiftAlgorithm(_Enum):
66
62
  return super().from_value(value=value)
67
63
 
68
64
 
69
- def test_overlap_stitching_strategy(overlap_1, overlap_2, stitching_strategies):
70
- """
71
- stitch the two ovrelap with all the requested strategies.
72
- Return a dictionary with stitching strategy as key and a result dict as value.
73
- result dict keys are: 'weights_overlap_1', 'weights_overlap_2', 'stiching'
74
- """
75
- res = {}
76
- for strategy in stitching_strategies:
77
- s = OverlapStitchingStrategy.from_value(strategy)
78
- stitcher = ZStichOverlapKernel(
79
- stitching_strategy=s,
80
- frame_width=overlap_1.shape[1],
81
- )
82
- stiched_overlap, w1, w2 = stitcher.stitch(overlap_1, overlap_2, check_input=True)
83
- res[s.value] = {
84
- "stitching": stiched_overlap,
85
- "weights_overlap_1": w1,
86
- "weights_overlap_2": w2,
87
- }
88
- return res
89
-
90
-
91
65
  def find_frame_relative_shifts(
92
66
  overlap_upper_frame: numpy.ndarray,
93
67
  overlap_lower_frame: numpy.ndarray,
94
- estimated_shifts,
68
+ estimated_shifts: tuple,
69
+ overlap_axis: int,
95
70
  x_cross_correlation_function=None,
96
71
  y_cross_correlation_function=None,
97
72
  x_shifts_params: Optional[dict] = None,
98
73
  y_shifts_params: Optional[dict] = None,
99
74
  ):
75
+ """
76
+ :param overlap_axis: axis in [0, 1] on which the overlap exists. In image space. So 0 is aka y and 1 as x
77
+ """
78
+ if not overlap_axis in (0, 1):
79
+ raise ValueError(f"overlap_axis should be in (0, 1). Get {overlap_axis}")
100
80
  from nabu.stitching.config import (
101
- KEY_WINDOW_SIZE,
102
81
  KEY_LOW_PASS_FILTER,
103
82
  KEY_HIGH_PASS_FILTER,
104
83
  ) # avoid cyclic import
@@ -188,6 +167,7 @@ def find_frame_relative_shifts(
188
167
  def find_volumes_relative_shifts(
189
168
  upper_volume: VolumeBase,
190
169
  lower_volume: VolumeBase,
170
+ overlap_axis: int,
191
171
  estimated_shifts,
192
172
  dim_axis_1: int,
193
173
  dtype,
@@ -210,6 +190,15 @@ def find_volumes_relative_shifts(
210
190
 
211
191
  if x_shifts_params is None:
212
192
  x_shifts_params = {}
193
+ # convert from overlap_axis (3D acquisition space) to overlap_axis_proj_space.
194
+ if overlap_axis == 1:
195
+ raise NotImplementedError("finding projection shift along axis 1 is not handled for projections")
196
+ elif overlap_axis == 0:
197
+ overlap_axis_proj_space = 0
198
+ elif overlap_axis == 2:
199
+ overlap_axis_proj_space = 1
200
+ else:
201
+ raise ValueError(f"Stitching is done in 3D space. Expect axis to be in [0,2]. Get {overlap_axis}")
213
202
 
214
203
  alignment_axis_2 = AlignmentAxis2.from_value(alignment_axis_2)
215
204
  alignment_axis_1 = AlignmentAxis1.from_value(alignment_axis_1)
@@ -299,6 +288,7 @@ def find_volumes_relative_shifts(
299
288
  y_cross_correlation_function=y_cross_correlation_function,
300
289
  x_shifts_params=x_shifts_params,
301
290
  y_shifts_params=y_shifts_params,
291
+ overlap_axis=overlap_axis_proj_space,
302
292
  )
303
293
 
304
294
 
@@ -308,7 +298,8 @@ from nabu.pipeline.estimators import estimate_cor
308
298
  def find_projections_relative_shifts(
309
299
  upper_scan: TomoScanBase,
310
300
  lower_scan: TomoScanBase,
311
- estimated_shifts,
301
+ estimated_shifts: tuple,
302
+ axis: int,
312
303
  flip_ud_upper_frame: bool = False,
313
304
  flip_ud_lower_frame: bool = False,
314
305
  projection_for_shift: Union[int, str] = "middle",
@@ -326,13 +317,16 @@ def find_projections_relative_shifts(
326
317
 
327
318
  :param TomoScanBase scan_0:
328
319
  :param TomoScanBase scan_1:
329
- :param int axis_0_overlap_px: overlap between the two scans in pixel
320
+ :param tuple estimated_shifts: 'a priori' shift estimation
321
+ :param int axis: axis on which the overlap / stitching is happening. In the 3D space (sample, detector referential)
322
+ :param bool flip_ud_upper_frame: is the upper frame flipped
323
+ :param bool flip_ud_lower_frame: is the lower frame flipped
330
324
  :param Union[int,str] projection_for_shift: index fo the projection to use (in projection space or in scan space ?. For now in projection) or str. If str must be in (`middle`, `first`, `last`)
325
+ :param bool invert_order: are projections inverted between the two scans (case if rotation angle are inverted)
331
326
  :param str x_cross_correlation_function: optional method to refine x shift from computing cross correlation. For now valid values are: ("skimage", "nabu-fft")
332
327
  :param str y_cross_correlation_function: optional method to refine y shift from computing cross correlation. For now valid values are: ("skimage", "nabu-fft")
333
- :param int minimal_overlap_area_for_cross_correlation: if first approximated overlap shift found from z_translation is lower than this value will fall back on taking the full image for the cross correlation and log a warning
334
- :param bool invert_order: are projections inverted between the two scans (case if rotation angle are inverted)
335
- :param tuple estimated_shifts: 'a priori' shift estimation
328
+ :param x_shifts_params: parameters to find the shift over x
329
+ :param y_shifts_params: parameters to find the shift over y
336
330
  :return: relative shift of scan_1 with scan_0 as reference: (y_shift, x_shift)
337
331
  :rtype: tuple
338
332
 
@@ -342,8 +336,18 @@ def find_projections_relative_shifts(
342
336
  x_shifts_params = {}
343
337
  if y_shifts_params is None:
344
338
  y_shifts_params = {}
345
- if estimated_shifts[0] < 0:
346
- raise ValueError("y_overlap_px is expected to be stricktly positive")
339
+
340
+ # convert from overlap_axis (3D acquisition space) to overlap_axis_proj_space.
341
+ if axis == 1:
342
+ axis_proj_space = 1
343
+ elif axis == 0:
344
+ axis_proj_space = 0
345
+ elif axis == 2:
346
+ raise NotImplementedError(
347
+ "finding projection shift along axis 1 (x-ray direction) is not handled for projections"
348
+ )
349
+ else:
350
+ raise ValueError(f"Stitching is done in 3D space. Expect axis to be in [0,2]. Get {axis}")
347
351
 
348
352
  x_cross_correlation_function = ShiftAlgorithm.from_value(x_cross_correlation_function)
349
353
  y_cross_correlation_function = ShiftAlgorithm.from_value(y_cross_correlation_function)
@@ -432,11 +436,11 @@ def find_projections_relative_shifts(
432
436
 
433
437
  upper_scan_transformations = list(upper_scan.get_detector_transformations(tuple()))
434
438
  if flip_ud_upper_frame:
435
- upper_scan_transformations.append(UDDetTransformation())
439
+ upper_scan_transformations.append(DetYFlipTransformation(flip=True))
436
440
  upper_scan_trans_matrix = build_matrix(upper_scan_transformations)
437
441
  lower_scan_transformations = list(lower_scan.get_detector_transformations(tuple()))
438
442
  if flip_ud_lower_frame:
439
- lower_scan_transformations.append(UDDetTransformation())
443
+ lower_scan_transformations.append(DetYFlipTransformation(flip=True))
440
444
  lower_scan_trans_matrix = build_matrix(lower_scan_transformations)
441
445
  upper_proj = get_flat_fielded_proj(
442
446
  upper_scan,
@@ -453,16 +457,30 @@ def find_projections_relative_shifts(
453
457
 
454
458
  from nabu.stitching.config import KEY_WINDOW_SIZE # avoid cyclic import
455
459
 
456
- w_window_size = int(y_shifts_params.get(KEY_WINDOW_SIZE, 400))
457
- start_overlap = max(estimated_shifts[0] // 2 - w_window_size // 2, 0)
458
- end_overlap = min(estimated_shifts[0] // 2 + w_window_size // 2, min(upper_proj.shape[0], lower_proj.shape[0]))
459
- if start_overlap == 0:
460
- overlap_upper_frame = upper_proj[-end_overlap:]
460
+ if axis_proj_space == 0:
461
+ w_window_size = int(y_shifts_params.get(KEY_WINDOW_SIZE, 400))
461
462
  else:
462
- overlap_upper_frame = upper_proj[-end_overlap:-start_overlap]
463
- overlap_lower_frame = lower_proj[start_overlap:end_overlap]
463
+ w_window_size = int(x_shifts_params.get(KEY_WINDOW_SIZE, 400))
464
+ start_overlap = max(estimated_shifts[axis_proj_space] // 2 - w_window_size // 2, 0)
465
+ end_overlap = min(
466
+ estimated_shifts[axis_proj_space] // 2 + w_window_size // 2,
467
+ min(upper_proj.shape[axis_proj_space], lower_proj.shape[axis_proj_space]),
468
+ )
469
+ o_upper_sel = numpy.array(range(-end_overlap, -start_overlap))
470
+ overlap_upper_frame = numpy.take_along_axis(
471
+ upper_proj,
472
+ o_upper_sel[:, None] if axis_proj_space == 0 else o_upper_sel[None, :],
473
+ axis=axis_proj_space,
474
+ )
475
+ o_lower_sel = numpy.array(range(start_overlap, end_overlap))
476
+ overlap_lower_frame = numpy.take_along_axis(
477
+ lower_proj,
478
+ o_lower_sel[:, None] if axis_proj_space == 0 else o_upper_sel[None, :],
479
+ axis=axis_proj_space,
480
+ )
481
+
464
482
  if not overlap_upper_frame.shape == overlap_lower_frame.shape:
465
- raise ValueError(f"Fail to get consistant overlap ({overlap_upper_frame.shape} vs {overlap_lower_frame.shape})")
483
+ raise ValueError(f"Fail to get consistent overlap ({overlap_upper_frame.shape} vs {overlap_lower_frame.shape})")
466
484
 
467
485
  return find_frame_relative_shifts(
468
486
  overlap_upper_frame=overlap_upper_frame,
@@ -472,6 +490,7 @@ def find_projections_relative_shifts(
472
490
  y_cross_correlation_function=y_cross_correlation_function,
473
491
  x_shifts_params=x_shifts_params,
474
492
  y_shifts_params=y_shifts_params,
493
+ overlap_axis=axis_proj_space,
475
494
  )
476
495
 
477
496
 
@@ -561,3 +580,11 @@ def find_shift_with_itk(img1: numpy.ndarray, img2: numpy.ndarray) -> tuple:
561
580
  translation_along_y = final_parameters.GetElement(1)
562
581
 
563
582
  return numpy.round(translation_along_y), numpy.round(translation_along_x)
583
+
584
+
585
+ def from_slice_to_n_elements(slice_: Union[slice, tuple]):
586
+ """Return the number of element in a slice or in a tuple"""
587
+ if isinstance(slice_, slice):
588
+ return (slice_.stop - slice_.start) / (slice_.step or 1)
589
+ else:
590
+ return len(slice_)
@@ -0,0 +1,27 @@
1
+ from tomoscan.identifier import BaseIdentifier
2
+ from nabu.stitching.stitcher.y_stitcher import PreProcessingYStitcher as PreProcessYStitcher
3
+ from nabu.stitching.config import PreProcessedYStitchingConfiguration
4
+
5
+
6
def y_stitching(configuration: PreProcessedYStitchingConfiguration, progress=None) -> BaseIdentifier:
    """
    Run a stitching along axis 1 (aka y) from the provided configuration.

    like:
              axis 0
                ^
                |
    x-ray       |
    -------->    ------> axis 2
               /
              /
            axis 1

    Only pre-processing (projection) y-stitching is handled for now.

    :param configuration: pre-processing y-stitching configuration
    :param progress: optional progress object forwarded to the stitcher
    :return: identifier returned by the stitcher
    :raises TypeError: if the configuration type is not handled
    """
    if not isinstance(configuration, PreProcessedYStitchingConfiguration):
        raise TypeError(
            f"configuration is expected to be in {(PreProcessedYStitchingConfiguration, )}. {type(configuration)} provided"
        )
    y_stitcher = PreProcessYStitcher(configuration=configuration, progress=progress)
    return y_stitcher.stitch()