nabu 2024.2.14__py3-none-any.whl → 2025.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. doc/doc_config.py +32 -0
  2. nabu/__init__.py +1 -1
  3. nabu/app/bootstrap_stitching.py +4 -2
  4. nabu/app/cast_volume.py +16 -14
  5. nabu/app/cli_configs.py +102 -9
  6. nabu/app/compare_volumes.py +1 -1
  7. nabu/app/composite_cor.py +2 -4
  8. nabu/app/diag_to_pix.py +5 -6
  9. nabu/app/diag_to_rot.py +10 -11
  10. nabu/app/double_flatfield.py +18 -5
  11. nabu/app/estimate_motion.py +75 -0
  12. nabu/app/multicor.py +28 -15
  13. nabu/app/parse_reconstruction_log.py +1 -0
  14. nabu/app/pcaflats.py +122 -0
  15. nabu/app/prepare_weights_double.py +1 -2
  16. nabu/app/reconstruct.py +1 -7
  17. nabu/app/reconstruct_helical.py +5 -9
  18. nabu/app/reduce_dark_flat.py +5 -4
  19. nabu/app/rotate.py +3 -1
  20. nabu/app/stitching.py +7 -2
  21. nabu/app/tests/test_reduce_dark_flat.py +2 -2
  22. nabu/app/validator.py +1 -4
  23. nabu/cuda/convolution.py +1 -1
  24. nabu/cuda/fft.py +1 -1
  25. nabu/cuda/medfilt.py +1 -1
  26. nabu/cuda/padding.py +1 -1
  27. nabu/cuda/src/backproj.cu +6 -6
  28. nabu/cuda/src/cone.cu +4 -0
  29. nabu/cuda/src/hierarchical_backproj.cu +14 -0
  30. nabu/cuda/utils.py +2 -2
  31. nabu/estimation/alignment.py +17 -31
  32. nabu/estimation/cor.py +27 -33
  33. nabu/estimation/cor_sino.py +2 -8
  34. nabu/estimation/focus.py +4 -8
  35. nabu/estimation/motion.py +557 -0
  36. nabu/estimation/tests/test_alignment.py +2 -0
  37. nabu/estimation/tests/test_motion_estimation.py +471 -0
  38. nabu/estimation/tests/test_tilt.py +1 -1
  39. nabu/estimation/tilt.py +6 -5
  40. nabu/estimation/translation.py +47 -1
  41. nabu/io/cast_volume.py +108 -18
  42. nabu/io/detector_distortion.py +5 -6
  43. nabu/io/reader.py +45 -6
  44. nabu/io/reader_helical.py +5 -4
  45. nabu/io/tests/test_cast_volume.py +2 -2
  46. nabu/io/tests/test_readers.py +41 -38
  47. nabu/io/tests/test_remove_volume.py +152 -0
  48. nabu/io/tests/test_writers.py +2 -2
  49. nabu/io/utils.py +8 -4
  50. nabu/io/writer.py +1 -2
  51. nabu/misc/fftshift.py +1 -1
  52. nabu/misc/fourier_filters.py +1 -1
  53. nabu/misc/histogram.py +1 -1
  54. nabu/misc/histogram_cuda.py +1 -1
  55. nabu/misc/padding_base.py +1 -1
  56. nabu/misc/rotation.py +1 -1
  57. nabu/misc/rotation_cuda.py +1 -1
  58. nabu/misc/tests/test_binning.py +1 -1
  59. nabu/misc/transpose.py +1 -1
  60. nabu/misc/unsharp.py +1 -1
  61. nabu/misc/unsharp_cuda.py +1 -1
  62. nabu/misc/unsharp_opencl.py +1 -1
  63. nabu/misc/utils.py +1 -1
  64. nabu/opencl/fft.py +1 -1
  65. nabu/opencl/padding.py +1 -1
  66. nabu/opencl/src/backproj.cl +6 -6
  67. nabu/opencl/utils.py +8 -8
  68. nabu/pipeline/config.py +2 -2
  69. nabu/pipeline/config_validators.py +46 -46
  70. nabu/pipeline/datadump.py +3 -3
  71. nabu/pipeline/estimators.py +271 -11
  72. nabu/pipeline/fullfield/chunked.py +103 -67
  73. nabu/pipeline/fullfield/chunked_cuda.py +5 -2
  74. nabu/pipeline/fullfield/computations.py +4 -1
  75. nabu/pipeline/fullfield/dataset_validator.py +0 -1
  76. nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
  77. nabu/pipeline/fullfield/nabu_config.py +36 -17
  78. nabu/pipeline/fullfield/processconfig.py +41 -7
  79. nabu/pipeline/fullfield/reconstruction.py +14 -10
  80. nabu/pipeline/helical/dataset_validator.py +3 -4
  81. nabu/pipeline/helical/fbp.py +4 -4
  82. nabu/pipeline/helical/filtering.py +5 -4
  83. nabu/pipeline/helical/gridded_accumulator.py +10 -11
  84. nabu/pipeline/helical/helical_chunked_regridded.py +1 -0
  85. nabu/pipeline/helical/helical_reconstruction.py +12 -9
  86. nabu/pipeline/helical/helical_utils.py +1 -2
  87. nabu/pipeline/helical/nabu_config.py +2 -1
  88. nabu/pipeline/helical/span_strategy.py +1 -0
  89. nabu/pipeline/helical/weight_balancer.py +2 -3
  90. nabu/pipeline/params.py +20 -3
  91. nabu/pipeline/tests/__init__.py +0 -0
  92. nabu/pipeline/tests/test_estimators.py +240 -3
  93. nabu/pipeline/utils.py +1 -1
  94. nabu/pipeline/writer.py +1 -1
  95. nabu/preproc/alignment.py +0 -10
  96. nabu/preproc/ccd.py +53 -3
  97. nabu/preproc/ctf.py +8 -8
  98. nabu/preproc/ctf_cuda.py +1 -1
  99. nabu/preproc/double_flatfield_cuda.py +2 -2
  100. nabu/preproc/double_flatfield_variable_region.py +0 -1
  101. nabu/preproc/flatfield.py +307 -2
  102. nabu/preproc/flatfield_cuda.py +1 -2
  103. nabu/preproc/flatfield_variable_region.py +3 -3
  104. nabu/preproc/phase.py +2 -4
  105. nabu/preproc/phase_cuda.py +2 -2
  106. nabu/preproc/shift.py +4 -2
  107. nabu/preproc/shift_cuda.py +0 -1
  108. nabu/preproc/tests/test_ctf.py +4 -4
  109. nabu/preproc/tests/test_double_flatfield.py +1 -1
  110. nabu/preproc/tests/test_flatfield.py +1 -1
  111. nabu/preproc/tests/test_paganin.py +1 -3
  112. nabu/preproc/tests/test_pcaflats.py +154 -0
  113. nabu/preproc/tests/test_vshift.py +4 -1
  114. nabu/processing/azim.py +9 -5
  115. nabu/processing/convolution_cuda.py +6 -4
  116. nabu/processing/fft_base.py +7 -3
  117. nabu/processing/fft_cuda.py +25 -164
  118. nabu/processing/fft_opencl.py +28 -6
  119. nabu/processing/fftshift.py +1 -1
  120. nabu/processing/histogram.py +1 -1
  121. nabu/processing/muladd.py +0 -1
  122. nabu/processing/padding_base.py +1 -1
  123. nabu/processing/padding_cuda.py +0 -2
  124. nabu/processing/processing_base.py +12 -6
  125. nabu/processing/rotation_cuda.py +3 -1
  126. nabu/processing/tests/test_fft.py +2 -64
  127. nabu/processing/tests/test_fftshift.py +1 -1
  128. nabu/processing/tests/test_medfilt.py +1 -3
  129. nabu/processing/tests/test_padding.py +1 -1
  130. nabu/processing/tests/test_roll.py +1 -1
  131. nabu/processing/tests/test_rotation.py +4 -2
  132. nabu/processing/unsharp_opencl.py +1 -1
  133. nabu/reconstruction/astra.py +245 -0
  134. nabu/reconstruction/cone.py +39 -9
  135. nabu/reconstruction/fbp.py +7 -0
  136. nabu/reconstruction/fbp_base.py +36 -5
  137. nabu/reconstruction/filtering.py +59 -25
  138. nabu/reconstruction/filtering_cuda.py +22 -21
  139. nabu/reconstruction/filtering_opencl.py +10 -14
  140. nabu/reconstruction/hbp.py +26 -13
  141. nabu/reconstruction/mlem.py +55 -16
  142. nabu/reconstruction/projection.py +3 -5
  143. nabu/reconstruction/sinogram.py +1 -1
  144. nabu/reconstruction/sinogram_cuda.py +0 -1
  145. nabu/reconstruction/tests/test_cone.py +37 -2
  146. nabu/reconstruction/tests/test_deringer.py +4 -4
  147. nabu/reconstruction/tests/test_fbp.py +36 -15
  148. nabu/reconstruction/tests/test_filtering.py +27 -7
  149. nabu/reconstruction/tests/test_halftomo.py +28 -2
  150. nabu/reconstruction/tests/test_mlem.py +94 -64
  151. nabu/reconstruction/tests/test_projector.py +7 -2
  152. nabu/reconstruction/tests/test_reconstructor.py +1 -1
  153. nabu/reconstruction/tests/test_sino_normalization.py +0 -1
  154. nabu/resources/dataset_analyzer.py +210 -24
  155. nabu/resources/gpu.py +4 -4
  156. nabu/resources/logger.py +4 -4
  157. nabu/resources/nxflatfield.py +103 -37
  158. nabu/resources/tests/test_dataset_analyzer.py +37 -0
  159. nabu/resources/tests/test_extract.py +11 -0
  160. nabu/resources/tests/test_nxflatfield.py +5 -5
  161. nabu/resources/utils.py +16 -10
  162. nabu/stitching/alignment.py +8 -11
  163. nabu/stitching/config.py +44 -35
  164. nabu/stitching/definitions.py +2 -2
  165. nabu/stitching/frame_composition.py +8 -10
  166. nabu/stitching/overlap.py +4 -4
  167. nabu/stitching/sample_normalization.py +5 -5
  168. nabu/stitching/slurm_utils.py +2 -2
  169. nabu/stitching/stitcher/base.py +2 -0
  170. nabu/stitching/stitcher/dumper/base.py +0 -1
  171. nabu/stitching/stitcher/dumper/postprocessing.py +1 -1
  172. nabu/stitching/stitcher/post_processing.py +11 -9
  173. nabu/stitching/stitcher/pre_processing.py +37 -31
  174. nabu/stitching/stitcher/single_axis.py +2 -3
  175. nabu/stitching/stitcher_2D.py +2 -1
  176. nabu/stitching/tests/test_config.py +10 -11
  177. nabu/stitching/tests/test_sample_normalization.py +1 -1
  178. nabu/stitching/tests/test_slurm_utils.py +1 -2
  179. nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
  180. nabu/stitching/tests/test_z_postprocessing_stitching.py +3 -3
  181. nabu/stitching/tests/test_z_preprocessing_stitching.py +27 -24
  182. nabu/stitching/utils/tests/__init__.py +0 -0
  183. nabu/stitching/utils/tests/test_post-processing.py +1 -0
  184. nabu/stitching/utils/utils.py +16 -18
  185. nabu/tests.py +0 -3
  186. nabu/testutils.py +62 -9
  187. nabu/utils.py +50 -20
  188. {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/METADATA +7 -7
  189. nabu-2025.1.0.dist-info/RECORD +328 -0
  190. {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/WHEEL +1 -1
  191. {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/entry_points.txt +2 -1
  192. nabu/app/correct_rot.py +0 -70
  193. nabu/io/tests/test_detector_distortion.py +0 -178
  194. nabu-2024.2.14.dist-info/RECORD +0 -317
  195. /nabu/{stitching → app}/tests/__init__.py +0 -0
  196. {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/licenses/LICENSE +0 -0
  197. {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/top_level.txt +0 -0
nabu/preproc/tests/test_pcaflats.py ADDED
@@ -0,0 +1,154 @@
+ import os
+ import numpy as np
+ import pytest
+
+ import h5py
+ from nabu.testutils import utilstest
+ from nabu.preproc.flatfield import (
+     PCAFlatsDecomposer,
+     PCAFlatsNormalizer,
+ )
+
+
+ @pytest.fixture(scope="class")
+ def bootstrap_pcaflats(request):
+     cls = request.cls
+     # TODO: these tolerances for having the tests pass should be tighter.
+     # Discrepancies between id11 code and nabu code are still mysterious.
+     cls.mean_abs_tol = 1e-1
+     cls.comps_abs_tol = 1e-2
+     cls.projs, cls.flats, cls.darks = get_pcaflats_data("test_pcaflats.npz")
+     cls.raw_projs = cls.projs.copy()  # Needed because flat correction is done in place.
+     ref_data = get_pcaflats_refdata("ref_pcaflats.npz")
+     cls.mean = ref_data["mean"]
+     cls.components_3 = ref_data["components_3"]
+     cls.components_15 = ref_data["components_15"]
+     cls.dark = ref_data["dark"]
+     cls.normalized_projs_3 = ref_data["normalized_projs_3"]
+     cls.normalized_projs_15 = ref_data["normalized_projs_15"]
+     cls.normalized_projs_custom_mask = ref_data["normalized_projs_custom_mask"]
+     cls.test_normalize_projs_custom_prop = ref_data["normalized_projs_custom_prop"]
+
+     cls.h5_filename_3 = get_h5_pcaflats("pcaflat_3.h5")
+     cls.h5_filename_15 = get_h5_pcaflats("pcaflat_15.h5")
+
+
+ def get_pcaflats_data(*dataset_path):
+     """
+     Get a dataset file from silx.org/pub/nabu/data
+     dataset_args is a list describing a nested folder structure, e.g.
+     ["path", "to", "my", "dataset.h5"]
+     """
+     dataset_relpath = os.path.join(*dataset_path)
+     dataset_downloaded_path = utilstest.getfile(dataset_relpath)
+     data = np.load(dataset_downloaded_path)
+     projs = data["projs"].astype(np.float32)
+     flats = data["flats"].astype(np.float32)
+     darks = data["darks"].astype(np.float32)
+
+     return projs, flats, darks
+
+
+ def get_h5_pcaflats(*dataset_path):
+     """
+     Get a dataset file from silx.org/pub/nabu/data
+     dataset_args is a list describing a nested folder structure, e.g.
+     ["path", "to", "my", "dataset.h5"]
+     """
+     dataset_relpath = os.path.join(*dataset_path)
+     dataset_downloaded_path = utilstest.getfile(dataset_relpath)
+
+     return dataset_downloaded_path
+
+
+ def get_pcaflats_refdata(*dataset_path):
+     """
+     Get a dataset file from silx.org/pub/nabu/data
+     dataset_args is a list describing a nested folder structure, e.g.
+     ["path", "to", "my", "dataset.h5"]
+     """
+     dataset_relpath = os.path.join(*dataset_path)
+     dataset_downloaded_path = utilstest.getfile(dataset_relpath)
+     data = np.load(dataset_downloaded_path)
+
+     return data
+
+
+ def get_decomposition(filename):
+     with h5py.File(filename, "r") as f:
+         # Load the dataset
+         p_comps = f["entry0000/p_components"][()]
+         p_mean = f["entry0000/p_mean"][()]
+         dark = f["entry0000/dark"][()]
+     return p_comps, p_mean, dark
+
+
+ @pytest.mark.usefixtures("bootstrap_pcaflats")
+ class TestPCAFlatsDecomposer:
+     def test_decompose_flats(self):
+         # Build 3-sigma basis
+         pca = PCAFlatsDecomposer(self.flats, self.darks, nsigma=3)
+         message = "Found a discrepancy between computed mean flat and reference."
+         assert np.allclose(self.mean, pca.mean, atol=self.mean_abs_tol), message
+         message = "Found a discrepancy between computed components and reference ones if nsigma=3."
+         assert np.allclose(self.components_3, np.array(pca.components), atol=self.comps_abs_tol), message
+
+         # Build 1.5-sigma basis
+         pca = PCAFlatsDecomposer(self.flats, self.darks, nsigma=1.5)
+         message = "Found a discrepancy between computed components and reference ones, if nsigma=1.5."
+         assert np.allclose(self.components_15, np.array(pca.components), atol=self.comps_abs_tol), message
+
+     def test_save_load_decomposition(self):
+         pca = PCAFlatsDecomposer(self.flats, self.darks, nsigma=3)
+         tmp_path = os.path.join(os.path.dirname(self.h5_filename_3), "PCA_Flats.h5")
+         pca.save_decomposition(path=tmp_path)
+         p_comps, p_mean, dark = get_decomposition(tmp_path)
+         message = "Found a discrepancy between saved and loaded mean flat."
+         assert np.allclose(self.mean, p_mean, atol=self.mean_abs_tol), message
+         message = "Found a discrepancy between saved and loaded components if nsigma=3."
+         assert np.allclose(self.components_3, p_comps, atol=self.comps_abs_tol), message
+         message = "Found a discrepancy between saved and loaded dark."
+         assert np.allclose(self.dark, dark, atol=self.comps_abs_tol), message
+         # Clean up
+         if os.path.exists(tmp_path):
+             os.remove(tmp_path)
+
+
+ @pytest.mark.usefixtures("bootstrap_pcaflats")
+ class TestPCAFlatsNormalizer:
+     def test_load_pcaflats(self):
+         """Tests that the structure of the output PCAFlat h5 file is correct."""
+         p_comps, p_mean, dark = get_decomposition(self.h5_filename_3)
+         # Check the shape of the loaded data
+         assert p_comps.shape[1:] == p_mean.shape
+         assert p_comps.shape[1:] == dark.shape
+
+     def test_normalize_projs(self):
+         p_comps, p_mean, dark = get_decomposition(self.h5_filename_3)
+         pca = PCAFlatsNormalizer(p_comps, dark, p_mean)
+         projs = self.raw_projs.copy()
+         pca.normalize_radios(projs)
+         assert np.allclose(projs, self.normalized_projs_3, atol=1e-2)
+         p_comps, p_mean, dark = get_decomposition(self.h5_filename_15)
+         pca = PCAFlatsNormalizer(p_comps, dark, p_mean)
+         projs = self.raw_projs.copy()
+         pca.normalize_radios(projs)
+         assert np.allclose(projs, self.normalized_projs_15, atol=1e-2)
+
+     def test_use_custom_mask(self):
+         mask = np.zeros(self.mean.shape, dtype=bool)
+         mask[:, :10] = True
+         mask[:, -10:] = True
+         p_comps, p_mean, dark = get_decomposition(self.h5_filename_3)
+
+         pca = PCAFlatsNormalizer(p_comps, dark, p_mean)
+         projs = self.raw_projs.copy()
+         pca.normalize_radios(projs, mask=mask)
+         assert np.allclose(projs, self.normalized_projs_custom_mask, atol=1e-2)
+
+     def test_change_mask_prop(self):
+         p_comps, p_mean, dark = get_decomposition(self.h5_filename_3)
+         pca = PCAFlatsNormalizer(p_comps, dark, p_mean)
+         projs = self.raw_projs.copy()
+         pca.normalize_radios(projs, prop=0.05)
+         assert np.allclose(projs, self.test_normalize_projs_custom_prop, atol=1e-2)
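
For orientation, here is a minimal end-to-end sketch of the new PCA-flats API exercised by the test above. It uses only names visible in the diff (PCAFlatsDecomposer, save_decomposition, the entry0000/* HDF5 layout, PCAFlatsNormalizer.normalize_radios); the synthetic arrays and file name are illustrative, not part of the package.

import h5py
import numpy as np
from nabu.preproc.flatfield import PCAFlatsDecomposer, PCAFlatsNormalizer

rng = np.random.default_rng(0)
flats = (1000 + 10 * rng.standard_normal((20, 64, 64))).astype(np.float32)
darks = (100 + rng.standard_normal((5, 64, 64))).astype(np.float32)
projs = (800 + 10 * rng.standard_normal((10, 64, 64))).astype(np.float32)

# Decompose the (dark-corrected) flats into a PCA basis, keeping 3-sigma components
decomposer = PCAFlatsDecomposer(flats, darks, nsigma=3)
decomposer.save_decomposition(path="pca_flats.h5")

# Reload the basis, using the same file layout as get_decomposition() in the test
with h5py.File("pca_flats.h5", "r") as f:
    p_comps = f["entry0000/p_components"][()]
    p_mean = f["entry0000/p_mean"][()]
    dark = f["entry0000/dark"][()]

# Flat-field the projections in place by fitting each radio against the PCA basis
normalizer = PCAFlatsNormalizer(p_comps, dark, p_mean)
normalizer.normalize_radios(projs)  # also accepts mask=... and prop=... per the tests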
nabu/preproc/tests/test_vshift.py CHANGED
@@ -70,4 +70,7 @@ class TestVerticalShift:
          Shifter_neg_cuda = CudaVerticalShift(d_radios.shape, -self.shifts)
          Shifter_neg_cuda.apply_vertical_shifts(d_radios2, self.indexes)
          err_max = np.max(np.abs(d_radios2.get() - radios2))
-         assert err_max < 1e-6, "Something wrong for negative translations: max error = %.2e" % err_max
+         #
+         # FIXME tolerance was downgraded from 1e-6 to 8e-6 when switching to numpy 2
+         #
+         assert err_max < 8e-6, "Something wrong for negative translations: max error = %.2e" % err_max
nabu/processing/azim.py CHANGED
@@ -96,8 +96,11 @@ def do_radial_distribution(ip, X0, Y0, mR, nBins=None, use_calibration=False, ca
      Accumulator = np.zeros((2, nBins))

      # Define the bounding box
-     xmin, xmax = X0 - mR, X0 + mR
-     ymin, ymax = Y0 - mR, Y0 + mR
+     height, width = ip.shape
+     xmin = max(int(X0 - mR), 0)
+     xmax = min(int(X0 + mR), width)
+     ymin = max(int(Y0 - mR), 0)
+     ymax = min(int(Y0 + mR), height)

      # Create grid of coordinates
      x = np.arange(xmin, xmax)
@@ -112,10 +115,11 @@ def do_radial_distribution(ip, X0, Y0, mR, nBins=None, use_calibration=False, ca
      bins = np.clip(bins - 1, 0, nBins - 1)  # Adjust bins to be in range [0, nBins-1]

      # Accumulate values
+     sub_image = ip[xmin:xmax, ymin:ymax]  # prevent issue on non-square images
      for b in range(nBins):
          mask = bins == b
          Accumulator[0, b] = np.sum(mask)
-         Accumulator[1, b] = np.sum(ip[mask])
+         Accumulator[1, b] = np.sum(sub_image[mask])

      # Normalize integrated intensity
      Accumulator[1] /= Accumulator[0]
@@ -123,11 +127,11 @@
      if use_calibration and cal is not None:
          # Apply calibration if units are provided
          radii = cal.pixel_width * mR * (np.arange(1, nBins + 1) / nBins)
-         units = cal.units
+         # units = cal.units
      else:
          # Use pixel units
          radii = mR * (np.arange(1, nBins + 1) / nBins)
-         units = "pixels"
+         # units = "pixels"

      if return_radii:
          return radii, Accumulator[1]
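
The fix above clamps the analysis window to the detector frame before building the coordinate grids, so the per-bin masks and the image crop always have matching shapes. Below is a minimal standalone sketch of that clamping pattern; clamped_roi is a hypothetical helper, and the row/column convention mirrors the diff's ip[xmin:xmax, ymin:ymax] indexing.

import numpy as np

def clamped_roi(img, x0, y0, radius):
    # Restrict the square window of half-width `radius` around (x0, y0)
    # to indices that actually exist in the image.
    height, width = img.shape
    xmin, xmax = max(int(x0 - radius), 0), min(int(x0 + radius), width)
    ymin, ymax = max(int(y0 - radius), 0), min(int(y0 + radius), height)
    return img[xmin:xmax, ymin:ymax]

img = np.arange(20 * 30, dtype=np.float32).reshape(20, 30)
roi = clamped_roi(img, x0=2, y0=28, radius=5)  # window partly outside the image
print(roi.shape)  # clipped to the image bounds instead of wrapping via negative indices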
nabu/processing/convolution_cuda.py CHANGED
@@ -159,7 +159,7 @@ class Convolution:
              self.d_kernel = self.cuda.to_device("d_kernel", self.kernel)
          else:
              if not (isinstance(self.kernel, self.cuda.array_class)):
-                 raise ValueError("kernel must be either numpy array or pycuda array")
+                 raise TypeError("kernel must be either numpy array or pycuda array")
              self.d_kernel = self.kernel
          self._old_input_ref = None
          self._old_output_ref = None
@@ -185,7 +185,7 @@ class Convolution:
          self._c_conv_mode = mp[self.mode]

      def _init_kernels(self):
-         if self.kernel_ndim > 1:
+         if self.kernel_ndim > 1:  # noqa: SIM102
              if np.abs(np.diff(self.kernel.shape)).max() > 0:
                  raise NotImplementedError("Non-separable convolution with non-square kernels is not implemented yet")
          # Compile source module
@@ -290,7 +290,7 @@ class Convolution:
          return ndim

      def _check_array(self, arr):
-         if not (isinstance(arr, self.cuda.array_class) or isinstance(arr, np.ndarray)):
+         if not (isinstance(arr, self.cuda.array_class) or isinstance(arr, np.ndarray)):  # noqa: SIM101
              raise TypeError("Expected either pycuda.gpuarray or numpy.ndarray")
          if arr.dtype != np.float32:
              raise TypeError("Data must be float32")
@@ -305,7 +305,7 @@ class Convolution:
              self._old_input_ref = self.data_in
              self.data_in = array
          data_in_ref = self.data_in
-         if output is not None:
+         if output is not None:  # noqa: SIM102
              if not (isinstance(output, np.ndarray)):
                  self._old_output_ref = self.data_out
                  self.data_out = output
@@ -324,11 +324,13 @@ class Convolution:
          cuda_kernel = self.cuda_kernels[axis]
          cuda_kernel_args = self._configure_kernel_args(self.kernel_args, input_ref, output_ref)
          ev = cuda_kernel.prepared_call(*cuda_kernel_args)
+         return ev

      def _nd_convolution(self):
          assert len(self.use_case_kernels) == 1
          cuda_kernel = self._module.get_function(self.use_case_kernels[0])
          ev = cuda_kernel.prepared_call(*self.kernel_args)
+         return ev

      def _recover_arrays_references(self):
          if self._old_input_ref is not None:
nabu/processing/fft_base.py CHANGED
@@ -35,7 +35,7 @@ class _BaseFFT:
            the transform is unitary. Both FFT and IFFT are scaled with 1/sqrt(N).
          * "none": no normalization is done: IFFT(FFT(data)) = data*N

-         Other parameters
+         Other Parameters
          -----------------
          backend_options: dict, optional
              Parameters to pass to CudaProcessing or OpenCLProcessing class.
@@ -93,6 +93,10 @@ class _BaseFFT:
          pass


+ def raise_base_class_error(slf, *args, **kwargs):
+     raise ValueError
+
+
  class _BaseVKFFT(_BaseFFT):
      """
      FFT using VKFFT backend
@@ -101,7 +105,7 @@ class _BaseVKFFT(_BaseFFT):
      implem = "vkfft"
      backend = "none"
      ProcessingCls = BaseClassError
-     vkffs_cls = BaseClassError
+     get_fft_obj = raise_base_class_error

      def _configure_batched_transform(self):
          if self.axes is not None and len(self.shape) == len(self.axes):
@@ -128,7 +132,7 @@ class _BaseVKFFT(_BaseFFT):
          self._vkfft_ndim = None

      def _compute_fft_plans(self):
-         self._vkfft_plan = self.vkffs_cls(
+         self._vkfft_plan = self.get_fft_obj(
              self.shape,
              self.dtype,
              ndim=self._vkfft_ndim,
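
This change generalizes an "error sentinel" pattern: class-level placeholders are bound to something that raises on use, so an unconfigured base class (or a backend whose dependency is missing) fails with an explicit error rather than an opaque NoneType one. A self-contained sketch of the idea, where the BaseClassError stand-in below only mimics the behavior the diff relies on (the real one lives in nabu.utils):

class BaseClassError:
    def __init__(self, *args, **kwargs):
        raise ValueError("dependency missing, or base class used directly")

def raise_base_class_error(slf, *args, **kwargs):
    raise ValueError("get_fft_obj must be provided by a concrete subclass")

class _BaseVKFFTSketch:
    ProcessingCls = BaseClassError
    get_fft_obj = raise_base_class_error  # looked up on the class, so it receives `self` as slf

    def compute_plan(self, shape, dtype):
        # Raises in the base class; concrete subclasses rebind get_fft_obj to a real factory
        return self.get_fft_obj(shape, dtype)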
nabu/processing/fft_cuda.py CHANGED
@@ -1,148 +1,33 @@
  import os
  import warnings
+ from functools import lru_cache
  from multiprocessing import get_context
  from multiprocessing.pool import Pool
- import numpy as np
- from ..utils import check_supported
- from .fft_base import _BaseFFT, _BaseVKFFT
+ from ..utils import BaseClassError, check_supported, no_decorator
+ from .fft_base import _BaseVKFFT

  try:
-     from pyvkfft.cuda import VkFFTApp as vk_cufft
+     from pyvkfft.cuda import VkFFTApp as CudaVkFFTApp

      __has_vkfft__ = True
  except (ImportError, OSError):
      __has_vkfft__ = False
-     vk_cufft = None
+     CudaVkFFTApp = BaseClassError
  from ..cuda.processing import CudaProcessing

- Plan = None
- cu_fft = None
- cu_ifft = None
- __has_skcuda__ = None
+ n_cached_ffts = int(os.getenv("NABU_FFT_CACHE", "0"))


- def init_skcuda():
-     # This needs to be done here, because scikit-cuda creates a Cuda context at import,
-     # which can mess things up in some cases.
-     # Ugly solution to an ugly problem.
-     global __has_skcuda__, Plan, cu_fft, cu_ifft
-     try:
-         from skcuda.fft import Plan
-         from skcuda.fft import fft as cu_fft
-         from skcuda.fft import ifft as cu_ifft
+ maybe_cached = lru_cache(maxsize=n_cached_ffts) if n_cached_ffts > 0 else no_decorator

-         __has_skcuda__ = True
-     except ImportError:
-         __has_skcuda__ = False

+ @maybe_cached
+ def _get_vkfft_cuda(*args, **kwargs):
+     return CudaVkFFTApp(*args, **kwargs)

- class SKCUFFT(_BaseFFT):
-     implem = "skcuda"
-     backend = "cuda"
-     ProcessingCls = CudaProcessing

-     def _configure_batched_transform(self):
-         if __has_skcuda__ is None:
-             init_skcuda()
-         if not (__has_skcuda__):
-             raise ImportError("Please install pycuda and scikit-cuda to use the CUDA back-end")
-
-         self.cufft_batch_size = 1
-         self.cufft_shape = self.shape
-         self._cufft_plan_kwargs = {}
-         if (self.axes is not None) and (len(self.axes) < len(self.shape)):
-             # In the easiest case, the transform is computed along the fastest dimensions:
-             #   - 1D transforms of lines of 2D data
-             #   - 2D transforms of images of 3D data (stacked along slow dim)
-             #   - 1D transforms of 3D data along fastest dim
-             # Otherwise, we have to configure cuda "advanced memory layout".
-             data_ndims = len(self.shape)
-
-             if data_ndims == 2:
-                 n_y, n_x = self.shape
-                 along_fast_dim = self.axes[0] == 1
-                 self.cufft_shape = n_x if along_fast_dim else n_y
-                 self.cufft_batch_size = n_y if along_fast_dim else n_x
-                 if not (along_fast_dim):
-                     # Batched vertical 1D FFT on 2D data need advanced data layout
-                     # http://docs.nvidia.com/cuda/cufft/#advanced-data-layout
-                     self._cufft_plan_kwargs = {
-                         "inembed": np.int32([0]),
-                         "istride": n_x,
-                         "idist": 1,
-                         "onembed": np.int32([0]),
-                         "ostride": n_x,
-                         "odist": 1,
-                     }
-
-             if data_ndims == 3:
-                 # TODO/FIXME - the following work for C2C but not R2C ?!
-                 # fast_axes = [(1, 2), (2, 1), (2,)]
-                 fast_axes = [(2,)]
-                 if self.axes not in fast_axes:
-                     raise NotImplementedError(
-                         "With the CUDA backend, batched transform on 3D data is only supported along fastest dimensions"
-                     )
-                 self.cufft_batch_size = self.shape[0]
-                 self.cufft_shape = self.shape[1:]
-                 if len(self.axes) == 1:
-                     # 1D transform on 3D data: here only supported along fast dim, so batch_size is Nx*Ny
-                     self.cufft_batch_size = np.prod(self.shape[:2])
-                     self.cufft_shape = (self.shape[-1],)
-             if len(self.cufft_shape) == 1:
-                 self.cufft_shape = self.cufft_shape[0]
-
-     def _configure_normalization(self, normalize):
-         self.normalize = normalize
-         if self.normalize == "ortho":
-             # TODO
-             raise NotImplementedError("Normalization mode 'ortho' is not implemented with CUDA backend yet.")
-         self.cufft_scale_inverse = self.normalize == "rescale"
-
-     def _compute_fft_plans(self):
-         self.plan_forward = Plan(  # pylint: disable = E1102
-             self.cufft_shape,
-             self.dtype,
-             self.dtype_out,
-             batch=self.cufft_batch_size,
-             stream=self.processing.stream,
-             **self._cufft_plan_kwargs,
-             # cufft extensible plan API is only supported after 0.5.1
-             # (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
-             # but there is still no official 0.5.2
-             # ~ auto_allocate=True # cufft extensible plan API
-         )
-         self.plan_inverse = Plan(  # pylint: disable = E1102
-             self.cufft_shape,  # not shape_out
-             self.dtype_out,
-             self.dtype,
-             batch=self.cufft_batch_size,
-             stream=self.processing.stream,
-             **self._cufft_plan_kwargs,
-             # cufft extensible plan API is only supported after 0.5.1
-             # (commit 65288d28ca0b93e1234133f8d460dc6becb65121)
-             # but there is still no official 0.5.2
-             # ~ auto_allocate=True
-         )
-
-     def fft(self, array, output=None):
-         if output is None:
-             output = self.output_fft = self.processing.allocate_array(
-                 "output_fft", self.shape_out, dtype=self.dtype_out
-             )
-         cu_fft(array, output, self.plan_forward, scale=False)  # pylint: disable = E1102
-         return output
-
-     def ifft(self, array, output=None):
-         if output is None:
-             output = self.output_ifft = self.processing.allocate_array("output_ifft", self.shape, dtype=self.dtype)
-         cu_ifft(  # pylint: disable = E1102
-             array,
-             output,
-             self.plan_inverse,
-             scale=self.cufft_scale_inverse,
-         )
-         return output
+ def get_vkfft_cuda(slf, *args, **kwargs):
+     return _get_vkfft_cuda(*args, **kwargs)


  class VKCUFFT(_BaseVKFFT):
@@ -153,7 +38,7 @@ class VKCUFFT(_BaseVKFFT):
      implem = "vkfft"
      backend = "cuda"
      ProcessingCls = CudaProcessing
-     vkffs_cls = vk_cufft
+     get_fft_obj = get_vkfft_cuda

      def _init_backend(self, backend_options):
          super()._init_backend(backend_options)
@@ -167,13 +52,14 @@ def _has_vkfft(x):

          if not __has_vkfft__:
              return False
-         vk = VKCUFFT((16,), "f")
+         _ = VKCUFFT((16,), "f")
          avail = True
      except (ImportError, RuntimeError, OSError, NameError):
          avail = False
      return avail


+ @lru_cache(maxsize=2)
  def has_vkfft(safe=True):
      """
      Determine whether pyvkfft is available.
@@ -184,44 +70,20 @@ def has_vkfft(safe=True):
      """
      if not safe:
          return _has_vkfft(None)
-     ctx = get_context("spawn")
-     with Pool(1, context=ctx) as p:
-         v = p.map(_has_vkfft, [1])[0]
-     return v
-
-
- def _has_skfft(x):
-     # should be run from within a Process
      try:
-         from nabu.processing.fft_cuda import SKCUFFT
-
-         sk = SKCUFFT((16,), "f")
-         avail = True
-     except (ImportError, RuntimeError, OSError, NameError):
-         avail = False
-     return avail
-
-
- def has_skcuda(safe=True):
-     """
-     Determine whether scikit-cuda/CUFFT is available.
-     Currently, scikit-cuda will create a Cuda context for Cublas, which can mess up the current execution.
-     Do it in a separate thread.
-     """
-     if not safe:
-         return _has_skfft(None)
-     ctx = get_context("spawn")
-     with Pool(1, context=ctx) as p:
-         v = p.map(_has_skfft, [1])[0]
+         ctx = get_context("spawn")
+         with Pool(1, context=ctx) as p:
+             v = p.map(_has_vkfft, [1])[0]
+     except AssertionError:
+         # Can get "AssertionError: daemonic processes are not allowed to have children"
+         # if the calling code is already a subprocess
+         return _has_vkfft(None)
      return v


+ @lru_cache(maxsize=2)
  def get_fft_class(backend="vkfft"):
      backends = {
-         "scikit-cuda": SKCUFFT,
-         "skcuda": SKCUFFT,
-         "cufft": SKCUFFT,
-         "scikit": SKCUFFT,
          "vkfft": VKCUFFT,
          "pyvkfft": VKCUFFT,
      }
@@ -237,7 +99,7 @@ def get_fft_class(backend="vkfft"):

      avail_fft_implems = get_available_fft_implems()
      if len(avail_fft_implems) == 0:
-         raise RuntimeError("Could not find any Cuda FFT implementation. Please install either scikit-cuda or pyvkfft")
+         raise RuntimeError("Could not find any Cuda FFT implementation. Please install pyvkfft")
      if backend not in avail_fft_implems:
          warnings.warn("Could not get FFT backend '%s'" % backend, RuntimeWarning)
          backend = avail_fft_implems[0]
@@ -245,10 +107,9 @@ def get_fft_class(backend="vkfft"):
      return get_fft_cls(backend)


+ @lru_cache(maxsize=1)
  def get_available_fft_implems():
      avail_implems = []
      if has_vkfft(safe=True):
          avail_implems.append("vkfft")
-     if has_skcuda(safe=True):
-         avail_implems.append("skcuda")
      return avail_implems
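
Besides dropping the scikit-cuda backend, this rewrite makes FFT plan caching opt-in via the NABU_FFT_CACHE environment variable. A standalone sketch of the pattern, with a local no_decorator assumed equivalent to the helper imported from nabu.utils (a decorator that returns the function unchanged):

import os
from functools import lru_cache

def no_decorator(func):
    return func

n_cached = int(os.getenv("NABU_FFT_CACHE", "0"))
maybe_cached = lru_cache(maxsize=n_cached) if n_cached > 0 else no_decorator

@maybe_cached
def make_plan(shape, dtype):
    # Stand-in for an expensive VkFFTApp construction; runs once per
    # (shape, dtype) when caching is enabled
    print("building plan for", shape, dtype)
    return object()

make_plan((16,), "f")
make_plan((16,), "f")  # cache hit if NABU_FFT_CACHE > 0

Note the indirection in the diff: get_fft_obj = get_vkfft_cuda is a plain function stored as a class attribute, so calling self.get_fft_obj(...) passes the instance as the first argument (slf); the wrapper discards it so the cache key contains only hashable plan parameters, not the FFT object itself.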
nabu/processing/fft_opencl.py CHANGED
@@ -1,15 +1,32 @@
+ from functools import lru_cache
+ import os
  from multiprocessing import get_context
  from multiprocessing.pool import Pool
+
+ from ..utils import BaseClassError, no_decorator
  from .fft_base import _BaseVKFFT
  from ..opencl.processing import OpenCLProcessing

  try:
-     from pyvkfft.opencl import VkFFTApp as vk_clfft
+     from pyvkfft.opencl import VkFFTApp as OpenCLVkFFTApp

      __has_vkfft__ = True
  except (ImportError, OSError):
      __has_vkfft__ = False
      vk_clfft = None
+     OpenCLVkFFTApp = BaseClassError
+
+ n_cached_ffts = int(os.getenv("NABU_FFT_CACHE", "0"))
+ maybe_cached = lru_cache(maxsize=n_cached_ffts) if n_cached_ffts > 0 else no_decorator
+
+
+ @maybe_cached
+ def _get_vkfft_opencl(*args, **kwargs):
+     return OpenCLVkFFTApp(*args, **kwargs)
+
+
+ def get_vkfft_opencl(slf, *args, **kwargs):
+     return _get_vkfft_opencl(*args, **kwargs)


  class VKCLFFT(_BaseVKFFT):
@@ -20,7 +37,7 @@ class VKCLFFT(_BaseVKFFT):
      implem = "vkfft"
      backend = "opencl"
      ProcessingCls = OpenCLProcessing
-     vkffs_cls = vk_clfft
+     get_fft_obj = get_vkfft_opencl

      def _init_backend(self, backend_options):
          super()._init_backend(backend_options)
@@ -34,7 +51,7 @@ def _has_vkfft(x):

          if not __has_vkfft__:
              return False
-         vk = VKCLFFT((16,), "f")
+         _ = VKCLFFT((16,), "f")
          avail = True
      except (RuntimeError, OSError):
          avail = False
@@ -48,7 +65,12 @@ def has_vkfft(safe=True):
      """
      if not safe:
          return _has_vkfft(None)
-     ctx = get_context("spawn")
-     with Pool(1, context=ctx) as p:
-         v = p.map(_has_vkfft, [1])[0]
+     try:
+         ctx = get_context("spawn")
+         with Pool(1, context=ctx) as p:
+             v = p.map(_has_vkfft, [1])[0]
+     except AssertionError:
+         # Can get "AssertionError: daemonic processes are not allowed to have children"
+         # if the calling code is already a subprocess
+         return _has_vkfft(None)
      return v
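
Both backends now share the same "probe in a spawned subprocess" pattern: the availability check runs in a fresh process so that any GPU context created by the probe cannot pollute the caller, with an in-process fallback when daemonic workers cannot spawn children. A minimal sketch of the pattern; the probe body is a placeholder for the real context-creating check:

from multiprocessing import get_context
from multiprocessing.pool import Pool

def _probe(_):
    # Must be a module-level function so the spawned worker can pickle it
    try:
        import pyvkfft  # noqa: F401  # would create a GPU context in a real probe
        return True
    except (ImportError, OSError):
        return False

def safe_probe():
    try:
        ctx = get_context("spawn")
        with Pool(1, context=ctx) as p:
            return p.map(_probe, [1])[0]
    except AssertionError:
        # "daemonic processes are not allowed to have children": already in a worker,
        # so fall back to probing in the current process
        return _probe(None)

print(safe_probe())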
nabu/processing/fftshift.py CHANGED
@@ -25,7 +25,7 @@ class FFTshiftBase:
      axes: tuple, optional
          Axes over which to shift. Default is None, which shifts all axes.

-     Other parameters
+     Other Parameters
      ----------------
      backend_options:
          named arguments to pass to CudaProcessing or OpenCLProcessing
nabu/processing/histogram.py CHANGED
@@ -146,7 +146,7 @@ class PartialHistogram:
          elif self.bin_width == "uint16":
              return self._bin_width_u16(dmin, dmax)
          else:
-             raise ValueError()
+             raise ValueError

      def _compute_histogram_fixed_bw(self, data, data_range=None):
          dmin, dmax = data.min(), data.max() if data_range is None else data_range
nabu/processing/muladd.py CHANGED
@@ -1,4 +1,3 @@
- import numpy as np
  from .processing_base import ProcessingBase

nabu/processing/padding_base.py CHANGED
@@ -23,7 +23,7 @@ class PaddingBase:
      mode: str
          Padding mode

-     Other parameters
+     Other Parameters
      ----------------
      constant_values: tuple
          Tuple containing the values to fill when mode="constant" (as in numpy.pad)
nabu/processing/padding_cuda.py CHANGED
@@ -1,7 +1,6 @@
  import numpy as np
  from ..utils import get_cuda_srcfile, updiv
  from ..cuda.processing import CudaProcessing
- from ..cuda.utils import __has_pycuda__
  from .padding_base import PaddingBase


@@ -12,7 +11,6 @@ class CudaPadding(PaddingBase):

      backend = "cuda"

-     # TODO docstring from base class
      def __init__(self, shape, pad_width, mode="constant", cuda_options=None, **kwargs):
          super().__init__(shape, pad_width, mode=mode, **kwargs)
          self.cuda_processing = self.processing = CudaProcessing(**(cuda_options or {}))