nabu 2024.1.9__py3-none-any.whl → 2024.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151)
  1. nabu/__init__.py +1 -1
  2. nabu/app/bootstrap.py +2 -3
  3. nabu/app/cast_volume.py +4 -2
  4. nabu/app/cli_configs.py +5 -0
  5. nabu/app/composite_cor.py +1 -1
  6. nabu/app/create_distortion_map_from_poly.py +5 -6
  7. nabu/app/diag_to_pix.py +7 -19
  8. nabu/app/diag_to_rot.py +14 -29
  9. nabu/app/double_flatfield.py +32 -44
  10. nabu/app/parse_reconstruction_log.py +3 -0
  11. nabu/app/reconstruct.py +53 -15
  12. nabu/app/reconstruct_helical.py +2 -2
  13. nabu/app/stitching.py +27 -13
  14. nabu/app/tests/test_reduce_dark_flat.py +4 -1
  15. nabu/cuda/kernel.py +11 -2
  16. nabu/cuda/processing.py +2 -2
  17. nabu/cuda/src/cone.cu +77 -0
  18. nabu/cuda/src/hierarchical_backproj.cu +271 -0
  19. nabu/cuda/utils.py +0 -6
  20. nabu/estimation/alignment.py +5 -19
  21. nabu/estimation/cor.py +173 -599
  22. nabu/estimation/cor_sino.py +356 -26
  23. nabu/estimation/focus.py +63 -11
  24. nabu/estimation/tests/test_cor.py +124 -58
  25. nabu/estimation/tests/test_focus.py +6 -6
  26. nabu/estimation/tilt.py +2 -1
  27. nabu/estimation/utils.py +5 -33
  28. nabu/io/__init__.py +1 -1
  29. nabu/io/cast_volume.py +1 -1
  30. nabu/io/reader.py +416 -21
  31. nabu/io/tests/test_readers.py +422 -0
  32. nabu/io/tests/test_writers.py +1 -102
  33. nabu/io/writer.py +4 -433
  34. nabu/opencl/kernel.py +14 -3
  35. nabu/opencl/processing.py +8 -0
  36. nabu/pipeline/config_validators.py +5 -2
  37. nabu/pipeline/datadump.py +12 -5
  38. nabu/pipeline/estimators.py +162 -188
  39. nabu/pipeline/fullfield/chunked.py +168 -92
  40. nabu/pipeline/fullfield/chunked_cuda.py +7 -3
  41. nabu/pipeline/fullfield/computations.py +2 -7
  42. nabu/pipeline/fullfield/dataset_validator.py +0 -4
  43. nabu/pipeline/fullfield/nabu_config.py +37 -13
  44. nabu/pipeline/fullfield/processconfig.py +22 -13
  45. nabu/pipeline/fullfield/reconstruction.py +13 -9
  46. nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
  47. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
  48. nabu/pipeline/helical/helical_reconstruction.py +1 -1
  49. nabu/pipeline/params.py +21 -1
  50. nabu/pipeline/processconfig.py +1 -12
  51. nabu/pipeline/reader.py +146 -0
  52. nabu/pipeline/tests/test_estimators.py +44 -72
  53. nabu/pipeline/utils.py +4 -2
  54. nabu/pipeline/writer.py +10 -2
  55. nabu/preproc/ccd_cuda.py +1 -1
  56. nabu/preproc/ctf.py +14 -7
  57. nabu/preproc/ctf_cuda.py +2 -3
  58. nabu/preproc/double_flatfield.py +5 -12
  59. nabu/preproc/double_flatfield_cuda.py +2 -2
  60. nabu/preproc/flatfield.py +5 -1
  61. nabu/preproc/flatfield_cuda.py +5 -1
  62. nabu/preproc/phase.py +24 -73
  63. nabu/preproc/phase_cuda.py +5 -8
  64. nabu/preproc/tests/test_ctf.py +11 -7
  65. nabu/preproc/tests/test_flatfield.py +67 -122
  66. nabu/preproc/tests/test_paganin.py +54 -30
  67. nabu/processing/azim.py +206 -0
  68. nabu/processing/convolution_cuda.py +1 -1
  69. nabu/processing/fft_cuda.py +15 -17
  70. nabu/processing/histogram.py +2 -0
  71. nabu/processing/histogram_cuda.py +2 -1
  72. nabu/processing/kernel_base.py +3 -0
  73. nabu/processing/muladd_cuda.py +1 -0
  74. nabu/processing/padding_opencl.py +1 -1
  75. nabu/processing/roll_opencl.py +1 -0
  76. nabu/processing/rotation_cuda.py +2 -2
  77. nabu/processing/tests/test_fft.py +17 -10
  78. nabu/processing/unsharp_cuda.py +1 -1
  79. nabu/reconstruction/cone.py +104 -40
  80. nabu/reconstruction/fbp.py +3 -0
  81. nabu/reconstruction/fbp_base.py +7 -2
  82. nabu/reconstruction/filtering.py +20 -7
  83. nabu/reconstruction/filtering_cuda.py +7 -1
  84. nabu/reconstruction/hbp.py +424 -0
  85. nabu/reconstruction/mlem.py +99 -0
  86. nabu/reconstruction/reconstructor.py +2 -0
  87. nabu/reconstruction/rings_cuda.py +19 -19
  88. nabu/reconstruction/sinogram_cuda.py +1 -0
  89. nabu/reconstruction/sinogram_opencl.py +3 -1
  90. nabu/reconstruction/tests/test_cone.py +10 -5
  91. nabu/reconstruction/tests/test_deringer.py +7 -6
  92. nabu/reconstruction/tests/test_fbp.py +124 -10
  93. nabu/reconstruction/tests/test_filtering.py +13 -11
  94. nabu/reconstruction/tests/test_halftomo.py +30 -4
  95. nabu/reconstruction/tests/test_mlem.py +91 -0
  96. nabu/reconstruction/tests/test_reconstructor.py +8 -3
  97. nabu/resources/dataset_analyzer.py +142 -92
  98. nabu/resources/gpu.py +1 -0
  99. nabu/resources/nxflatfield.py +134 -125
  100. nabu/resources/templates/id16a_fluo.conf +42 -0
  101. nabu/resources/tests/test_extract.py +10 -0
  102. nabu/resources/tests/test_nxflatfield.py +2 -2
  103. nabu/stitching/alignment.py +80 -24
  104. nabu/stitching/config.py +105 -68
  105. nabu/stitching/definitions.py +1 -0
  106. nabu/stitching/frame_composition.py +68 -60
  107. nabu/stitching/overlap.py +91 -51
  108. nabu/stitching/single_axis_stitching.py +32 -0
  109. nabu/stitching/slurm_utils.py +6 -6
  110. nabu/stitching/stitcher/__init__.py +0 -0
  111. nabu/stitching/stitcher/base.py +124 -0
  112. nabu/stitching/stitcher/dumper/__init__.py +3 -0
  113. nabu/stitching/stitcher/dumper/base.py +94 -0
  114. nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
  115. nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
  116. nabu/stitching/stitcher/post_processing.py +555 -0
  117. nabu/stitching/stitcher/pre_processing.py +1068 -0
  118. nabu/stitching/stitcher/single_axis.py +484 -0
  119. nabu/stitching/stitcher/stitcher.py +0 -0
  120. nabu/stitching/stitcher/y_stitcher.py +13 -0
  121. nabu/stitching/stitcher/z_stitcher.py +45 -0
  122. nabu/stitching/stitcher_2D.py +278 -0
  123. nabu/stitching/tests/test_config.py +12 -37
  124. nabu/stitching/tests/test_frame_composition.py +33 -59
  125. nabu/stitching/tests/test_overlap.py +149 -7
  126. nabu/stitching/tests/test_utils.py +1 -1
  127. nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
  128. nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
  129. nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
  130. nabu/stitching/utils/__init__.py +1 -0
  131. nabu/stitching/utils/post_processing.py +281 -0
  132. nabu/stitching/utils/tests/test_post-processing.py +21 -0
  133. nabu/stitching/{utils.py → utils/utils.py} +79 -52
  134. nabu/stitching/y_stitching.py +27 -0
  135. nabu/stitching/z_stitching.py +32 -2263
  136. nabu/testutils.py +1 -152
  137. nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
  138. nabu/utils.py +158 -61
  139. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/METADATA +10 -3
  140. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/RECORD +144 -121
  141. nabu/io/tiffwriter_zmm.py +0 -99
  142. nabu/pipeline/fallback_utils.py +0 -149
  143. nabu/pipeline/helical/tests/test_accumulator.py +0 -158
  144. nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
  145. nabu/pipeline/helical/tests/test_strategy.py +0 -61
  146. nabu/pipeline/helical/utils.py +0 -51
  147. nabu/pipeline/tests/test_chunk_reader.py +0 -74
  148. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
  149. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/WHEEL +0 -0
  150. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
  151. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,431 @@
