nabu 2025.1.0.dev5__py3-none-any.whl → 2025.1.0.dev13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. nabu/__init__.py +1 -1
  2. nabu/app/double_flatfield.py +18 -5
  3. nabu/app/multicor.py +25 -10
  4. nabu/app/reconstruct_helical.py +4 -4
  5. nabu/app/stitching.py +7 -2
  6. nabu/cuda/src/backproj.cu +10 -10
  7. nabu/cuda/src/cone.cu +4 -0
  8. nabu/cuda/utils.py +1 -1
  9. nabu/estimation/cor.py +3 -3
  10. nabu/io/cast_volume.py +16 -0
  11. nabu/io/reader.py +3 -2
  12. nabu/opencl/src/backproj.cl +10 -10
  13. nabu/pipeline/estimators.py +6 -6
  14. nabu/pipeline/fullfield/chunked.py +13 -13
  15. nabu/pipeline/fullfield/computations.py +4 -1
  16. nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
  17. nabu/pipeline/fullfield/nabu_config.py +16 -4
  18. nabu/pipeline/fullfield/processconfig.py +25 -4
  19. nabu/pipeline/fullfield/reconstruction.py +9 -4
  20. nabu/pipeline/helical/gridded_accumulator.py +1 -1
  21. nabu/pipeline/helical/helical_reconstruction.py +2 -2
  22. nabu/pipeline/helical/nabu_config.py +1 -1
  23. nabu/pipeline/helical/weight_balancer.py +1 -1
  24. nabu/pipeline/params.py +8 -3
  25. nabu/preproc/shift.py +1 -1
  26. nabu/preproc/tests/test_ctf.py +1 -1
  27. nabu/preproc/tests/test_paganin.py +1 -3
  28. nabu/processing/fft_base.py +6 -2
  29. nabu/processing/fft_cuda.py +17 -167
  30. nabu/processing/fft_opencl.py +19 -2
  31. nabu/processing/padding_cuda.py +0 -1
  32. nabu/processing/processing_base.py +11 -5
  33. nabu/processing/tests/test_fft.py +1 -63
  34. nabu/reconstruction/cone.py +39 -9
  35. nabu/reconstruction/fbp.py +7 -0
  36. nabu/reconstruction/fbp_base.py +8 -0
  37. nabu/reconstruction/filtering.py +59 -25
  38. nabu/reconstruction/filtering_cuda.py +21 -20
  39. nabu/reconstruction/filtering_opencl.py +8 -14
  40. nabu/reconstruction/hbp.py +10 -10
  41. nabu/reconstruction/mlem.py +3 -0
  42. nabu/reconstruction/rings_cuda.py +41 -13
  43. nabu/reconstruction/tests/test_cone.py +35 -0
  44. nabu/reconstruction/tests/test_deringer.py +2 -2
  45. nabu/reconstruction/tests/test_fbp.py +35 -14
  46. nabu/reconstruction/tests/test_filtering.py +14 -5
  47. nabu/reconstruction/tests/test_halftomo.py +1 -1
  48. nabu/reconstruction/tests/test_reconstructor.py +1 -1
  49. nabu/resources/dataset_analyzer.py +34 -2
  50. nabu/resources/tests/test_extract.py +4 -2
  51. nabu/stitching/config.py +6 -1
  52. nabu/stitching/stitcher/dumper/__init__.py +1 -0
  53. nabu/stitching/stitcher/dumper/postprocessing.py +105 -1
  54. nabu/stitching/stitcher/post_processing.py +14 -4
  55. nabu/stitching/stitcher/pre_processing.py +1 -1
  56. nabu/stitching/stitcher/single_axis.py +8 -7
  57. nabu/stitching/stitcher/z_stitcher.py +8 -4
  58. nabu/stitching/utils/utils.py +2 -2
  59. nabu/testutils.py +2 -2
  60. nabu/utils.py +9 -2
  61. {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev13.dist-info}/METADATA +9 -28
  62. {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev13.dist-info}/RECORD +66 -65
  63. {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev13.dist-info}/WHEEL +1 -1
  64. {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev13.dist-info}/entry_points.txt +0 -0
  65. {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev13.dist-info/licenses}/LICENSE +0 -0
  66. {nabu-2025.1.0.dev5.dist-info → nabu-2025.1.0.dev13.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
1
1
  import os
2
2
  import posixpath
3
3
  import numpy as np
4
+ from .get_double_flatfield import get_double_flatfield
4
5
  from silx.io import get_data
5
6
  from silx.io.url import DataUrl
6
7
  from ...utils import copy_dict_items, compare_dicts
@@ -32,6 +33,7 @@ class ProcessConfig(ProcessConfigBase):
32
33
 
33
34
  (2) update_dataset_info_with_user_config
34
35
  - Update flats/darks
36
+ - Double-flat-field
35
37
  - CoR (value or estimation method) # no estimation yet
36
38
  - rotation angles
37
39
  - translations files
@@ -89,12 +91,29 @@ class ProcessConfig(ProcessConfigBase):
89
91
  self.subsampling_factor = subsampling_factor or 1
90
92
  self.subsampling_start = subsampling_start or 0
91
93
 
94
+ self._get_double_flatfield()
92
95
  self._update_dataset_with_user_overwrites()
93
96
  self._get_rotation_axis_position()
94
97
  self._update_rotation_angles()
95
98
  self._get_translation_file("reconstruction", "translation_movements_file", "translations")
96
99
  self._get_user_sino_normalization()
97
100
 
101
+ def _get_double_flatfield(self):
102
+ self._dff_file = None
103
+ dff_mode = self.nabu_config["preproc"]["double_flatfield"]
104
+ if not (dff_mode):
105
+ return
106
+ self._dff_file = get_double_flatfield(
107
+ self.dataset_info,
108
+ dff_mode,
109
+ output_dir=self.nabu_config["output"]["location"],
110
+ darks_flats_dir=self.nabu_config["dataset"]["darks_flats_dir"],
111
+ dff_options={
112
+ "dff_sigma": self.nabu_config["preproc"]["dff_sigma"],
113
+ "do_flatfield": (self.nabu_config["preproc"]["flatfield"] is not False),
114
+ },
115
+ )
116
+
98
117
  def _update_dataset_with_user_overwrites(self):
99
118
  user_overwrites = self.nabu_config["dataset"]["overwrite_metadata"].strip()
100
119
  if user_overwrites in ("", None):
@@ -451,11 +470,11 @@ class ProcessConfig(ProcessConfigBase):
451
470
  #
452
471
  # Double flat field
453
472
  #
454
- if nabu_config["preproc"]["double_flatfield_enabled"]:
473
+ if nabu_config["preproc"]["double_flatfield"]:
455
474
  tasks.append("double_flatfield")
456
475
  options["double_flatfield"] = {
457
476
  "sigma": nabu_config["preproc"]["dff_sigma"],
458
- "processes_file": nabu_config["preproc"]["processes_file"],
477
+ "processes_file": self._dff_file or nabu_config["preproc"]["processes_file"],
459
478
  "log_min_clip": nabu_config["preproc"]["log_min_clip"],
460
479
  "log_max_clip": nabu_config["preproc"]["log_max_clip"],
461
480
  }
@@ -558,6 +577,7 @@ class ProcessConfig(ProcessConfigBase):
558
577
  self.rec_params,
559
578
  [
560
579
  "method",
580
+ "iterations",
561
581
  "implementation",
562
582
  "fbp_filter_type",
563
583
  "fbp_filter_cutoff",
@@ -575,6 +595,7 @@ class ProcessConfig(ProcessConfigBase):
575
595
  "sample_detector_dist",
576
596
  "hbp_legs",
577
597
  "hbp_reduction_steps",
598
+ "crop_filtered_data",
578
599
  ],
579
600
  )
