nabu 2023.2.1__py3-none-any.whl → 2024.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183)
  1. doc/conf.py +1 -1
  2. doc/doc_config.py +32 -0
  3. nabu/__init__.py +2 -1
  4. nabu/app/bootstrap_stitching.py +1 -1
  5. nabu/app/cli_configs.py +122 -2
  6. nabu/app/composite_cor.py +27 -2
  7. nabu/app/correct_rot.py +70 -0
  8. nabu/app/create_distortion_map_from_poly.py +42 -18
  9. nabu/app/diag_to_pix.py +358 -0
  10. nabu/app/diag_to_rot.py +449 -0
  11. nabu/app/generate_header.py +4 -3
  12. nabu/app/histogram.py +2 -2
  13. nabu/app/multicor.py +6 -1
  14. nabu/app/parse_reconstruction_log.py +151 -0
  15. nabu/app/prepare_weights_double.py +83 -22
  16. nabu/app/reconstruct.py +5 -1
  17. nabu/app/reconstruct_helical.py +7 -0
  18. nabu/app/reduce_dark_flat.py +6 -3
  19. nabu/app/rotate.py +4 -4
  20. nabu/app/stitching.py +16 -2
  21. nabu/app/tests/test_reduce_dark_flat.py +18 -2
  22. nabu/app/validator.py +4 -4
  23. nabu/cuda/convolution.py +8 -376
  24. nabu/cuda/fft.py +4 -0
  25. nabu/cuda/kernel.py +4 -4
  26. nabu/cuda/medfilt.py +5 -158
  27. nabu/cuda/padding.py +5 -71
  28. nabu/cuda/processing.py +23 -2
  29. nabu/cuda/src/ElementOp.cu +78 -0
  30. nabu/cuda/src/backproj.cu +28 -2
  31. nabu/cuda/src/fourier_wavelets.cu +2 -2
  32. nabu/cuda/src/normalization.cu +23 -0
  33. nabu/cuda/src/padding.cu +2 -2
  34. nabu/cuda/src/transpose.cu +16 -0
  35. nabu/cuda/utils.py +39 -0
  36. nabu/estimation/alignment.py +10 -1
  37. nabu/estimation/cor.py +808 -38
  38. nabu/estimation/cor_sino.py +7 -9
  39. nabu/estimation/tests/test_cor.py +85 -3
  40. nabu/io/reader.py +26 -18
  41. nabu/io/tests/test_cast_volume.py +3 -3
  42. nabu/io/tests/test_detector_distortion.py +3 -3
  43. nabu/io/tiffwriter_zmm.py +2 -2
  44. nabu/io/utils.py +14 -4
  45. nabu/io/writer.py +5 -3
  46. nabu/misc/fftshift.py +6 -0
  47. nabu/misc/histogram.py +5 -285
  48. nabu/misc/histogram_cuda.py +8 -104
  49. nabu/misc/kernel_base.py +3 -121
  50. nabu/misc/padding_base.py +5 -69
  51. nabu/misc/processing_base.py +3 -107
  52. nabu/misc/rotation.py +5 -62
  53. nabu/misc/rotation_cuda.py +5 -65
  54. nabu/misc/transpose.py +6 -0
  55. nabu/misc/unsharp.py +3 -78
  56. nabu/misc/unsharp_cuda.py +5 -52
  57. nabu/misc/unsharp_opencl.py +8 -85
  58. nabu/opencl/fft.py +6 -0
  59. nabu/opencl/kernel.py +21 -6
  60. nabu/opencl/padding.py +5 -72
  61. nabu/opencl/processing.py +27 -5
  62. nabu/opencl/src/backproj.cl +3 -3
  63. nabu/opencl/src/fftshift.cl +65 -12
  64. nabu/opencl/src/padding.cl +2 -2
  65. nabu/opencl/src/roll.cl +96 -0
  66. nabu/opencl/src/transpose.cl +16 -0
  67. nabu/pipeline/config_validators.py +63 -3
  68. nabu/pipeline/dataset_validator.py +2 -2
  69. nabu/pipeline/estimators.py +193 -35
  70. nabu/pipeline/fullfield/chunked.py +34 -17
  71. nabu/pipeline/fullfield/chunked_cuda.py +7 -5
  72. nabu/pipeline/fullfield/computations.py +48 -13
  73. nabu/pipeline/fullfield/nabu_config.py +13 -13
  74. nabu/pipeline/fullfield/processconfig.py +10 -5
  75. nabu/pipeline/fullfield/reconstruction.py +1 -2
  76. nabu/pipeline/helical/fbp.py +5 -0
  77. nabu/pipeline/helical/filtering.py +12 -9
  78. nabu/pipeline/helical/gridded_accumulator.py +179 -33
  79. nabu/pipeline/helical/helical_chunked_regridded.py +262 -151
  80. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +4 -11
  81. nabu/pipeline/helical/helical_reconstruction.py +56 -18
  82. nabu/pipeline/helical/span_strategy.py +1 -1
  83. nabu/pipeline/helical/tests/test_accumulator.py +4 -0
  84. nabu/pipeline/params.py +23 -2
  85. nabu/pipeline/processconfig.py +3 -8
  86. nabu/pipeline/tests/test_chunk_reader.py +78 -0
  87. nabu/pipeline/tests/test_estimators.py +120 -2
  88. nabu/pipeline/utils.py +25 -0
  89. nabu/pipeline/writer.py +2 -0
  90. nabu/preproc/ccd_cuda.py +9 -7
  91. nabu/preproc/ctf.py +21 -26
  92. nabu/preproc/ctf_cuda.py +25 -25
  93. nabu/preproc/double_flatfield.py +14 -2
  94. nabu/preproc/double_flatfield_cuda.py +7 -11
  95. nabu/preproc/flatfield_cuda.py +23 -27
  96. nabu/preproc/phase.py +19 -24
  97. nabu/preproc/phase_cuda.py +21 -21
  98. nabu/preproc/shift_cuda.py +58 -28
  99. nabu/preproc/tests/test_ctf.py +5 -5
  100. nabu/preproc/tests/test_double_flatfield.py +2 -2
  101. nabu/preproc/tests/test_vshift.py +13 -2
  102. nabu/processing/__init__.py +0 -0
  103. nabu/processing/convolution_cuda.py +375 -0
  104. nabu/processing/fft_base.py +163 -0
  105. nabu/processing/fft_cuda.py +256 -0
  106. nabu/processing/fft_opencl.py +54 -0
  107. nabu/processing/fftshift.py +134 -0
  108. nabu/processing/histogram.py +286 -0
  109. nabu/processing/histogram_cuda.py +103 -0
  110. nabu/processing/kernel_base.py +126 -0
  111. nabu/processing/medfilt_cuda.py +159 -0
  112. nabu/processing/muladd.py +29 -0
  113. nabu/processing/muladd_cuda.py +68 -0
  114. nabu/processing/padding_base.py +71 -0
  115. nabu/processing/padding_cuda.py +75 -0
  116. nabu/processing/padding_opencl.py +77 -0
  117. nabu/processing/processing_base.py +123 -0
  118. nabu/processing/roll_opencl.py +64 -0
  119. nabu/processing/rotation.py +63 -0
  120. nabu/processing/rotation_cuda.py +66 -0
  121. nabu/processing/tests/__init__.py +0 -0
  122. nabu/processing/tests/test_fft.py +268 -0
  123. nabu/processing/tests/test_fftshift.py +71 -0
  124. nabu/{misc → processing}/tests/test_histogram.py +2 -4
  125. nabu/{cuda → processing}/tests/test_medfilt.py +1 -1
  126. nabu/processing/tests/test_muladd.py +54 -0
  127. nabu/{cuda → processing}/tests/test_padding.py +119 -75
  128. nabu/processing/tests/test_roll.py +63 -0
  129. nabu/{misc → processing}/tests/test_rotation.py +3 -2
  130. nabu/processing/tests/test_transpose.py +72 -0
  131. nabu/{misc → processing}/tests/test_unsharp.py +41 -8
  132. nabu/processing/transpose.py +126 -0
  133. nabu/processing/unsharp.py +79 -0
  134. nabu/processing/unsharp_cuda.py +53 -0
  135. nabu/processing/unsharp_opencl.py +75 -0
  136. nabu/reconstruction/fbp.py +34 -10
  137. nabu/reconstruction/fbp_base.py +35 -16
  138. nabu/reconstruction/fbp_opencl.py +7 -12
  139. nabu/reconstruction/filtering.py +2 -2
  140. nabu/reconstruction/filtering_cuda.py +13 -14
  141. nabu/reconstruction/filtering_opencl.py +3 -4
  142. nabu/reconstruction/projection.py +2 -0
  143. nabu/reconstruction/rings.py +158 -1
  144. nabu/reconstruction/rings_cuda.py +218 -58
  145. nabu/reconstruction/sinogram_cuda.py +16 -12
  146. nabu/reconstruction/tests/test_deringer.py +116 -14
  147. nabu/reconstruction/tests/test_fbp.py +22 -31
  148. nabu/reconstruction/tests/test_filtering.py +11 -2
  149. nabu/resources/dataset_analyzer.py +89 -26
  150. nabu/resources/nxflatfield.py +2 -2
  151. nabu/resources/tests/test_nxflatfield.py +1 -1
  152. nabu/resources/utils.py +9 -2
  153. nabu/stitching/alignment.py +184 -0
  154. nabu/stitching/config.py +241 -39
  155. nabu/stitching/definitions.py +6 -0
  156. nabu/stitching/frame_composition.py +4 -2
  157. nabu/stitching/overlap.py +99 -3
  158. nabu/stitching/sample_normalization.py +60 -0
  159. nabu/stitching/slurm_utils.py +10 -10
  160. nabu/stitching/tests/test_alignment.py +99 -0
  161. nabu/stitching/tests/test_config.py +16 -1
  162. nabu/stitching/tests/test_overlap.py +68 -2
  163. nabu/stitching/tests/test_sample_normalization.py +49 -0
  164. nabu/stitching/tests/test_slurm_utils.py +5 -5
  165. nabu/stitching/tests/test_utils.py +3 -33
  166. nabu/stitching/tests/test_z_stitching.py +391 -22
  167. nabu/stitching/utils.py +144 -202
  168. nabu/stitching/z_stitching.py +309 -126
  169. nabu/testutils.py +18 -0
  170. nabu/thirdparty/tomocupy_remove_stripe.py +586 -0
  171. nabu/utils.py +32 -6
  172. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/LICENSE +1 -1
  173. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/METADATA +5 -5
  174. nabu-2024.1.0rc3.dist-info/RECORD +296 -0
  175. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/WHEEL +1 -1
  176. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/entry_points.txt +5 -1
  177. nabu/conftest.py +0 -14
  178. nabu/opencl/fftshift.py +0 -92
  179. nabu/opencl/tests/test_fftshift.py +0 -55
  180. nabu/opencl/tests/test_padding.py +0 -84
  181. nabu-2023.2.1.dist-info/RECORD +0 -252
  182. /nabu/cuda/src/{fftshift.cu → dfi_fftshift.cu} +0 -0
  183. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,163 @@