1
+ import os
2
+ from silx.image.phantomgenerator import PhantomGenerator
3
+ from scipy.ndimage import shift as scipy_shift
4
+ import numpy
5
+ import pytest
6
+ from nabu.stitching.config import PreProcessedZStitchingConfiguration
7
+ from nabu.stitching.config import KEY_IMG_REG_METHOD
8
+ from nabu.stitching.overlap import ImageStichOverlapKernel, OverlapStitchingStrategy
9
+ from nabu.stitching.z_stitching import (
10
+ PreProcessZStitcher,
11
+ )
12
+ from nabu.stitching.stitcher_2D import stitch_raw_frames, get_overlap_areas
13
+ from nxtomo.nxobject.nxdetector import ImageKey
14
+ from nxtomo.utils.transformation import DetYFlipTransformation, DetZFlipTransformation
15
+ from nxtomo.application.nxtomo import NXtomo
16
+ from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
17
+ from nabu.stitching.utils import ShiftAlgorithm
18
+ import h5py
19
+
20
+
21
# Stitching scenarios used by test_PreProcessZStitcher. Keys of each entry:
# - n_proj: number of projections written in each generated NXtomo
# - raw_pos: (axis 0, axis 1, axis 2) position of each frame; the axis 0 value is
#   written as the sample z_translation of the corresponding NXtomo
# - input_pos: (axis 0, axis 1, axis 2) positions given to the stitching
#   configuration as axis_0_pos_px / axis_1_pos_px / axis_2_pos_px
# - raw_shifts: shift applied (through scipy.ndimage.shift) to the reference
#   phantom before cropping it into each raw frame
# - check_bb (optional): expected z bounding box of each input scan
# - axis_0_params / axis_2_params (optional): per-axis stitching parameters
_stitching_configurations = (
    # simple case where shifts are provided
    dict(
        n_proj=4,
        raw_pos=((0, 0, 0), (-90, 0, 0), (-180, 0, 0)),
        input_pos=((0, 0, 0), (-90, 0, 0), (-180, 0, 0)),
        raw_shifts=((0, 0), (-90, 0), (-180, 0)),
    ),
    # simple case where shift is found from z position
    dict(
        n_proj=4,
        raw_pos=((90, 0, 0), (0, 0, 0), (-90, 0, 0)),
        input_pos=((90, 0, 0), (0, 0, 0), (-90, 0, 0)),
        check_bb=((40, 140), (-50, 50), (-140, -40)),
        axis_0_params={KEY_IMG_REG_METHOD: ShiftAlgorithm.NONE},
        axis_2_params={KEY_IMG_REG_METHOD: ShiftAlgorithm.NONE},
        raw_shifts=((0, 0), (-90, 0), (-180, 0)),
    ),
)
44
+
45
+
46
@pytest.mark.parametrize("configuration", _stitching_configurations)
@pytest.mark.parametrize("dtype", (numpy.float32, numpy.int16))
def test_PreProcessZStitcher(tmp_path, dtype, configuration):
    """
    Test the PreProcessZStitcher class and ensure a full stitching can be done automatically.

    Overlapping frames are cut out of a shifted Shepp-Logan phantom, one per entry of
    ``configuration["raw_shifts"]``. Each frame is saved as its own NXtomo file, and the
    files are z-stitched.

    The first stitched projection is compared with the reference phantom. The test then
    checks the metadata (distance, energy, image keys) and that the stitching configuration
    was saved next to the data.
    """
    n_proj = configuration["n_proj"]
    ref_frame_width = 280
    raw_frame_height = 100
    # NOTE(review): the phantom is cast to `dtype` *before* being scaled. For integer dtypes
    # its values are therefore truncated, and the product is promoted to float64. The
    # int16 case thus feeds float64 data to the stitcher. Kept as-is so that the
    # reference data of this test is unchanged.
    ref_frame = PhantomGenerator.get2DPhantomSheppLogan(n=ref_frame_width).astype(dtype) * 256.0

    # add some mark for image registration
    ref_frame[:, 96] = -3.2
    ref_frame[:, 125] = 9.1
    ref_frame[:, 165] = 4.4
    ref_frame[:, 200] = -2.5

    # create raw data: each frame is the first `raw_frame_height` rows of the shifted phantom
    frames = tuple(
        scipy_shift(ref_frame, shift=frame_shift)[:raw_frame_height] for frame_shift in configuration["raw_shifts"]
    )
    input_pos = configuration["input_pos"]
    raw_pos = configuration["raw_pos"]

    # create a NXtomo for each of those raw data
    raw_data_dir = tmp_path / "raw_data"
    raw_data_dir.mkdir()
    output_dir = tmp_path / "output_dir"
    output_dir.mkdir()
    # the axis 0 component of each raw position is saved as the sample z translation
    z_positions = tuple(pos[0] for pos in raw_pos)
    scans = []
    for i_frame, (frame, z_pos) in enumerate(zip(frames, z_positions)):
        nx_tomo = NXtomo()
        nx_tomo.sample.z_translation = [z_pos] * n_proj
        nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False)
        nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
        nx_tomo.instrument.detector.x_pixel_size = 1.0
        nx_tomo.instrument.detector.y_pixel_size = 1.0
        nx_tomo.instrument.detector.distance = 2.3
        nx_tomo.energy = 19.2
        nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj)

        file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx")
        entry = f"entry000{i_frame}"
        nx_tomo.save(file_path=file_path, data_path=entry)
        scans.append(NXtomoScan(scan=file_path, entry=entry))

    # if requested: check bounding box
    check_bb = configuration.get("check_bb", None)
    if check_bb is not None:
        for scan, expected_bb in zip(scans, check_bb):
            assert scan.get_bounding_box(axis="z") == expected_bb

    output_file_path = os.path.join(output_dir, "stitched.nx")
    output_data_path = "stitched"
    # regroup the per-frame (axis 0, axis 1, axis 2) positions into one tuple per axis
    axis_0_pos_px, axis_1_pos_px, axis_2_pos_px = zip(*input_pos)
    z_stich_config = PreProcessedZStitchingConfiguration(
        stitching_strategy=OverlapStitchingStrategy.LINEAR_WEIGHTS,
        overwrite_results=True,
        axis_0_pos_px=axis_0_pos_px,
        axis_1_pos_px=axis_1_pos_px,
        axis_2_pos_px=axis_2_pos_px,
        axis_0_pos_mm=None,
        axis_1_pos_mm=None,
        axis_2_pos_mm=None,
        input_scans=scans,
        output_file_path=output_file_path,
        output_data_path=output_data_path,
        axis_0_params=configuration.get("axis_0_params", {}),
        axis_1_params=configuration.get("axis_1_params", {}),
        axis_2_params=configuration.get("axis_2_params", {}),
        output_nexus_version=None,
        slices=None,
        slurm_config=None,
        slice_for_cross_correlation="middle",
        pixel_size=None,
    )
    stitcher = PreProcessZStitcher(z_stich_config)
    output_identifier = stitcher.stitch()
    assert output_identifier.file_path == output_file_path
    assert output_identifier.data_path == output_data_path

    created_nx_tomo = NXtomo().load(
        file_path=output_identifier.file_path,
        data_path=output_identifier.data_path,
        detector_data_as="as_numpy_array",
    )

    assert created_nx_tomo.instrument.detector.data.ndim == 3
    stitched_frame = created_nx_tomo.instrument.detector.data[0, :ref_frame_width, :]
    mean_abs_error = configuration.get("mean_abs_error", None)
    if mean_abs_error is not None:
        assert numpy.mean(numpy.abs(ref_frame - stitched_frame)) < mean_abs_error
    else:
        numpy.testing.assert_array_almost_equal(ref_frame, stitched_frame)

    # check also other metadata are here
    assert created_nx_tomo.instrument.detector.distance.value == 2.3
    assert created_nx_tomo.energy.value == 19.2
    numpy.testing.assert_array_equal(
        created_nx_tomo.instrument.detector.image_key_control,
        # use the enum member directly (was `ImageKey.PROJECTION.PROJECTION`)
        numpy.asarray([ImageKey.PROJECTION] * n_proj),
    )

    # check configuration has been saved
    with h5py.File(output_identifier.file_path, mode="r") as h5f:
        assert "stitching_configuration" in h5f[output_identifier.data_path]