580
601
  rec_options = options["reconstruction"]
@@ -593,8 +614,6 @@ class ProcessConfig(ProcessConfigBase):
593
614
  voxel_size,
594
615
  ) # pix size is in microns in dataset_info
595
616
 
596
- rec_options["iterations"] = nabu_config["reconstruction"]["iterations"]
597
-
598
617
  # x/y/z position information
599
618
  def get_mean_pos(position_array):
600
619
  if position_array is None:
@@ -616,6 +635,8 @@ class ProcessConfig(ProcessConfigBase):
616
635
  rec_options["position"] = mean_positions_xyz
617
636
  if rec_options["method"] == "cone" and rec_options["sample_detector_dist"] is None:
618
637
  rec_options["sample_detector_dist"] = self.dataset_info.distance # was checked to be not None earlier
638
+ if rec_options["method"].lower() == "mlem" and rec_options["implementation"] in [None, ""]:
639
+ rec_options["implementation"] = "corrct"
619
640
 
620
641
  # New key
621
642
  rec_options["cor_estimated_auto"] = isinstance(nabu_config["reconstruction"]["rotation_axis_position"], str)
@@ -120,7 +120,7 @@ class FullFieldReconstructor:
120
120
  vm = virtual_memory()
121
121
  self.resources["mem_avail_GB"] = vm.available / 1e9
