nabu 2024.1.10__py3-none-any.whl → 2024.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. nabu/__init__.py +1 -1
  2. nabu/app/bootstrap.py +2 -3
  3. nabu/app/cast_volume.py +4 -2
  4. nabu/app/cli_configs.py +5 -0
  5. nabu/app/composite_cor.py +1 -1
  6. nabu/app/create_distortion_map_from_poly.py +5 -6
  7. nabu/app/diag_to_pix.py +7 -19
  8. nabu/app/diag_to_rot.py +14 -29
  9. nabu/app/double_flatfield.py +32 -44
  10. nabu/app/parse_reconstruction_log.py +3 -0
  11. nabu/app/reconstruct.py +53 -15
  12. nabu/app/reconstruct_helical.py +2 -2
  13. nabu/app/stitching.py +27 -13
  14. nabu/app/tests/__init__.py +0 -0
  15. nabu/app/tests/test_reduce_dark_flat.py +4 -1
  16. nabu/cuda/kernel.py +11 -2
  17. nabu/cuda/processing.py +2 -2
  18. nabu/cuda/src/cone.cu +77 -0
  19. nabu/cuda/src/hierarchical_backproj.cu +271 -0
  20. nabu/cuda/utils.py +0 -6
  21. nabu/estimation/alignment.py +5 -19
  22. nabu/estimation/cor.py +173 -599
  23. nabu/estimation/cor_sino.py +356 -26
  24. nabu/estimation/focus.py +63 -11
  25. nabu/estimation/tests/test_cor.py +124 -58
  26. nabu/estimation/tests/test_focus.py +6 -6
  27. nabu/estimation/tilt.py +2 -1
  28. nabu/estimation/utils.py +5 -33
  29. nabu/io/__init__.py +1 -1
  30. nabu/io/cast_volume.py +1 -1
  31. nabu/io/reader.py +416 -21
  32. nabu/io/tests/test_readers.py +422 -0
  33. nabu/io/tests/test_writers.py +1 -102
  34. nabu/io/writer.py +4 -433
  35. nabu/opencl/kernel.py +14 -3
  36. nabu/opencl/processing.py +8 -0
  37. nabu/pipeline/config_validators.py +5 -2
  38. nabu/pipeline/datadump.py +12 -5
  39. nabu/pipeline/estimators.py +162 -188
  40. nabu/pipeline/fullfield/chunked.py +168 -92
  41. nabu/pipeline/fullfield/chunked_cuda.py +7 -3
  42. nabu/pipeline/fullfield/computations.py +2 -7
  43. nabu/pipeline/fullfield/dataset_validator.py +0 -4
  44. nabu/pipeline/fullfield/nabu_config.py +37 -13
  45. nabu/pipeline/fullfield/processconfig.py +22 -13
  46. nabu/pipeline/fullfield/reconstruction.py +13 -9
  47. nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
  48. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
  49. nabu/pipeline/helical/helical_reconstruction.py +1 -1
  50. nabu/pipeline/params.py +21 -1
  51. nabu/pipeline/processconfig.py +1 -12
  52. nabu/pipeline/reader.py +146 -0
  53. nabu/pipeline/tests/test_estimators.py +44 -72
  54. nabu/pipeline/utils.py +4 -2
  55. nabu/pipeline/writer.py +10 -2
  56. nabu/preproc/ccd_cuda.py +1 -1
  57. nabu/preproc/ctf.py +14 -7
  58. nabu/preproc/ctf_cuda.py +2 -3
  59. nabu/preproc/double_flatfield.py +5 -12
  60. nabu/preproc/double_flatfield_cuda.py +2 -2
  61. nabu/preproc/flatfield.py +5 -1
  62. nabu/preproc/flatfield_cuda.py +5 -1
  63. nabu/preproc/phase.py +24 -73
  64. nabu/preproc/phase_cuda.py +5 -8
  65. nabu/preproc/tests/test_ctf.py +11 -7
  66. nabu/preproc/tests/test_flatfield.py +67 -122
  67. nabu/preproc/tests/test_paganin.py +54 -30
  68. nabu/processing/azim.py +206 -0
  69. nabu/processing/convolution_cuda.py +1 -1
  70. nabu/processing/fft_cuda.py +15 -17
  71. nabu/processing/histogram.py +2 -0
  72. nabu/processing/histogram_cuda.py +2 -1
  73. nabu/processing/kernel_base.py +3 -0
  74. nabu/processing/muladd_cuda.py +1 -0
  75. nabu/processing/padding_opencl.py +1 -1
  76. nabu/processing/roll_opencl.py +1 -0
  77. nabu/processing/rotation_cuda.py +2 -2
  78. nabu/processing/tests/test_fft.py +17 -10
  79. nabu/processing/unsharp_cuda.py +1 -1
  80. nabu/reconstruction/cone.py +104 -40
  81. nabu/reconstruction/fbp.py +3 -0
  82. nabu/reconstruction/fbp_base.py +7 -2
  83. nabu/reconstruction/filtering.py +20 -7
  84. nabu/reconstruction/filtering_cuda.py +7 -1
  85. nabu/reconstruction/hbp.py +424 -0
  86. nabu/reconstruction/mlem.py +99 -0
  87. nabu/reconstruction/reconstructor.py +2 -0
  88. nabu/reconstruction/rings_cuda.py +19 -19
  89. nabu/reconstruction/sinogram_cuda.py +1 -0
  90. nabu/reconstruction/sinogram_opencl.py +3 -1
  91. nabu/reconstruction/tests/test_cone.py +10 -5
  92. nabu/reconstruction/tests/test_deringer.py +7 -6
  93. nabu/reconstruction/tests/test_fbp.py +124 -10
  94. nabu/reconstruction/tests/test_filtering.py +13 -11
  95. nabu/reconstruction/tests/test_halftomo.py +30 -4
  96. nabu/reconstruction/tests/test_mlem.py +91 -0
  97. nabu/reconstruction/tests/test_reconstructor.py +8 -3
  98. nabu/resources/dataset_analyzer.py +142 -92
  99. nabu/resources/gpu.py +1 -0
  100. nabu/resources/nxflatfield.py +134 -125
  101. nabu/resources/templates/id16a_fluo.conf +42 -0
  102. nabu/resources/tests/test_extract.py +10 -0
  103. nabu/resources/tests/test_nxflatfield.py +2 -2
  104. nabu/stitching/alignment.py +80 -24
  105. nabu/stitching/config.py +105 -68
  106. nabu/stitching/definitions.py +1 -0
  107. nabu/stitching/frame_composition.py +68 -60
  108. nabu/stitching/overlap.py +91 -51
  109. nabu/stitching/single_axis_stitching.py +32 -0
  110. nabu/stitching/slurm_utils.py +6 -6
  111. nabu/stitching/stitcher/__init__.py +0 -0
  112. nabu/stitching/stitcher/base.py +124 -0
  113. nabu/stitching/stitcher/dumper/__init__.py +3 -0
  114. nabu/stitching/stitcher/dumper/base.py +94 -0
  115. nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
  116. nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
  117. nabu/stitching/stitcher/post_processing.py +555 -0
  118. nabu/stitching/stitcher/pre_processing.py +1068 -0
  119. nabu/stitching/stitcher/single_axis.py +484 -0
  120. nabu/stitching/stitcher/stitcher.py +0 -0
  121. nabu/stitching/stitcher/y_stitcher.py +13 -0
  122. nabu/stitching/stitcher/z_stitcher.py +45 -0
  123. nabu/stitching/stitcher_2D.py +278 -0
  124. nabu/stitching/tests/test_config.py +12 -37
  125. nabu/stitching/tests/test_frame_composition.py +33 -59
  126. nabu/stitching/tests/test_overlap.py +149 -7
  127. nabu/stitching/tests/test_utils.py +1 -1
  128. nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
  129. nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
  130. nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
  131. nabu/stitching/utils/__init__.py +1 -0
  132. nabu/stitching/utils/post_processing.py +281 -0
  133. nabu/stitching/utils/tests/test_post-processing.py +21 -0
  134. nabu/stitching/{utils.py → utils/utils.py} +79 -52
  135. nabu/stitching/y_stitching.py +27 -0
  136. nabu/stitching/z_stitching.py +32 -2281
  137. nabu/testutils.py +1 -152
  138. nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
  139. nabu/utils.py +158 -61
  140. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/METADATA +24 -17
  141. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/RECORD +145 -121
  142. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/WHEEL +1 -1
  143. nabu/io/tiffwriter_zmm.py +0 -99
  144. nabu/pipeline/fallback_utils.py +0 -149
  145. nabu/pipeline/helical/tests/test_accumulator.py +0 -158
  146. nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
  147. nabu/pipeline/helical/tests/test_strategy.py +0 -61
  148. nabu/pipeline/helical/utils.py +0 -51
  149. nabu/pipeline/tests/test_chunk_reader.py +0 -74
  150. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
  151. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
  152. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,11 @@