+
174
+
175
# Parametrization of test_DistributePreProcessZStitcher. Keys of each entry:
# - "slices": sequence of values. Each value is passed in turn as the `slices` of one
#   stitching job, and the outputs of all jobs are concatenated afterwards.
# - "complete": selects which checks the test runs on the concatenated result.
#   True: one stitched projection is compared with the raw data.
#   False: the first / middle / last stitched projections are compared.
slices_to_test_pre = (
    dict(slices=(None,), complete=True),
    dict(slices=(("first",), ("middle",), ("last",)), complete=False),
    # NOTE(review): with Python slice semantics, slice(3, -1, 1) stops before the last
    # projection. Confirm how the stitcher interprets a negative stop.
    dict(slices=((0, 1, 2), slice(3, -1, 1)), complete=True),
)
189
+
190
+
191
def build_nxtomos(output_dir) -> tuple:
    r"""
    Build two NXtomo files in output_dir, ready to be stitched along z.

    Both files are cropped from the same ``raw_data`` volume, so the stitched result can be
    compared with it:

    * ``nxtomo_0.nx`` (z translation 40) holds rows 60: (the last 68 rows) of every projection
    * ``nxtomo_1.nx`` (z translation 94) holds rows :80 (the first 80 rows) of every projection

    Rows 60 to 79 are therefore present in both scans (the overlap area).

         /\
         |             ______________
         |            |   nxtomo 1   |
       Z |            |    frame     |
         |            |~~~~~~~~~~~~~~|
         |            |~~~~~~~~~~~~~~|
         |            |______________|
         |             ______________
         |            |~~~~~~~~~~~~~~|
         |            |~~~~~~~~~~~~~~|
         |            |   nxtomo 0   |
         |            |    frame     |
         |            |______________|
         |
    <-----------------------------------------------
                y (in acquisition space)

    * ~: represent the overlap area

    :param output_dir: folder where the two NXtomo files are written
    :return: tuple (scans, z_positions, raw_data), where:

        * scans is the list of NXtomoScan, one per file
        * z_positions is the z translation of each scan
        * raw_data is the full (n_projs, 128, 128) volume the frames were cropped from
    """
    n_projs = 100
    frame_height = frame_width = 128
    # use n_projs rather than repeating the literal so both stay in sync
    raw_data = numpy.arange(n_projs * frame_height * frame_width).reshape((n_projs, frame_height, frame_width))

    # create raw data: two overlapping crops of every projection
    frame_0 = raw_data[:, 60:]
    assert frame_0.ndim == 3
    frame_0_pos = 40
    frame_1 = raw_data[:, 0:80]
    assert frame_1.ndim == 3
    frame_1_pos = 94
    frames = (frame_0, frame_1)
    z_positions = (frame_0_pos, frame_1_pos)

    # create a NXtomo for each of those raw data
    scans = []
    for i_frame, (frame, z_pos) in enumerate(zip(frames, z_positions)):
        nx_tomo = NXtomo()
        nx_tomo.sample.z_translation = [z_pos] * n_projs
        nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_projs, endpoint=False)
        nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_projs
        nx_tomo.instrument.detector.x_pixel_size = 1.0
        nx_tomo.instrument.detector.y_pixel_size = 1.0
        nx_tomo.instrument.detector.distance = 2.3
        nx_tomo.energy = 19.2
        nx_tomo.instrument.detector.data = frame

        file_path = os.path.join(output_dir, f"nxtomo_{i_frame}.nx")
        entry = f"entry000{i_frame}"
        nx_tomo.save(file_path=file_path, data_path=entry)
        scans.append(NXtomoScan(scan=file_path, entry=entry))
    return scans, z_positions, raw_data