1
+ import numpy as np
2
+ from ..utils import BaseClassError
3
+
4
+
5
class _BaseFFT:
    """
    A base class for FFTs.

    The constructor runs a fixed sequence of configuration hooks. In this base class,
    _configure_batched_transform(), _configure_normalization() and _compute_fft_plans()
    do nothing: back-end sub-classes override them.
    """

    implem = "none"
    ProcessingCls = BaseClassError

    def __init__(self, shape, dtype, r2c=True, axes=None, normalize="rescale", **backend_options):
        """
        Base class for Fast Fourier Transform (FFT).

        Parameters
        ----------
        shape: int or list of int
            Shape of the input data
        dtype: str or numpy.dtype
            Data type of the input data: float32, float64, complex64 or complex128.
        r2c: bool, optional
            Whether to use real-to-complex transform for real-valued input. Default is True.
            This parameter is ignored for complex-valued input, which always uses a
            complex-to-complex transform.
        axes: list of int, optional
            Axes along which FFT is computed. Negative values are accepted.
              * For 2D transform: axes=(1,0)
              * For batched 1D transform of 2D image: axes=(-1,)
        normalize: str, optional
            Whether to normalize FFT and IFFT. Possible values are:
              * "rescale": in this case, Fourier data is divided by "N"
                before IFFT, so that IFFT(FFT(data)) = data.
                This corresponds to numpy norm=None i.e norm="backward".
              * "ortho": in this case, FFT and IFFT are adjoint of each other,
                the transform is unitary. Both FFT and IFFT are scaled with 1/sqrt(N).
              * "none": no normalization is done : IFFT(FFT(data)) = data*N

        Other parameters
        -----------------
        backend_options: dict, optional
            Parameters to pass to CudaProcessing or OpenCLProcessing class.

        Raises
        ------
        ValueError
            If dtype is not one of the supported data types.
        """
        self._init_backend(backend_options)
        # dtypes (and thus self.r2c) must be known before computing the output shape
        self._set_dtypes(dtype, r2c)
        self._set_shape_and_axes(shape, axes)
        self._configure_batched_transform()
        self._configure_normalization(normalize)
        self._compute_fft_plans()

    def _init_backend(self, backend_options):
        # Back-end helper object, used e.g. for array allocation (allocate_array)
        self.processing = self.ProcessingCls(**backend_options)

    def _set_dtypes(self, dtype, r2c):
        self.dtype = np.dtype(dtype)
        dtypes_mapping = {
            np.dtype("float32"): np.complex64,
            np.dtype("float64"): np.complex128,
            np.dtype("complex64"): np.complex64,
            np.dtype("complex128"): np.complex128,
        }
        if self.dtype not in dtypes_mapping:
            raise ValueError("Invalid input data type: got %s" % self.dtype)
        self.dtype_out = dtypes_mapping[self.dtype]
        # A real-to-complex transform is only meaningful for real-valued input.
        # For complex input it would wrongly halve the output shape along the last axis.
        self.r2c = bool(r2c) and self.dtype.kind != "c"

    def _set_shape_and_axes(self, shape, axes):
        # Input shape
        if np.isscalar(shape):
            shape = (shape,)
        self.shape = tuple(shape)
        # Axes: indexing the default axes normalizes negative values (e.g. -1 -> ndim - 1)
        default_axes = tuple(range(len(self.shape)))
        if axes is None:
            self.axes = default_axes
        else:
            self.axes = tuple(np.array(default_axes)[np.array(axes)])
        # Output shape: with R2C, the last transformed axis only has N//2 + 1 coefficients
        shape_out = self.shape
        if self.r2c:
            reduced_dim = self.axes[-1]
            shape_out = list(shape_out)
            shape_out[reduced_dim] = shape_out[reduced_dim] // 2 + 1
            shape_out = tuple(shape_out)
        self.shape_out = shape_out

    def _configure_batched_transform(self):
        # Hook for back-ends: configure batched transforms from self.shape / self.axes
        pass

    def _configure_normalization(self, normalize):
        # Hook for back-ends: configure the normalization mode
        pass

    def _compute_fft_plans(self):
        # Hook for back-ends: create the FFT plans
        pass
