nabu 2024.1.10__py3-none-any.whl → 2024.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nabu/__init__.py +1 -1
- nabu/app/bootstrap.py +2 -3
- nabu/app/cast_volume.py +4 -2
- nabu/app/cli_configs.py +5 -0
- nabu/app/composite_cor.py +1 -1
- nabu/app/create_distortion_map_from_poly.py +5 -6
- nabu/app/diag_to_pix.py +7 -19
- nabu/app/diag_to_rot.py +14 -29
- nabu/app/double_flatfield.py +32 -44
- nabu/app/parse_reconstruction_log.py +3 -0
- nabu/app/reconstruct.py +53 -15
- nabu/app/reconstruct_helical.py +2 -2
- nabu/app/stitching.py +27 -13
- nabu/app/tests/__init__.py +0 -0
- nabu/app/tests/test_reduce_dark_flat.py +4 -1
- nabu/cuda/kernel.py +11 -2
- nabu/cuda/processing.py +2 -2
- nabu/cuda/src/cone.cu +77 -0
- nabu/cuda/src/hierarchical_backproj.cu +271 -0
- nabu/cuda/utils.py +0 -6
- nabu/estimation/alignment.py +5 -19
- nabu/estimation/cor.py +173 -599
- nabu/estimation/cor_sino.py +356 -26
- nabu/estimation/focus.py +63 -11
- nabu/estimation/tests/test_cor.py +124 -58
- nabu/estimation/tests/test_focus.py +6 -6
- nabu/estimation/tilt.py +2 -1
- nabu/estimation/utils.py +5 -33
- nabu/io/__init__.py +1 -1
- nabu/io/cast_volume.py +1 -1
- nabu/io/reader.py +416 -21
- nabu/io/tests/test_readers.py +422 -0
- nabu/io/tests/test_writers.py +1 -102
- nabu/io/writer.py +4 -433
- nabu/opencl/kernel.py +14 -3
- nabu/opencl/processing.py +8 -0
- nabu/pipeline/config_validators.py +5 -2
- nabu/pipeline/datadump.py +12 -5
- nabu/pipeline/estimators.py +162 -188
- nabu/pipeline/fullfield/chunked.py +168 -92
- nabu/pipeline/fullfield/chunked_cuda.py +7 -3
- nabu/pipeline/fullfield/computations.py +2 -7
- nabu/pipeline/fullfield/dataset_validator.py +0 -4
- nabu/pipeline/fullfield/nabu_config.py +37 -13
- nabu/pipeline/fullfield/processconfig.py +22 -13
- nabu/pipeline/fullfield/reconstruction.py +13 -9
- nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
- nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
- nabu/pipeline/helical/helical_reconstruction.py +1 -1
- nabu/pipeline/params.py +21 -1
- nabu/pipeline/processconfig.py +1 -12
- nabu/pipeline/reader.py +146 -0
- nabu/pipeline/tests/test_estimators.py +44 -72
- nabu/pipeline/utils.py +4 -2
- nabu/pipeline/writer.py +10 -2
- nabu/preproc/ccd_cuda.py +1 -1
- nabu/preproc/ctf.py +14 -7
- nabu/preproc/ctf_cuda.py +2 -3
- nabu/preproc/double_flatfield.py +5 -12
- nabu/preproc/double_flatfield_cuda.py +2 -2
- nabu/preproc/flatfield.py +5 -1
- nabu/preproc/flatfield_cuda.py +5 -1
- nabu/preproc/phase.py +24 -73
- nabu/preproc/phase_cuda.py +5 -8
- nabu/preproc/tests/test_ctf.py +11 -7
- nabu/preproc/tests/test_flatfield.py +67 -122
- nabu/preproc/tests/test_paganin.py +54 -30
- nabu/processing/azim.py +206 -0
- nabu/processing/convolution_cuda.py +1 -1
- nabu/processing/fft_cuda.py +15 -17
- nabu/processing/histogram.py +2 -0
- nabu/processing/histogram_cuda.py +2 -1
- nabu/processing/kernel_base.py +3 -0
- nabu/processing/muladd_cuda.py +1 -0
- nabu/processing/padding_opencl.py +1 -1
- nabu/processing/roll_opencl.py +1 -0
- nabu/processing/rotation_cuda.py +2 -2
- nabu/processing/tests/test_fft.py +17 -10
- nabu/processing/unsharp_cuda.py +1 -1
- nabu/reconstruction/cone.py +104 -40
- nabu/reconstruction/fbp.py +3 -0
- nabu/reconstruction/fbp_base.py +7 -2
- nabu/reconstruction/filtering.py +20 -7
- nabu/reconstruction/filtering_cuda.py +7 -1
- nabu/reconstruction/hbp.py +424 -0
- nabu/reconstruction/mlem.py +99 -0
- nabu/reconstruction/reconstructor.py +2 -0
- nabu/reconstruction/rings_cuda.py +19 -19
- nabu/reconstruction/sinogram_cuda.py +1 -0
- nabu/reconstruction/sinogram_opencl.py +3 -1
- nabu/reconstruction/tests/test_cone.py +10 -5
- nabu/reconstruction/tests/test_deringer.py +7 -6
- nabu/reconstruction/tests/test_fbp.py +124 -10
- nabu/reconstruction/tests/test_filtering.py +13 -11
- nabu/reconstruction/tests/test_halftomo.py +30 -4
- nabu/reconstruction/tests/test_mlem.py +91 -0
- nabu/reconstruction/tests/test_reconstructor.py +8 -3
- nabu/resources/dataset_analyzer.py +142 -92
- nabu/resources/gpu.py +1 -0
- nabu/resources/nxflatfield.py +134 -125
- nabu/resources/templates/id16a_fluo.conf +42 -0
- nabu/resources/tests/test_extract.py +10 -0
- nabu/resources/tests/test_nxflatfield.py +2 -2
- nabu/stitching/alignment.py +80 -24
- nabu/stitching/config.py +105 -68
- nabu/stitching/definitions.py +1 -0
- nabu/stitching/frame_composition.py +68 -60
- nabu/stitching/overlap.py +91 -51
- nabu/stitching/single_axis_stitching.py +32 -0
- nabu/stitching/slurm_utils.py +6 -6
- nabu/stitching/stitcher/__init__.py +0 -0
- nabu/stitching/stitcher/base.py +124 -0
- nabu/stitching/stitcher/dumper/__init__.py +3 -0
- nabu/stitching/stitcher/dumper/base.py +94 -0
- nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
- nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
- nabu/stitching/stitcher/post_processing.py +555 -0
- nabu/stitching/stitcher/pre_processing.py +1068 -0
- nabu/stitching/stitcher/single_axis.py +484 -0
- nabu/stitching/stitcher/stitcher.py +0 -0
- nabu/stitching/stitcher/y_stitcher.py +13 -0
- nabu/stitching/stitcher/z_stitcher.py +45 -0
- nabu/stitching/stitcher_2D.py +278 -0
- nabu/stitching/tests/test_config.py +12 -37
- nabu/stitching/tests/test_frame_composition.py +33 -59
- nabu/stitching/tests/test_overlap.py +149 -7
- nabu/stitching/tests/test_utils.py +1 -1
- nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
- nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
- nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
- nabu/stitching/utils/__init__.py +1 -0
- nabu/stitching/utils/post_processing.py +281 -0
- nabu/stitching/utils/tests/test_post-processing.py +21 -0
- nabu/stitching/{utils.py → utils/utils.py} +79 -52
- nabu/stitching/y_stitching.py +27 -0
- nabu/stitching/z_stitching.py +32 -2281
- nabu/testutils.py +1 -152
- nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
- nabu/utils.py +158 -61
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/METADATA +24 -17
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/RECORD +145 -121
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/WHEEL +1 -1
- nabu/io/tiffwriter_zmm.py +0 -99
- nabu/pipeline/fallback_utils.py +0 -149
- nabu/pipeline/helical/tests/test_accumulator.py +0 -158
- nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
- nabu/pipeline/helical/tests/test_strategy.py +0 -61
- nabu/pipeline/helical/utils.py +0 -51
- nabu/pipeline/tests/test_chunk_reader.py +0 -74
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,11 @@
|
|
1
|
-
from tempfile import mkdtemp
|
2
1
|
import os
|
3
2
|
import numpy as np
|
4
3
|
import pytest
|
5
|
-
from silx.io.url import DataUrl
|
6
|
-
from silx.io import get_data
|
7
|
-
from silx.io.dictdump import dicttoh5
|
8
4
|
from nabu.cuda.utils import get_cuda_context, __has_pycuda__
|
9
|
-
from nabu.preproc.flatfield import FlatField
|
5
|
+
from nabu.preproc.flatfield import FlatField
|
10
6
|
|
11
7
|
if __has_pycuda__:
|
12
|
-
|
13
|
-
from nabu.preproc.flatfield_cuda import CudaFlatFieldDataUrls, CudaFlatField
|
8
|
+
from nabu.preproc.flatfield_cuda import CudaFlatField
|
14
9
|
|
15
10
|
|
16
11
|
# Flats values should be O(k) so that linear interpolation between flats gives exact results
|
@@ -84,7 +79,6 @@ def generate_test_flatfield_generalized(
|
|
84
79
|
flats_values,
|
85
80
|
darks_indices,
|
86
81
|
darks_values,
|
87
|
-
h5_fname,
|
88
82
|
dtype=np.uint16,
|
89
83
|
):
|
90
84
|
"""
|
@@ -112,14 +106,11 @@ def generate_test_flatfield_generalized(
|
|
112
106
|
-------
|
113
107
|
radios: numpy.ndarray
|
114
108
|
3D array with raw radios
|
115
|
-
darks: dict of
|
116
|
-
Dictionary where each key is the dark indice, and value is
|
117
|
-
flats: dict of
|
118
|
-
Dictionary where each key is the flat indice, and value is
|
109
|
+
darks: dict of arrays
|
110
|
+
Dictionary where each key is the dark indice, and value is an array
|
111
|
+
flats: dict of arrays
|
112
|
+
Dictionary where each key is the flat indice, and value is an array
|
119
113
|
"""
|
120
|
-
tempdir = mkdtemp(prefix="nabu_")
|
121
|
-
testffname = os.path.join(tempdir, h5_fname)
|
122
|
-
|
123
114
|
# Radios
|
124
115
|
radios = np.zeros((len(radios_values),) + image_shape, dtype="f")
|
125
116
|
n_radios = radios.shape[0]
|
@@ -129,26 +120,15 @@ def generate_test_flatfield_generalized(
|
|
129
120
|
|
130
121
|
# Flats
|
131
122
|
flats = {}
|
132
|
-
flats_urls = {}
|
133
123
|
for i, flat_idx in enumerate(flats_indices):
|
134
|
-
flats[
|
135
|
-
flats_urls[flat_idx] = DataUrl(
|
136
|
-
file_path=testffname, data_path=str("/flats/flats_%06d" % flat_idx), scheme="silx"
|
137
|
-
)
|
124
|
+
flats[flat_idx] = np.zeros(img_shape, dtype=dtype) + flats_values[i]
|
138
125
|
|
139
126
|
# Darks
|
140
127
|
darks = {}
|
141
|
-
darks_urls = {}
|
142
128
|
for i, dark_idx in enumerate(darks_indices):
|
143
|
-
darks[
|
144
|
-
darks_urls[dark_idx] = DataUrl(
|
145
|
-
file_path=testffname, data_path=str("/darks/darks_%06d" % dark_idx), scheme="silx"
|
146
|
-
)
|
129
|
+
darks[dark_idx] = np.zeros(img_shape, dtype=dtype) + darks_values[i]
|
147
130
|
|
148
|
-
|
149
|
-
dicttoh5(darks, testffname, h5path="/darks", mode="a")
|
150
|
-
|
151
|
-
return radios, flats_urls, darks_urls
|
131
|
+
return radios, flats, darks
|
152
132
|
|
153
133
|
|
154
134
|
@pytest.fixture(scope="class")
|
@@ -173,7 +153,7 @@ def bootstrap(request):
|
|
173
153
|
class TestFlatField:
|
174
154
|
def get_test_elements(self, case_name):
|
175
155
|
config = flatfield_tests_cases[case_name]
|
176
|
-
radios_stack,
|
156
|
+
radios_stack, flats, darks = generate_test_flatfield_generalized(
|
177
157
|
config["image_shape"],
|
178
158
|
config["radios_indices"],
|
179
159
|
config["radios_values"],
|
@@ -181,12 +161,11 @@ class TestFlatField:
|
|
181
161
|
config["flats_values"],
|
182
162
|
config["darks_indices"],
|
183
163
|
config["darks_values"],
|
184
|
-
"test_ff.h5",
|
185
164
|
)
|
186
|
-
fname = flats_url[list(flats_url.keys())[0]].file_path()
|
187
|
-
self.tmp_files.append(fname)
|
188
|
-
self.tmp_dirs.append(os.path.dirname(fname))
|
189
|
-
return radios_stack,
|
165
|
+
# fname = flats_url[list(flats_url.keys())[0]].file_path()
|
166
|
+
# self.tmp_files.append(fname)
|
167
|
+
# self.tmp_dirs.append(os.path.dirname(fname))
|
168
|
+
return radios_stack, flats, darks, config
|
190
169
|
|
191
170
|
@staticmethod
|
192
171
|
def check_normalized_radios(radios_corr, expected_values):
|
@@ -203,9 +182,9 @@ class TestFlatField:
|
|
203
182
|
(I - D)/(F - D) where I = (1, 2, ...), D = 1, F = 0.5
|
204
183
|
= (0, -2, -4, -6, ...)
|
205
184
|
"""
|
206
|
-
radios_stack,
|
185
|
+
radios_stack, flats, darks, config = self.get_test_elements("simple_nearest_interp")
|
207
186
|
|
208
|
-
flatfield =
|
187
|
+
flatfield = FlatField(radios_stack.shape, flats, darks)
|
209
188
|
radios_corr = flatfield.normalize_radios(np.copy(radios_stack))
|
210
189
|
self.check_normalized_radios(radios_corr, config["expected_result"])
|
211
190
|
|
@@ -213,16 +192,17 @@ class TestFlatField:
|
|
213
192
|
"""
|
214
193
|
Same as test_flatfield_simple, but in a vertical subregion of the radios.
|
215
194
|
"""
|
216
|
-
radios_stack,
|
195
|
+
radios_stack, flats, darks, config = self.get_test_elements("simple_nearest_interp")
|
217
196
|
end_z = 51
|
197
|
+
flats = {k: arr[:end_z, :] for k, arr in flats.items()}
|
198
|
+
darks = {k: arr[:end_z, :] for k, arr in darks.items()}
|
218
199
|
radios_chunk = np.copy(radios_stack[:, :end_z, :])
|
219
200
|
# we only have a chunk in memory. Instantiate the class with the
|
220
201
|
# corresponding subregion to only load the relevant part of dark/flat
|
221
|
-
flatfield =
|
202
|
+
flatfield = FlatField(
|
222
203
|
radios_chunk.shape,
|
223
|
-
|
224
|
-
|
225
|
-
sub_region=(None, None, None, end_z), # start_x, end_x, start_z, end_z
|
204
|
+
flats,
|
205
|
+
darks,
|
226
206
|
)
|
227
207
|
radios_corr = flatfield.normalize_radios(radios_chunk)
|
228
208
|
self.check_normalized_radios(radios_corr, config["expected_result"])
|
@@ -239,8 +219,8 @@ class TestFlatField:
|
|
239
219
|
= (I-D)/(F-D)
|
240
220
|
= (I-1)/I
|
241
221
|
"""
|
242
|
-
radios_stack,
|
243
|
-
flatfield =
|
222
|
+
radios_stack, flats, darks, config = self.get_test_elements("two_flats_no_radios_indices")
|
223
|
+
flatfield = FlatField(radios_stack.shape, flats, darks)
|
244
224
|
radios_corr = flatfield.normalize_radios(np.copy(radios_stack))
|
245
225
|
self.check_normalized_radios(radios_corr, config["expected_result"])
|
246
226
|
|
@@ -249,10 +229,10 @@ class TestFlatField:
|
|
249
229
|
# F = 2 11
|
250
230
|
# F_i = 2 3.8 5.6 7.4 9.2 11 11 11 11 11
|
251
231
|
# R = 0 .357 .435 .469 .488 .5 .6 .7 .8 .9
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
flatfield =
|
232
|
+
flats = {k: v.copy() for k, v in flats.items()}
|
233
|
+
flats[5] = flats[9]
|
234
|
+
flats.pop(9)
|
235
|
+
flatfield = FlatField(radios_stack.shape, flats, darks)
|
256
236
|
radios_corr = flatfield.normalize_radios(np.copy(radios_stack))
|
257
237
|
self.check_normalized_radios(
|
258
238
|
radios_corr, [0.0, 0.35714286, 0.43478261, 0.46875, 0.48780488, 0.5, 0.6, 0.7, 0.8, 0.9]
|
@@ -263,13 +243,13 @@ class TestFlatField:
|
|
263
243
|
"""
|
264
244
|
Test the flat-field with cuda back-end.
|
265
245
|
"""
|
266
|
-
radios_stack,
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
darks_url,
|
246
|
+
radios_stack, flats, darks, config = self.get_test_elements("two_flats_no_radios_indices")
|
247
|
+
cuda_flatfield = CudaFlatField(
|
248
|
+
radios_stack.shape,
|
249
|
+
flats,
|
250
|
+
darks,
|
272
251
|
)
|
252
|
+
d_radios = cuda_flatfield.cuda_processing.to_device("d_radios", radios_stack.astype("f"))
|
273
253
|
cuda_flatfield.normalize_radios(d_radios)
|
274
254
|
radios_corr = d_radios.get()
|
275
255
|
self.check_normalized_radios(radios_corr, config["expected_result"])
|
@@ -277,27 +257,27 @@ class TestFlatField:
|
|
277
257
|
# Linear interpolation, two flats, one dark
|
278
258
|
def test_twoflats_simple(self):
|
279
259
|
radios, flats, darks, config = self.get_test_elements("two_flats_with_radios_indices")
|
280
|
-
FF =
|
260
|
+
FF = FlatField(radios.shape, flats, darks, radios_indices=config["radios_indices"])
|
281
261
|
FF.normalize_radios(radios)
|
282
262
|
self.check_normalized_radios(radios, config["expected_result"])
|
283
263
|
|
284
264
|
def _setup_numerical_issue(self):
|
285
265
|
radios, flats, darks, config = self.get_test_elements("two_flats_with_radios_indices")
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
for flat_idx,
|
290
|
-
|
291
|
-
|
292
|
-
for dark_idx,
|
293
|
-
|
294
|
-
|
266
|
+
flats_copy = {}
|
267
|
+
darks_copy = {}
|
268
|
+
|
269
|
+
for flat_idx, flat in flats.items():
|
270
|
+
flats_copy[flat_idx] = flat.copy()
|
271
|
+
flats_copy[flat_idx][0, 0] = 99
|
272
|
+
for dark_idx, dark in darks.items():
|
273
|
+
darks_copy[dark_idx] = dark.copy()
|
274
|
+
darks_copy[dark_idx][0, 0] = 99
|
295
275
|
radios[:, 0, 0] = 99
|
296
|
-
return radios,
|
276
|
+
return radios, flats_copy, darks_copy, config
|
297
277
|
|
298
278
|
def _check_numerical_issue(self, radios, expected_result, nan_value=None):
|
299
279
|
if nan_value is None:
|
300
|
-
assert np.
|
280
|
+
assert np.all(np.logical_not(np.isfinite(radios[:, 0, 0]))), "First pixel should be nan or inf"
|
301
281
|
radios[:, 0, 0] = radios[:, 1, 1]
|
302
282
|
self.check_normalized_radios(radios, expected_result)
|
303
283
|
else:
|
@@ -341,10 +321,10 @@ class TestFlatField:
|
|
341
321
|
"""
|
342
322
|
radios, flats, darks, config = self._setup_numerical_issue()
|
343
323
|
radios0 = radios.copy()
|
344
|
-
d_radios = garray.to_gpu(radios)
|
345
324
|
FF_no_nan_handling = CudaFlatField(
|
346
325
|
radios.shape, flats, darks, radios_indices=config["radios_indices"], nan_value=None
|
347
326
|
)
|
327
|
+
d_radios = FF_no_nan_handling.cuda_processing.to_device("radios", radios)
|
348
328
|
# In a cuda kernel, no one can hear you scream
|
349
329
|
FF_no_nan_handling.normalize_radios(d_radios)
|
350
330
|
radios = d_radios.get()
|
@@ -363,7 +343,7 @@ class TestFlatField:
|
|
363
343
|
def test_srcurrent(self):
|
364
344
|
radios, flats, darks, config = self.get_test_elements("three_flats_srcurrent")
|
365
345
|
|
366
|
-
FF =
|
346
|
+
FF = FlatField(
|
367
347
|
radios.shape,
|
368
348
|
flats,
|
369
349
|
darks,
|
@@ -377,9 +357,8 @@ class TestFlatField:
|
|
377
357
|
@pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test")
|
378
358
|
def test_srcurrent_cuda(self):
|
379
359
|
radios, flats, darks, config = self.get_test_elements("three_flats_srcurrent")
|
380
|
-
d_radios = garray.to_gpu(radios)
|
381
360
|
|
382
|
-
FF =
|
361
|
+
FF = CudaFlatField(
|
383
362
|
radios.shape,
|
384
363
|
flats,
|
385
364
|
darks,
|
@@ -387,6 +366,7 @@ class TestFlatField:
|
|
387
366
|
radios_srcurrent=config["radios_srcurrent"],
|
388
367
|
flats_srcurrent=config["flats_srcurrent"],
|
389
368
|
)
|
369
|
+
d_radios = FF.cuda_processing.to_device("radios", radios)
|
390
370
|
FF.normalize_radios(d_radios)
|
391
371
|
radios_corr = d_radios.get()
|
392
372
|
self.check_normalized_radios(radios_corr, config["expected_result"])
|
@@ -429,8 +409,6 @@ class FlatFieldTestDataset:
|
|
429
409
|
|
430
410
|
def __init__(self):
|
431
411
|
self._generate_projections()
|
432
|
-
self._dump_to_h5()
|
433
|
-
self._generate_dataurls()
|
434
412
|
|
435
413
|
def get_flat_idx(self, proj_idx):
|
436
414
|
flats_idx = sorted(list(self.flats.keys()))
|
@@ -461,26 +439,6 @@ class FlatFieldTestDataset:
|
|
461
439
|
self.projs[str(proj_idx)] = np.zeros(self.shp, "f") + proj_val
|
462
440
|
self.projs_data[i] = self.projs[str(proj_idx)]
|
463
441
|
|
464
|
-
def _dump_to_h5(self):
|
465
|
-
self.tempdir = mkdtemp(prefix="nabu_")
|
466
|
-
self.fname = os.path.join(self.tempdir, "projs_flats.h5")
|
467
|
-
dicttoh5(
|
468
|
-
{
|
469
|
-
"projs": self.projs,
|
470
|
-
"flats": {str(k): v for k, v in self.flats.items()},
|
471
|
-
"darks": {str(k): v for k, v in self.darks.items()},
|
472
|
-
},
|
473
|
-
h5file=self.fname,
|
474
|
-
)
|
475
|
-
|
476
|
-
def _generate_dataurls(self):
|
477
|
-
self.flats_urls = {}
|
478
|
-
for idx in self.flats.keys():
|
479
|
-
self.flats_urls[int(idx)] = DataUrl(file_path=self.fname, data_path="/flats/%d" % idx)
|
480
|
-
self.darks_urls = {}
|
481
|
-
for idx in self.darks.keys():
|
482
|
-
self.darks_urls[int(idx)] = DataUrl(file_path=self.fname, data_path="/darks/0")
|
483
|
-
|
484
442
|
|
485
443
|
@pytest.fixture(scope="class")
|
486
444
|
def bootstraph5(request):
|
@@ -496,9 +454,6 @@ def bootstraph5(request):
|
|
496
454
|
cls.tol_std = 1e-3
|
497
455
|
|
498
456
|
yield
|
499
|
-
# tear-down
|
500
|
-
os.remove(cls.dataset.fname)
|
501
|
-
os.rmdir(cls.dataset.tempdir)
|
502
457
|
|
503
458
|
|
504
459
|
@pytest.mark.usefixtures("bootstraph5")
|
@@ -512,10 +467,10 @@ class TestFlatFieldH5:
|
|
512
467
|
assert np.max(np.abs(errs)) < self.tol, "Something wrong with flat-field normalization"
|
513
468
|
|
514
469
|
def test_flatfield(self):
|
515
|
-
flatfield =
|
470
|
+
flatfield = FlatField(
|
516
471
|
self.dataset.projs_data.shape,
|
517
|
-
self.dataset.
|
518
|
-
self.dataset.
|
472
|
+
self.dataset.flats,
|
473
|
+
self.dataset.darks,
|
519
474
|
radios_indices=self.dataset.projs_idx,
|
520
475
|
interpolation="linear",
|
521
476
|
)
|
@@ -525,13 +480,13 @@ class TestFlatFieldH5:
|
|
525
480
|
|
526
481
|
@pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test")
|
527
482
|
def test_cuda_flatfield(self):
|
528
|
-
|
529
|
-
cuda_flatfield = CudaFlatFieldDataUrls(
|
483
|
+
cuda_flatfield = CudaFlatField(
|
530
484
|
self.dataset.projs_data.shape,
|
531
|
-
self.dataset.
|
532
|
-
self.dataset.
|
485
|
+
self.dataset.flats,
|
486
|
+
self.dataset.darks,
|
533
487
|
radios_indices=self.dataset.projs_idx,
|
534
488
|
)
|
489
|
+
d_projs = cuda_flatfield.cuda_processing.to_device("d_projs", self.dataset.projs_data)
|
535
490
|
cuda_flatfield.normalize_radios(d_projs)
|
536
491
|
projs = d_projs.get()
|
537
492
|
self.check_normalization(projs)
|
@@ -551,24 +506,17 @@ class TestFlatFieldH5:
|
|
551
506
|
def generate_test_flatfield(n_radios, radio_shape, flat_interval, h5_fname):
|
552
507
|
radios = np.zeros((n_radios,) + radio_shape, "f")
|
553
508
|
dark_data = np.ones(radios.shape[1:], "f")
|
554
|
-
tempdir = mkdtemp(prefix="nabu_")
|
555
|
-
testffname = os.path.join(tempdir, h5_fname)
|
556
509
|
flats = {}
|
557
|
-
flats_urls = {}
|
558
510
|
# F_i = i + 2
|
559
511
|
# R_i = i*(F_i - 1) + 1
|
560
512
|
# N_i = (R_i - D)/(F_i - D) = i*(F_i - 1)/( F_i - 1) = i
|
561
513
|
for i in range(n_radios):
|
562
514
|
f_i = i + 2
|
563
515
|
if (i % flat_interval) == 0:
|
564
|
-
flats[
|
565
|
-
flats_urls[i] = DataUrl(file_path=testffname, data_path=str("/flats/flats_%06d" % i), scheme="silx")
|
516
|
+
flats[i] = np.zeros(radio_shape, "f") + f_i
|
566
517
|
radios[i] = i * (f_i - 1) + 1
|
567
|
-
|
568
|
-
|
569
|
-
dicttoh5(dark, testffname, h5path="/dark", mode="a")
|
570
|
-
dark_url = {0: DataUrl(file_path=testffname, data_path="/dark/dark_0000", scheme="silx")}
|
571
|
-
return radios, flats_urls, dark_url
|
518
|
+
darks = {0: dark_data}
|
519
|
+
return radios, flats, darks
|
572
520
|
|
573
521
|
|
574
522
|
@pytest.fixture(scope="class")
|
@@ -582,17 +530,14 @@ def bootstrap_multiflats(request):
|
|
582
530
|
|
583
531
|
radios, flats, dark = generate_test_flatfield(n_radios, radio_shape, cls.flat_interval, h5_fname)
|
584
532
|
cls.radios = radios
|
585
|
-
cls.
|
586
|
-
cls.
|
533
|
+
cls.flats = flats
|
534
|
+
cls.darks = dark
|
587
535
|
cls.expected_results = np.arange(n_radios)
|
588
536
|
|
589
537
|
cls.tol = 5e-4
|
590
538
|
cls.tol_std = 1e-4
|
591
539
|
|
592
540
|
yield
|
593
|
-
# tear down
|
594
|
-
os.remove(dark[0].file_path())
|
595
|
-
os.rmdir(os.path.dirname(dark[0].file_path()))
|
596
541
|
|
597
542
|
|
598
543
|
@pytest.mark.usefixtures("bootstrap_multiflats")
|
@@ -607,7 +552,7 @@ class TestFlatFieldMultiFlat:
|
|
607
552
|
assert np.max(np.abs(errs)) < self.tol, "Something wrong with flat-field normalization"
|
608
553
|
|
609
554
|
def test_flatfield(self):
|
610
|
-
flatfield =
|
555
|
+
flatfield = FlatField(self.radios.shape, self.flats, self.darks, interpolation="linear")
|
611
556
|
projs = np.copy(self.radios)
|
612
557
|
flatfield.normalize_radios(projs)
|
613
558
|
print(projs[:, 0, 0])
|
@@ -615,12 +560,12 @@ class TestFlatFieldMultiFlat:
|
|
615
560
|
|
616
561
|
@pytest.mark.skipif(not (__has_pycuda__), reason="Need cuda/pycuda for this test")
|
617
562
|
def test_cuda_flatfield(self):
|
618
|
-
|
619
|
-
cuda_flatfield = CudaFlatFieldDataUrls(
|
563
|
+
cuda_flatfield = CudaFlatField(
|
620
564
|
self.radios.shape,
|
621
|
-
self.
|
622
|
-
self.
|
565
|
+
self.flats,
|
566
|
+
self.darks,
|
623
567
|
)
|
568
|
+
d_projs = cuda_flatfield.cuda_processing.to_device("radios", self.radios)
|
624
569
|
cuda_flatfield.normalize_radios(d_projs)
|
625
570
|
projs = d_projs.get()
|
626
571
|
self.check_normalization(projs)
|
@@ -1,32 +1,35 @@
|
|
1
1
|
import pytest
|
2
2
|
import numpy as np
|
3
3
|
from nabu.preproc.phase import PaganinPhaseRetrieval
|
4
|
-
from nabu.
|
4
|
+
from nabu.processing.fft_cuda import get_available_fft_implems
|
5
|
+
from nabu.testutils import generate_tests_scenarios, get_data
|
5
6
|
from nabu.thirdparty.tomopy_phase import retrieve_phase
|
6
|
-
from nabu.cuda.utils import __has_pycuda__
|
7
|
+
from nabu.cuda.utils import __has_pycuda__
|
7
8
|
|
9
|
+
__has_cufft__ = False
|
8
10
|
if __has_pycuda__:
|
9
11
|
from nabu.preproc.phase_cuda import CudaPaganinPhaseRetrieval
|
10
12
|
|
11
|
-
|
12
|
-
|
13
|
-
"distance": 1,
|
14
|
-
"energy": 35,
|
15
|
-
"delta_beta": 1e1,
|
16
|
-
"margin": ((50, 50), (0, 0)),
|
17
|
-
}
|
18
|
-
]
|
13
|
+
avail_fft = get_available_fft_implems()
|
14
|
+
__has_cufft__ = len(avail_fft) > 0
|
19
15
|
|
16
|
+
scenarios = {
|
17
|
+
"distance": [1],
|
18
|
+
"energy": [35],
|
19
|
+
"delta_beta": [1e1],
|
20
|
+
"margin": [((50, 50), (0, 0)), None],
|
21
|
+
}
|
20
22
|
|
21
|
-
|
23
|
+
scenarios = generate_tests_scenarios(scenarios)
|
24
|
+
|
25
|
+
|
26
|
+
@pytest.fixture(scope="class")
|
22
27
|
def bootstrap(request):
|
23
28
|
cls = request.cls
|
24
|
-
cls.paganin_config = request.param
|
25
29
|
|
26
30
|
cls.data = get_data("mri_proj_astra.npz")["data"]
|
27
31
|
cls.rtol = 1.1e-6
|
28
32
|
cls.rtol_pag = 5e-3
|
29
|
-
cls.paganin = PaganinPhaseRetrieval(cls.data.shape, **cls.paganin_config)
|
30
33
|
|
31
34
|
|
32
35
|
@pytest.mark.usefixtures("bootstrap")
|
@@ -36,32 +39,53 @@ class TestPaganin:
|
|
36
39
|
The reference implementation is tomopy.
|
37
40
|
"""
|
38
41
|
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
42
|
+
@staticmethod
|
43
|
+
def get_paganin_instance_and_data(cfg, data):
|
44
|
+
pag_kwargs = cfg.copy()
|
45
|
+
margin = pag_kwargs.pop("margin")
|
46
|
+
if margin is not None:
|
47
|
+
data = np.pad(data, margin, mode="edge")
|
48
|
+
paganin = PaganinPhaseRetrieval(data.shape, **pag_kwargs)
|
49
|
+
return paganin, data, pag_kwargs
|
43
50
|
|
44
|
-
|
45
|
-
|
51
|
+
@staticmethod
|
52
|
+
def crop_to_margin(data, margin):
|
53
|
+
if margin is None:
|
54
|
+
return data
|
55
|
+
((U, D), (L, R)) = margin
|
56
|
+
D = None if D == 0 else -D
|
57
|
+
R = None if R == 0 else -R
|
58
|
+
return data[U:D, L:R]
|
46
59
|
|
60
|
+
@pytest.mark.parametrize("config", scenarios)
|
61
|
+
def test_paganin(self, config):
|
62
|
+
paganin, data, _ = self.get_paganin_instance_and_data(config, self.data)
|
63
|
+
res = paganin.apply_filter(data)
|
64
|
+
|
65
|
+
data_tomopy = np.atleast_3d(np.copy(data)).T
|
47
66
|
res_tomopy = retrieve_phase(
|
48
67
|
data_tomopy,
|
49
|
-
pixel_size=
|
50
|
-
dist=
|
51
|
-
energy=
|
52
|
-
alpha=1.0 / (4 * 3.141592**2 *
|
68
|
+
pixel_size=paganin.pixel_size_xy_micron[0] * 1e-4,
|
69
|
+
dist=paganin.distance_cm,
|
70
|
+
energy=paganin.energy_kev,
|
71
|
+
alpha=1.0 / (4 * 3.141592**2 * paganin.delta_beta),
|
53
72
|
)
|
54
|
-
res_tomopy = self.crop_to_margin(res_tomopy[0].T)
|
55
73
|
|
56
|
-
|
74
|
+
res_tomopy = self.crop_to_margin(res_tomopy[0].T, config["margin"])
|
75
|
+
res = self.crop_to_margin(res, config["margin"])
|
57
76
|
|
58
77
|
errmax = np.max(np.abs(res - res_tomopy) / np.max(res_tomopy))
|
59
78
|
assert errmax < self.rtol_pag, "Max error is too high"
|
60
79
|
|
61
|
-
@pytest.mark.skipif(
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
80
|
+
@pytest.mark.skipif(
|
81
|
+
not (__has_pycuda__ and __has_cufft__), reason="Need pycuda and (scikit-cuda or vkfft) for this test"
|
82
|
+
)
|
83
|
+
@pytest.mark.parametrize("config", scenarios)
|
84
|
+
def test_gpu_paganin(self, config):
|
85
|
+
paganin, data, pag_kwargs = self.get_paganin_instance_and_data(config, self.data)
|
86
|
+
|
87
|
+
gpu_paganin = CudaPaganinPhaseRetrieval(data.shape, **pag_kwargs)
|
88
|
+
ref = paganin.apply_filter(data)
|
89
|
+
res = gpu_paganin.apply_filter(data)
|
66
90
|
errmax = np.max(np.abs((res - ref) / np.max(ref)))
|
67
91
|
assert errmax < self.rtol, "Max error is too high"
|