245
+
246
+
247
@pytest.mark.parametrize("configuration_dist", slices_to_test_pre)
def test_DistributePreProcessZStitcher(tmp_path, configuration_dist):
    """
    Test a pre-processing z-stitching split into several jobs.

    Each value of ``configuration_dist["slices"]`` runs as its own stitching job. The
    stitched NXtomos are then concatenated and compared with the raw data the input scans
    were cropped from.
    """
    slices = configuration_dist["slices"]
    complete = configuration_dist["complete"]

    raw_data_dir = tmp_path / "raw_data"
    raw_data_dir.mkdir()

    output_dir = tmp_path / "output_dir"
    output_dir.mkdir()

    scans, z_positions, raw_data = build_nxtomos(output_dir=raw_data_dir)
    stitched_nx_tomo = []
    for s in slices:
        # every job writes to its own entry of the same output file
        output_file_path = os.path.join(output_dir, "stitched_section.nx")
        output_data_path = f"stitched_{s}"
        z_stich_config = PreProcessedZStitchingConfiguration(
            axis_0_pos_px=z_positions,
            axis_1_pos_px=(0, 0),
            axis_2_pos_px=None,
            axis_0_pos_mm=None,
            axis_1_pos_mm=None,
            axis_2_pos_mm=None,
            axis_0_params={},
            axis_1_params={},
            axis_2_params={},
            stitching_strategy=OverlapStitchingStrategy.CLOSEST,
            overwrite_results=True,
            input_scans=scans,
            output_file_path=output_file_path,
            output_data_path=output_data_path,
            output_nexus_version=None,
            slices=s,
            slurm_config=None,
            slice_for_cross_correlation="middle",
            pixel_size=None,
        )
        stitcher = PreProcessZStitcher(z_stich_config)
        output_identifier = stitcher.stitch()
        assert output_identifier.file_path == output_file_path
        assert output_identifier.data_path == output_data_path

        created_nx_tomo = NXtomo().load(
            file_path=output_identifier.file_path,
            data_path=output_identifier.data_path,
            detector_data_as="as_numpy_array",
        )
        stitched_nx_tomo.append(created_nx_tomo)
    assert len(stitched_nx_tomo) == len(slices)
    final_nx_tomo = NXtomo.concatenate(stitched_nx_tomo)
    assert isinstance(final_nx_tomo.instrument.detector.data, numpy.ndarray)
    final_nx_tomo.save(
        file_path=os.path.join(output_dir, "final_stitched.nx"),
        data_path="entry0000",
    )

    final_data = final_nx_tomo.instrument.detector.data
    # The previous `len(...) == 128` / `len(...) == 3` lines were bare expressions that
    # asserted nothing. 128 is the stitched frame height, not a number of projections,
    # so check the frame shape instead.
    assert final_data.shape[1:] == raw_data.shape[1:]
    if complete:
        # check one stitched projection against the raw data
        numpy.testing.assert_array_almost_equal(raw_data[1], final_data[1, :, :])
    else:
        # one stitched projection per requested slice: first, middle and last
        assert len(final_data) == 3
        # test middle
        numpy.testing.assert_array_almost_equal(raw_data[49], final_data[1, :, :])
        # test first
        numpy.testing.assert_array_almost_equal(raw_data[0], final_data[0, :, :])
        # test last
        numpy.testing.assert_array_almost_equal(raw_data[-1], final_data[-1, :, :])
