nabu 2023.2.1__py3-none-any.whl → 2024.1.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/conf.py +1 -1
- doc/doc_config.py +32 -0
- nabu/__init__.py +2 -1
- nabu/app/bootstrap_stitching.py +1 -1
- nabu/app/cli_configs.py +122 -2
- nabu/app/composite_cor.py +27 -2
- nabu/app/correct_rot.py +70 -0
- nabu/app/create_distortion_map_from_poly.py +42 -18
- nabu/app/diag_to_pix.py +358 -0
- nabu/app/diag_to_rot.py +449 -0
- nabu/app/generate_header.py +4 -3
- nabu/app/histogram.py +2 -2
- nabu/app/multicor.py +6 -1
- nabu/app/parse_reconstruction_log.py +151 -0
- nabu/app/prepare_weights_double.py +83 -22
- nabu/app/reconstruct.py +5 -1
- nabu/app/reconstruct_helical.py +7 -0
- nabu/app/reduce_dark_flat.py +6 -3
- nabu/app/rotate.py +4 -4
- nabu/app/stitching.py +16 -2
- nabu/app/tests/test_reduce_dark_flat.py +18 -2
- nabu/app/validator.py +4 -4
- nabu/cuda/convolution.py +8 -376
- nabu/cuda/fft.py +4 -0
- nabu/cuda/kernel.py +4 -4
- nabu/cuda/medfilt.py +5 -158
- nabu/cuda/padding.py +5 -71
- nabu/cuda/processing.py +23 -2
- nabu/cuda/src/ElementOp.cu +78 -0
- nabu/cuda/src/backproj.cu +28 -2
- nabu/cuda/src/fourier_wavelets.cu +2 -2
- nabu/cuda/src/normalization.cu +23 -0
- nabu/cuda/src/padding.cu +2 -2
- nabu/cuda/src/transpose.cu +16 -0
- nabu/cuda/utils.py +39 -0
- nabu/estimation/alignment.py +10 -1
- nabu/estimation/cor.py +808 -38
- nabu/estimation/cor_sino.py +7 -9
- nabu/estimation/tests/test_cor.py +85 -3
- nabu/io/reader.py +26 -18
- nabu/io/tests/test_cast_volume.py +3 -3
- nabu/io/tests/test_detector_distortion.py +3 -3
- nabu/io/tiffwriter_zmm.py +2 -2
- nabu/io/utils.py +14 -4
- nabu/io/writer.py +5 -3
- nabu/misc/fftshift.py +6 -0
- nabu/misc/histogram.py +5 -285
- nabu/misc/histogram_cuda.py +8 -104
- nabu/misc/kernel_base.py +3 -121
- nabu/misc/padding_base.py +5 -69
- nabu/misc/processing_base.py +3 -107
- nabu/misc/rotation.py +5 -62
- nabu/misc/rotation_cuda.py +5 -65
- nabu/misc/transpose.py +6 -0
- nabu/misc/unsharp.py +3 -78
- nabu/misc/unsharp_cuda.py +5 -52
- nabu/misc/unsharp_opencl.py +8 -85
- nabu/opencl/fft.py +6 -0
- nabu/opencl/kernel.py +21 -6
- nabu/opencl/padding.py +5 -72
- nabu/opencl/processing.py +27 -5
- nabu/opencl/src/backproj.cl +3 -3
- nabu/opencl/src/fftshift.cl +65 -12
- nabu/opencl/src/padding.cl +2 -2
- nabu/opencl/src/roll.cl +96 -0
- nabu/opencl/src/transpose.cl +16 -0
- nabu/pipeline/config_validators.py +63 -3
- nabu/pipeline/dataset_validator.py +2 -2
- nabu/pipeline/estimators.py +193 -35
- nabu/pipeline/fullfield/chunked.py +34 -17
- nabu/pipeline/fullfield/chunked_cuda.py +7 -5
- nabu/pipeline/fullfield/computations.py +48 -13
- nabu/pipeline/fullfield/nabu_config.py +13 -13
- nabu/pipeline/fullfield/processconfig.py +10 -5
- nabu/pipeline/fullfield/reconstruction.py +1 -2
- nabu/pipeline/helical/fbp.py +5 -0
- nabu/pipeline/helical/filtering.py +12 -9
- nabu/pipeline/helical/gridded_accumulator.py +179 -33
- nabu/pipeline/helical/helical_chunked_regridded.py +262 -151
- nabu/pipeline/helical/helical_chunked_regridded_cuda.py +4 -11
- nabu/pipeline/helical/helical_reconstruction.py +56 -18
- nabu/pipeline/helical/span_strategy.py +1 -1
- nabu/pipeline/helical/tests/test_accumulator.py +4 -0
- nabu/pipeline/params.py +23 -2
- nabu/pipeline/processconfig.py +3 -8
- nabu/pipeline/tests/test_chunk_reader.py +78 -0
- nabu/pipeline/tests/test_estimators.py +120 -2
- nabu/pipeline/utils.py +25 -0
- nabu/pipeline/writer.py +2 -0
- nabu/preproc/ccd_cuda.py +9 -7
- nabu/preproc/ctf.py +21 -26
- nabu/preproc/ctf_cuda.py +25 -25
- nabu/preproc/double_flatfield.py +14 -2
- nabu/preproc/double_flatfield_cuda.py +7 -11
- nabu/preproc/flatfield_cuda.py +23 -27
- nabu/preproc/phase.py +19 -24
- nabu/preproc/phase_cuda.py +21 -21
- nabu/preproc/shift_cuda.py +58 -28
- nabu/preproc/tests/test_ctf.py +5 -5
- nabu/preproc/tests/test_double_flatfield.py +2 -2
- nabu/preproc/tests/test_vshift.py +13 -2
- nabu/processing/__init__.py +0 -0
- nabu/processing/convolution_cuda.py +375 -0
- nabu/processing/fft_base.py +163 -0
- nabu/processing/fft_cuda.py +256 -0
- nabu/processing/fft_opencl.py +54 -0
- nabu/processing/fftshift.py +134 -0
- nabu/processing/histogram.py +286 -0
- nabu/processing/histogram_cuda.py +103 -0
- nabu/processing/kernel_base.py +126 -0
- nabu/processing/medfilt_cuda.py +159 -0
- nabu/processing/muladd.py +29 -0
- nabu/processing/muladd_cuda.py +68 -0
- nabu/processing/padding_base.py +71 -0
- nabu/processing/padding_cuda.py +75 -0
- nabu/processing/padding_opencl.py +77 -0
- nabu/processing/processing_base.py +123 -0
- nabu/processing/roll_opencl.py +64 -0
- nabu/processing/rotation.py +63 -0
- nabu/processing/rotation_cuda.py +66 -0
- nabu/processing/tests/__init__.py +0 -0
- nabu/processing/tests/test_fft.py +268 -0
- nabu/processing/tests/test_fftshift.py +71 -0
- nabu/{misc → processing}/tests/test_histogram.py +2 -4
- nabu/{cuda → processing}/tests/test_medfilt.py +1 -1
- nabu/processing/tests/test_muladd.py +54 -0
- nabu/{cuda → processing}/tests/test_padding.py +119 -75
- nabu/processing/tests/test_roll.py +63 -0
- nabu/{misc → processing}/tests/test_rotation.py +3 -2
- nabu/processing/tests/test_transpose.py +72 -0
- nabu/{misc → processing}/tests/test_unsharp.py +41 -8
- nabu/processing/transpose.py +126 -0
- nabu/processing/unsharp.py +79 -0
- nabu/processing/unsharp_cuda.py +53 -0
- nabu/processing/unsharp_opencl.py +75 -0
- nabu/reconstruction/fbp.py +34 -10
- nabu/reconstruction/fbp_base.py +35 -16
- nabu/reconstruction/fbp_opencl.py +7 -12
- nabu/reconstruction/filtering.py +2 -2
- nabu/reconstruction/filtering_cuda.py +13 -14
- nabu/reconstruction/filtering_opencl.py +3 -4
- nabu/reconstruction/projection.py +2 -0
- nabu/reconstruction/rings.py +158 -1
- nabu/reconstruction/rings_cuda.py +218 -58
- nabu/reconstruction/sinogram_cuda.py +16 -12
- nabu/reconstruction/tests/test_deringer.py +116 -14
- nabu/reconstruction/tests/test_fbp.py +22 -31
- nabu/reconstruction/tests/test_filtering.py +11 -2
- nabu/resources/dataset_analyzer.py +89 -26
- nabu/resources/nxflatfield.py +2 -2
- nabu/resources/tests/test_nxflatfield.py +1 -1
- nabu/resources/utils.py +9 -2
- nabu/stitching/alignment.py +184 -0
- nabu/stitching/config.py +241 -39
- nabu/stitching/definitions.py +6 -0
- nabu/stitching/frame_composition.py +4 -2
- nabu/stitching/overlap.py +99 -3
- nabu/stitching/sample_normalization.py +60 -0
- nabu/stitching/slurm_utils.py +10 -10
- nabu/stitching/tests/test_alignment.py +99 -0
- nabu/stitching/tests/test_config.py +16 -1
- nabu/stitching/tests/test_overlap.py +68 -2
- nabu/stitching/tests/test_sample_normalization.py +49 -0
- nabu/stitching/tests/test_slurm_utils.py +5 -5
- nabu/stitching/tests/test_utils.py +3 -33
- nabu/stitching/tests/test_z_stitching.py +391 -22
- nabu/stitching/utils.py +144 -202
- nabu/stitching/z_stitching.py +309 -126
- nabu/testutils.py +18 -0
- nabu/thirdparty/tomocupy_remove_stripe.py +586 -0
- nabu/utils.py +32 -6
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/LICENSE +1 -1
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/METADATA +5 -5
- nabu-2024.1.0rc3.dist-info/RECORD +296 -0
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/WHEEL +1 -1
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/entry_points.txt +5 -1
- nabu/conftest.py +0 -14
- nabu/opencl/fftshift.py +0 -92
- nabu/opencl/tests/test_fftshift.py +0 -55
- nabu/opencl/tests/test_padding.py +0 -84
- nabu-2023.2.1.dist-info/RECORD +0 -252
- /nabu/cuda/src/{fftshift.cu → dfi_fftshift.cu} +0 -0
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/top_level.txt +0 -0
nabu/preproc/ctf.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1
1
|
import math
|
2
2
|
import numpy as np
|
3
|
+
from scipy.fft import rfft2, irfft2, fft2, ifft2
|
3
4
|
from ..resources.logger import LoggerOrPrint
|
4
5
|
from ..misc import fourier_filters
|
5
6
|
from ..misc.padding import pad_interpolate, recut
|
6
|
-
from ..utils import get_num_threads
|
7
|
+
from ..utils import get_num_threads, deprecation_warning
|
7
8
|
|
8
9
|
|
9
10
|
class GeoPars:
|
@@ -111,6 +112,7 @@ class CTFPhaseRetrieval:
|
|
111
112
|
lim2=0.2,
|
112
113
|
use_rfft=False,
|
113
114
|
fftw_num_threads=None,
|
115
|
+
fft_num_threads=None,
|
114
116
|
logger=None,
|
115
117
|
):
|
116
118
|
"""
|
@@ -138,10 +140,11 @@ class CTFPhaseRetrieval:
|
|
138
140
|
use_rfft: bool, optional
|
139
141
|
Whether to use real-to-complex (R2C) FFT instead of usual complex-to-complex (C2C).
|
140
142
|
fftw_num_threads: bool or None or int, optional
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
143
|
+
DEPRECATED - please use fft_num_threads instead.
|
144
|
+
fft_num_threads: bool or None or int, optional
|
145
|
+
Number of threads to use for FFT.
|
146
|
+
If a number is provided: number of threads to use for FFT.
|
147
|
+
You can pass a negative number to use N - fft_num_threads cores.
|
145
148
|
logger: optional
|
146
149
|
a logger object
|
147
150
|
"""
|
@@ -152,12 +155,18 @@ class CTFPhaseRetrieval:
|
|
152
155
|
self._calc_shape(shape, padded_shape, padding_mode)
|
153
156
|
self.delta_beta = delta_beta
|
154
157
|
|
158
|
+
# COMPAT.
|
159
|
+
if fftw_num_threads is not None:
|
160
|
+
deprecation_warning("'fftw_num_threads' is replaced with 'fft_num_threads'", func_name="ctf_fftw")
|
161
|
+
fft_num_threads = fftw_num_threads
|
162
|
+
# ---
|
163
|
+
|
155
164
|
self.lim = None
|
156
165
|
self.lim1 = lim1
|
157
166
|
self.lim2 = lim2
|
158
167
|
self.normalize_by_mean = normalize_by_mean
|
159
168
|
self.translation_vh = translation_vh
|
160
|
-
self._setup_fft(use_rfft,
|
169
|
+
self._setup_fft(use_rfft, fft_num_threads)
|
161
170
|
self._get_ctf_filter()
|
162
171
|
|
163
172
|
def _calc_shape(self, shape, padded_shape, padding_mode):
|
@@ -175,25 +184,11 @@ class CTFPhaseRetrieval:
|
|
175
184
|
self.shape_padded = tuple(padded_shape)
|
176
185
|
self.padding_mode = padding_mode
|
177
186
|
|
178
|
-
def _setup_fft(self, use_rfft,
|
187
|
+
def _setup_fft(self, use_rfft, fft_num_threads):
|
179
188
|
self.use_rfft = use_rfft
|
180
|
-
self._fft_func =
|
181
|
-
self._ifft_func =
|
182
|
-
self.
|
183
|
-
if fftw_num_threads is False:
|
184
|
-
return
|
185
|
-
fftw_num_threads = get_num_threads(fftw_num_threads)
|
186
|
-
if self.use_rfft and (fftw_num_threads > 0):
|
187
|
-
# importing silx.math.fft creates opencl contexts all over the place
|
188
|
-
# because of the silx.opencl.ocl singleton.
|
189
|
-
# So, import silx as late as possible
|
190
|
-
from silx.math.fft.fftw import FFTW, __have_fftw__
|
191
|
-
|
192
|
-
if __have_fftw__:
|
193
|
-
self.use_fftw = True
|
194
|
-
self.fftw = FFTW(shape=self.shape_padded, dtype="f", num_threads=fftw_num_threads)
|
195
|
-
self._fft_func = self.fftw.fft
|
196
|
-
self._ifft_func = self.fftw.ifft
|
189
|
+
self._fft_func = rfft2 if use_rfft else fft2
|
190
|
+
self._ifft_func = irfft2 if use_rfft else ifft2
|
191
|
+
self.fft_num_threads = get_num_threads(fft_num_threads)
|
197
192
|
|
198
193
|
def _get_ctf_filter(self):
|
199
194
|
"""
|
@@ -320,7 +315,7 @@ class CTFPhaseRetrieval:
|
|
320
315
|
self._ctf_filter_denom = (2 * self.unreg_filter_denom * self.unreg_filter_denom + self.lim).astype(np.complex64)
|
321
316
|
|
322
317
|
def _apply_filter(self, img):
|
323
|
-
img_f = self._fft_func(img)
|
318
|
+
img_f = self._fft_func(img, workers=self.fft_num_threads)
|
324
319
|
img_f *= self.unreg_filter_denom
|
325
320
|
|
326
321
|
unreg_filter_denom_0_mean = self.unreg_filter_denom[0, 0]
|
@@ -331,7 +326,7 @@ class CTFPhaseRetrieval:
|
|
331
326
|
|
332
327
|
## formula 8, with regularisation to stay at a safe distance from the poles
|
333
328
|
img_f /= self._ctf_filter_denom
|
334
|
-
ph = self._ifft_func(img_f).real
|
329
|
+
ph = self._ifft_func(img_f, workers=self.fft_num_threads).real
|
335
330
|
return ph
|
336
331
|
|
337
332
|
def retrieve_phase(self, img, output=None):
|
nabu/preproc/ctf_cuda.py
CHANGED
@@ -1,12 +1,15 @@
|
|
1
1
|
import numpy as np
|
2
|
-
from
|
3
|
-
from ..utils import calc_padding_lengths, updiv, get_cuda_srcfile
|
2
|
+
from ..utils import calc_padding_lengths, updiv, get_cuda_srcfile, docstring
|
4
3
|
from ..cuda.processing import CudaProcessing
|
5
|
-
from ..cuda.
|
6
|
-
from ..
|
4
|
+
from ..cuda.utils import __has_pycuda__
|
5
|
+
from ..processing.padding_cuda import CudaPadding
|
6
|
+
from ..processing.fft_cuda import get_fft_class
|
7
7
|
from .phase_cuda import CudaPaganinPhaseRetrieval
|
8
8
|
from .ctf import CTFPhaseRetrieval
|
9
9
|
|
10
|
+
if __has_pycuda__:
|
11
|
+
from pycuda import gpuarray as garray
|
12
|
+
|
10
13
|
|
11
14
|
# TODO:
|
12
15
|
# - better padding scheme (for now 2*shape)
|
@@ -17,6 +20,7 @@ class CudaCTFPhaseRetrieval(CTFPhaseRetrieval):
|
|
17
20
|
Cuda back-end of CTFPhaseRetrieval
|
18
21
|
"""
|
19
22
|
|
23
|
+
@docstring(CTFPhaseRetrieval)
|
20
24
|
def __init__(
|
21
25
|
self,
|
22
26
|
shape,
|
@@ -29,9 +33,11 @@ class CudaCTFPhaseRetrieval(CTFPhaseRetrieval):
|
|
29
33
|
lim1=1.0e-5,
|
30
34
|
lim2=0.2,
|
31
35
|
use_rfft=True,
|
32
|
-
fftw_num_threads=None,
|
36
|
+
fftw_num_threads=None, # COMPAT.
|
37
|
+
fft_num_threads=None,
|
33
38
|
logger=None,
|
34
39
|
cuda_options=None,
|
40
|
+
fft_backend="skcuda",
|
35
41
|
):
|
36
42
|
"""
|
37
43
|
Initialize a CudaCTFPhaseRetrieval.
|
@@ -62,30 +68,26 @@ class CudaCTFPhaseRetrieval(CTFPhaseRetrieval):
|
|
62
68
|
lim2=lim2,
|
63
69
|
logger=logger,
|
64
70
|
use_rfft=True,
|
65
|
-
|
71
|
+
fft_num_threads=False,
|
66
72
|
)
|
67
73
|
self._init_ctf_filter()
|
68
74
|
self._init_cuda_padding()
|
69
|
-
self._init_fft()
|
75
|
+
self._init_fft(fft_backend)
|
70
76
|
self._init_mult_kernel()
|
71
77
|
|
72
78
|
def _init_ctf_filter(self):
|
73
79
|
self._mean_scale_factor = self.unreg_filter_denom[0, 0] * np.prod(self.shape_padded)
|
74
|
-
self._d_filter_num =
|
75
|
-
self._d_filter_denom =
|
76
|
-
(1.0 / (2 * self.unreg_filter_denom * self.unreg_filter_denom + self.lim)).astype("f")
|
80
|
+
self._d_filter_num = self.cuda_processing.to_device("_d_filter_num", self.unreg_filter_denom).astype("f")
|
81
|
+
self._d_filter_denom = self.cuda_processing.to_device(
|
82
|
+
"_d_filter_denom", (1.0 / (2 * self.unreg_filter_denom * self.unreg_filter_denom + self.lim)).astype("f")
|
77
83
|
)
|
78
84
|
|
79
85
|
def _init_cuda_padding(self):
|
80
86
|
pad_width = calc_padding_lengths(self.shape, self.shape_padded)
|
81
87
|
# Custom coordinate transform to get directly FFT layout
|
82
|
-
R, C = np.indices(self.shape, dtype=np.int32)
|
83
|
-
coords_R = np.roll(
|
84
|
-
|
85
|
-
)
|
86
|
-
coords_C = np.roll(
|
87
|
-
np.pad(C, pad_width, mode=self.padding_mode), (-pad_width[0][0], -pad_width[1][0]), axis=(0, 1)
|
88
|
-
)
|
88
|
+
R, C = np.indices(self.shape, dtype=np.int32, sparse=True)
|
89
|
+
coords_R = np.roll(np.pad(R.ravel(), pad_width[0], mode=self.padding_mode), -pad_width[0][0])
|
90
|
+
coords_C = np.roll(np.pad(C.ravel(), pad_width[1], mode=self.padding_mode), -pad_width[1][0])
|
89
91
|
self.cuda_padding = CudaPadding(
|
90
92
|
self.shape,
|
91
93
|
(coords_R, coords_C),
|
@@ -93,16 +95,14 @@ class CudaCTFPhaseRetrieval(CTFPhaseRetrieval):
|
|
93
95
|
# propagate cuda options ?
|
94
96
|
)
|
95
97
|
|
96
|
-
def _init_fft(self):
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
self.
|
101
|
-
self.d_radio_padded = self.cufft.data_in
|
102
|
-
self.d_radio_f = self.cufft.data_out
|
98
|
+
def _init_fft(self, fft_backend):
|
99
|
+
fft_cls = get_fft_class(backend=fft_backend)
|
100
|
+
self.cufft = fft_cls(shape=self.shape_padded, dtype=np.float32, r2c=True)
|
101
|
+
self.d_radio_padded = self.cuda_processing.allocate_array("d_radio_padded", self.shape_padded, "f")
|
102
|
+
self.d_radio_f = self.cuda_processing.allocate_array("d_radio_f", self.cufft.shape_out, np.complex64)
|
103
103
|
|
104
104
|
def _init_mult_kernel(self):
|
105
|
-
self.cpxmult_kernel =
|
105
|
+
self.cpxmult_kernel = self.cuda_processing.kernel(
|
106
106
|
"CTF_kernel",
|
107
107
|
filename=get_cuda_srcfile("ElementOp.cu"),
|
108
108
|
signature="PPPfii",
|
nabu/preproc/double_flatfield.py
CHANGED
@@ -5,6 +5,7 @@ from silx.io.url import DataUrl
|
|
5
5
|
from ..utils import check_supported, check_shape, get_2D_3D_shape
|
6
6
|
from ..io.reader import Readers
|
7
7
|
from ..io.writer import Writers
|
8
|
+
from .ccd import Log
|
8
9
|
|
9
10
|
|
10
11
|
class DoubleFlatField:
|
@@ -22,6 +23,8 @@ class DoubleFlatField:
|
|
22
23
|
average_is_on_log=False,
|
23
24
|
sigma_filter=None,
|
24
25
|
filter_mode="reflect",
|
26
|
+
log_clip_min=None,
|
27
|
+
log_clip_max=None,
|
25
28
|
):
|
26
29
|
"""
|
27
30
|
Init double flat field by summing a series of urls and considering the same subregion of them.
|
@@ -55,6 +58,8 @@ class DoubleFlatField:
|
|
55
58
|
self.radios_shape = get_2D_3D_shape(shape)
|
56
59
|
self.n_angles = self.radios_shape[0]
|
57
60
|
self.shape = self.radios_shape[1:]
|
61
|
+
self._log_clip_min = log_clip_min
|
62
|
+
self._log_clip_max = log_clip_max
|
58
63
|
self._init_filedump(result_url, sub_region, detector_corrector)
|
59
64
|
self._init_processing(input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode)
|
60
65
|
self._computed = False
|
@@ -112,17 +117,19 @@ class DoubleFlatField:
|
|
112
117
|
self.sigma_filter = None
|
113
118
|
self.filter_mode = filter_mode
|
114
119
|
proc = lambda x, o: np.copyto(o, x)
|
120
|
+
self._mlog = Log((1,) + self.shape, clip_min=self._log_clip_min, clip_max=self._log_clip_max)
|
121
|
+
|
115
122
|
if self.input_is_mlog:
|
116
123
|
if not self.average_is_on_log:
|
117
124
|
proc = lambda x, o: np.exp(-x, out=o)
|
118
125
|
else:
|
119
126
|
if self.average_is_on_log:
|
120
|
-
proc =
|
127
|
+
proc = self._proc_mlog
|
121
128
|
|
122
129
|
postproc = lambda x: x
|
123
130
|
if self.output_is_mlog:
|
124
131
|
if not self.average_is_on_log:
|
125
|
-
postproc =
|
132
|
+
postproc = self._proc_mlog
|
126
133
|
else:
|
127
134
|
if self.average_is_on_log:
|
128
135
|
postproc = lambda x: np.exp(-x)
|
@@ -130,6 +137,11 @@ class DoubleFlatField:
|
|
130
137
|
self.proc = proc
|
131
138
|
self.postproc = postproc
|
132
139
|
|
140
|
+
def _proc_mlog(self, x, o):
|
141
|
+
o[:] = x[:]
|
142
|
+
self._mlog.take_logarithm(o)
|
143
|
+
return o
|
144
|
+
|
133
145
|
def compute_double_flatfield(self, radios, recompute=False):
|
134
146
|
"""
|
135
147
|
Read the radios and generate the "double flat field" by averaging
|
@@ -2,7 +2,8 @@ from .double_flatfield import DoubleFlatField
|
|
2
2
|
from ..utils import check_shape
|
3
3
|
from ..cuda.utils import __has_pycuda__
|
4
4
|
from ..cuda.processing import CudaProcessing
|
5
|
-
from ..
|
5
|
+
from ..processing.unsharp_cuda import CudaUnsharpMask
|
6
|
+
from .ccd_cuda import CudaLog
|
6
7
|
|
7
8
|
if __has_pycuda__:
|
8
9
|
import pycuda.gpuarray as garray
|
@@ -21,6 +22,8 @@ class CudaDoubleFlatField(DoubleFlatField):
|
|
21
22
|
average_is_on_log=False,
|
22
23
|
sigma_filter=None,
|
23
24
|
filter_mode="reflect",
|
25
|
+
log_clip_min=None,
|
26
|
+
log_clip_max=None,
|
24
27
|
cuda_options=None,
|
25
28
|
):
|
26
29
|
"""
|
@@ -37,6 +40,8 @@ class CudaDoubleFlatField(DoubleFlatField):
|
|
37
40
|
average_is_on_log=average_is_on_log,
|
38
41
|
sigma_filter=sigma_filter,
|
39
42
|
filter_mode=filter_mode,
|
43
|
+
log_clip_min=log_clip_min,
|
44
|
+
log_clip_max=log_clip_max,
|
40
45
|
)
|
41
46
|
self._init_gaussian_filter()
|
42
47
|
|
@@ -57,16 +62,6 @@ class CudaDoubleFlatField(DoubleFlatField):
|
|
57
62
|
cumath.exp(o, out=o)
|
58
63
|
return o
|
59
64
|
|
60
|
-
@staticmethod
|
61
|
-
def _proc_mlog(x, o, min_clip=None):
|
62
|
-
if min_clip is not None:
|
63
|
-
garray.maximum(x, min_clip, out=o)
|
64
|
-
cumath.log(o, out=o)
|
65
|
-
else:
|
66
|
-
cumath.log(x, out=o)
|
67
|
-
o *= -1
|
68
|
-
return o
|
69
|
-
|
70
65
|
def _init_processing(self, input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode):
|
71
66
|
self.input_is_mlog = input_is_mlog
|
72
67
|
self.output_is_mlog = output_is_mlog
|
@@ -77,6 +72,7 @@ class CudaDoubleFlatField(DoubleFlatField):
|
|
77
72
|
self.filter_mode = filter_mode
|
78
73
|
# proc = lambda x,o: np.copyto(o, x)
|
79
74
|
proc = self._proc_copy
|
75
|
+
self._mlog = CudaLog((1,) + self.shape, clip_min=self._log_clip_min, clip_max=self._log_clip_max)
|
80
76
|
if self.input_is_mlog:
|
81
77
|
if not self.average_is_on_log:
|
82
78
|
# proc = lambda x,o: np.exp(-x, out=o)
|
nabu/preproc/flatfield_cuda.py
CHANGED
@@ -1,25 +1,25 @@
|
|
1
|
-
from typing import Union
|
2
1
|
import numpy as np
|
3
|
-
|
2
|
+
|
3
|
+
from nabu.cuda.processing import CudaProcessing
|
4
4
|
from ..preproc.flatfield import FlatFieldArrays
|
5
|
-
from ..cuda.kernel import CudaKernel
|
6
5
|
from ..utils import get_cuda_srcfile
|
7
6
|
from ..io.reader import load_images_from_dataurl_dict
|
7
|
+
from ..cuda.utils import __has_pycuda__
|
8
8
|
|
9
9
|
|
10
10
|
class CudaFlatFieldArrays(FlatFieldArrays):
|
11
11
|
def __init__(
|
12
12
|
self,
|
13
|
-
radios_shape
|
14
|
-
flats
|
15
|
-
darks
|
13
|
+
radios_shape,
|
14
|
+
flats,
|
15
|
+
darks,
|
16
16
|
radios_indices=None,
|
17
|
-
interpolation
|
17
|
+
interpolation="linear",
|
18
18
|
distortion_correction=None,
|
19
19
|
nan_value=1.0,
|
20
20
|
radios_srcurrent=None,
|
21
21
|
flats_srcurrent=None,
|
22
|
-
cuda_options
|
22
|
+
cuda_options=None,
|
23
23
|
):
|
24
24
|
"""
|
25
25
|
Initialize a flat-field normalization CUDA process.
|
@@ -41,16 +41,10 @@ class CudaFlatFieldArrays(FlatFieldArrays):
|
|
41
41
|
flats_srcurrent=flats_srcurrent,
|
42
42
|
nan_value=nan_value,
|
43
43
|
)
|
44
|
-
self.
|
44
|
+
self.cuda_processing = CudaProcessing(**(cuda_options or {}))
|
45
45
|
self._init_cuda_kernels()
|
46
46
|
self._load_flats_and_darks_on_gpu()
|
47
47
|
|
48
|
-
def _set_cuda_options(self, user_cuda_options):
|
49
|
-
self.cuda_options = {"device_id": None, "ctx": None, "cleanup_at_exit": None}
|
50
|
-
if user_cuda_options is None:
|
51
|
-
user_cuda_options = {}
|
52
|
-
self.cuda_options.update(user_cuda_options)
|
53
|
-
|
54
48
|
def _init_cuda_kernels(self):
|
55
49
|
# TODO
|
56
50
|
if self.interpolation != "linear":
|
@@ -63,7 +57,7 @@ class CudaFlatFieldArrays(FlatFieldArrays):
|
|
63
57
|
]
|
64
58
|
if self.nan_value is not None:
|
65
59
|
options.append("-DNAN_VALUE=%f" % self.nan_value)
|
66
|
-
self.cuda_kernel =
|
60
|
+
self.cuda_kernel = self.cuda_processing.kernel(
|
67
61
|
"flatfield_normalization", self._cuda_fname, signature="PPPiiiPP", options=options
|
68
62
|
)
|
69
63
|
self._nx = np.int32(self.shape[1])
|
@@ -71,17 +65,19 @@ class CudaFlatFieldArrays(FlatFieldArrays):
|
|
71
65
|
|
72
66
|
def _load_flats_and_darks_on_gpu(self):
|
73
67
|
# Flats
|
74
|
-
self.d_flats =
|
68
|
+
self.d_flats = self.cuda_processing.allocate_array("d_flats", (self.n_flats,) + self.shape, np.float32)
|
75
69
|
for i, flat_idx in enumerate(self._sorted_flat_indices):
|
76
70
|
self.d_flats[i].set(np.ascontiguousarray(self.flats[flat_idx], dtype=np.float32))
|
77
71
|
# Darks
|
78
|
-
self.d_darks =
|
72
|
+
self.d_darks = self.cuda_processing.allocate_array("d_darks", (self.n_darks,) + self.shape, np.float32)
|
79
73
|
for i, dark_idx in enumerate(self._sorted_dark_indices):
|
80
74
|
self.d_darks[i].set(np.ascontiguousarray(self.darks[dark_idx], dtype=np.float32))
|
81
|
-
self.d_darks_indices =
|
75
|
+
self.d_darks_indices = self.cuda_processing.to_device(
|
76
|
+
"d_darks_indices", np.array(self._sorted_dark_indices, dtype=np.int32)
|
77
|
+
)
|
82
78
|
# Indices
|
83
|
-
self.d_flats_indices =
|
84
|
-
self.d_flats_weights =
|
79
|
+
self.d_flats_indices = self.cuda_processing.to_device("d_flats_indices", self.flats_idx)
|
80
|
+
self.d_flats_weights = self.cuda_processing.to_device("d_flats_weights", self.flats_weights)
|
85
81
|
|
86
82
|
def normalize_radios(self, radios):
|
87
83
|
"""
|
@@ -93,7 +89,7 @@ class CudaFlatFieldArrays(FlatFieldArrays):
|
|
93
89
|
radios_shape: `pycuda.gpuarray.GPUArray`
|
94
90
|
Radios chunk.
|
95
91
|
"""
|
96
|
-
if not (isinstance(radios,
|
92
|
+
if not (isinstance(radios, self.cuda_processing.array_class)):
|
97
93
|
raise ValueError("Expected a pycuda.gpuarray (got %s)" % str(type(radios)))
|
98
94
|
if radios.dtype != np.float32:
|
99
95
|
raise ValueError("radios must be in float32 dtype (got %s)" % str(radios.dtype))
|
@@ -121,16 +117,16 @@ CudaFlatField = CudaFlatFieldArrays
|
|
121
117
|
class CudaFlatFieldDataUrls(CudaFlatField):
|
122
118
|
def __init__(
|
123
119
|
self,
|
124
|
-
radios_shape
|
125
|
-
flats
|
126
|
-
darks
|
120
|
+
radios_shape,
|
121
|
+
flats,
|
122
|
+
darks,
|
127
123
|
radios_indices=None,
|
128
|
-
interpolation
|
124
|
+
interpolation="linear",
|
129
125
|
distortion_correction=None,
|
130
126
|
nan_value=1.0,
|
131
127
|
radios_srcurrent=None,
|
132
128
|
flats_srcurrent=None,
|
133
|
-
cuda_options
|
129
|
+
cuda_options=None,
|
134
130
|
**chunk_reader_kwargs,
|
135
131
|
):
|
136
132
|
flats_arrays_dict = load_images_from_dataurl_dict(flats, **chunk_reader_kwargs)
|
nabu/preproc/phase.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from math import pi
|
2
2
|
from bisect import bisect
|
3
3
|
import numpy as np
|
4
|
+
from scipy.fft import rfft2, irfft2, fft2, ifft2
|
4
5
|
from ..utils import generate_powers, get_decay, check_supported, get_num_threads, deprecation_warning
|
5
6
|
|
6
7
|
#
|
@@ -53,6 +54,7 @@ class PaganinPhaseRetrieval:
|
|
53
54
|
use_rfft=True,
|
54
55
|
use_R2C=None,
|
55
56
|
fftw_num_threads=None,
|
57
|
+
fft_num_threads=None,
|
56
58
|
):
|
57
59
|
"""
|
58
60
|
Paganin Phase Retrieval for an infinitely distant point source.
|
@@ -113,9 +115,11 @@ class PaganinPhaseRetrieval:
|
|
113
115
|
use_R2C: bool, optional
|
114
116
|
DEPRECATED, use use_rfft instead
|
115
117
|
fftw_num_threads: bool or None or int, optional
|
116
|
-
|
118
|
+
DEPRECATED - please use fft_num_threads
|
119
|
+
fft_num_threads: bool or None or int, optional
|
120
|
+
Number of threads for FFT.
|
117
121
|
Default is to use all available threads. You can pass a negative number
|
118
|
-
to use N -
|
122
|
+
to use N - fft_num_threads cores.
|
119
123
|
|
120
124
|
Important
|
121
125
|
----------
|
@@ -171,8 +175,11 @@ class PaganinPhaseRetrieval:
|
|
171
175
|
# COMPAT.
|
172
176
|
if use_R2C is not None:
|
173
177
|
deprecation_warning("'use_R2C' is replaced with 'use_rfft'", func_name="pag_r2c")
|
174
|
-
|
175
|
-
|
178
|
+
if fftw_num_threads is not None:
|
179
|
+
deprecation_warning("'fftw_num_threads' is replaced with 'fft_num_threads'", func_name="pag_fftw")
|
180
|
+
fft_num_threads = fftw_num_threads
|
181
|
+
# ---
|
182
|
+
self._get_fft(use_rfft, fft_num_threads)
|
176
183
|
self.compute_filter()
|
177
184
|
|
178
185
|
def _init_parameters(self, distance, energy, pixel_size, delta_beta, padding):
|
@@ -191,28 +198,16 @@ class PaganinPhaseRetrieval:
|
|
191
198
|
"reflect": self._pad_reflect,
|
192
199
|
}
|
193
200
|
|
194
|
-
def _get_fft(self, use_rfft,
|
201
|
+
def _get_fft(self, use_rfft, fft_num_threads):
|
195
202
|
self.use_rfft = use_rfft
|
196
203
|
self.use_R2C = use_rfft # Compat.
|
197
|
-
|
204
|
+
self.fft_num_threads = get_num_threads(fft_num_threads)
|
198
205
|
if self.use_rfft:
|
199
|
-
self.fft_func =
|
200
|
-
self.ifft_func =
|
206
|
+
self.fft_func = rfft2
|
207
|
+
self.ifft_func = irfft2
|
201
208
|
else:
|
202
|
-
self.fft_func =
|
203
|
-
self.ifft_func =
|
204
|
-
self.use_fftw = False
|
205
|
-
if self.use_rfft and (fftw_num_threads > 0):
|
206
|
-
# importing silx.math.fft creates opencl contexts all over the place
|
207
|
-
# because of the silx.opencl.ocl singleton.
|
208
|
-
# So, import silx as late as possible
|
209
|
-
from silx.math.fft.fftw import FFTW, __have_fftw__
|
210
|
-
|
211
|
-
if __have_fftw__:
|
212
|
-
self.use_fftw = True
|
213
|
-
self.fftw = FFTW(shape=self.shape_padded, dtype="f", num_threads=fftw_num_threads)
|
214
|
-
self.fft_func = self.fftw.fft
|
215
|
-
self.ifft_func = self.fftw.ifft
|
209
|
+
self.fft_func = fft2
|
210
|
+
self.ifft_func = ifft2
|
216
211
|
|
217
212
|
def _calc_shape(self, shape, margin):
|
218
213
|
if np.isscalar(shape):
|
@@ -378,9 +373,9 @@ class PaganinPhaseRetrieval:
|
|
378
373
|
|
379
374
|
def apply_filter(self, radio, padding_method=None, output=None):
|
380
375
|
self.pad_data(radio, padding_method=padding_method)
|
381
|
-
radio_f = self.fft_func(self.data_padded)
|
376
|
+
radio_f = self.fft_func(self.data_padded, workers=self.fft_num_threads)
|
382
377
|
radio_f *= self.paganin_filter
|
383
|
-
radio_filtered = self.ifft_func(radio_f).real
|
378
|
+
radio_filtered = self.ifft_func(radio_f, workers=self.fft_num_threads).real
|
384
379
|
s0, s1 = self.shape_inner
|
385
380
|
((U, _), (L, _)) = self.margin
|
386
381
|
if output is None:
|
nabu/preproc/phase_cuda.py
CHANGED
@@ -1,15 +1,15 @@
|
|
1
1
|
import numpy as np
|
2
2
|
import pycuda.driver as cuda
|
3
|
-
from
|
4
|
-
from ..utils import get_cuda_srcfile, check_supported
|
5
|
-
from .phase import PaganinPhaseRetrieval
|
3
|
+
from ..utils import get_cuda_srcfile, check_supported, docstring
|
6
4
|
from ..cuda.processing import CudaProcessing
|
7
|
-
from ..
|
5
|
+
from ..processing.fft_cuda import get_fft_class
|
6
|
+
from .phase import PaganinPhaseRetrieval
|
8
7
|
|
9
8
|
|
10
9
|
class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
|
11
10
|
supported_paddings = ["zeros", "constant", "edge"]
|
12
11
|
|
12
|
+
@docstring(PaganinPhaseRetrieval)
|
13
13
|
def __init__(
|
14
14
|
self,
|
15
15
|
shape,
|
@@ -20,7 +20,9 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
|
|
20
20
|
padding="edge",
|
21
21
|
margin=None,
|
22
22
|
cuda_options=None,
|
23
|
-
fftw_num_threads=None,
|
23
|
+
fftw_num_threads=None, # COMPAT.
|
24
|
+
fft_num_threads=None,
|
25
|
+
fft_backend="skcuda",
|
24
26
|
):
|
25
27
|
"""
|
26
28
|
Please refer to the documentation of
|
@@ -37,10 +39,10 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
|
|
37
39
|
padding=padding,
|
38
40
|
margin=margin,
|
39
41
|
use_rfft=True,
|
40
|
-
|
42
|
+
fft_num_threads=False,
|
41
43
|
)
|
42
44
|
self._init_gpu_arrays()
|
43
|
-
self._init_fft()
|
45
|
+
self._init_fft(fft_backend)
|
44
46
|
self._init_padding_kernel()
|
45
47
|
self._init_mult_kernel()
|
46
48
|
|
@@ -51,25 +53,23 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
|
|
51
53
|
return padding
|
52
54
|
|
53
55
|
def _init_gpu_arrays(self):
|
54
|
-
self.d_paganin_filter =
|
56
|
+
self.d_paganin_filter = self.cuda_processing.to_device(
|
57
|
+
"d_paganin_filter", np.ascontiguousarray(self.paganin_filter, dtype=np.float32)
|
58
|
+
)
|
55
59
|
|
56
60
|
# overwrite parent method, don't initialize any FFT plan
|
57
|
-
def _get_fft(self, use_rfft,
|
61
|
+
def _get_fft(self, use_rfft, fft_num_threads):
|
58
62
|
self.use_rfft = use_rfft
|
59
|
-
self.use_fftw = False
|
60
|
-
|
61
|
-
def _init_fft(self):
|
62
|
-
# Import has to be done here, otherwise scikit-cuda creates a cuda/cublas context at import
|
63
|
-
from silx.math.fft.cufft import CUFFT
|
64
63
|
|
65
|
-
|
66
|
-
|
67
|
-
self.
|
68
|
-
self.
|
64
|
+
def _init_fft(self, fft_backend):
|
65
|
+
fft_cls = get_fft_class(backend=fft_backend)
|
66
|
+
self.cufft = fft_cls(shape=self.data_padded.shape, dtype=np.float32, r2c=True)
|
67
|
+
self.d_radio_padded = self.cuda_processing.allocate_array("d_radio_padded", self.cufft.shape, "f")
|
68
|
+
self.d_radio_f = self.cuda_processing.allocate_array("d_radio_f", self.cufft.shape_out, np.complex64)
|
69
69
|
|
70
70
|
def _init_padding_kernel(self):
|
71
71
|
kern_signature = {"constant": "Piiiiiiiiffff", "edge": "Piiiiiiii"}
|
72
|
-
self.padding_kernel =
|
72
|
+
self.padding_kernel = self.cuda_processing.kernel(
|
73
73
|
"padding_%s" % self.padding,
|
74
74
|
filename=get_cuda_srcfile("padding.cu"),
|
75
75
|
signature=kern_signature[self.padding],
|
@@ -92,7 +92,7 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
|
|
92
92
|
self.padding_kernel_args.extend([0, 0, 0, 0])
|
93
93
|
|
94
94
|
def _init_mult_kernel(self):
|
95
|
-
self.cpxmult_kernel =
|
95
|
+
self.cpxmult_kernel = self.cuda_processing.kernel(
|
96
96
|
"inplace_complexreal_mul_2Dby2D",
|
97
97
|
filename=get_cuda_srcfile("ElementOp.cu"),
|
98
98
|
signature="PPii",
|
@@ -109,7 +109,7 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
|
|
109
109
|
assert data.dtype == np.float32
|
110
110
|
# Rectangular memcopy
|
111
111
|
# TODO profile, and if needed include this copy in the padding kernel
|
112
|
-
if isinstance(data, np.ndarray) or isinstance(data,
|
112
|
+
if isinstance(data, np.ndarray) or isinstance(data, self.cuda_processing.array_class):
|
113
113
|
self.d_radio_padded[: self.shape[0], : self.shape[1]] = data[:, :]
|
114
114
|
elif isinstance(data, cuda.DeviceAllocation):
|
115
115
|
# TODO manual memcpy2D
|