1
- from tempfile import mkdtemp
2
1
  import os
3
2
  import numpy as np
4
3
  import pytest
5
- from silx.io.url import DataUrl
6
- from silx.io import get_data
7
- from silx.io.dictdump import dicttoh5
8
4
  from nabu.cuda.utils import get_cuda_context, __has_pycuda__
9
- from nabu.preproc.flatfield import FlatField, FlatFieldDataUrls
5
+ from nabu.preproc.flatfield import FlatField
10
6
 
11
7
  if __has_pycuda__:
12
- import pycuda.gpuarray as garray
13
- from nabu.preproc.flatfield_cuda import CudaFlatFieldDataUrls, CudaFlatField
8
+ from nabu.preproc.flatfield_cuda import CudaFlatField
14
9
 
15
10
 
16
11
  # Flats values should be O(k) so that linear interpolation between flats gives exact results
@@ -84,7 +79,6 @@ def generate_test_flatfield_generalized(
84
79
  flats_values,
85
80
  darks_indices,
86
81
  darks_values,
87
- h5_fname,
88
82
  dtype=np.uint16,
89
83
  ):
90
84
  """
@@ -112,14 +106,11 @@ def generate_test_flatfield_generalized(
112
106
  -------
113
107
  radios: numpy.ndarray
114
108
  3D array with raw radios
115
- darks: dict of DataUrls
116
- Dictionary where each key is the dark indice, and value is a DataUrl
117
- flats: dict of DataUrls
118
- Dictionary where each key is the flat indice, and value is a DataUrl
109
+ darks: dict of arrays
110
+ Dictionary where each key is the dark indice, and value is an array
111
+ flats: dict of arrays
112
+ Dictionary where each key is the flat indice, and value is an array
119
113
  """