94
+
95
+
96
class _BaseVKFFT(_BaseFFT):
    """
    FFT using VKFFT backend (pyvkfft).

    Sub-classes set ProcessingCls, vkffs_cls (the pyvkfft VkFFTApp class to use) and,
    usually in _init_backend(), self._vkfft_other_init_kwargs (extra back-end specific
    VkFFTApp arguments, e.g. a stream or a queue).
    """

    implem = "vkfft"
    backend = "none"
    ProcessingCls = BaseClassError
    vkffs_cls = BaseClassError

    def _configure_batched_transform(self):
        # Transform along all axes: let vkfft use its default (no ndim, no axes)
        if self.axes is not None and len(self.shape) == len(self.axes):
            self.axes = None
            return
        along_fast_axes = is_fast_axes(len(self.shape), self.axes)
        if self.r2c and not along_fast_axes:
            # batched Real-to-complex transforms are supported only along fast axes
            raise ValueError("For %dD R2C, only batched transforms along fast axes are allowed" % (len(self.shape)))
        if along_fast_axes:
            # vkfft can do a batched transform along the fastest axes by providing ndim=XX, axes=None
            self._vkfft_ndim = len(self.axes)
            self.axes = None
        else:
            # C2C batched transform along other axes: pass the axes explicitly.
            # Dropping them would make vkfft transform along the fastest axes instead.
            self.axes = tuple(int(ax) for ax in self.axes)

    def _configure_normalization(self, normalize):
        self.normalize = normalize
        # Translate nabu normalization names to pyvkfft "norm" values.
        # Unknown names fall back to 1 (same as "rescale").
        self._vkfft_norm = {
            "rescale": 1,
            "backward": 1,
            "ortho": "ortho",
            "none": 0,
        }.get(self.normalize, 1)

    def _set_shape_and_axes(self, shape, axes):
        super()._set_shape_and_axes(shape, axes)
        # Only set for batched transforms along the fast axes (see _configure_batched_transform)
        self._vkfft_ndim = None

    def _compute_fft_plans(self):
        self._vkfft_plan = self.vkffs_cls(
            self.shape,
            self.dtype,
            ndim=self._vkfft_ndim,
            inplace=False,
            norm=self._vkfft_norm,
            r2c=self.r2c,
            dct=False,
            axes=self.axes,
            strides=None,
            **getattr(self, "_vkfft_other_init_kwargs", {}),
        )

    def fft(self, array, output=None):
        """
        Forward FFT of `array`. If `output` is not provided, an array of shape
        self.shape_out and dtype self.dtype_out is allocated by the processing object.
        """
        if output is None:
            output = self.output_fft = self.processing.allocate_array(
                "output_fft", self.shape_out, dtype=self.dtype_out
            )
        return self._vkfft_plan.fft(array, dest=output)

    def ifft(self, array, output=None):
        """
        Inverse FFT of `array`. If `output` is not provided, an array of shape
        self.shape and dtype self.dtype is allocated by the processing object.
        """
        if output is None:
            output = self.output_ifft = self.processing.allocate_array("output_ifft", self.shape, dtype=self.dtype)
        return self._vkfft_plan.ifft(array, dest=output)