122
122
  # Account for other memory constraints. There might be a better way
123
- slurm_mem_constraint_MB = int(environ.get("SLURM_MEM_PER_NODE", 0))
123
+ slurm_mem_constraint_MB = int(environ.get("SLURM_MEM_PER_NODE", 0)) # noqa: PLW1508
124
124
  if slurm_mem_constraint_MB > 0:
125
125
  self.resources["mem_avail_GB"] = slurm_mem_constraint_MB / 1e3
126
126
  #
@@ -131,8 +131,12 @@ class FullFieldReconstructor:
131
131
  self.resources["gpus"] = avail_gpus
132
132
  if len(avail_gpus) == 0:
133
133
  return
134
- # pick first GPU by default. TODO: handle user's nabu_config["resources"]["gpu_id"]
135
- self.resources["gpu_id"] = self._gpu_id = first_generator_item(avail_gpus.keys())
134
+ user_gpus = self.process_config.nabu_config.get("resources", {}).get("gpu_id", [])
135
+ if len(user_gpus) == 0:
136
+ user_gpus = [0]
137
+ # For now nabu does not support multi-GPU reconstruction. Take the first one.
138
+ user_gpu_idx = user_gpus[0]
139
+ self.resources["gpu_id"] = self._gpu_id = list(avail_gpus.keys())[user_gpu_idx]
136
140
 
137
141
  def _get_backend(self, backend, cuda_options):
138
142
  self._pipeline_cls = ChunkedPipeline
@@ -145,6 +149,7 @@ class FullFieldReconstructor:
145
149
  backend = "numpy"
146
150
  else:
147
151
  self.gpu_mem = self.resources["gpus"][self._gpu_id]["memory_GB"] * self.gpu_mem_fraction
152
+ self.cuda_options = {"device_id": self._gpu_id}
148
153
  if backend == "cuda":
149
154
  if not (__has_pycuda__):
150
155
  raise RuntimeError("pycuda not avilable")
@@ -307,7 +312,7 @@ class FullFieldReconstructor:
307
312
  sigma = opts["unsharp_sigma"]
308
313
  # nabu uses cutoff = 4
309
314
  cutoff = 4
310
- gaussian_kernel_size = int(ceil(2 * cutoff * sigma + 1))
315
+ gaussian_kernel_size = ceil(2 * cutoff * sigma + 1)
311
316
  self.logger.debug("Unsharp mask margin: %d pixels" % gaussian_kernel_size)
312
317
  return (gaussian_kernel_size, gaussian_kernel_size)
313
318
 