120
- tempdir = mkdtemp(prefix="nabu_")
121
- testffname = os.path.join(tempdir, h5_fname)
122
-
123
114
  # Radios
124
115
  radios = np.zeros((len(radios_values),) + image_shape, dtype="f")
125
116
  n_radios = radios.shape[0]
@@ -129,26 +120,15 @@ def generate_test_flatfield_generalized(
129
120
 
130
121
  # Flats
131
122
  flats = {}
132
- flats_urls = {}
133
123
  for i, flat_idx in enumerate(flats_indices):
134
- flats["flats_%06d" % flat_idx] = np.zeros(img_shape, dtype=dtype) + flats_values[i]
135
- flats_urls[flat_idx] = DataUrl(
136
- file_path=testffname, data_path=str("/flats/flats_%06d" % flat_idx), scheme="silx"
137
- )
124
+ flats[flat_idx] = np.zeros(img_shape, dtype=dtype) + flats_values[i]
138
125
 
139
126
  # Darks
140
127
  darks = {}
141
- darks_urls = {}
142
128
  for i, dark_idx in enumerate(darks_indices):
143
- darks["darks_%06d" % dark_idx] = np.zeros(img_shape, dtype=dtype) + darks_values[i]
144
- darks_urls[dark_idx] = DataUrl(
145
- file_path=testffname, data_path=str("/darks/darks_%06d" % dark_idx), scheme="silx"
146
- )
129
+ darks[dark_idx] = np.zeros(img_shape, dtype=dtype) + darks_values[i]
147
130
 
148
- dicttoh5(flats, testffname, h5path="/flats", mode="w")
149
- dicttoh5(darks, testffname, h5path="/darks", mode="a")
150
-
151
- return radios, flats_urls, darks_urls
131
+ return radios, flats, darks
152
132
 
153
133
 
154
134
  @pytest.fixture(scope="class")
@@ -173,7 +153,7 @@ def bootstrap(request):
173
153
  class TestFlatField:
174
154
  def get_test_elements(self, case_name):
175
155
  config = flatfield_tests_cases[case_name]
176
- radios_stack, flats_url, darks_url = generate_test_flatfield_generalized(
156
+ radios_stack, flats, darks = generate_test_flatfield_generalized(
177
157
  config["image_shape"],
178
158
  config["radios_indices"],
179
159
  config["radios_values"],
@@ -181,12 +161,11 @@ class TestFlatField:
181
161
  config["flats_values"],
182
162
  config["darks_indices"],
183
163
  config["darks_values"],
184
- "test_ff.h5",
185
164
  )
186
- fname = flats_url[list(flats_url.keys())[0]].file_path()
187
- self.tmp_files.append(fname)
188
- self.tmp_dirs.append(os.path.dirname(fname))
189
- return radios_stack, flats_url, darks_url, config
165
+ # fname = flats_url[list(flats_url.keys())[0]].file_path()
166
+ # self.tmp_files.append(fname)
167
+ # self.tmp_dirs.append(os.path.dirname(fname))
168
+ return radios_stack, flats, darks, config
190
169
 
191
170
  @staticmethod
192
171
  def check_normalized_radios(radios_corr, expected_values):
@@ -203,9 +182,9 @@ class TestFlatField:
203
182
  (I - D)/(F - D) where I = (1, 2, ...), D = 1, F = 0.5
204
183
  = (0, -2, -4, -6, ...)
205
184
  """
206
- radios_stack, flats_url, darks_url, config = self.get_test_elements("simple_nearest_interp")
185
+ radios_stack, flats, darks, config = self.get_test_elements("simple_nearest_interp")
207
186
 
208
- flatfield = FlatFieldDataUrls(radios_stack.shape, flats_url, darks_url)
187
+ flatfield = FlatField(radios_stack.shape, flats, darks)
209
188
  radios_corr = flatfield.normalize_radios(np.copy(radios_stack))
210
189
  self.check_normalized_radios(radios_corr, config["expected_result"])
211
190
 
@@ -213,16 +192,17 @@ class TestFlatField:
213
192
  """
214
193
  Same as test_flatfield_simple, but in a vertical subregion of the radios.
215
194
  """
216
- radios_stack, flats_url, darks_url, config = self.get_test_elements("simple_nearest_interp")
195
+ radios_stack, flats, darks, config = self.get_test_elements("simple_nearest_interp")
217
196
  end_z = 51
197
+ flats = {k: arr[:end_z, :] for k, arr in flats.items()}
198
+ darks = {k: arr[:end_z, :] for k, arr in darks.items()}
218
199
  radios_chunk = np.copy(radios_stack[:, :end_z, :])
219
200
  # we only have a chunk in memory. Instantiate the class with the
220
201
  # corresponding subregion to only load the relevant part of dark/flat
221
- flatfield = FlatFieldDataUrls(
202
+ flatfield = FlatField(
222
203
  radios_chunk.shape,
223
- flats_url,
224
- darks_url,
225
- sub_region=(None, None, None, end_z), # start_x, end_x, start_z, end_z
204
+ flats,
205
+ darks,
226
206
  )
227
207
  radios_corr = flatfield.normalize_radios(radios_chunk)
228
208
  self.check_normalized_radios(radios_corr, config["expected_result"])
@@ -239,8 +219,8 @@ class TestFlatField:
239
219
  = (I-D)/(F-D)
240
220
  = (I-1)/I
241
221
  """
242
- radios_stack, flats_url, darks_url, config = self.get_test_elements("two_flats_no_radios_indices")
243
- flatfield = FlatFieldDataUrls(radios_stack.shape, flats_url, darks_url)
222
+ radios_stack, flats, darks, config = self.get_test_elements("two_flats_no_radios_indices")
223
+ flatfield = FlatField(radios_stack.shape, flats, darks)
244
224
  radios_corr = flatfield.normalize_radios(np.copy(radios_stack))
245
225
  self.check_normalized_radios(radios_corr, config["expected_result"])
246
226
 
@@ -249,10 +229,10 @@ class TestFlatField:
249
229
  # F = 2 11
250
230
  # F_i = 2 3.8 5.6 7.4 9.2 11 11 11 11 11
251
231
  # R = 0 .357 .435 .469 .488 .5 .6 .7 .8 .9
252
- flats_url = flats_url.copy()
253
- flats_url[5] = flats_url[9]
254
- flats_url.pop(9)
255
- flatfield = FlatFieldDataUrls(radios_stack.shape, flats_url, darks_url)
232
+ flats = {k: v.copy() for k, v in flats.items()}
233
+ flats[5] = flats[9]
234
+ flats.pop(9)
235
+ flatfield = FlatField(radios_stack.shape, flats, darks)
256
236
  radios_corr = flatfield.normalize_radios(np.copy(radios_stack))
257
237
  self.check_normalized_radios(
258
238
  radios_corr, [0.0, 0.35714286, 0.43478261, 0.46875, 0.48780488, 0.5, 0.6, 0.7, 0.8, 0.9]
@@ -263,13 +243,13 @@ class TestFlatField:
263
243
  """
264
244
  Test the flat-field with cuda back-end.
265
245
  """
266
- radios_stack, flats_url, darks_url, config = self.get_test_elements("two_flats_no_radios_indices")
267
- d_radios = garray.to_gpu(radios_stack.astype("f"))
268
- cuda_flatfield = CudaFlatFieldDataUrls(
269
- d_radios.shape,
270
- flats_url,
271
- darks_url,
246
+ radios_stack, flats, darks, config = self.get_test_elements("two_flats_no_radios_indices")
247
+ cuda_flatfield = CudaFlatField(
248
+ radios_stack.shape,
249
+ flats,
250
+ darks,
272
251
  )
252
+ d_radios = cuda_flatfield.cuda_processing.to_device("d_radios", radios_stack.astype("f"))
273
253
  cuda_flatfield.normalize_radios(d_radios)
274
254
  radios_corr = d_radios.get()
275
255
  self.check_normalized_radios(radios_corr, config["expected_result"])
@@ -277,27 +257,27 @@ class TestFlatField:
277
257
  # Linear interpolation, two flats, one dark
278
258
  def test_twoflats_simple(self):
279
259
  radios, flats, darks, config = self.get_test_elements("two_flats_with_radios_indices")
280
- FF = FlatFieldDataUrls(radios.shape, flats, darks, radios_indices=config["radios_indices"])
260
+ FF = FlatField(radios.shape, flats, darks, radios_indices=config["radios_indices"])
281
261
  FF.normalize_radios(radios)
282
262
  self.check_normalized_radios(radios, config["expected_result"])
283
263
 
284
264
  def _setup_numerical_issue(self):
285
265
  radios, flats, darks, config = self.get_test_elements("two_flats_with_radios_indices")
286
-
287
- # Retrieve the actual data for radios/darks/flats to use FlatField instead of FlatFieldDataUrl.
288
- # Create a setting yielding "0/0": one pixel such that flat==dark and radio==dark
289
- for flat_idx, flat_url in flats.items():
290
- flats[flat_idx] = get_data(flat_url)
291
- flats[flat_idx][0, 0] = 99
292
- for dark_idx, dark_url in darks.items():
293
- darks[dark_idx] = get_data(dark_url)
294
- darks[dark_idx][0, 0] = 99
266
+ flats_copy = {}
267
+ darks_copy = {}
268
+
269
+ for flat_idx, flat in flats.items():
270
+ flats_copy[flat_idx] = flat.copy()
271
+ flats_copy[flat_idx][0, 0] = 99
272
+ for dark_idx, dark in darks.items():
273
+ darks_copy[dark_idx] = dark.copy()
274
+ darks_copy[dark_idx][0, 0] = 99
295
275
  radios[:, 0, 0] = 99
296
- return radios, flats, darks, config
276
+ return radios, flats_copy, darks_copy, config
297
277
 
298
278
  def _check_numerical_issue(self, radios, expected_result, nan_value=None):
299
279
  if nan_value is None:
300
- assert np.alltrue(np.logical_not(np.isfinite(radios[:, 0, 0]))), "First pixel should be nan or inf"
280
+ assert np.all(np.logical_not(np.isfinite(radios[:, 0, 0]))), "First pixel should be nan or inf"
301
281
  radios[:, 0, 0] = radios[:, 1, 1]
302
282
  self.check_normalized_radios(radios, expected_result)
303
283
  else:
@@ -341,10 +321,10 @@ class TestFlatField:
341
321
  """
342
322
  radios, flats, darks, config = self._setup_numerical_issue()
343
323
  radios0 = radios.copy()
344
- d_radios = garray.to_gpu(radios)
345
324
  FF_no_nan_handling = CudaFlatField(
346
325
  radios.shape, flats, darks, radios_indices=config["radios_indices"], nan_value=None
347
326
  )
327
+ d_radios = FF_no_nan_handling.cuda_processing.to_device("radios", radios)
348
328
  # In a cuda kernel, no one can hear you scream
349
329
  FF_no_nan_handling.normalize_radios(d_radios)
350
330
  radios = d_radios.get()
@@ -363,7 +343,7 @@ class TestFlatField:
363
343
  def test_srcurrent(self):
364
344
  radios, flats, darks, config = self.get_test_elements("three_flats_srcurrent")
365
345
 
366
- FF = FlatFieldDataUrls(
346
+ FF = FlatField(
367
347
  radios.shape,
368
348
  flats,
369
349
  darks,
@@ -377,9 +357,8 @@ class TestFlatField:
377
357
  @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test")
378
358
  def test_srcurrent_cuda(self):
379
359
  radios, flats, darks, config = self.get_test_elements("three_flats_srcurrent")
380
- d_radios = garray.to_gpu(radios)
381
360
 
382
- FF = CudaFlatFieldDataUrls(
361
+ FF = CudaFlatField(
383
362
  radios.shape,
384
363
  flats,
385
364
  darks,
@@ -387,6 +366,7 @@ class TestFlatField:
387
366
  radios_srcurrent=config["radios_srcurrent"],
388
367
  flats_srcurrent=config["flats_srcurrent"],
389
368
  )
369
+ d_radios = FF.cuda_processing.to_device("radios", radios)
390
370
  FF.normalize_radios(d_radios)
391
371
  radios_corr = d_radios.get()
392
372
  self.check_normalized_radios(radios_corr, config["expected_result"])
@@ -429,8 +409,6 @@ class FlatFieldTestDataset:
429
409
 
430
410
  def __init__(self):
431
411
  self._generate_projections()
432
- self._dump_to_h5()
433
- self._generate_dataurls()
434
412
 
435
413
  def get_flat_idx(self, proj_idx):
436
414
  flats_idx = sorted(list(self.flats.keys()))
@@ -461,26 +439,6 @@ class FlatFieldTestDataset:
461
439
  self.projs[str(proj_idx)] = np.zeros(self.shp, "f") + proj_val
462
440
  self.projs_data[i] = self.projs[str(proj_idx)]
463
441
 
464
- def _dump_to_h5(self):
465
- self.tempdir = mkdtemp(prefix="nabu_")
466
- self.fname = os.path.join(self.tempdir, "projs_flats.h5")
467
- dicttoh5(
468
- {
469
- "projs": self.projs,
470
- "flats": {str(k): v for k, v in self.flats.items()},
471
- "darks": {str(k): v for k, v in self.darks.items()},
472
- },
473
- h5file=self.fname,
474
- )
475
-
476
- def _generate_dataurls(self):
477
- self.flats_urls = {}
478
- for idx in self.flats.keys():
479
- self.flats_urls[int(idx)] = DataUrl(file_path=self.fname, data_path="/flats/%d" % idx)
480
- self.darks_urls = {}
481
- for idx in self.darks.keys():
482
- self.darks_urls[int(idx)] = DataUrl(file_path=self.fname, data_path="/darks/0")
483
-
484
442
 
485
443
  @pytest.fixture(scope="class")
486
444
  def bootstraph5(request):
@@ -496,9 +454,6 @@ def bootstraph5(request):
496
454
  cls.tol_std = 1e-3
497
455
 
498
456
  yield
499
- # tear-down
500
- os.remove(cls.dataset.fname)
501
- os.rmdir(cls.dataset.tempdir)
502
457
 
503
458
 
504
459
  @pytest.mark.usefixtures("bootstraph5")
@@ -512,10 +467,10 @@ class TestFlatFieldH5:
512
467
  assert np.max(np.abs(errs)) < self.tol, "Something wrong with flat-field normalization"
513
468
 
514
469
  def test_flatfield(self):
515
- flatfield = FlatFieldDataUrls(
470
+ flatfield = FlatField(
516
471
  self.dataset.projs_data.shape,
517
- self.dataset.flats_urls,
518
- self.dataset.darks_urls,
472
+ self.dataset.flats,
473
+ self.dataset.darks,
519
474
  radios_indices=self.dataset.projs_idx,
520
475
  interpolation="linear",
521
476
  )
@@ -525,13 +480,13 @@ class TestFlatFieldH5:
525
480
 
526
481
  @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test")
527
482
  def test_cuda_flatfield(self):
528
- d_projs = garray.to_gpu(self.dataset.projs_data)
529
- cuda_flatfield = CudaFlatFieldDataUrls(
483
+ cuda_flatfield = CudaFlatField(
530
484
  self.dataset.projs_data.shape,
531
- self.dataset.flats_urls,
532
- self.dataset.darks_urls,
485
+ self.dataset.flats,
486
+ self.dataset.darks,
533
487
  radios_indices=self.dataset.projs_idx,
534
488
  )
489
+ d_projs = cuda_flatfield.cuda_processing.to_device("d_projs", self.dataset.projs_data)
535
490
  cuda_flatfield.normalize_radios(d_projs)
536
491
  projs = d_projs.get()
537
492
  self.check_normalization(projs)
@@ -551,24 +506,17 @@ class TestFlatFieldH5:
551
506
  def generate_test_flatfield(n_radios, radio_shape, flat_interval, h5_fname):
552
507
  radios = np.zeros((n_radios,) + radio_shape, "f")
553
508
  dark_data = np.ones(radios.shape[1:], "f")
554
- tempdir = mkdtemp(prefix="nabu_")
555
- testffname = os.path.join(tempdir, h5_fname)
556
509
  flats = {}
557
- flats_urls = {}
558
510
  # F_i = i + 2
559
511
  # R_i = i*(F_i - 1) + 1
560
512
  # N_i = (R_i - D)/(F_i - D) = i*(F_i - 1)/( F_i - 1) = i
561
513
  for i in range(n_radios):
562
514
  f_i = i + 2
563
515
  if (i % flat_interval) == 0:
564
- flats["flats_%06d" % i] = np.zeros(radio_shape, "f") + f_i
565
- flats_urls[i] = DataUrl(file_path=testffname, data_path=str("/flats/flats_%06d" % i), scheme="silx")
516
+ flats[i] = np.zeros(radio_shape, "f") + f_i
566
517
  radios[i] = i * (f_i - 1) + 1
567
- dark = {"dark_0000": dark_data}
568
- dicttoh5(flats, testffname, h5path="/flats", mode="w")
569
- dicttoh5(dark, testffname, h5path="/dark", mode="a")
570
- dark_url = {0: DataUrl(file_path=testffname, data_path="/dark/dark_0000", scheme="silx")}
571
- return radios, flats_urls, dark_url
518
+ darks = {0: dark_data}
519
+ return radios, flats, darks
572
520
 
573
521
 
574
522
  @pytest.fixture(scope="class")
@@ -582,17 +530,14 @@ def bootstrap_multiflats(request):
582
530
 
583
531
  radios, flats, dark = generate_test_flatfield(n_radios, radio_shape, cls.flat_interval, h5_fname)
584
532
  cls.radios = radios
585
- cls.flats_urls = flats
586
- cls.darks_urls = dark
533
+ cls.flats = flats
534
+ cls.darks = dark
587
535
  cls.expected_results = np.arange(n_radios)
588
536
 
589
537
  cls.tol = 5e-4
590
538
  cls.tol_std = 1e-4
591
539
 
592
540
  yield
593
- # tear down
594
- os.remove(dark[0].file_path())
595
- os.rmdir(os.path.dirname(dark[0].file_path()))
596
541
 
597
542
 
598
543
  @pytest.mark.usefixtures("bootstrap_multiflats")
@@ -607,7 +552,7 @@ class TestFlatFieldMultiFlat:
607
552
  assert np.max(np.abs(errs)) < self.tol, "Something wrong with flat-field normalization"
608
553
 
609
554
  def test_flatfield(self):
610
- flatfield = FlatFieldDataUrls(self.radios.shape, self.flats_urls, self.darks_urls, interpolation="linear")
555
+ flatfield = FlatField(self.radios.shape, self.flats, self.darks, interpolation="linear")
611
556
  projs = np.copy(self.radios)
612
557
  flatfield.normalize_radios(projs)
613
558
  print(projs[:, 0, 0])
@@ -615,12 +560,12 @@ class TestFlatFieldMultiFlat:
615
560
 
616
561
  @pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test")
617
562
  def test_cuda_flatfield(self):
618
- d_projs = garray.to_gpu(self.radios)
619
- cuda_flatfield = CudaFlatFieldDataUrls(
563
+ cuda_flatfield = CudaFlatField(
620
564
  self.radios.shape,
621
- self.flats_urls,
622
- self.darks_urls,
565
+ self.flats,
566
+ self.darks,
623
567
  )
568
+ d_projs = cuda_flatfield.cuda_processing.to_device("radios", self.radios)
624
569
  cuda_flatfield.normalize_radios(d_projs)
625
570
  projs = d_projs.get()
626
571
  self.check_normalization(projs)
@@ -1,32 +1,35 @@
1
1
  import pytest
2
2
  import numpy as np
3
3
  from nabu.preproc.phase import PaganinPhaseRetrieval
4
- from nabu.testutils import get_data
4
+ from nabu.processing.fft_cuda import get_available_fft_implems
5
+ from nabu.testutils import generate_tests_scenarios, get_data
5
6
  from nabu.thirdparty.tomopy_phase import retrieve_phase
6
- from nabu.cuda.utils import __has_pycuda__, __has_cufft__
7
+ from nabu.cuda.utils import __has_pycuda__
7
8
 
9
+ __has_cufft__ = False
8
10
  if __has_pycuda__:
9
11
  from nabu.preproc.phase_cuda import CudaPaganinPhaseRetrieval
10
12
 
11
- scenarios = [
12
- {
13
- "distance": 1,
14
- "energy": 35,
15
- "delta_beta": 1e1,
16
- "margin": ((50, 50), (0, 0)),
17
- }
18
- ]
13
+ avail_fft = get_available_fft_implems()
14
+ __has_cufft__ = len(avail_fft) > 0
19
15
 
16
+ scenarios = {
17
+ "distance": [1],
18
+ "energy": [35],
19
+ "delta_beta": [1e1],
20
+ "margin": [((50, 50), (0, 0)), None],
21
+ }
20
22
 
21
- @pytest.fixture(scope="class", params=scenarios)
23
+ scenarios = generate_tests_scenarios(scenarios)
24
+
25
+
26
+ @pytest.fixture(scope="class")
22
27
  def bootstrap(request):
23
28
  cls = request.cls
24
- cls.paganin_config = request.param
25
29
 
26
30
  cls.data = get_data("mri_proj_astra.npz")["data"]
27
31
  cls.rtol = 1.1e-6
28
32
  cls.rtol_pag = 5e-3
29
- cls.paganin = PaganinPhaseRetrieval(cls.data.shape, **cls.paganin_config)
30
33
 
31
34
 
32
35
  @pytest.mark.usefixtures("bootstrap")
@@ -36,32 +39,53 @@ class TestPaganin:
36
39
  The reference implementation is tomopy.
37
40
  """
38
41
 
39
- def crop_to_margin(self, data):
40
- s0, s1 = self.paganin.shape_inner
41
- ((U, _), (L, _)) = self.paganin.margin
42
- return data[U : U + s0, L : L + s1]
42
+ @staticmethod
43
+ def get_paganin_instance_and_data(cfg, data):
44
+ pag_kwargs = cfg.copy()
45
+ margin = pag_kwargs.pop("margin")
46
+ if margin is not None:
47
+ data = np.pad(data, margin, mode="edge")
48
+ paganin = PaganinPhaseRetrieval(data.shape, **pag_kwargs)
49
+ return paganin, data, pag_kwargs
43
50
 
44
- def test_paganin(self):
45
- data_tomopy = np.atleast_3d(np.copy(self.data)).T
51
+ @staticmethod
52
+ def crop_to_margin(data, margin):
53
+ if margin is None:
54
+ return data
55
+ ((U, D), (L, R)) = margin
56
+ D = None if D == 0 else -D
57
+ R = None if R == 0 else -R
58
+ return data[U:D, L:R]
46
59
 
60
+ @pytest.mark.parametrize("config", scenarios)
61
+ def test_paganin(self, config):
62
+ paganin, data, _ = self.get_paganin_instance_and_data(config, self.data)
63
+ res = paganin.apply_filter(data)
64
+
65
+ data_tomopy = np.atleast_3d(np.copy(data)).T
47
66
  res_tomopy = retrieve_phase(
48
67
  data_tomopy,
49
- pixel_size=self.paganin.pixel_size_micron * 1e-4,
50
- dist=self.paganin.distance_cm,
51
- energy=self.paganin.energy_kev,
52
- alpha=1.0 / (4 * 3.141592**2 * self.paganin.delta_beta),
68
+ pixel_size=paganin.pixel_size_xy_micron[0] * 1e-4,
69
+ dist=paganin.distance_cm,
70
+ energy=paganin.energy_kev,
71
+ alpha=1.0 / (4 * 3.141592**2 * paganin.delta_beta),
53
72
  )
54
- res_tomopy = self.crop_to_margin(res_tomopy[0].T)
55
73
 
56
- res = self.paganin.apply_filter(self.data)
74
+ res_tomopy = self.crop_to_margin(res_tomopy[0].T, config["margin"])
75
+ res = self.crop_to_margin(res, config["margin"])
57
76
 
58
77
  errmax = np.max(np.abs(res - res_tomopy) / np.max(res_tomopy))
59
78
  assert errmax < self.rtol_pag, "Max error is too high"
60
79
 
61
- @pytest.mark.skipif(not (__has_pycuda__ and __has_cufft__), reason="Need pycuda and scikit-cuda for this test")
62
- def test_gpu_paganin(self):
63
- gpu_paganin = CudaPaganinPhaseRetrieval(self.data.shape, **self.paganin_config)
64
- ref = self.paganin.apply_filter(self.data)
65
- res = gpu_paganin.apply_filter(self.data)
80
+ @pytest.mark.skipif(
81
+ not (__has_pycuda__ and __has_cufft__), reason="Need pycuda and (scikit-cuda or vkfft) for this test"
82
+ )
83
+ @pytest.mark.parametrize("config", scenarios)
84
+ def test_gpu_paganin(self, config):
85
+ paganin, data, pag_kwargs = self.get_paganin_instance_and_data(config, self.data)
86
+
87
+ gpu_paganin = CudaPaganinPhaseRetrieval(data.shape, **pag_kwargs)
88
+ ref = paganin.apply_filter(data)
89
+ res = gpu_paganin.apply_filter(data)
66
90
  errmax = np.max(np.abs((res - ref) / np.max(ref)))
67
91
  assert errmax < self.rtol, "Max error is too high"