155
+
156
+
157
def is_fast_axes(ndim, axes):
    """
    Tell whether `axes` are exactly the last (fast) len(axes) dimensions of an
    ndim-dimensional array.

    Negative axes are accepted (e.g. -1 stands for ndim - 1), and the order
    of `axes` does not matter.
    """
    normalized = sorted(ax if ax >= 0 else ax + ndim for ax in axes)
    n_axes = len(normalized)
    return list(range(ndim))[-n_axes:] == normalized
@@ -0,0 +1,256 @@
1
+ import os
2
+ import warnings
3
+ from multiprocessing import get_context
4
+ from multiprocessing.pool import Pool
5
+ import numpy as np
6
+ from ..utils import check_supported
7
+ from .fft_base import _BaseFFT, _BaseVKFFT
8
+
9
+ try:
10
+ from pyvkfft.cuda import VkFFTApp as vk_cufft
11
+
12
+ __has_vkfft__ = True
13
+ except (ImportError, OSError):
14
+ __has_vkfft__ = False
15
+ vk_cufft = None
16
+ from ..cuda.processing import CudaProcessing
17
+
18
# scikit-cuda symbols, filled lazily by init_skcuda() because importing
# scikit-cuda creates a Cuda context. __has_skcuda__ stays None until that
# import has been attempted, then becomes True or False.
Plan = cu_fft = cu_ifft = None
__has_skcuda__ = None
22
+
23
+
24
def init_skcuda():
    """
    Import the scikit-cuda FFT symbols (Plan, fft, ifft) into the module globals
    Plan, cu_fft and cu_ifft, and set the global __has_skcuda__ to True on success
    or False if the import fails.
    """
    # This needs to be done here, because scikit-cuda creates a Cuda context at import,
    # which can mess things up in some cases.
    # Ugly solution to an ugly problem.
    global __has_skcuda__, Plan, cu_fft, cu_ifft
    try:
        from skcuda.fft import Plan
        from skcuda.fft import fft as cu_fft
        from skcuda.fft import ifft as cu_ifft

        __has_skcuda__ = True
    except (ImportError, OSError):
        # OSError is caught as for the pyvkfft import of this module;
        # presumably raised when the Cuda/cufft shared libraries cannot be loaded.
        __has_skcuda__ = False
