nabu 2024.1.10__py3-none-any.whl → 2024.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152) hide show
  1. nabu/__init__.py +1 -1
  2. nabu/app/bootstrap.py +2 -3
  3. nabu/app/cast_volume.py +4 -2
  4. nabu/app/cli_configs.py +5 -0
  5. nabu/app/composite_cor.py +1 -1
  6. nabu/app/create_distortion_map_from_poly.py +5 -6
  7. nabu/app/diag_to_pix.py +7 -19
  8. nabu/app/diag_to_rot.py +14 -29
  9. nabu/app/double_flatfield.py +32 -44
  10. nabu/app/parse_reconstruction_log.py +3 -0
  11. nabu/app/reconstruct.py +53 -15
  12. nabu/app/reconstruct_helical.py +2 -2
  13. nabu/app/stitching.py +27 -13
  14. nabu/app/tests/__init__.py +0 -0
  15. nabu/app/tests/test_reduce_dark_flat.py +4 -1
  16. nabu/cuda/kernel.py +11 -2
  17. nabu/cuda/processing.py +2 -2
  18. nabu/cuda/src/cone.cu +77 -0
  19. nabu/cuda/src/hierarchical_backproj.cu +271 -0
  20. nabu/cuda/utils.py +0 -6
  21. nabu/estimation/alignment.py +5 -19
  22. nabu/estimation/cor.py +173 -599
  23. nabu/estimation/cor_sino.py +356 -26
  24. nabu/estimation/focus.py +63 -11
  25. nabu/estimation/tests/test_cor.py +124 -58
  26. nabu/estimation/tests/test_focus.py +6 -6
  27. nabu/estimation/tilt.py +2 -1
  28. nabu/estimation/utils.py +5 -33
  29. nabu/io/__init__.py +1 -1
  30. nabu/io/cast_volume.py +1 -1
  31. nabu/io/reader.py +416 -21
  32. nabu/io/tests/test_readers.py +422 -0
  33. nabu/io/tests/test_writers.py +1 -102
  34. nabu/io/writer.py +4 -433
  35. nabu/opencl/kernel.py +14 -3
  36. nabu/opencl/processing.py +8 -0
  37. nabu/pipeline/config_validators.py +5 -2
  38. nabu/pipeline/datadump.py +12 -5
  39. nabu/pipeline/estimators.py +162 -188
  40. nabu/pipeline/fullfield/chunked.py +168 -92
  41. nabu/pipeline/fullfield/chunked_cuda.py +7 -3
  42. nabu/pipeline/fullfield/computations.py +2 -7
  43. nabu/pipeline/fullfield/dataset_validator.py +0 -4
  44. nabu/pipeline/fullfield/nabu_config.py +37 -13
  45. nabu/pipeline/fullfield/processconfig.py +22 -13
  46. nabu/pipeline/fullfield/reconstruction.py +13 -9
  47. nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
  48. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
  49. nabu/pipeline/helical/helical_reconstruction.py +1 -1
  50. nabu/pipeline/params.py +21 -1
  51. nabu/pipeline/processconfig.py +1 -12
  52. nabu/pipeline/reader.py +146 -0
  53. nabu/pipeline/tests/test_estimators.py +44 -72
  54. nabu/pipeline/utils.py +4 -2
  55. nabu/pipeline/writer.py +10 -2
  56. nabu/preproc/ccd_cuda.py +1 -1
  57. nabu/preproc/ctf.py +14 -7
  58. nabu/preproc/ctf_cuda.py +2 -3
  59. nabu/preproc/double_flatfield.py +5 -12
  60. nabu/preproc/double_flatfield_cuda.py +2 -2
  61. nabu/preproc/flatfield.py +5 -1
  62. nabu/preproc/flatfield_cuda.py +5 -1
  63. nabu/preproc/phase.py +24 -73
  64. nabu/preproc/phase_cuda.py +5 -8
  65. nabu/preproc/tests/test_ctf.py +11 -7
  66. nabu/preproc/tests/test_flatfield.py +67 -122
  67. nabu/preproc/tests/test_paganin.py +54 -30
  68. nabu/processing/azim.py +206 -0
  69. nabu/processing/convolution_cuda.py +1 -1
  70. nabu/processing/fft_cuda.py +15 -17
  71. nabu/processing/histogram.py +2 -0
  72. nabu/processing/histogram_cuda.py +2 -1
  73. nabu/processing/kernel_base.py +3 -0
  74. nabu/processing/muladd_cuda.py +1 -0
  75. nabu/processing/padding_opencl.py +1 -1
  76. nabu/processing/roll_opencl.py +1 -0
  77. nabu/processing/rotation_cuda.py +2 -2
  78. nabu/processing/tests/test_fft.py +17 -10
  79. nabu/processing/unsharp_cuda.py +1 -1
  80. nabu/reconstruction/cone.py +104 -40
  81. nabu/reconstruction/fbp.py +3 -0
  82. nabu/reconstruction/fbp_base.py +7 -2
  83. nabu/reconstruction/filtering.py +20 -7
  84. nabu/reconstruction/filtering_cuda.py +7 -1
  85. nabu/reconstruction/hbp.py +424 -0
  86. nabu/reconstruction/mlem.py +99 -0
  87. nabu/reconstruction/reconstructor.py +2 -0
  88. nabu/reconstruction/rings_cuda.py +19 -19
  89. nabu/reconstruction/sinogram_cuda.py +1 -0
  90. nabu/reconstruction/sinogram_opencl.py +3 -1
  91. nabu/reconstruction/tests/test_cone.py +10 -5
  92. nabu/reconstruction/tests/test_deringer.py +7 -6
  93. nabu/reconstruction/tests/test_fbp.py +124 -10
  94. nabu/reconstruction/tests/test_filtering.py +13 -11
  95. nabu/reconstruction/tests/test_halftomo.py +30 -4
  96. nabu/reconstruction/tests/test_mlem.py +91 -0
  97. nabu/reconstruction/tests/test_reconstructor.py +8 -3
  98. nabu/resources/dataset_analyzer.py +142 -92
  99. nabu/resources/gpu.py +1 -0
  100. nabu/resources/nxflatfield.py +134 -125
  101. nabu/resources/templates/id16a_fluo.conf +42 -0
  102. nabu/resources/tests/test_extract.py +10 -0
  103. nabu/resources/tests/test_nxflatfield.py +2 -2
  104. nabu/stitching/alignment.py +80 -24
  105. nabu/stitching/config.py +105 -68
  106. nabu/stitching/definitions.py +1 -0
  107. nabu/stitching/frame_composition.py +68 -60
  108. nabu/stitching/overlap.py +91 -51
  109. nabu/stitching/single_axis_stitching.py +32 -0
  110. nabu/stitching/slurm_utils.py +6 -6
  111. nabu/stitching/stitcher/__init__.py +0 -0
  112. nabu/stitching/stitcher/base.py +124 -0
  113. nabu/stitching/stitcher/dumper/__init__.py +3 -0
  114. nabu/stitching/stitcher/dumper/base.py +94 -0
  115. nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
  116. nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
  117. nabu/stitching/stitcher/post_processing.py +555 -0
  118. nabu/stitching/stitcher/pre_processing.py +1068 -0
  119. nabu/stitching/stitcher/single_axis.py +484 -0
  120. nabu/stitching/stitcher/stitcher.py +0 -0
  121. nabu/stitching/stitcher/y_stitcher.py +13 -0
  122. nabu/stitching/stitcher/z_stitcher.py +45 -0
  123. nabu/stitching/stitcher_2D.py +278 -0
  124. nabu/stitching/tests/test_config.py +12 -37
  125. nabu/stitching/tests/test_frame_composition.py +33 -59
  126. nabu/stitching/tests/test_overlap.py +149 -7
  127. nabu/stitching/tests/test_utils.py +1 -1
  128. nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
  129. nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
  130. nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
  131. nabu/stitching/utils/__init__.py +1 -0
  132. nabu/stitching/utils/post_processing.py +281 -0
  133. nabu/stitching/utils/tests/test_post-processing.py +21 -0
  134. nabu/stitching/{utils.py → utils/utils.py} +79 -52
  135. nabu/stitching/y_stitching.py +27 -0
  136. nabu/stitching/z_stitching.py +32 -2281
  137. nabu/testutils.py +1 -152
  138. nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
  139. nabu/utils.py +158 -61
  140. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/METADATA +24 -17
  141. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/RECORD +145 -121
  142. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/WHEEL +1 -1
  143. nabu/io/tiffwriter_zmm.py +0 -99
  144. nabu/pipeline/fallback_utils.py +0 -149
  145. nabu/pipeline/helical/tests/test_accumulator.py +0 -158
  146. nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
  147. nabu/pipeline/helical/tests/test_strategy.py +0 -61
  148. nabu/pipeline/helical/utils.py +0 -51
  149. nabu/pipeline/tests/test_chunk_reader.py +0 -74
  150. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
  151. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
  152. {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,206 @@
1
+ from multiprocessing.pool import ThreadPool
2
+ import numpy as np
3
+
4
+ try:
5
+ from skimage.transform import warp_polar
6
+
7
+ __have_skimage__ = True
8
+ except ImportError:
9
+ __have_skimage__ = False
10
+
11
+
12
def azimuthal_integration(img, axes=(-2, -1), domain="direct"):
    """
    Computes azimuthal integration of an image or a stack of images.

    Each pixel value is split linearly between the two integer radii that
    surround its distance to the center, and contributions are summed per
    integer radius (``numpy.bincount``).

    Parameters
    ----------
    img : `numpy.array_like`
        The image or stack of images.
    axes : tuple(int, int), optional
        Axes that need to be azimuthally integrated (negative values allowed).
        The default is (-2, -1).
    domain : string, optional
        Domain of the integration. Options are: "direct" | "fourier". Default is "direct".
        In "direct" mode, radii are measured from the middle of the image.
        In "fourier" mode, radii are built from ``numpy.fft.fftfreq`` (unshifted
        frequency layout, zero frequency at index 0).

    Raises
    ------
    ValueError
        When not passing images, passing wrong/repeated axes, or an unknown domain.

    Returns
    -------
    `numpy.array_like`
        The azimuthally integrated profile. For a stack, the two integrated
        axes are replaced by one trailing radius axis.
    """
    img = np.asarray(img)
    if not img.ndim >= 2:
        raise ValueError("Input image should be at least 2-dimensional.")
    if not len(axes) == 2:
        raise ValueError("Input axes should be 2.")
    domain = domain.lower()
    if domain not in ("direct", "fourier"):
        raise ValueError("Unknown domain '%s', expected 'direct' or 'fourier'" % domain)

    img_axes_dims = np.array((img.shape[axes[0]], img.shape[axes[1]]))
    if domain == "direct":
        half_dims = (img_axes_dims - 1) / 2
        xx = np.linspace(-half_dims[0], half_dims[0], img_axes_dims[0])
        yy = np.linspace(-half_dims[1], half_dims[1], img_axes_dims[1])
    else:
        xx = np.fft.fftfreq(img_axes_dims[0], 1 / img_axes_dims[0])
        yy = np.fft.fftfreq(img_axes_dims[1], 1 / img_axes_dims[1])
    xy = np.stack(np.meshgrid(xx, yy, indexing="ij"))
    r = np.sqrt(np.sum(xy**2, axis=0))

    # Move the integrated axes to the end (in the order given by `axes`) so that
    # every trailing 2D slice has the same shape as `r`. np.moveaxis normalizes
    # negative axes and rejects repeated ones. The former pop/append-by-index
    # approach swapped the two axes, which broke non-square images.
    img = np.moveaxis(img, axes, (-2, -1))
    if img.ndim > 2:
        img_old_shape = img.shape[:-2]
        img = np.reshape(img, [-1, *img_axes_dims])

    # Linear split of each pixel between floor(r) and floor(r) + 1
    r_l = np.floor(r)
    r_u = r_l + 1
    w_l = (r_u - r) * img
    w_u = (r - r_l) * img

    r_all = np.concatenate((r_l.flatten(), r_u.flatten())).astype(np.int64)
    if img.ndim == 2:
        w_all = np.concatenate((w_l.flatten(), w_u.flatten()))
        return np.bincount(r_all, weights=w_all)

    az_img = np.array(
        [
            np.bincount(r_all, weights=np.concatenate((w_l[ii].ravel(), w_u[ii].ravel())))
            for ii in range(img.shape[0])
        ]
    )
    return np.reshape(az_img, (*img_old_shape, az_img.shape[-1]))
78
+
79
+
80
def do_radial_distribution(ip, X0, Y0, mR, nBins=None, use_calibration=False, cal=None, return_radii=False):
    """
    Radially averaged intensity profile of an image around (X0, Y0).

    Translates the Java method `doRadialDistribution` (from imagej) into Python using NumPy.
    Initial translation done by chatgpt-4o on 2024-11-08.

    Every pixel of the box [X0 - mR, X0 + mR) x [Y0 - mR, Y0 + mR) lying inside the image
    is assigned to bin ``floor(R / mR * nBins) - 1``, clipped to [0, nBins - 1], where R is
    its distance to the center. Bin 0 therefore also receives the innermost pixels, and
    the last bin receives the box corners beyond mR. The profile is the mean pixel value
    per bin.

    Args:
     - ip: A 2D numpy array representing the image. X0 indexes axis 0, Y0 indexes axis 1.
     - X0, Y0: Coordinates of the center (pixels).
     - mR: Maximum radius.
     - nBins: Number of bins (optional, defaults to int(3*mR/4)).
     - use_calibration: Boolean indicating if calibration should be applied to the radii.
     - cal: Calibration object with attribute `pixel_width` (optional).
     - return_radii: If True, also return the outer radius of each bin.

    Returns:
     - The mean value per bin (NaN for empty bins), or (radii, profile) if return_radii is True.
    """
    if nBins is None:
        nBins = int(3 * mR / 4)

    if nBins > 0:
        ip = np.asarray(ip)
        # Bounding box of the disk, clipped to the image. The former version applied a
        # (2*mR, 2*mR) mask directly to `ip`, which failed unless the image had exactly
        # that shape (e.g. odd-sized or non-square images).
        x_start, x_stop = max(int(np.floor(X0 - mR)), 0), min(int(np.ceil(X0 + mR)), ip.shape[0])
        y_start, y_stop = max(int(np.floor(Y0 - mR)), 0), min(int(np.ceil(Y0 + mR)), ip.shape[1])
        xv, yv = np.meshgrid(np.arange(x_start, x_stop), np.arange(y_start, y_stop), indexing="ij")

        # Distance of each pixel of the box to the center
        R = np.sqrt((xv - X0) ** 2 + (yv - Y0) ** 2)

        # Bin index, adjusted to be in range [0, nBins-1]
        bins = np.floor((R / mR) * nBins).astype(int)
        bins = np.clip(bins - 1, 0, nBins - 1).ravel()

        # One pass per quantity instead of one full-box mask per bin
        counts = np.bincount(bins, minlength=nBins)
        sums = np.bincount(bins, weights=ip[x_start:x_stop, y_start:y_stop].ravel(), minlength=nBins)

        # Mean intensity per bin; empty bins stay NaN (0/0) as before, without the warning
        with np.errstate(invalid="ignore", divide="ignore"):
            profile = sums / counts
    else:
        # No bins: nothing to accumulate
        profile = np.zeros(0)

    if not return_radii:
        return profile
    bin_fraction = np.arange(1, nBins + 1) / nBins
    if use_calibration and cal is not None:
        # Apply calibration (physical pixel size)
        return cal.pixel_width * mR * bin_fraction, profile
    # Pixel units
    return mR * bin_fraction, profile
136
+
137
+
138
# OK-ish, but small discrepancy with do_radial_distribution.
# 20-40X faster than above methods for (2048, 2048) images
# Also it assumes a uniform sampling
# No idea why there is this "offset=1", to be investigated - perhaps radius=0 is also calculated ?
def azimuthal_integration_skimage(img, center=None, offset=1):
    """
    Azimuthal average of a 2D image, based on skimage.transform.warp_polar.

    Parameters
    ----------
    img: numpy.ndarray
        2D image.
    center: tuple, optional
        Center of the polar transform, forwarded to warp_polar (None = image center).
    offset: int, optional
        Index of the first radial sample returned.

    Returns
    -------
    profile: numpy.ndarray
        1D array of at most min(img.shape) // 2 values: the polar image averaged
        over its axis 0 (the angular axis of warp_polar's output).

    Raises
    ------
    ImportError
        If scikit-image is not available.
    """
    if not __have_skimage__:
        raise ImportError("azimuthal_integration_skimage() needs scikit-image (skimage.transform.warp_polar)")
    # Radial extent of the polar image: about half-size * sqrt(2), to reach the corners
    shape2 = [int(s // 2 * 1.4142) for s in img.shape]
    s = min(img.shape) // 2
    img_polar = warp_polar(img, output_shape=shape2, center=center)
    return img_polar.mean(axis=0)[offset : offset + s]
147
+
148
+
149
def _apply_on_images_stack(func, images_stack, n_threads=4, func_args=None, func_kwargs=None):
    """
    Run `func` on each image of a stack with a pool of `n_threads` threads.

    Images are taken by iterating over the first axis of `images_stack`, and each call is
    ``func(image, *func_args, **func_kwargs)``. The per-image results are gathered, in stack
    order, into a numpy array.
    """
    extra_args = [] if not func_args else func_args
    extra_kwargs = {} if not func_kwargs else func_kwargs

    with ThreadPool(n_threads) as pool:
        per_image = pool.map(lambda image: func(image, *extra_args, **extra_kwargs), images_stack)
    return np.array(per_image)
159
+
160
+
161
def _apply_on_patches_stack(func, images_stack, n_threads=4, func_args=None, func_kwargs=None):
    """
    Run `func` on every patch of a stack of patched images, one thread task per image.

    `images_stack` must be 5D: (n_images, n_patchs_y, patch_height, n_patchs_x, patch_width).
    Patch (i, j) of image k is ``images_stack[k, i, :, j, :]``. `func` is first evaluated
    on the very first patch to learn the shape and dtype of its output.

    Returns an array of shape (n_images, n_patchs_y, n_patchs_x) + output_shape_of_func.
    """
    _, n_rows, _, n_cols, _ = images_stack.shape
    extra_args = [] if not func_args else func_args
    extra_kwargs = {} if not func_kwargs else func_kwargs
    first = func(images_stack[0, 0, :, 0, :], *extra_args, **extra_kwargs)

    def _run_one(image):
        # Row-major traversal of the patch grid, same as nested row/column loops
        collected = np.zeros((n_rows, n_cols) + first.shape, dtype=first.dtype)
        for row, col in np.ndindex(n_rows, n_cols):
            collected[row, col] = func(image[row, :, col, :], *extra_args, **extra_kwargs)
        return collected

    with ThreadPool(n_threads) as pool:
        per_image = pool.map(_run_one, images_stack)
    return np.array(per_image)
179
+
180
+
181
def azimuthal_integration_imagej_stack(images_stack, n_threads=4):
    """
    ImageJ-like radial profile (see do_radial_distribution) of each image of a stack.

    Parameters
    ----------
    images_stack: numpy.ndarray
        Either a 3D stack of images (n_images, height, width), or a 5D stack of patches
        (n_images, n_patchs_y, patch_height, n_patchs_x, patch_width).
    n_threads: int
        Number of threads used to process the stack.

    Returns
    -------
    profiles: numpy.ndarray
        One profile of min(height, width) // 2 bins per image (or per patch). Each profile
        is centered on (height // 2, width // 2), with radius min(height, width) // 2.

    Raises
    ------
    ValueError
        If images_stack is neither 3D nor 5D.
    """
    if images_stack.ndim == 3:
        img_shape = images_stack.shape[-2:]
        _apply = _apply_on_images_stack
    elif images_stack.ndim == 5:
        img_shape = np.array(images_stack.shape)[[-3, -1]]
        _apply = _apply_on_patches_stack
    else:
        raise ValueError(
            "Expected a 3D stack of images or a 5D stack of patches, got %d dimensions" % images_stack.ndim
        )
    n_rows, n_cols = (int(n) for n in img_shape)
    radius = min(n_rows, n_cols) // 2
    return _apply(
        do_radial_distribution,
        images_stack,
        n_threads=n_threads,
        # Center on the middle of each image (for square images: (s // 2, s // 2) as before)
        func_args=[n_rows // 2, n_cols // 2, radius],
        func_kwargs={"nBins": radius, "return_radii": False},
    )
198
+
199
+
200
def azimuthal_integration_skimage_stack(images_stack, n_threads=4):
    """
    Azimuthal average (see azimuthal_integration_skimage) of each image of a stack.

    Parameters
    ----------
    images_stack: numpy.ndarray
        Either a 3D stack of images (n_images, height, width), or a 5D stack of patches
        (n_images, n_patchs_y, patch_height, n_patchs_x, patch_width).
    n_threads: int
        Number of threads used to process the stack.

    Raises
    ------
    ValueError
        If images_stack is neither 3D nor 5D.
    """
    if images_stack.ndim == 3:
        return _apply_on_images_stack(azimuthal_integration_skimage, images_stack, n_threads=n_threads)
    if images_stack.ndim == 5:
        return _apply_on_patches_stack(azimuthal_integration_skimage, images_stack, n_threads=n_threads)
    raise ValueError(
        "Expected a 3D stack of images or a 5D stack of patches, got %d dimensions" % images_stack.ndim
    )
@@ -197,7 +197,7 @@ class Convolution:
197
197
  self.sourcemodule_kwargs["include_dirs"] = include_dirs
198
198
  with open(fname) as fid:
199
199
  cuda_src = fid.read()
200
- self._module = SourceModule(cuda_src, **self.sourcemodule_kwargs)
200
+ self._module = SourceModule(cuda_src, **self.sourcemodule_kwargs) # pylint: disable=E0606
201
201
  # Blocks, grid
202
202
  self._block_size = {1: (32, 1, 1), 2: (32, 32, 1), 3: (16, 8, 8)}[self.data_ndim] # TODO tune
203
203
  self._n_blocks = tuple([int(updiv(a, b)) for a, b in zip(self.shape[::-1], self._block_size)])
@@ -169,7 +169,7 @@ def _has_vkfft(x):
169
169
  return False
170
170
  vk = VKCUFFT((16,), "f")
171
171
  avail = True
172
- except (RuntimeError, OSError):
172
+ except (ImportError, RuntimeError, OSError, NameError):
173
173
  avail = False
174
174
  return avail
175
175
 
@@ -197,7 +197,7 @@ def _has_skfft(x):
197
197
 
198
198
  sk = SKCUFFT((16,), "f")
199
199
  avail = True
200
- except (ImportError, RuntimeError, OSError):
200
+ except (ImportError, RuntimeError, OSError, NameError):
201
201
  avail = False
202
202
  return avail
203
203
 
@@ -216,7 +216,7 @@ def has_skcuda(safe=True):
216
216
  return v
217
217
 
218
218
 
219
- def get_fft_class(backend="skcuda"):
219
+ def get_fft_class(backend="vkfft"):
220
220
  backends = {
221
221
  "scikit-cuda": SKCUFFT,
222
222
  "skcuda": SKCUFFT,
@@ -226,31 +226,29 @@ def get_fft_class(backend="skcuda"):
226
226
  "pyvkfft": VKCUFFT,
227
227
  }
228
228
 
229
- def check_vkfft(asked_fft_cls):
230
- if asked_fft_cls is VKCUFFT:
231
- if has_vkfft(safe=True) is False:
232
- warnings.warn("Could not get VKFFT backend. Falling-back to scikit-cuda/CUFFT instead.", RuntimeWarning)
233
- return SKCUFFT
234
- return VKCUFFT
235
- return SKCUFFT
236
-
237
229
  def get_fft_cls(asked_fft_backend):
238
230
  asked_fft_backend = asked_fft_backend.lower()
239
- check_supported(asked_fft_backend, list(backends.keys()), "FFT backend name")
240
- asked_fft_cls = backends[asked_fft_backend]
241
- fft_cls = check_vkfft(asked_fft_cls)
242
- return fft_cls
231
+ check_supported(asked_fft_backend, list(backends.keys()), "Cuda FFT backend name")
232
+ return backends[asked_fft_backend]
243
233
 
244
234
  asked_fft_backend_env = os.environ.get("NABU_FFT_BACKEND", "")
245
235
  if asked_fft_backend_env != "":
246
236
  return get_fft_cls(asked_fft_backend_env)
237
+
238
+ avail_fft_implems = get_available_fft_implems()
239
+ if len(avail_fft_implems) == 0:
240
+ raise RuntimeError("Could not any Cuda FFT implementation. Please install either scikit-cuda or pyvkfft")
241
+ if backend not in avail_fft_implems:
242
+ warnings.warn("Could not get FFT backend '%s'" % backend, RuntimeWarning)
243
+ backend = avail_fft_implems[0]
244
+
247
245
  return get_fft_cls(backend)
248
246
 
249
247
 
250
248
  def get_available_fft_implems():
251
249
  avail_implems = []
252
- if has_skcuda(safe=True):
253
- avail_implems.append("skcuda")
254
250
  if has_vkfft(safe=True):
255
251
  avail_implems.append("vkfft")
252
+ if has_skcuda(safe=True):
253
+ avail_implems.append("skcuda")
256
254
  return avail_implems
@@ -117,6 +117,8 @@ class PartialHistogram:
117
117
  elif self.backend == "silx":
118
118
  histogrammer = Histogramnd(data, n_bins=self.num_bins, histo_range=(dmin, dmax), last_bin_closed=True)
119
119
  res = histogrammer.histo, histogrammer.edges[0] # pylint: disable=E1136
120
+ else:
121
+ raise ValueError("Unknown backend")
120
122
  return res
121
123
 
122
124
  def _merge_histograms_fixed_nbins(self, histograms, dont_truncate_bins=False):
@@ -25,7 +25,7 @@ class CudaPartialHistogram(PartialHistogram):
25
25
  num_bins=num_bins,
26
26
  min_bins=min_bins,
27
27
  )
28
- self.cuda_processing = CudaProcessing(**(cuda_options or {}))
28
+ self.cuda_processing = CudaProcessing(**(cuda_options or {})) # pylint: disable=E0606
29
29
  self._init_cuda_histogram()
30
30
 
31
31
  def _init_cuda_histogram(self):
@@ -43,6 +43,7 @@ class CudaPartialHistogram(PartialHistogram):
43
43
  # Should be possible to do both in one single pass with ReductionKernel
44
44
  # and garray.vec.float2, but the last step in volatile shared memory
45
45
  # still gives errors. To be investigated...
46
+ # pylint: disable=E0606
46
47
  data_min = garray.min(data).get()[()]
47
48
  data_max = garray.max(data).get()[()]
48
49
  else:
@@ -2,6 +2,7 @@
2
2
  Base class for CudaKernel and OpenCLKernel
3
3
  Should not be used directly
4
4
  """
5
+
5
6
  from ..utils import updiv
6
7
 
7
8
 
@@ -47,9 +48,11 @@ class KernelBase:
47
48
  filename=None,
48
49
  src=None,
49
50
  automation_params=None,
51
+ silent_compilation_warnings=False,
50
52
  ):
51
53
  self.check_filename_src(filename, src)
52
54
  self.set_automation_params(automation_params)
55
+ self.silent_compilation_warnings = silent_compilation_warnings
53
56
 
54
57
  def check_filename_src(self, filename, src):
55
58
  err_msg = "Please provide either filename or src"
@@ -42,6 +42,7 @@ class CudaMulAdd(MulAdd):
42
42
  raise ValueError("delta_x or delta_y is 0")
43
43
 
44
44
  # can't use "int4" in pycuda ? int2 seems fine. Go figure
45
+ # pylint: disable=E0606
45
46
  dst_x_range = np.array(dst_coords[:2], dtype=garray.vec.int2)
46
47
  dst_y_range = np.array(dst_coords[2:], dtype=garray.vec.int2)
47
48
  other_x_range = np.array(other_coords[:2], dtype=garray.vec.int2)
@@ -27,7 +27,7 @@ class OpenCLPadding(PaddingBase):
27
27
  self.d_padded_array_constant = self.processing.to_device(
28
28
  "d_padded_array_constant", self.padded_array_constant
29
29
  )
30
- self.memcpy2D = OpenCLMemcpy2D(ctx=self.processing.ctx, queue=self.queue)
30
+ self.memcpy2D = OpenCLMemcpy2D(ctx=self.processing.ctx, queue=self.queue) # pylint: disable=E0606
31
31
  return
32
32
  self._coords_transform_kernel = self.processing.kernel(
33
33
  "coordinate_transform",
@@ -1,6 +1,7 @@
1
1
  #
2
2
  # WIP !
3
3
  #
4
+ # pylint: skip-file
4
5
  import numpy as np
5
6
  from ..opencl.utils import __has_pyopencl__
6
7
  from ..utils import get_opencl_srcfile
@@ -23,11 +23,11 @@ class CudaRotation(Rotation):
23
23
  self._init_rotation_kernel()
24
24
 
25
25
  def _allocate_arrays(self):
26
- self._d_image_cua = cuda.np_to_array(np.zeros(self.shape, "f"), "C")
26
+ self._d_image_cua = cuda.np_to_array(np.zeros(self.shape, "f"), "C") # pylint: disable=E0606
27
27
  self.cuda_processing.init_arrays_to_none(["d_output"])
28
28
 
29
29
  def _init_rotation_kernel(self):
30
- self.cuda_rotation_kernel = CudaKernel("rotate", get_cuda_srcfile("rotation.cu"))
30
+ self.cuda_rotation_kernel = CudaKernel("rotate", get_cuda_srcfile("rotation.cu")) # pylint: disable=E0606
31
31
  self.texref_image = self.cuda_rotation_kernel.module.get_texref("tex_image")
32
32
  self.texref_image.set_filter_mode(cuda.filter_mode.LINEAR) # bilinear
33
33
  self.texref_image.set_address_mode(0, cuda.address_mode.CLAMP) # TODO tune
@@ -4,17 +4,14 @@ import numpy as np
4
4
  from scipy.fft import fftn, ifftn, rfftn, irfftn
5
5
  from nabu.testutils import generate_tests_scenarios, get_data, get_array_of_given_shape, __do_long_tests__
6
6
  from nabu.cuda.utils import get_cuda_context, __has_pycuda__
7
- from nabu.processing.fft_cuda import SKCUFFT, VKCUFFT, has_vkfft as has_cuda_vkfft
7
+ from nabu.processing.fft_cuda import SKCUFFT, VKCUFFT, get_available_fft_implems
8
8
  from nabu.opencl.utils import __has_pyopencl__, get_opencl_context
9
9
  from nabu.processing.fft_opencl import VKCLFFT, has_vkfft as has_cl_vkfft
10
10
  from nabu.processing.fft_base import is_fast_axes
11
11
 
12
- try:
13
- import skcuda
14
-
15
- __has_skcuda__ = True
16
- except ImportError:
17
- __has_skcuda__ = False
12
+ available_cuda_fft = get_available_fft_implems()
13
+ __has_vkfft__ = "vkfft" in available_cuda_fft
14
+ __has_skcuda__ = "skcuda" in available_cuda_fft
18
15
 
19
16
 
20
17
  scenarios = {
@@ -60,7 +57,7 @@ def bootstrap(request):
60
57
  def _get_fft_cls(backend):
61
58
  fft_cls = None
62
59
  if backend == "cuda":
63
- if not (has_cuda_vkfft() and __has_pycuda__):
60
+ if not (__has_vkfft__ and __has_pycuda__):
64
61
  pytest.skip("Need vkfft and pycuda to use VKCUFFT")
65
62
  fft_cls = VKCUFFT
66
63
  if backend == "opencl":
@@ -116,7 +113,9 @@ class TestFFT:
116
113
  ref = ref_ifft_func(data, axes=axes)
117
114
  return ref
118
115
 
119
- @pytest.mark.skipif(not (__has_skcuda__ and __has_pycuda__), reason="Need pycuda scikit-cuda for this test")
116
+ @pytest.mark.skipif(
117
+ not (__has_skcuda__ and __has_pycuda__), reason="Need pycuda and (scikit-cuda or vkfft) for this test"
118
+ )
120
119
  @pytest.mark.parametrize("config", scenarios)
121
120
  def test_sckcuda(self, config):
122
121
  r2c = config["r2c"]
@@ -146,7 +145,9 @@ class TestFFT:
146
145
  # Perhaps we should also check against numpy/scipy ifft,
147
146
  # but it does not yield the good shape for R2C on odd-sized data
148
147
 
149
- @pytest.mark.skipif(not (__has_skcuda__ and __has_pycuda__), reason="Need pycuda scikit-cuda for this test")
148
+ @pytest.mark.skipif(
149
+ not (__has_skcuda__ and __has_pycuda__), reason="Need pycuda and (scikit-cuda or vkfft) for this test"
150
+ )
150
151
  @pytest.mark.parametrize("config", scenarios)
151
152
  def test_skcuda_batched(self, config):
152
153
  shape = config["shape"]
@@ -188,6 +189,12 @@ class TestFFT:
188
189
  if ndim >= 2 and r2c and shape[-1] & 1:
189
190
  pytest.skip("R2C with odd-sized fast dimension is not supported in VKFFT")
190
191
 
192
+ # FIXME - vkfft + POCL fail for R2C in one dimension
193
+ if config["backend"] == "opencl" and r2c and ndim == 1:
194
+ if self.cl_ctx.devices[0].platform.name.strip().lower() == "portable computing language":
195
+ pytest.skip("Something wrong with vkfft + pocl for R2C 1D")
196
+ # ---
197
+
191
198
  data = self._get_data_array(config)
192
199
 
193
200
  res, fft_obj = self._do_fft(data, r2c, return_fft_obj=True, backend_cls=fft_cls)
@@ -32,7 +32,7 @@ class CudaUnsharpMask(UnsharpMask):
32
32
 
33
33
  def _init_mad_kernel(self):
34
34
  # garray.GPUArray.mul_add is out of place...
35
- self.mad_kernel = ElementwiseKernel(
35
+ self.mad_kernel = ElementwiseKernel( # pylint: disable=E0606
36
36
  "float* array, float fac, float* other, float otherfac",
37
37
  "array[i] = fac * array[i] + otherfac * other[i]",
38
38
  name="mul_add",