nabu 2023.2.1__py3-none-any.whl → 2024.1.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/conf.py +1 -1
- doc/doc_config.py +32 -0
- nabu/__init__.py +2 -1
- nabu/app/bootstrap_stitching.py +1 -1
- nabu/app/cli_configs.py +122 -2
- nabu/app/composite_cor.py +27 -2
- nabu/app/correct_rot.py +70 -0
- nabu/app/create_distortion_map_from_poly.py +42 -18
- nabu/app/diag_to_pix.py +358 -0
- nabu/app/diag_to_rot.py +449 -0
- nabu/app/generate_header.py +4 -3
- nabu/app/histogram.py +2 -2
- nabu/app/multicor.py +6 -1
- nabu/app/parse_reconstruction_log.py +151 -0
- nabu/app/prepare_weights_double.py +83 -22
- nabu/app/reconstruct.py +5 -1
- nabu/app/reconstruct_helical.py +7 -0
- nabu/app/reduce_dark_flat.py +6 -3
- nabu/app/rotate.py +4 -4
- nabu/app/stitching.py +16 -2
- nabu/app/tests/test_reduce_dark_flat.py +18 -2
- nabu/app/validator.py +4 -4
- nabu/cuda/convolution.py +8 -376
- nabu/cuda/fft.py +4 -0
- nabu/cuda/kernel.py +4 -4
- nabu/cuda/medfilt.py +5 -158
- nabu/cuda/padding.py +5 -71
- nabu/cuda/processing.py +23 -2
- nabu/cuda/src/ElementOp.cu +78 -0
- nabu/cuda/src/backproj.cu +28 -2
- nabu/cuda/src/fourier_wavelets.cu +2 -2
- nabu/cuda/src/normalization.cu +23 -0
- nabu/cuda/src/padding.cu +2 -2
- nabu/cuda/src/transpose.cu +16 -0
- nabu/cuda/utils.py +39 -0
- nabu/estimation/alignment.py +10 -1
- nabu/estimation/cor.py +808 -38
- nabu/estimation/cor_sino.py +7 -9
- nabu/estimation/tests/test_cor.py +85 -3
- nabu/io/reader.py +26 -18
- nabu/io/tests/test_cast_volume.py +3 -3
- nabu/io/tests/test_detector_distortion.py +3 -3
- nabu/io/tiffwriter_zmm.py +2 -2
- nabu/io/utils.py +14 -4
- nabu/io/writer.py +5 -3
- nabu/misc/fftshift.py +6 -0
- nabu/misc/histogram.py +5 -285
- nabu/misc/histogram_cuda.py +8 -104
- nabu/misc/kernel_base.py +3 -121
- nabu/misc/padding_base.py +5 -69
- nabu/misc/processing_base.py +3 -107
- nabu/misc/rotation.py +5 -62
- nabu/misc/rotation_cuda.py +5 -65
- nabu/misc/transpose.py +6 -0
- nabu/misc/unsharp.py +3 -78
- nabu/misc/unsharp_cuda.py +5 -52
- nabu/misc/unsharp_opencl.py +8 -85
- nabu/opencl/fft.py +6 -0
- nabu/opencl/kernel.py +21 -6
- nabu/opencl/padding.py +5 -72
- nabu/opencl/processing.py +27 -5
- nabu/opencl/src/backproj.cl +3 -3
- nabu/opencl/src/fftshift.cl +65 -12
- nabu/opencl/src/padding.cl +2 -2
- nabu/opencl/src/roll.cl +96 -0
- nabu/opencl/src/transpose.cl +16 -0
- nabu/pipeline/config_validators.py +63 -3
- nabu/pipeline/dataset_validator.py +2 -2
- nabu/pipeline/estimators.py +193 -35
- nabu/pipeline/fullfield/chunked.py +34 -17
- nabu/pipeline/fullfield/chunked_cuda.py +7 -5
- nabu/pipeline/fullfield/computations.py +48 -13
- nabu/pipeline/fullfield/nabu_config.py +13 -13
- nabu/pipeline/fullfield/processconfig.py +10 -5
- nabu/pipeline/fullfield/reconstruction.py +1 -2
- nabu/pipeline/helical/fbp.py +5 -0
- nabu/pipeline/helical/filtering.py +12 -9
- nabu/pipeline/helical/gridded_accumulator.py +179 -33
- nabu/pipeline/helical/helical_chunked_regridded.py +262 -151
- nabu/pipeline/helical/helical_chunked_regridded_cuda.py +4 -11
- nabu/pipeline/helical/helical_reconstruction.py +56 -18
- nabu/pipeline/helical/span_strategy.py +1 -1
- nabu/pipeline/helical/tests/test_accumulator.py +4 -0
- nabu/pipeline/params.py +23 -2
- nabu/pipeline/processconfig.py +3 -8
- nabu/pipeline/tests/test_chunk_reader.py +78 -0
- nabu/pipeline/tests/test_estimators.py +120 -2
- nabu/pipeline/utils.py +25 -0
- nabu/pipeline/writer.py +2 -0
- nabu/preproc/ccd_cuda.py +9 -7
- nabu/preproc/ctf.py +21 -26
- nabu/preproc/ctf_cuda.py +25 -25
- nabu/preproc/double_flatfield.py +14 -2
- nabu/preproc/double_flatfield_cuda.py +7 -11
- nabu/preproc/flatfield_cuda.py +23 -27
- nabu/preproc/phase.py +19 -24
- nabu/preproc/phase_cuda.py +21 -21
- nabu/preproc/shift_cuda.py +58 -28
- nabu/preproc/tests/test_ctf.py +5 -5
- nabu/preproc/tests/test_double_flatfield.py +2 -2
- nabu/preproc/tests/test_vshift.py +13 -2
- nabu/processing/__init__.py +0 -0
- nabu/processing/convolution_cuda.py +375 -0
- nabu/processing/fft_base.py +163 -0
- nabu/processing/fft_cuda.py +256 -0
- nabu/processing/fft_opencl.py +54 -0
- nabu/processing/fftshift.py +134 -0
- nabu/processing/histogram.py +286 -0
- nabu/processing/histogram_cuda.py +103 -0
- nabu/processing/kernel_base.py +126 -0
- nabu/processing/medfilt_cuda.py +159 -0
- nabu/processing/muladd.py +29 -0
- nabu/processing/muladd_cuda.py +68 -0
- nabu/processing/padding_base.py +71 -0
- nabu/processing/padding_cuda.py +75 -0
- nabu/processing/padding_opencl.py +77 -0
- nabu/processing/processing_base.py +123 -0
- nabu/processing/roll_opencl.py +64 -0
- nabu/processing/rotation.py +63 -0
- nabu/processing/rotation_cuda.py +66 -0
- nabu/processing/tests/__init__.py +0 -0
- nabu/processing/tests/test_fft.py +268 -0
- nabu/processing/tests/test_fftshift.py +71 -0
- nabu/{misc → processing}/tests/test_histogram.py +2 -4
- nabu/{cuda → processing}/tests/test_medfilt.py +1 -1
- nabu/processing/tests/test_muladd.py +54 -0
- nabu/{cuda → processing}/tests/test_padding.py +119 -75
- nabu/processing/tests/test_roll.py +63 -0
- nabu/{misc → processing}/tests/test_rotation.py +3 -2
- nabu/processing/tests/test_transpose.py +72 -0
- nabu/{misc → processing}/tests/test_unsharp.py +41 -8
- nabu/processing/transpose.py +126 -0
- nabu/processing/unsharp.py +79 -0
- nabu/processing/unsharp_cuda.py +53 -0
- nabu/processing/unsharp_opencl.py +75 -0
- nabu/reconstruction/fbp.py +34 -10
- nabu/reconstruction/fbp_base.py +35 -16
- nabu/reconstruction/fbp_opencl.py +7 -12
- nabu/reconstruction/filtering.py +2 -2
- nabu/reconstruction/filtering_cuda.py +13 -14
- nabu/reconstruction/filtering_opencl.py +3 -4
- nabu/reconstruction/projection.py +2 -0
- nabu/reconstruction/rings.py +158 -1
- nabu/reconstruction/rings_cuda.py +218 -58
- nabu/reconstruction/sinogram_cuda.py +16 -12
- nabu/reconstruction/tests/test_deringer.py +116 -14
- nabu/reconstruction/tests/test_fbp.py +22 -31
- nabu/reconstruction/tests/test_filtering.py +11 -2
- nabu/resources/dataset_analyzer.py +89 -26
- nabu/resources/nxflatfield.py +2 -2
- nabu/resources/tests/test_nxflatfield.py +1 -1
- nabu/resources/utils.py +9 -2
- nabu/stitching/alignment.py +184 -0
- nabu/stitching/config.py +241 -39
- nabu/stitching/definitions.py +6 -0
- nabu/stitching/frame_composition.py +4 -2
- nabu/stitching/overlap.py +99 -3
- nabu/stitching/sample_normalization.py +60 -0
- nabu/stitching/slurm_utils.py +10 -10
- nabu/stitching/tests/test_alignment.py +99 -0
- nabu/stitching/tests/test_config.py +16 -1
- nabu/stitching/tests/test_overlap.py +68 -2
- nabu/stitching/tests/test_sample_normalization.py +49 -0
- nabu/stitching/tests/test_slurm_utils.py +5 -5
- nabu/stitching/tests/test_utils.py +3 -33
- nabu/stitching/tests/test_z_stitching.py +391 -22
- nabu/stitching/utils.py +144 -202
- nabu/stitching/z_stitching.py +309 -126
- nabu/testutils.py +18 -0
- nabu/thirdparty/tomocupy_remove_stripe.py +586 -0
- nabu/utils.py +32 -6
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/LICENSE +1 -1
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/METADATA +5 -5
- nabu-2024.1.0rc3.dist-info/RECORD +296 -0
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/WHEEL +1 -1
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/entry_points.txt +5 -1
- nabu/conftest.py +0 -14
- nabu/opencl/fftshift.py +0 -92
- nabu/opencl/tests/test_fftshift.py +0 -55
- nabu/opencl/tests/test_padding.py +0 -84
- nabu-2023.2.1.dist-info/RECORD +0 -252
- /nabu/cuda/src/{fftshift.cu → dfi_fftshift.cu} +0 -0
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/top_level.txt +0 -0
nabu/processing/fft_base.py (new file):

@@ -0,0 +1,163 @@
+import numpy as np
+from ..utils import BaseClassError
+
+
+class _BaseFFT:
+    """
+    A base class for FFTs.
+    """
+
+    implem = "none"
+    ProcessingCls = BaseClassError
+
+    def __init__(self, shape, dtype, r2c=True, axes=None, normalize="rescale", **backend_options):
+        """
+        Base class for Fast Fourier Transform (FFT).
+
+        Parameters
+        ----------
+        shape: list of int
+            Shape of the input data
+        dtype: str or numpy.dtype
+            Data type of the input data
+        r2c: bool, optional
+            Whether to use real-to-complex transform for real-valued input. Default is True.
+        axes: list of int, optional
+            Axes along which FFT is computed.
+              * For 2D transform: axes=(1,0)
+              * For batched 1D transform of 2D image: axes=(-1,)
+        normalize: str, optional
+            Whether to normalize FFT and IFFT. Possible values are:
+              * "rescale": in this case, Fourier data is divided by "N"
+                before IFFT, so that IFFT(FFT(data)) = data.
+                This corresponds to numpy norm=None i.e norm="backward".
+              * "ortho": in this case, FFT and IFFT are adjoint of eachother,
+                the transform is unitary. Both FFT and IFFT are scaled with 1/sqrt(N).
+              * "none": no normalizatio is done : IFFT(FFT(data)) = data*N
+
+        Other parameters
+        -----------------
+        backend_options: dict, optional
+            Parameters to pass to CudaProcessing or OpenCLProcessing class.
+        """
+        self._init_backend(backend_options)
+        self._set_dtypes(dtype, r2c)
+        self._set_shape_and_axes(shape, axes)
+        self._configure_batched_transform()
+        self._configure_normalization(normalize)
+        self._compute_fft_plans()
+
+    def _init_backend(self, backend_options):
+        self.processing = self.ProcessingCls(**backend_options)
+
+    def _set_dtypes(self, dtype, r2c):
+        self.dtype = np.dtype(dtype)
+        dtypes_mapping = {
+            np.dtype("float32"): np.complex64,
+            np.dtype("float64"): np.complex128,
+            np.dtype("complex64"): np.complex64,
+            np.dtype("complex128"): np.complex128,
+        }
+        if self.dtype not in dtypes_mapping:
+            raise ValueError("Invalid input data type: got %s" % self.dtype)
+        self.dtype_out = dtypes_mapping[self.dtype]
+        self.r2c = r2c
+
+    def _set_shape_and_axes(self, shape, axes):
+        # Input shape
+        if np.isscalar(shape):
+            shape = (shape,)
+        self.shape = shape
+        # Axes
+        default_axes = tuple(range(len(self.shape)))
+        if axes is None:
+            self.axes = default_axes
+        else:
+            self.axes = tuple(np.array(default_axes)[np.array(axes)])
+        # Output shape
+        shape_out = self.shape
+        if self.r2c:
+            reduced_dim = self.axes[-1] if self.axes is not None else -1
+            shape_out = list(shape_out)
+            shape_out[reduced_dim] = shape_out[reduced_dim] // 2 + 1
+            shape_out = tuple(shape_out)
+        self.shape_out = shape_out
+
+    def _configure_batched_transform(self):
+        pass
+
+    def _configure_normalization(self, normalize):
+        pass
+
+    def _compute_fft_plans(self):
+        pass
+
+
+class _BaseVKFFT(_BaseFFT):
+    """
+    FFT using VKFFT backend
+    """
+
+    implem = "vkfft"
+    backend = "none"
+    ProcessingCls = BaseClassError
+    vkffs_cls = BaseClassError
+
+    def _configure_batched_transform(self):
+        if self.axes is not None and len(self.shape) == len(self.axes):
+            self.axes = None
+            return
+        if self.r2c:
+            # batched Real-to-complex transforms are supported only along fast axes
+            if not (is_fast_axes(len(self.shape), self.axes)):
+                raise ValueError("For %dD R2C, only batched transforms along fast axes are allowed" % (len(self.shape)))
+        self._vkfft_ndim = len(self.axes)
+        self.axes = None  # vkfft still can do a batched transform by providing dim=XX, axes=None
+
+    def _configure_normalization(self, normalize):
+        self.normalize = normalize
+        self._vkfft_norm = {
+            "rescale": 1,
+            "backward": 1,
+            "ortho": "ortho",
+            "none": 0,
+        }.get(self.normalize, 1)
+
+    def _set_shape_and_axes(self, shape, axes):
+        super()._set_shape_and_axes(shape, axes)
+        self._vkfft_ndim = None
+
+    def _compute_fft_plans(self):
+        self._vkfft_plan = self.vkffs_cls(
+            self.shape,
+            self.dtype,
+            ndim=self._vkfft_ndim,
+            inplace=False,
+            norm=self._vkfft_norm,
+            r2c=self.r2c,
+            dct=False,
+            axes=self.axes,
+            strides=None,
+            **self._vkfft_other_init_kwargs,
+        )
+
+    def fft(self, array, output=None):
+        if output is None:
+            output = self.output_fft = self.processing.allocate_array(
+                "output_fft", self.shape_out, dtype=self.dtype_out
+            )
+        return self._vkfft_plan.fft(array, dest=output)
+
+    def ifft(self, array, output=None):
+        if output is None:
+            output = self.output_ifft = self.processing.allocate_array("output_ifft", self.shape, dtype=self.dtype)
+        return self._vkfft_plan.ifft(array, dest=output)
+
+
+def is_fast_axes(ndim, axes):
+    """
+    Return true if "axes" are the fast dimensions
+    """
+    all_axes = list(range(ndim))
+    axes = sorted([ax + ndim if ax < 0 else ax for ax in axes])  # transform "-1" to an actual axis index (1 for 2D)
+    return all_axes[-len(axes) :] == axes
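
The _BaseFFT docstring above distinguishes three normalization conventions ("rescale", "ortho", "none"). The short snippet below is not part of the diff; it is a minimal numpy sketch illustrating the round-trip behaviour each convention describes.

import numpy as np

data = np.random.rand(8)
N = data.size

# "rescale" (numpy norm="backward"): forward unscaled, inverse divided by N
assert np.allclose(np.fft.ifft(np.fft.fft(data)), data)

# "ortho": both directions scaled by 1/sqrt(N), so the transform is unitary
assert np.allclose(np.fft.ifft(np.fft.fft(data, norm="ortho"), norm="ortho"), data)

# "none": no scaling in either direction, so a round trip multiplies by N
assert np.allclose(np.fft.ifft(np.fft.fft(data), norm="forward"), data * N)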
nabu/processing/fft_cuda.py (new file):

@@ -0,0 +1,256 @@
+import os
+import warnings
+from multiprocessing import get_context
+from multiprocessing.pool import Pool
+import numpy as np
+from ..utils import check_supported
+from .fft_base import _BaseFFT, _BaseVKFFT
+
+try:
+    from pyvkfft.cuda import VkFFTApp as vk_cufft
+
+    __has_vkfft__ = True
+except (ImportError, OSError):
+    __has_vkfft__ = False
+    vk_cufft = None
+from ..cuda.processing import CudaProcessing
+
+Plan = None
+cu_fft = None
+cu_ifft = None
+__has_skcuda__ = None
+
+
+def init_skcuda():
+    # This needs to be done here, because scikit-cuda creates a Cuda context at import,
+    # which can mess things up in some cases.
+    # Ugly solution to an ugly problem.
+    global __has_skcuda__, Plan, cu_fft, cu_ifft
+    try:
+        from skcuda.fft import Plan
+        from skcuda.fft import fft as cu_fft
+        from skcuda.fft import ifft as cu_ifft
+
+        __has_skcuda__ = True
+    except ImportError:
+        __has_skcuda__ = False
+
+
+class SKCUFFT(_BaseFFT):
+    implem = "skcuda"
+    backend = "cuda"
+    ProcessingCls = CudaProcessing
+
+    def _configure_batched_transform(self):
+        if __has_skcuda__ is None:
+            init_skcuda()
+        if not (__has_skcuda__):
+            raise ImportError("Please install pycuda and scikit-cuda to use the CUDA back-end")
+
+        self.cufft_batch_size = 1
+        self.cufft_shape = self.shape
+        self._cufft_plan_kwargs = {}
+        if (self.axes is not None) and (len(self.axes) < len(self.shape)):
+            # In the easiest case, the transform is computed along the fastest dimensions:
+            # - 1D transforms of lines of 2D data
+            # - 2D transforms of images of 3D data (stacked along slow dim)
+            # - 1D transforms of 3D data along fastest dim
+            # Otherwise, we have to configure cuda "advanced memory layout".
+            data_ndims = len(self.shape)
+
+            if data_ndims == 2:
+                n_y, n_x = self.shape
+                along_fast_dim = self.axes[0] == 1
+                self.cufft_shape = n_x if along_fast_dim else n_y
+                self.cufft_batch_size = n_y if along_fast_dim else n_x
+                if not (along_fast_dim):
+                    # Batched vertical 1D FFT on 2D data need advanced data layout
+                    # http://docs.nvidia.com/cuda/cufft/#advanced-data-layout
+                    self._cufft_plan_kwargs = {
+                        "inembed": np.int32([0]),
+                        "istride": n_x,
+                        "idist": 1,
+                        "onembed": np.int32([0]),
+                        "ostride": n_x,
+                        "odist": 1,
+                    }
+
+            if data_ndims == 3:
+                # TODO/FIXME - the following work for C2C but not R2C ?!
+                # fast_axes = [(1, 2), (2, 1), (2,)]
+                fast_axes = [(2,)]
+                if self.axes not in fast_axes:
+                    raise NotImplementedError(
+                        "With the CUDA backend, batched transform on 3D data is only supported along fastest dimensions"
+                    )
+                self.cufft_batch_size = self.shape[0]
+                self.cufft_shape = self.shape[1:]
+                if len(self.axes) == 1:
+                    # 1D transform on 3D data: here only supported along fast dim, so batch_size is Nx*Ny
+                    self.cufft_batch_size = np.prod(self.shape[:2])
+                    self.cufft_shape = (self.shape[-1],)
+                if len(self.cufft_shape) == 1:
+                    self.cufft_shape = self.cufft_shape[0]
+
+    def _configure_normalization(self, normalize):
+        self.normalize = normalize
+        if self.normalize == "ortho":
+            # TODO
+            raise NotImplementedError("Normalization mode 'ortho' is not implemented with CUDA backend yet.")
+        self.cufft_scale_inverse = self.normalize == "rescale"
+
+    def _compute_fft_plans(self):
+        self.plan_forward = Plan(  # pylint: disable = E1102
+            self.cufft_shape,
+            self.dtype,
+            self.dtype_out,
+            batch=self.cufft_batch_size,
+            stream=self.processing.stream,
+            **self._cufft_plan_kwargs,
+            # cufft extensible plan API is only supported after 0.5.1
+            # (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
+            # but there is still no official 0.5.2
+            # ~ auto_allocate=True # cufft extensible plan API
+        )
+        self.plan_inverse = Plan(  # pylint: disable = E1102
+            self.cufft_shape,  # not shape_out
+            self.dtype_out,
+            self.dtype,
+            batch=self.cufft_batch_size,
+            stream=self.processing.stream,
+            **self._cufft_plan_kwargs,
+            # cufft extensible plan API is only supported after 0.5.1
+            # (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
+            # but there is still no official 0.5.2
+            # ~ auto_allocate=True
+        )
+
+    def fft(self, array, output=None):
+        if output is None:
+            output = self.output_fft = self.processing.allocate_array(
+                "output_fft", self.shape_out, dtype=self.dtype_out
+            )
+        cu_fft(array, output, self.plan_forward, scale=False)  # pylint: disable = E1102
+        return output
+
+    def ifft(self, array, output=None):
+        if output is None:
+            output = self.output_ifft = self.processing.allocate_array("output_ifft", self.shape, dtype=self.dtype)
+        cu_ifft(  # pylint: disable = E1102
+            array,
+            output,
+            self.plan_inverse,
+            scale=self.cufft_scale_inverse,
+        )
+        return output
+
+
+class VKCUFFT(_BaseVKFFT):
+    """
+    Cuda FFT, using VKFFT backend
+    """
+
+    implem = "vkfft"
+    backend = "cuda"
+    ProcessingCls = CudaProcessing
+    vkffs_cls = vk_cufft
+
+    def _init_backend(self, backend_options):
+        super()._init_backend(backend_options)
+        self._vkfft_other_init_kwargs = {"stream": self.processing.stream}
+
+
+def _has_vkfft(x):
+    # should be run from within a Process
+    try:
+        from nabu.processing.fft_cuda import VKCUFFT, __has_vkfft__
+
+        if not __has_vkfft__:
+            return False
+        vk = VKCUFFT((16,), "f")
+        avail = True
+    except (RuntimeError, OSError):
+        avail = False
+    return avail
+
+
+def has_vkfft(safe=True):
+    """
+    Determine whether pyvkfft is available.
+    For Cuda GPUs, vkfft relies on nvrtc which supports a narrow range of Cuda devices.
+    Unfortunately, it's not possible to determine whether vkfft is available before creating a Cuda context.
+    So we create a process (from scratch, i.e no fork), do the test within, and exit.
+    This function cannot be tested from a notebook/console, a proper entry point has to be created (if __name__ == "__main__").
+    """
+    if not safe:
+        return _has_vkfft(None)
+    ctx = get_context("spawn")
+    with Pool(1, context=ctx) as p:
+        v = p.map(_has_vkfft, [1])[0]
+    return v
+
+
+def _has_skfft(x):
+    # should be run from within a Process
+    try:
+        from nabu.processing.fft_cuda import SKCUFFT
+
+        sk = SKCUFFT((16,), "f")
+        avail = True
+    except (ImportError, RuntimeError, OSError):
+        avail = False
+    return avail
+
+
+def has_skcuda(safe=True):
+    """
+    Determine whether scikit-cuda/CUFFT is available.
+    Currently, scikit-cuda will create a Cuda context for Cublas, which can mess up the current execution.
+    Do it in a separate thread.
+    """
+    if not safe:
+        return _has_skfft(None)
+    ctx = get_context("spawn")
+    with Pool(1, context=ctx) as p:
+        v = p.map(_has_skfft, [1])[0]
+    return v
+
+
+def get_fft_class(backend="skcuda"):
+    backends = {
+        "scikit-cuda": SKCUFFT,
+        "skcuda": SKCUFFT,
+        "cufft": SKCUFFT,
+        "scikit": SKCUFFT,
+        "vkfft": VKCUFFT,
+        "pyvkfft": VKCUFFT,
+    }
+
+    def check_vkfft(asked_fft_cls):
+        if asked_fft_cls is VKCUFFT:
+            if has_vkfft(safe=True) is False:
+                warnings.warn("Could not get VKFFT backend. Falling-back to scikit-cuda/CUFFT instead.", RuntimeWarning)
+                return SKCUFFT
+            return VKCUFFT
+        return SKCUFFT
+
+    def get_fft_cls(asked_fft_backend):
+        asked_fft_backend = asked_fft_backend.lower()
+        check_supported(asked_fft_backend, list(backends.keys()), "FFT backend name")
+        asked_fft_cls = backends[asked_fft_backend]
+        fft_cls = check_vkfft(asked_fft_cls)
+        return fft_cls
+
+    asked_fft_backend_env = os.environ.get("NABU_FFT_BACKEND", "")
+    if asked_fft_backend_env != "":
+        return get_fft_cls(asked_fft_backend_env)
+    return get_fft_cls(backend)
+
+
+def get_available_fft_implems():
+    avail_implems = []
+    if has_skcuda(safe=True):
+        avail_implems.append("skcuda")
+    if has_vkfft(safe=True):
+        avail_implems.append("vkfft")
+    return avail_implems
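
The new fft_cuda module exposes get_fft_class(), which selects between the scikit-cuda/CUFFT and VKFFT backends and checks the NABU_FFT_BACKEND environment variable before its "backend" argument. The snippet below is a hedged usage sketch, not part of the diff, assuming a CUDA-capable machine with pycuda and at least one of pyvkfft or scikit-cuda installed; the shape is illustrative.

import os
from nabu.processing.fft_cuda import get_available_fft_implems, get_fft_class

print(get_available_fft_implems())         # e.g. ["skcuda", "vkfft"], depending on what is installed

# The environment variable takes precedence over the "backend" argument, and
# VKFFT silently falls back to SKCUFFT (with a RuntimeWarning) when unavailable.
os.environ["NABU_FFT_BACKEND"] = "vkfft"
fft_cls = get_fft_class(backend="skcuda")  # returns VKCUFFT here if VKFFT is usable
fft = fft_cls((2048, 2048), "f", r2c=True, axes=(-1,))  # batched 1D R2C along the fast axis
print(fft.shape_out)                       # (2048, 1025) for this real-to-complex transform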
nabu/processing/fft_opencl.py (new file):

@@ -0,0 +1,54 @@
+from multiprocessing import get_context
+from multiprocessing.pool import Pool
+from .fft_base import _BaseVKFFT
+from ..opencl.processing import OpenCLProcessing
+
+try:
+    from pyvkfft.opencl import VkFFTApp as vk_clfft
+
+    __has_vkfft__ = True
+except (ImportError, OSError):
+    __has_vkfft__ = False
+    vk_clfft = None
+
+
+class VKCLFFT(_BaseVKFFT):
+    """
+    OpenCL FFT, using VKFFT backend
+    """
+
+    implem = "vkfft"
+    backend = "opencl"
+    ProcessingCls = OpenCLProcessing
+    vkffs_cls = vk_clfft
+
+    def _init_backend(self, backend_options):
+        super()._init_backend(backend_options)
+        self._vkfft_other_init_kwargs = {"queue": self.processing.queue}
+
+
+def _has_vkfft(x):
+    # should be run from within a Process
+    try:
+        from nabu.processing.fft_opencl import VKCLFFT, __has_vkfft__
+
+        if not __has_vkfft__:
+            return False
+        vk = VKCLFFT((16,), "f")
+        avail = True
+    except (RuntimeError, OSError):
+        avail = False
+    return avail
+
+
+def has_vkfft(safe=True):
+    """
+    Determine whether pyvkfft is available.
+    This function cannot be tested from a notebook/console, a proper entry point has to be created (if __name__ == "__main__").
+    """
+    if not safe:
+        return _has_vkfft(None)
+    ctx = get_context("spawn")
+    with Pool(1, context=ctx) as p:
+        v = p.map(_has_vkfft, [1])[0]
+    return v
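
As the has_vkfft() docstring above notes, the safe (spawn-based) availability check cannot run from a notebook or interactive console. The following sketch, not part of the diff, shows one way to probe the OpenCL VKFFT backend from a proper entry point; the shape is illustrative.

from nabu.processing.fft_opencl import VKCLFFT, has_vkfft

if __name__ == "__main__":
    if has_vkfft(safe=True):   # spawns a child process to probe pyvkfft/OpenCL
        fft = VKCLFFT((512, 512), "f", r2c=True, axes=(-1,))
        print(fft.shape_out)   # (512, 257)
    else:
        print("pyvkfft / OpenCL backend not available")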
nabu/processing/fftshift.py (new file):

@@ -0,0 +1,134 @@
+import numpy as np
+from ..utils import BaseClassError, get_opencl_srcfile, updiv
+from ..opencl.kernel import OpenCLKernel
+from ..opencl.processing import OpenCLProcessing
+from pyopencl.tools import dtype_to_ctype as cl_dtype_to_ctype
+
+
+class FFTshiftBase:
+    KernelCls = BaseClassError
+    ProcessingCls = BaseClassError
+    dtype_to_ctype = BaseClassError
+    backend = "none"
+
+    def __init__(self, shape, dtype, dst_dtype=None, axes=None, **backend_options):
+        """
+
+        Parameters
+        ----------
+        shape: tuple
+            Array shape - can be 1D or 2D. 3D is not supported.
+        dtype: str or numpy.dtype
+            Data type, eg. "f", numpy.complex64, ...
+        dst_dtype: str or numpy.dtype
+            Output data type. If not provided (default), the shift is done in-place.
+        axes: tuple, optional
+            Axes over which to shift. Default is None, which shifts all axes.
+
+        Other parameters
+        ----------------
+        backend_options:
+            named arguments to pass to CudaProcessing or OpenCLProcessing
+        """
+        #
+        if axes not in [1, (1,), (-1,)]:
+            raise NotImplementedError
+        #
+        self.processing = self.ProcessingCls(**backend_options)
+        self.shape = shape
+        if len(self.shape) not in [1, 2]:
+            raise ValueError("Expected 1D or 2D array")
+        self.dtype = np.dtype(dtype)
+        self.dst_dtype = dst_dtype
+
+        if dst_dtype is None:
+            self._configure_inplace_shift()
+        else:
+            self._configure_out_of_place_shift()
+        self._configure_kenel_initialization()
+        self._fftshift_kernel = self.KernelCls(*self._kernel_init_args, **self._kernel_init_kwargs)
+        self._configure_kernel_call()
+
+    def _configure_inplace_shift(self):
+        self.inplace = True
+        # in-place on odd-sized array is more difficult - see fftshift.cl
+        if self.shape[-1] & 1:
+            raise NotImplementedError
+        #
+        self._kernel_init_args = [
+            "fftshift_x_inplace",
+        ]
+        self._kernel_init_kwargs = {
+            "options": [
+                "-DDTYPE=%s" % self.dtype_to_ctype(self.dtype),
+            ],
+        }
+
+    def _configure_out_of_place_shift(self):
+        self.inplace = False
+        self._kernel_init_args = [
+            "fftshift_x",
+        ]
+        self._kernel_init_kwargs = {
+            "options": [
+                "-DDTYPE=%s" % self.dtype_to_ctype(self.dtype),
+                "-DDTYPE_OUT=%s" % self.dtype_to_ctype(np.dtype(self.dst_dtype)),
+            ],
+        }
+        additional_flag = None
+        input_is_complex = np.iscomplexobj(np.ones(1, dtype=self.dtype))
+        output_is_complex = np.iscomplexobj(np.ones(1, dtype=self.dst_dtype))
+        if not (input_is_complex) and output_is_complex:
+            additional_flag = "-DCAST_TO_COMPLEX"
+        if input_is_complex and not (output_is_complex):
+            additional_flag = "-DCAST_TO_REAL"
+        if additional_flag is not None:
+            self._kernel_init_kwargs["options"].append(additional_flag)
+
+    def _call_fftshift_inplace(self, arr, direction):
+        self._fftshift_kernel(  # pylint: disable=E1102
+            arr, np.int32(self.shape[1]), np.int32(self.shape[0]), np.int32(direction), **self._kernel_kwargs
+        )
+        return arr
+
+    def _call_fftshift_out_of_place(self, arr, dst, direction):
+        if dst is None:
+            dst = self.processing.allocate_array("dst", arr.shape, dtype=self.dst_dtype)
+        self._fftshift_kernel(  # pylint: disable=E1102
+            arr, dst, np.int32(self.shape[1]), np.int32(self.shape[0]), np.int32(direction), **self._kernel_kwargs
+        )
+        return dst
+
+    def fftshift(self, arr, dst=None):
+        if self.inplace:
+            return self._call_fftshift_inplace(arr, 1)
+        else:
+            return self._call_fftshift_out_of_place(arr, dst, 1)
+
+    def ifftshift(self, arr, dst=None):
+        if self.inplace:
+            return self._call_fftshift_inplace(arr, -1)
+        else:
+            return self._call_fftshift_out_of_place(arr, dst, -1)
+
+
+class OpenCLFFTshift(FFTshiftBase):
+    KernelCls = OpenCLKernel
+    ProcessingCls = OpenCLProcessing
+    dtype_to_ctype = cl_dtype_to_ctype
+    backend = "opencl"
+
+    def _configure_kenel_initialization(self):
+        self._kernel_init_args.append(self.processing.ctx)
+        self._kernel_init_kwargs.update(
+            {
+                "filename": get_opencl_srcfile("fftshift.cl"),
+                "queue": self.processing.queue,
+            }
+        )
+
+    def _configure_kernel_call(self):
+        # TODO in-place fftshift needs to launch only arr.size//2 threads
+        block = (16, 16, 1)
+        grid = [updiv(a, b) * b for a, b in zip(self.shape[::-1], block)]
+        self._kernel_kwargs = {"global_size": grid, "local_size": block}
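
The new fftshift module wraps an OpenCL kernel that shifts along the last axis, either in place or out of place with an optional cast (e.g. real input to complex output). The snippet below is a usage sketch, not part of the diff, assuming a working pyopencl installation; the shape and array names are illustrative.

import numpy as np
from nabu.processing.fftshift import OpenCLFFTshift

shape = (128, 256)  # the in-place kernel requires an even last dimension

# In-place shift along the last axis (no dst_dtype given)
shifter = OpenCLFFTshift(shape, np.complex64, axes=(-1,))
d_arr = shifter.processing.allocate_array("arr", shape, dtype=np.complex64)
shifter.fftshift(d_arr)   # shifts d_arr in place
shifter.ifftshift(d_arr)  # inverse shift restores the original layout

# Out-of-place variant that also casts real input to complex output
caster = OpenCLFFTshift(shape, np.float32, dst_dtype=np.complex64, axes=(-1,))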