37
+
38
+
39
class SKCUFFT(_BaseFFT):
    """
    Cuda FFT based on scikit-cuda (CUFFT).
    """

    implem = "skcuda"
    backend = "cuda"
    ProcessingCls = CudaProcessing

    def _configure_batched_transform(self):
        """
        Compute the CUFFT plan geometry (transform shape, batch size, data layout)
        from self.shape and self.axes.

        Raises
        ------
        ImportError
            If scikit-cuda cannot be imported.
        NotImplementedError
            For 3D data, if the batched transform is not along the fastest dimension.
        """
        # scikit-cuda is imported lazily, since importing it creates a Cuda context
        if __has_skcuda__ is None:
            init_skcuda()
        if not (__has_skcuda__):
            raise ImportError("Please install pycuda and scikit-cuda to use the CUDA back-end")

        self.cufft_batch_size = 1
        self.cufft_shape = self.shape
        self._cufft_plan_kwargs = {}
        if (self.axes is not None) and (len(self.axes) < len(self.shape)):
            # In the easiest case, the transform is computed along the fastest dimensions:
            #  - 1D transforms of lines of 2D data
            #  - 2D transforms of images of 3D data (stacked along slow dim)
            #  - 1D transforms of 3D data along fastest dim
            # Otherwise, we have to configure cuda "advanced memory layout".
            data_ndims = len(self.shape)

            if data_ndims == 2:
                n_y, n_x = self.shape
                along_fast_dim = self.axes[0] == 1
                self.cufft_shape = n_x if along_fast_dim else n_y
                self.cufft_batch_size = n_y if along_fast_dim else n_x
                if not (along_fast_dim):
                    # Batched vertical 1D FFT on 2D data need advanced data layout
                    # http://docs.nvidia.com/cuda/cufft/#advanced-data-layout
                    self._cufft_plan_kwargs = {
                        "inembed": np.int32([0]),
                        "istride": n_x,
                        "idist": 1,
                        "onembed": np.int32([0]),
                        "ostride": n_x,
                        "odist": 1,
                    }

            if data_ndims == 3:
                # TODO/FIXME - the following work for C2C but not R2C ?!
                # fast_axes = [(1, 2), (2, 1), (2,)]
                fast_axes = [(2,)]
                if self.axes not in fast_axes:
                    raise NotImplementedError(
                        "With the CUDA backend, batched transform on 3D data is only supported along fastest dimensions"
                    )
                self.cufft_batch_size = self.shape[0]
                self.cufft_shape = self.shape[1:]
                if len(self.axes) == 1:
                    # 1D transform on 3D data: here only supported along fast dim, so batch_size is Nx*Ny
                    self.cufft_batch_size = int(np.prod(self.shape[:2]))
                    self.cufft_shape = (self.shape[-1],)
                if len(self.cufft_shape) == 1:
                    self.cufft_shape = self.cufft_shape[0]

    def _configure_normalization(self, normalize):
        self.normalize = normalize
        if self.normalize == "ortho":
            # TODO
            raise NotImplementedError("Normalization mode 'ortho' is not implemented with CUDA backend yet.")
        # "backward" is numpy's name for "rescale" (IFFT divided by N), as accepted by the VKFFT backend
        self.cufft_scale_inverse = self.normalize in ("rescale", "backward")

    def _compute_fft_plans(self):
        self.plan_forward = Plan(  # pylint: disable = E1102
            self.cufft_shape,
            self.dtype,
            self.dtype_out,
            batch=self.cufft_batch_size,
            stream=self.processing.stream,
            **self._cufft_plan_kwargs,
            # cufft extensible plan API is only supported after 0.5.1
            # (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
            # but there is still no official 0.5.2
            # ~ auto_allocate=True # cufft extensible plan API
        )
        self.plan_inverse = Plan(  # pylint: disable = E1102
            self.cufft_shape,  # not shape_out
            self.dtype_out,
            self.dtype,
            batch=self.cufft_batch_size,
            stream=self.processing.stream,
            **self._cufft_plan_kwargs,
            # cufft extensible plan API is only supported after 0.5.1
            # (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
            # but there is still no official 0.5.2
            # ~ auto_allocate=True
        )

    def fft(self, array, output=None):
        """
        Forward FFT (never scaled). If `output` is not provided, an array of shape
        self.shape_out and dtype self.dtype_out is allocated.
        """
        if output is None:
            output = self.output_fft = self.processing.allocate_array(
                "output_fft", self.shape_out, dtype=self.dtype_out
            )
        cu_fft(array, output, self.plan_forward, scale=False)  # pylint: disable = E1102
        return output

    def ifft(self, array, output=None):
        """
        Inverse FFT, scaled according to the normalization mode. If `output` is not
        provided, an array of shape self.shape and dtype self.dtype is allocated.
        """
        if output is None:
            output = self.output_ifft = self.processing.allocate_array("output_ifft", self.shape, dtype=self.dtype)
        cu_ifft(  # pylint: disable = E1102
            array,
            output,
            self.plan_inverse,
            scale=self.cufft_scale_inverse,
        )
        return output