317
+
318
+
319
def test_get_overlap_areas():
    """Check that get_overlap_areas extracts the same 4-element overlap from both frames."""
    upper_frame = numpy.linspace(7, 15, num=9, endpoint=True)  # 7.0, 8.0, ..., 15.0
    lower_frame = numpy.linspace(0, 12, num=13, endpoint=True)  # 0.0, 1.0, ..., 12.0
    expected_overlap = numpy.linspace(8, 11, num=4, endpoint=True)  # 8.0, 9.0, 10.0, 11.0

    upper_overlap, lower_overlap = get_overlap_areas(
        upper_frame=upper_frame,
        lower_frame=lower_frame,
        upper_frame_key_line=3,
        lower_frame_key_line=10,
        overlap_size=4,
        stitching_axis=0,
    )

    numpy.testing.assert_array_equal(upper_overlap, lower_overlap)
    numpy.testing.assert_array_equal(upper_overlap, expected_overlap)
335
+
336
+
337
def test_frame_flip(tmp_path):
    """
    Check z-stitching of NXtomos whose frames are flipped.

    Each raw frame is flipped, and the flip is registered as a detector transformation of
    its NXtomo. The stitched projection must then match the unflipped reference phantom.
    """
    ref_frame_width = 280
    n_proj = 10
    # number of phantom rows kept in each raw frame (renamed from `raw_frame_width`,
    # since it slices rows, not columns)
    raw_frame_height = 100
    ref_frame = PhantomGenerator.get2DPhantomSheppLogan(n=ref_frame_width).astype(numpy.float32) * 256.0
    # create raw data: first `raw_frame_height` rows of the phantom shifted by 0, -90 and -180 rows
    frame_shifts = ((0, 0), (-90, 0), (-180, 0))

    x_flips = [False, True, True]
    y_flips = [False, False, True]

    def apply_flip(frame, flip_x, flip_y):
        """Return `frame` flipped left-right if flip_x, then up-down if flip_y."""
        if flip_x:
            frame = numpy.fliplr(frame)
        if flip_y:
            frame = numpy.flipud(frame)
        return frame

    # build the flipped frames eagerly (a list rather than a one-shot `map` iterator)
    frames = [
        apply_flip(scipy_shift(ref_frame, shift=frame_shift)[:raw_frame_height], flip_x, flip_y)
        for frame_shift, flip_x, flip_y in zip(frame_shifts, x_flips, y_flips)
    ]

    # create a NXtomo for each of those raw data
    raw_data_dir = tmp_path / "raw_data"
    raw_data_dir.mkdir()
    output_dir = tmp_path / "output_dir"
    output_dir.mkdir()
    z_position = (90, 0, -90)

    scans = []
    for i_frame, (frame, z_pos, x_flip, y_flip) in enumerate(zip(frames, z_position, x_flips, y_flips)):
        nx_tomo = NXtomo()
        nx_tomo.sample.z_translation = [z_pos] * n_proj
        nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False)
        nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
        nx_tomo.instrument.detector.x_pixel_size = 1.0
        nx_tomo.instrument.detector.y_pixel_size = 1.0
        nx_tomo.instrument.detector.distance = 2.3
        # register the flips applied to the raw frame so the stitcher can undo them
        nx_tomo.instrument.detector.transformations.add_transformation(DetZFlipTransformation(flip=x_flip))
        nx_tomo.instrument.detector.transformations.add_transformation(DetYFlipTransformation(flip=y_flip))
        nx_tomo.energy = 19.2
        nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj)

        file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx")
        entry = f"entry000{i_frame}"
        nx_tomo.save(file_path=file_path, data_path=entry)
        scans.append(NXtomoScan(scan=file_path, entry=entry))

    output_file_path = os.path.join(output_dir, "stitched.nx")
    output_data_path = "stitched"
    assert len(scans) == 3
    z_stich_config = PreProcessedZStitchingConfiguration(
        axis_0_pos_px=(0, -90, -180),
        axis_1_pos_px=(0, 0, 0),
        axis_2_pos_px=None,
        axis_0_pos_mm=None,
        axis_1_pos_mm=None,
        axis_2_pos_mm=None,
        axis_0_params={},
        axis_1_params={},
        axis_2_params={},
        stitching_strategy=OverlapStitchingStrategy.LINEAR_WEIGHTS,
        overwrite_results=True,
        input_scans=scans,
        output_file_path=output_file_path,
        output_data_path=output_data_path,
        output_nexus_version=None,
        slices=None,
        slurm_config=None,
        slice_for_cross_correlation="middle",
        pixel_size=None,
    )
    stitcher = PreProcessZStitcher(z_stich_config)
    output_identifier = stitcher.stitch()
    assert output_identifier.file_path == output_file_path
    assert output_identifier.data_path == output_data_path

    created_nx_tomo = NXtomo().load(
        file_path=output_identifier.file_path,
        data_path=output_identifier.data_path,
        detector_data_as="as_numpy_array",
    )

    assert created_nx_tomo.instrument.detector.data.ndim == 3
    # ensure flipping has been taken into account: the stitched frame matches the unflipped reference
    numpy.testing.assert_array_almost_equal(ref_frame, created_nx_tomo.instrument.detector.data[0, :ref_frame_width, :])
    # the stitched NXtomo is expected to carry no detector transformation
    assert len(created_nx_tomo.instrument.detector.transformations) == 0
@@ -0,0 +1 @@
1
+ from .utils import *