@@ -532,7 +532,7 @@ def get_reconstruction_space(span_info, min_scanwise_z, end_scanwise_z, phase_ma
532
532
  # regridded dataset, estimating a meaningul angular step representative
533
533
  # of the raw data
534
534
  my_angle_step = abs(np.diff(span_info.projection_angles_deg).mean())
535
- n_gridded_angles = int(round(360.0 / my_angle_step))
535
+ n_gridded_angles = round(360.0 / my_angle_step)
536
536
 
537
537
  radios_h = phase_margin_pix + (my_z_end - my_z_min) + phase_margin_pix
538
538
 
@@ -168,8 +168,8 @@ class HelicalReconstructorRegridded:
168
168
 
169
169
  # the meaming of z_min and z_max is: position in slices units from the
170
170
  # first available slice and in the direction of the scan
171
- self.z_min = int(round(z_start * (0 - z_fract_min) + z_max * z_fract_min))
172
- self.z_max = int(round(z_start * (0 - z_fract_max) + z_max * z_fract_max)) + 1
171
+ self.z_min = round(z_start * (0 - z_fract_min) + z_max * z_fract_min)
172
+ self.z_max = round(z_start * (0 - z_fract_max) + z_max * z_fract_max) + 1
173
173
 
174
174
  def _compute_translations_margin(self):
175
175
  return 0, 0
@@ -43,7 +43,7 @@ nabu_config["preproc"]["processes_file"] = {
43
43
  "validator": optional_file_location_validator,
44
44
  "type": "required",
45
45
  }
46
- nabu_config["preproc"]["double_flatfield_enabled"]["default"] = 1
46
+ nabu_config["preproc"]["double_flatfield"]["default"] = 1
47
47
 
48
48
 
49
49
  nabu_config["reconstruction"].update(
@@ -83,7 +83,7 @@ def shift(arr, shift, fill_value=0.0):
83
83
  """
84
84
  result = np.zeros_like(arr)
85
85
 
86
- num1 = int(math.floor(shift))
86
+ num1 = math.floor(shift)
87
87
  num2 = num1 + 1
88
88
  partition = shift - num1
89
89
 
nabu/pipeline/params.py CHANGED
@@ -25,12 +25,17 @@ unsharp_methods = {
25
25
  "": None,
26
26
  }
27
27
 
28
+ # see PaddingBase.supported_modes
28
29
  padding_modes = {
29
- "edges": "edge",
30
- "edge": "edge",
31
- "mirror": "mirror",
32
30
  "zeros": "zeros",
33
31
  "zero": "zeros",
32
+ "constant": "zeros",
33
+ "edges": "edge",
34
+ "edge": "edge",
35
+ "mirror": "reflect",
36
+ "reflect": "reflect",
37
+ "symmetric": "symmetric",
38
+ "wrap": "wrap",
34
39
  }
35
40
 
36
41
  reconstruction_methods = {
nabu/preproc/shift.py CHANGED
@@ -42,7 +42,7 @@ class VerticalShift:
42
42
  def _init_interp_coefficients(self):
43
43
  self.interp_infos = []
44
44
  for s in self.shifts:
45
- s0 = int(floor(s))
45
+ s0 = floor(s)
46
46
  f = s - s0
47
47
  self.interp_infos.append([s0, f])
48
48
 
@@ -223,7 +223,7 @@ class TestCtf:
223
223
  # phase_fft = ctf_fft.retrieve_phase(img)
224
224
  self.check_result(phase_r2c, self.ref_plain, "Something wrong with CtfFilter-FFT")
225
225
 
226
- @pytest.mark.skipif(not (__has_pycuda__ and __has_cufft__), reason="pycuda and (scikit-cuda or vkfft)")
226
+ @pytest.mark.skipif(not (__has_pycuda__ and __has_cufft__), reason="pycuda and (cupy? or vkfft)")
227
227
  def test_cuda_ctf(self):
228
228
  data = nabu_get_data("brain_phantom.npz")["data"]
229
229
  delta_beta = 50.0
@@ -77,9 +77,7 @@ class TestPaganin:
77
77
  errmax = np.max(np.abs(res - res_tomopy) / np.max(res_tomopy))
78
78
  assert errmax < self.rtol_pag, "Max error is too high"
79
79
 
80
- @pytest.mark.skipif(
81
- not (__has_pycuda__ and __has_cufft__), reason="Need pycuda and (scikit-cuda or vkfft) for this test"
82
- )
80
+ @pytest.mark.skipif(not (__has_pycuda__ and __has_cufft__), reason="Need pycuda and (cupy? or vkfft) for this test")
83
81
  @pytest.mark.parametrize("config", scenarios)
84
82
  def test_gpu_paganin(self, config):
85
83
  paganin, data, pag_kwargs = self.get_paganin_instance_and_data(config, self.data)
@@ -93,6 +93,10 @@ class _BaseFFT:
93
93
  pass
94
94
 
95
95
 
96
+ def raise_base_class_error(slf, *args, **kwargs):
97
+ raise ValueError
98
+
99
+
96
100
  class _BaseVKFFT(_BaseFFT):
97
101
  """
98
102
  FFT using VKFFT backend
@@ -101,7 +105,7 @@ class _BaseVKFFT(_BaseFFT):
101
105
  implem = "vkfft"
102
106
  backend = "none"
103
107
  ProcessingCls = BaseClassError
104
- vkffs_cls = BaseClassError
108
+ get_fft_obj = raise_base_class_error
105
109
 
106
110
  def _configure_batched_transform(self):
107
111
  if self.axes is not None and len(self.shape) == len(self.axes):
@@ -128,7 +132,7 @@ class _BaseVKFFT(_BaseFFT):
128
132
  self._vkfft_ndim = None
129
133
 
130
134
  def _compute_fft_plans(self):
131
- self._vkfft_plan = self.vkffs_cls(
135
+ self._vkfft_plan = self.get_fft_obj(
132
136
  self.shape,
133
137
  self.dtype,
134
138
  ndim=self._vkfft_ndim,
@@ -1,149 +1,33 @@
1
1
  import os
2
2
  import warnings
3
+ from functools import lru_cache
3
4
  from multiprocessing import get_context
4
5
  from multiprocessing.pool import Pool
5
- import numpy as np
6
- from ..utils import check_supported
7
- from .fft_base import _BaseFFT, _BaseVKFFT
6
+ from ..utils import BaseClassError, check_supported, no_decorator
7
+ from .fft_base import _BaseVKFFT
8
8
 
9
9
  try:
10
- from pyvkfft.cuda import VkFFTApp as vk_cufft
10
+ from pyvkfft.cuda import VkFFTApp as CudaVkFFTApp
11
11
 
12
12
  __has_vkfft__ = True
13
13
  except (ImportError, OSError):
14
14
  __has_vkfft__ = False
15
- vk_cufft = None
15
+ CudaVkFFTApp = BaseClassError
16
16
  from ..cuda.processing import CudaProcessing
17
17
 
18
- Plan = None
19
- cu_fft = None
20
- cu_ifft = None
21
- __has_skcuda__ = None
18
+ n_cached_ffts = int(os.getenv("NABU_FFT_CACHE", "0"))
22
19
 
23
20
 
24
- def init_skcuda():
25
- # This needs to be done here, because scikit-cuda creates a Cuda context at import,
26
- # which can mess things up in some cases.
27
- # Ugly solution to an ugly problem.
28
- # ruff: noqa: PLW0603
29
- global __has_skcuda__, Plan, cu_fft, cu_ifft
30
- try:
31
- from skcuda.fft import Plan
32
- from skcuda.fft import fft as cu_fft
33
- from skcuda.fft import ifft as cu_ifft
34
-
35
- __has_skcuda__ = True
36
- except ImportError:
37
- __has_skcuda__ = False
38
-
39
-
40
- class SKCUFFT(_BaseFFT):
41
- implem = "skcuda"
42
- backend = "cuda"
43
- ProcessingCls = CudaProcessing
44
-
45
- def _configure_batched_transform(self):
46
- if __has_skcuda__ is None:
47
- init_skcuda()
48
- if not (__has_skcuda__):
49
- raise ImportError("Please install pycuda and scikit-cuda to use the CUDA back-end")
50
-
51
- self.cufft_batch_size = 1
52
- self.cufft_shape = self.shape
53
- self._cufft_plan_kwargs = {}
54
- if (self.axes is not None) and (len(self.axes) < len(self.shape)):
55
- # In the easiest case, the transform is computed along the fastest dimensions:
56
- # - 1D transforms of lines of 2D data
57
- # - 2D transforms of images of 3D data (stacked along slow dim)
58
- # - 1D transforms of 3D data along fastest dim
59
- # Otherwise, we have to configure cuda "advanced memory layout".
60
- data_ndims = len(self.shape)
21
+ maybe_cached = lru_cache(maxsize=n_cached_ffts) if n_cached_ffts > 0 else no_decorator
61
22
 
62
- if data_ndims == 2:
63
- n_y, n_x = self.shape
64
- along_fast_dim = self.axes[0] == 1
65
- self.cufft_shape = n_x if along_fast_dim else n_y
66
- self.cufft_batch_size = n_y if along_fast_dim else n_x
67
- if not (along_fast_dim):
68
- # Batched vertical 1D FFT on 2D data need advanced data layout
69
- # http://docs.nvidia.com/cuda/cufft/#advanced-data-layout
70
- self._cufft_plan_kwargs = {
71
- "inembed": np.int32([0]),
72
- "istride": n_x,
73
- "idist": 1,
74
- "onembed": np.int32([0]),
75
- "ostride": n_x,
76
- "odist": 1,
77
- }
78
23
 
79
- if data_ndims == 3:
80
- # TODO/FIXME - the following work for C2C but not R2C ?!
81
- # fast_axes = [(1, 2), (2, 1), (2,)]
82
- fast_axes = [(2,)]
83
- if self.axes not in fast_axes:
84
- raise NotImplementedError(
85
- "With the CUDA backend, batched transform on 3D data is only supported along fastest dimensions"
86
- )
87
- self.cufft_batch_size = self.shape[0]
88
- self.cufft_shape = self.shape[1:]
89
- if len(self.axes) == 1:
90
- # 1D transform on 3D data: here only supported along fast dim, so batch_size is Nx*Ny
91
- self.cufft_batch_size = np.prod(self.shape[:2])
92
- self.cufft_shape = (self.shape[-1],)
93
- if len(self.cufft_shape) == 1:
94
- self.cufft_shape = self.cufft_shape[0]
24
+ @maybe_cached
25
+ def _get_vkfft_cuda(*args, **kwargs):
26
+ return CudaVkFFTApp(*args, **kwargs)
95
27
 
96
- def _configure_normalization(self, normalize):
97
- self.normalize = normalize
98
- if self.normalize == "ortho":
99
- # TODO
100
- raise NotImplementedError("Normalization mode 'ortho' is not implemented with CUDA backend yet.")
101
- self.cufft_scale_inverse = self.normalize == "rescale"
102
28
 
103
- def _compute_fft_plans(self):
104
- self.plan_forward = Plan( # pylint: disable = E1102
105
- self.cufft_shape,
106
- self.dtype,
107
- self.dtype_out,
108
- batch=self.cufft_batch_size,
109
- stream=self.processing.stream,
110
- **self._cufft_plan_kwargs,
111
- # cufft extensible plan API is only supported after 0.5.1
112
- # (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
113
- # but there is still no official 0.5.2
114
- # ~ auto_allocate=True # cufft extensible plan API
115
- )
116
- self.plan_inverse = Plan( # pylint: disable = E1102
117
- self.cufft_shape, # not shape_out
118
- self.dtype_out,
119
- self.dtype,
120
- batch=self.cufft_batch_size,
121
- stream=self.processing.stream,
122
- **self._cufft_plan_kwargs,
123
- # cufft extensible plan API is only supported after 0.5.1
124
- # (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
125
- # but there is still no official 0.5.2
126
- # ~ auto_allocate=True
127
- )
128
-
129
- def fft(self, array, output=None):
130
- if output is None:
131
- output = self.output_fft = self.processing.allocate_array(
132
- "output_fft", self.shape_out, dtype=self.dtype_out
133
- )
134
- cu_fft(array, output, self.plan_forward, scale=False) # pylint: disable = E1102
135
- return output
136
-
137
- def ifft(self, array, output=None):
138
- if output is None:
139
- output = self.output_ifft = self.processing.allocate_array("output_ifft", self.shape, dtype=self.dtype)
140
- cu_ifft( # pylint: disable = E1102
141
- array,
142
- output,
143
- self.plan_inverse,
144
- scale=self.cufft_scale_inverse,
145
- )
146
- return output
29
+ def get_vkfft_cuda(slf, *args, **kwargs):
30
+ return _get_vkfft_cuda(*args, **kwargs)
147
31
 
148
32
 
149
33
  class VKCUFFT(_BaseVKFFT):
@@ -154,7 +38,7 @@ class VKCUFFT(_BaseVKFFT):
154
38
  implem = "vkfft"
155
39
  backend = "cuda"
156
40
  ProcessingCls = CudaProcessing
157
- vkffs_cls = vk_cufft
41
+ get_fft_obj = get_vkfft_cuda
158
42
 
159
43
  def _init_backend(self, backend_options):
160
44
  super()._init_backend(backend_options)
@@ -175,6 +59,7 @@ def _has_vkfft(x):
175
59
  return avail
176
60
 
177
61
 
62
+ @lru_cache(maxsize=2)
178
63
  def has_vkfft(safe=True):
179
64
  """
180
65
  Determine whether pyvkfft is available.
@@ -196,43 +81,9 @@ def has_vkfft(safe=True):
196
81
  return v
197
82
 
198
83
 
199
- def _has_skfft(x):
200
- # should be run from within a Process
201
- try:
202
- from nabu.processing.fft_cuda import SKCUFFT
203
-
204
- _ = SKCUFFT((16,), "f")
205
- avail = True
206
- except (ImportError, RuntimeError, OSError, NameError):
207
- avail = False
208
- return avail
209
-
210
-
211
- def has_skcuda(safe=True):
212
- """
213
- Determine whether scikit-cuda/CUFFT is available.
214
- Currently, scikit-cuda will create a Cuda context for Cublas, which can mess up the current execution.
215
- Do it in a separate thread.
216
- """
217
- if not safe:
218
- return _has_skfft(None)
219
- try:
220
- ctx = get_context("spawn")
221
- with Pool(1, context=ctx) as p:
222
- v = p.map(_has_skfft, [1])[0]
223
- except AssertionError:
224
- # Can get AssertionError: daemonic processes are not allowed to have children
225
- # if the calling code is already a subprocess
226
- return _has_skfft(None)
227
- return v
228
-
229
-
84
+ @lru_cache(maxsize=2)
230
85
  def get_fft_class(backend="vkfft"):
231
86
  backends = {
232
- "scikit-cuda": SKCUFFT,
233
- "skcuda": SKCUFFT,
234
- "cufft": SKCUFFT,
235
- "scikit": SKCUFFT,
236
87
  "vkfft": VKCUFFT,
237
88
  "pyvkfft": VKCUFFT,
238
89
  }
@@ -248,7 +99,7 @@ def get_fft_class(backend="vkfft"):
248
99
 
249
100
  avail_fft_implems = get_available_fft_implems()
250
101
  if len(avail_fft_implems) == 0:
251
- raise RuntimeError("Could not any Cuda FFT implementation. Please install either scikit-cuda or pyvkfft")
102
+ raise RuntimeError("Could not any Cuda FFT implementation. Please install pyvkfft")
252
103
  if backend not in avail_fft_implems:
253
104
  warnings.warn("Could not get FFT backend '%s'" % backend, RuntimeWarning)
254
105
  backend = avail_fft_implems[0]
@@ -256,10 +107,9 @@ def get_fft_class(backend="vkfft"):
256
107
  return get_fft_cls(backend)
257
108
 
258
109
 
110
+ @lru_cache(maxsize=1)
259
111
  def get_available_fft_implems():
260
112
  avail_implems = []
261
113
  if has_vkfft(safe=True):
262
114
  avail_implems.append("vkfft")
263
- if has_skcuda(safe=True):
264
- avail_implems.append("skcuda")
265
115
  return avail_implems
@@ -1,15 +1,32 @@
1
+ from functools import lru_cache
2
+ import os
1
3
  from multiprocessing import get_context
2
4
  from multiprocessing.pool import Pool
5
+
6
+ from ..utils import BaseClassError, no_decorator
3
7
  from .fft_base import _BaseVKFFT
4
8
  from ..opencl.processing import OpenCLProcessing
5
9
 
6
10
  try:
7
- from pyvkfft.opencl import VkFFTApp as vk_clfft
11
+ from pyvkfft.opencl import VkFFTApp as OpenCLVkFFTApp
8
12
 
9
13
  __has_vkfft__ = True
10
14
  except (ImportError, OSError):
11
15
  __has_vkfft__ = False
12
16
  vk_clfft = None
17
+ OpenCLVkFFTApp = BaseClassError
18
+
19
+ n_cached_ffts = int(os.getenv("NABU_FFT_CACHE", "0"))
20
+ maybe_cached = lru_cache(maxsize=n_cached_ffts) if n_cached_ffts > 0 else no_decorator
21
+
22
+
23
+ @maybe_cached
24
+ def _get_vkfft_opencl(*args, **kwargs):
25
+ return OpenCLVkFFTApp(*args, **kwargs)
26
+
27
+
28
+ def get_vkfft_opencl(slf, *args, **kwargs):
29
+ return _get_vkfft_opencl(*args, **kwargs)
13
30
 
14
31
 
15
32
  class VKCLFFT(_BaseVKFFT):
@@ -20,7 +37,7 @@ class VKCLFFT(_BaseVKFFT):
20
37
  implem = "vkfft"
21
38
  backend = "opencl"
22
39
  ProcessingCls = OpenCLProcessing
23
- vkffs_cls = vk_clfft
40
+ get_fft_obj = get_vkfft_opencl
24
41
 
25
42
  def _init_backend(self, backend_options):
26
43
  super()._init_backend(backend_options)
@@ -11,7 +11,6 @@ class CudaPadding(PaddingBase):
11
11
 
12
12
  backend = "cuda"
13
13
 
14
- # TODO docstring from base class
15
14
  def __init__(self, shape, pad_width, mode="constant", cuda_options=None, **kwargs):
16
15
  super().__init__(shape, pad_width, mode=mode, **kwargs)
17
16
  self.cuda_processing = self.processing = CudaProcessing(**(cuda_options or {}))
@@ -99,6 +99,15 @@ class ProcessingBase:
99
99
  _recover_arrays_references = recover_arrays_references
100
100
  _allocate_array = allocate_array
101
101
  _set_array = set_array
102
+ # --
103
+
104
+ def is_contiguous(self, arr):
105
+ if isinstance(arr, self.array_class):
106
+ return arr.flags.c_contiguous
107
+ elif isinstance(arr, np.ndarray):
108
+ return arr.flags["C_CONTIGUOUS"]
109
+ else:
110
+ raise TypeError
102
111
 
103
112
  def check_array(self, arr, expected_shape, expected_dtype="f", check_contiguous=True):
104
113
  """
@@ -108,11 +117,8 @@ class ProcessingBase:
108
117
  raise ValueError("Expected shape %s but got %s" % (str(expected_shape), str(arr.shape)))
109
118
  if arr.dtype != np.dtype(expected_dtype):
110
119
  raise ValueError("Expected data type %s but got %s" % (str(expected_dtype), str(arr.dtype)))
111
- if check_contiguous:
112
- if isinstance(arr, np.ndarray) and not (arr.flags["C_CONTIGUOUS"]):
113
- raise ValueError("Expected C-contiguous array")
114
- if isinstance(arr, self.array_class) and not arr.flags.c_contiguous:
115
- raise ValueError("Expected C-contiguous array")
120
+ if check_contiguous and not (self.is_contiguous(arr)):
121
+ raise ValueError("Expected C-contiguous array")
116
122
 
117
123
  def kernel(self, *args, **kwargs):
118
124
  raise ValueError("Base class")
@@ -4,14 +4,13 @@ import numpy as np
4
4
  from scipy.fft import fftn, ifftn, rfftn, irfftn
5
5
  from nabu.testutils import generate_tests_scenarios, get_data, get_array_of_given_shape, __do_long_tests__
6
6
  from nabu.cuda.utils import get_cuda_context, __has_pycuda__
7
- from nabu.processing.fft_cuda import SKCUFFT, VKCUFFT, get_available_fft_implems
7
+ from nabu.processing.fft_cuda import VKCUFFT, get_available_fft_implems
8
8
  from nabu.opencl.utils import __has_pyopencl__, get_opencl_context
9
9
  from nabu.processing.fft_opencl import VKCLFFT, has_vkfft as has_cl_vkfft
10
10
  from nabu.processing.fft_base import is_fast_axes
11
11
 
12
12
  available_cuda_fft = get_available_fft_implems()
13
13
  __has_vkfft__ = "vkfft" in available_cuda_fft
14
- __has_skcuda__ = "skcuda" in available_cuda_fft
15
14
 
16
15
 
17
16
  scenarios = {
@@ -113,67 +112,6 @@ class TestFFT:
113
112
  ref = ref_ifft_func(data, axes=axes)
114
113
  return ref
115
114
 
116
- @pytest.mark.skipif(
117
- not (__has_skcuda__ and __has_pycuda__), reason="Need pycuda and (scikit-cuda or vkfft) for this test"
118
- )
119
- @pytest.mark.parametrize("config", scenarios)
120
- def test_sckcuda(self, config):
121
- r2c = config["r2c"]
122
- shape = config["shape"]
123
- precision = config["precision"]
124
- ndim = len(shape)
125
- if ndim == 3 and not (__do_long_tests__):
126
- pytest.skip("3D FFTs are done only for long tests - use NABU_LONG_TESTS=1")
127
-
128
- data = self._get_data_array(config)
129
-
130
- res, cufft = self._do_fft(data, r2c, return_fft_obj=True, backend_cls=SKCUFFT)
131
- ref = self._do_reference_fft(data, r2c)
132
-
133
- tol = self.abs_tol[precision][ndim]
134
- self.check_result(res, ref, config, tol, name="skcuda")
135
-
136
- # Complex-to-complex can also be performed on real data (as in numpy.fft.fft(real_data))
137
- if not (r2c):
138
- res = self._do_fft(data, False, backend_cls=SKCUFFT)
139
- ref = self._do_reference_fft(data, False)
140
- self.check_result(res, ref, config, tol, name="skcuda")
141
-
142
- # IFFT
143
- res = cufft.ifft(cufft.output_fft).get()
144
- self.check_result(res, data, config, tol, name="skcuda")
145
- # Perhaps we should also check against numpy/scipy ifft,
146
- # but it does not yield the good shape for R2C on odd-sized data
147
-
148
- @pytest.mark.skipif(
149
- not (__has_skcuda__ and __has_pycuda__), reason="Need pycuda and (scikit-cuda or vkfft) for this test"
150
- )
151
- @pytest.mark.parametrize("config", scenarios)
152
- def test_skcuda_batched(self, config):
153
- shape = config["shape"]
154
- if len(shape) == 1:
155
- return
156
- elif len(shape) == 3 and not (__do_long_tests__):
157
- pytest.skip("3D FFTs are done only for long tests - use NABU_LONG_TESTS=1")
158
- r2c = config["r2c"]
159
- tol = self.abs_tol[config["precision"]][len(shape)]
160
-
161
- data = self._get_data_array(config)
162
-
163
- if data.ndim == 2:
164
- axes_to_test = [(0,), (1,)]
165
- elif data.ndim == 3:
166
- # axes_to_test = [(1, 2), (2, 1), (2,)] # See fft.py: works for C2C but not R2C ?
167
- axes_to_test = [(2,)]
168
-
169
- for axes in axes_to_test:
170
- res, cufft = self._do_fft(data, r2c, axes=axes, return_fft_obj=True, backend_cls=SKCUFFT)
171
- ref = self._do_reference_fft(data, r2c, axes=axes)
172
- self.check_result(res, ref, config, tol, name="skcuda batched axes=%s" % (str(axes)))
173
- # IFFT
174
- res = cufft.ifft(cufft.output_fft).get()
175
- self.check_result(res, data, config, tol, name="skcuda")
176
-
177
115
  @pytest.mark.parametrize("config", scenarios)
178
116
  def test_vkfft(self, config):
179
117
  backend = config["backend"]