146
+
147
+
148
class VKCUFFT(_BaseVKFFT):
    """
    Cuda FFT relying on the VKFFT library (VkFFTApp from pyvkfft.cuda).
    """

    implem = "vkfft"
    backend = "cuda"
    ProcessingCls = CudaProcessing
    vkffs_cls = vk_cufft

    def _init_backend(self, backend_options):
        super()._init_backend(backend_options)
        # Extra VkFFTApp constructor argument: use the Cuda stream of the processing object
        cuda_stream = self.processing.stream
        self._vkfft_other_init_kwargs = dict(stream=cuda_stream)
161
+
162
+
163
def _has_vkfft(x):
    """
    Check whether a VKFFT-based Cuda FFT object can actually be created.

    Meant to be run from within a separate process (see has_vkfft()).
    The argument is unused; it exists so that this function can be used with Pool.map.
    Returns True if the FFT object could be created, False otherwise.
    """
    try:
        from nabu.processing.fft_cuda import VKCUFFT, __has_vkfft__

        if not __has_vkfft__:
            return False
        # Creating a small transform is the actual availability test
        VKCUFFT((16,), "f")
    except (ImportError, RuntimeError, OSError):
        # ImportError: the Cuda back-end (e.g. pycuda) is not installed
        return False
    return True
175
+
176
+
177
def has_vkfft(safe=True):
    """
    Determine whether pyvkfft is available.

    For Cuda GPUs, vkfft relies on nvrtc which supports a narrow range of Cuda devices.
    Unfortunately, it's not possible to determine whether vkfft is available before creating a Cuda context.
    So when safe=True, the test is done in a new process (created from scratch, i.e no fork), which then exits.
    When safe=False, the test runs in the current process.
    This function cannot be tested from a notebook/console, a proper entry point has to be created (if __name__ == "__main__").
    """
    if safe:
        spawn_context = get_context("spawn")
        with Pool(1, context=spawn_context) as pool:
            (is_available,) = pool.map(_has_vkfft, [1])
        return is_available
    return _has_vkfft(None)
191
+
192
+
193
def _has_skfft(x):
    """
    Check whether a scikit-cuda based FFT object can be created.

    Meant to be run from within a separate process (see has_skcuda()).
    The argument is unused; it exists so that this function can be used with Pool.map.
    """
    try:
        from nabu.processing.fft_cuda import SKCUFFT

        SKCUFFT((16,), "f")
    except (ImportError, RuntimeError, OSError):
        return False
    return True
203
+
204
+
205
def has_skcuda(safe=True):
    """
    Determine whether scikit-cuda/CUFFT is available.

    Currently, scikit-cuda will create a Cuda context for Cublas, which can mess up the current execution.
    So when safe=True, the test is done in a separate process (spawned, not forked).
    When safe=False, the test runs in the current process.
    """
    if safe:
        spawn_context = get_context("spawn")
        with Pool(1, context=spawn_context) as pool:
            (is_available,) = pool.map(_has_skfft, [1])
        return is_available
    return _has_skfft(None)
217
+
218
+
219
def get_fft_class(backend="skcuda"):
    """
    Return the Cuda FFT class to use for a given back-end name.

    The NABU_FFT_BACKEND environment variable, when set to a non-empty value,
    takes precedence over `backend`. Names are case-insensitive.
    If VKFFT is asked but cannot be used, a RuntimeWarning is emitted and the
    scikit-cuda/CUFFT class is returned instead.
    """
    backends = {
        "scikit-cuda": SKCUFFT,
        "skcuda": SKCUFFT,
        "cufft": SKCUFFT,
        "scikit": SKCUFFT,
        "vkfft": VKCUFFT,
        "pyvkfft": VKCUFFT,
    }
    requested_name = os.environ.get("NABU_FFT_BACKEND", "") or backend
    requested_name = requested_name.lower()
    check_supported(requested_name, list(backends.keys()), "FFT backend name")

    if backends[requested_name] is not VKCUFFT:
        return SKCUFFT
    # VKFFT asked: make sure it is usable (tested in a separate process)
    if has_vkfft(safe=True) is False:
        warnings.warn("Could not get VKFFT backend. Falling-back to scikit-cuda/CUFFT instead.", RuntimeWarning)
        return SKCUFFT
    return VKCUFFT
248
+
249
+
250
def get_available_fft_implems():
    """
    Return the list of usable Cuda FFT implementations, among "skcuda" and "vkfft"
    (in this order). Each one is tested in a separate process.
    """
    availability_checks = (("skcuda", has_skcuda), ("vkfft", has_vkfft))
    return [name for name, is_available in availability_checks if is_available(safe=True)]
@@ -0,0 +1,54 @@
1
+ from multiprocessing import get_context
2
+ from multiprocessing.pool import Pool
3
+ from .fft_base import _BaseVKFFT
4
+ from ..opencl.processing import OpenCLProcessing
5
+
6
+ try:
7
+ from pyvkfft.opencl import VkFFTApp as vk_clfft
8
+
9
+ __has_vkfft__ = True
10
+ except (ImportError, OSError):
11
+ __has_vkfft__ = False
12
+ vk_clfft = None
13
+
14
+
15
class VKCLFFT(_BaseVKFFT):
    """
    OpenCL FFT relying on the VKFFT library (VkFFTApp from pyvkfft.opencl).
    """

    implem = "vkfft"
    backend = "opencl"
    ProcessingCls = OpenCLProcessing
    vkffs_cls = vk_clfft

    def _init_backend(self, backend_options):
        super()._init_backend(backend_options)
        # Extra VkFFTApp constructor argument: use the OpenCL queue of the processing object
        cl_queue = self.processing.queue
        self._vkfft_other_init_kwargs = dict(queue=cl_queue)
28
+
29
+
30
def _has_vkfft(x):
    """
    Check whether a VKFFT-based OpenCL FFT object can actually be created.

    Meant to be run from within a separate process (see has_vkfft()).
    The argument is unused; it exists so that this function can be used with Pool.map.
    Returns True if the FFT object could be created, False otherwise.
    """
    try:
        from nabu.processing.fft_opencl import VKCLFFT, __has_vkfft__

        if not __has_vkfft__:
            return False
        # Creating a small transform is the actual availability test
        VKCLFFT((16,), "f")
    except (ImportError, RuntimeError, OSError):
        # ImportError: the OpenCL back-end (e.g. pyopencl) is not installed
        return False
    return True
42
+
43
+
44
def has_vkfft(safe=True):
    """
    Determine whether pyvkfft is available.

    When safe=True, the test is done in a new process (spawned, not forked);
    when safe=False, it runs in the current process.
    This function cannot be tested from a notebook/console, a proper entry point has to be created (if __name__ == "__main__").
    """
    if safe:
        spawn_context = get_context("spawn")
        with Pool(1, context=spawn_context) as pool:
            (is_available,) = pool.map(_has_vkfft, [1])
        return is_available
    return _has_vkfft(None)
@@ -0,0 +1,134 @@
1
+ import numpy as np
2
+ from ..utils import BaseClassError, get_opencl_srcfile, updiv
3
+ from ..opencl.kernel import OpenCLKernel
4
+ from ..opencl.processing import OpenCLProcessing
5
+ from pyopencl.tools import dtype_to_ctype as cl_dtype_to_ctype
6
+
7
+
8
class FFTshiftBase:
    """
    Back-end agnostic fftshift/ifftshift along the last axis of 1D or 2D arrays,
    done either in-place or out-of-place by a kernel.

    Sub-classes provide KernelCls, ProcessingCls, dtype_to_ctype, and implement
    _configure_kenel_initialization() and _configure_kernel_call().
    """

    KernelCls = BaseClassError
    ProcessingCls = BaseClassError
    dtype_to_ctype = BaseClassError
    backend = "none"

    def __init__(self, shape, dtype, dst_dtype=None, axes=None, **backend_options):
        """

        Parameters
        ----------
        shape: int or tuple
            Array shape - can be 1D or 2D. 3D is not supported.
        dtype: str or numpy.dtype
            Data type, eg. "f", numpy.complex64, ...
        dst_dtype: str or numpy.dtype
            Output data type. If not provided (default), the shift is done in-place.
        axes: int or tuple, optional
            Axes over which to shift. Only a shift along the last axis is currently
            supported: axes=1, (1,) or (-1,).
            For 1D arrays, axes=None (all axes), 0 or (0,) are also accepted,
            since they designate the same (last) axis.

        Other parameters
        ----------------
        backend_options:
            named arguments to pass to CudaProcessing or OpenCLProcessing

        Raises
        ------
        NotImplementedError
            If `axes` is not supported, or for an in-place shift of an array whose
            last dimension is odd.
        ValueError
            If the array is neither 1D nor 2D.
        """
        if np.isscalar(shape):
            shape = (shape,)
        supported_axes = [1, (1,), (-1,)]
        if len(shape) == 1:
            supported_axes += [None, 0, (0,)]
        if axes not in supported_axes:
            raise NotImplementedError(
                "Only a shift along the last axis is supported (axes=(-1,)), got axes=%s" % str(axes)
            )
        self.shape = shape
        if len(self.shape) not in [1, 2]:
            raise ValueError("Expected 1D or 2D array")
        # Validate everything before creating the back-end processing object
        self.processing = self.ProcessingCls(**backend_options)
        self.dtype = np.dtype(dtype)
        self.dst_dtype = dst_dtype

        if dst_dtype is None:
            self._configure_inplace_shift()
        else:
            self._configure_out_of_place_shift()
        self._configure_kenel_initialization()
        self._fftshift_kernel = self.KernelCls(*self._kernel_init_args, **self._kernel_init_kwargs)
        self._configure_kernel_call()

    def _get_kernel_dims(self):
        # (Nx, Ny) passed to the kernels: a 1D array of n elements is handled as Nx=n, Ny=1
        n_x = int(self.shape[-1])
        n_y = int(np.prod(self.shape[:-1]))
        return n_x, n_y

    def _configure_inplace_shift(self):
        self.inplace = True
        # in-place on odd-sized array is more difficult - see fftshift.cl
        if self.shape[-1] & 1:
            raise NotImplementedError("In-place fftshift is not supported when the last dimension is odd")
        #
        self._kernel_init_args = [
            "fftshift_x_inplace",
        ]
        self._kernel_init_kwargs = {
            "options": [
                "-DDTYPE=%s" % self.dtype_to_ctype(self.dtype),
            ],
        }

    def _configure_out_of_place_shift(self):
        self.inplace = False
        self._kernel_init_args = [
            "fftshift_x",
        ]
        self._kernel_init_kwargs = {
            "options": [
                "-DDTYPE=%s" % self.dtype_to_ctype(self.dtype),
                "-DDTYPE_OUT=%s" % self.dtype_to_ctype(np.dtype(self.dst_dtype)),
            ],
        }
        # Extra compilation flag when input and output differ in "complexness"
        additional_flag = None
        input_is_complex = np.iscomplexobj(np.ones(1, dtype=self.dtype))
        output_is_complex = np.iscomplexobj(np.ones(1, dtype=self.dst_dtype))
        if not (input_is_complex) and output_is_complex:
            additional_flag = "-DCAST_TO_COMPLEX"
        if input_is_complex and not (output_is_complex):
            additional_flag = "-DCAST_TO_REAL"
        if additional_flag is not None:
            self._kernel_init_kwargs["options"].append(additional_flag)

    def _call_fftshift_inplace(self, arr, direction):
        n_x, n_y = self._get_kernel_dims()
        self._fftshift_kernel(  # pylint: disable=E1102
            arr, np.int32(n_x), np.int32(n_y), np.int32(direction), **self._kernel_kwargs
        )
        return arr

    def _call_fftshift_out_of_place(self, arr, dst, direction):
        if dst is None:
            dst = self.processing.allocate_array("dst", arr.shape, dtype=self.dst_dtype)
        n_x, n_y = self._get_kernel_dims()
        self._fftshift_kernel(  # pylint: disable=E1102
            arr, dst, np.int32(n_x), np.int32(n_y), np.int32(direction), **self._kernel_kwargs
        )
        return dst

    def fftshift(self, arr, dst=None):
        """
        fftshift of `arr`. For an in-place shift `dst` is ignored and `arr` is returned;
        otherwise the result is written to `dst` (allocated if not provided) and returned.
        """
        if self.inplace:
            return self._call_fftshift_inplace(arr, 1)
        else:
            return self._call_fftshift_out_of_place(arr, dst, 1)

    def ifftshift(self, arr, dst=None):
        """
        ifftshift of `arr`. Same conventions as fftshift().
        """
        if self.inplace:
            return self._call_fftshift_inplace(arr, -1)
        else:
            return self._call_fftshift_out_of_place(arr, dst, -1)
113
+
114
+
115
class OpenCLFFTshift(FFTshiftBase):
    """
    OpenCL implementation of FFTshiftBase, using the kernels of fftshift.cl.
    """

    KernelCls = OpenCLKernel
    ProcessingCls = OpenCLProcessing
    # staticmethod: the base class calls self.dtype_to_ctype(dtype). A plain function stored
    # as a class attribute would be bound to the instance and receive `self` as first argument.
    dtype_to_ctype = staticmethod(cl_dtype_to_ctype)
    backend = "opencl"

    def _configure_kenel_initialization(self):
        self._kernel_init_args.append(self.processing.ctx)
        self._kernel_init_kwargs.update(
            {
                "filename": get_opencl_srcfile("fftshift.cl"),
                "queue": self.processing.queue,
            }
        )

    def _configure_kernel_call(self):
        # TODO in-place fftshift needs to launch only arr.size//2 threads
        block = (16, 16, 1)
        # Work items span (Nx, Ny); a 1D array of n elements is handled as Nx=n, Ny=1
        n_x = self.shape[-1]
        n_y = int(np.prod(self.shape[:-1]))
        grid = [updiv(n, b) * b for n, b in zip((n_x, n_y), block)]
        self._kernel_kwargs = {"global_size": grid, "local_size": block}
+ self._kernel_kwargs = {"global_size": grid, "local_size": block}