nabu 2023.2.1__py3-none-any.whl → 2024.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183)
  1. doc/conf.py +1 -1
  2. doc/doc_config.py +32 -0
  3. nabu/__init__.py +2 -1
  4. nabu/app/bootstrap_stitching.py +1 -1
  5. nabu/app/cli_configs.py +122 -2
  6. nabu/app/composite_cor.py +27 -2
  7. nabu/app/correct_rot.py +70 -0
  8. nabu/app/create_distortion_map_from_poly.py +42 -18
  9. nabu/app/diag_to_pix.py +358 -0
  10. nabu/app/diag_to_rot.py +449 -0
  11. nabu/app/generate_header.py +4 -3
  12. nabu/app/histogram.py +2 -2
  13. nabu/app/multicor.py +6 -1
  14. nabu/app/parse_reconstruction_log.py +151 -0
  15. nabu/app/prepare_weights_double.py +83 -22
  16. nabu/app/reconstruct.py +5 -1
  17. nabu/app/reconstruct_helical.py +7 -0
  18. nabu/app/reduce_dark_flat.py +6 -3
  19. nabu/app/rotate.py +4 -4
  20. nabu/app/stitching.py +16 -2
  21. nabu/app/tests/test_reduce_dark_flat.py +18 -2
  22. nabu/app/validator.py +4 -4
  23. nabu/cuda/convolution.py +8 -376
  24. nabu/cuda/fft.py +4 -0
  25. nabu/cuda/kernel.py +4 -4
  26. nabu/cuda/medfilt.py +5 -158
  27. nabu/cuda/padding.py +5 -71
  28. nabu/cuda/processing.py +23 -2
  29. nabu/cuda/src/ElementOp.cu +78 -0
  30. nabu/cuda/src/backproj.cu +28 -2
  31. nabu/cuda/src/fourier_wavelets.cu +2 -2
  32. nabu/cuda/src/normalization.cu +23 -0
  33. nabu/cuda/src/padding.cu +2 -2
  34. nabu/cuda/src/transpose.cu +16 -0
  35. nabu/cuda/utils.py +39 -0
  36. nabu/estimation/alignment.py +10 -1
  37. nabu/estimation/cor.py +808 -38
  38. nabu/estimation/cor_sino.py +7 -9
  39. nabu/estimation/tests/test_cor.py +85 -3
  40. nabu/io/reader.py +26 -18
  41. nabu/io/tests/test_cast_volume.py +3 -3
  42. nabu/io/tests/test_detector_distortion.py +3 -3
  43. nabu/io/tiffwriter_zmm.py +2 -2
  44. nabu/io/utils.py +14 -4
  45. nabu/io/writer.py +5 -3
  46. nabu/misc/fftshift.py +6 -0
  47. nabu/misc/histogram.py +5 -285
  48. nabu/misc/histogram_cuda.py +8 -104
  49. nabu/misc/kernel_base.py +3 -121
  50. nabu/misc/padding_base.py +5 -69
  51. nabu/misc/processing_base.py +3 -107
  52. nabu/misc/rotation.py +5 -62
  53. nabu/misc/rotation_cuda.py +5 -65
  54. nabu/misc/transpose.py +6 -0
  55. nabu/misc/unsharp.py +3 -78
  56. nabu/misc/unsharp_cuda.py +5 -52
  57. nabu/misc/unsharp_opencl.py +8 -85
  58. nabu/opencl/fft.py +6 -0
  59. nabu/opencl/kernel.py +21 -6
  60. nabu/opencl/padding.py +5 -72
  61. nabu/opencl/processing.py +27 -5
  62. nabu/opencl/src/backproj.cl +3 -3
  63. nabu/opencl/src/fftshift.cl +65 -12
  64. nabu/opencl/src/padding.cl +2 -2
  65. nabu/opencl/src/roll.cl +96 -0
  66. nabu/opencl/src/transpose.cl +16 -0
  67. nabu/pipeline/config_validators.py +63 -3
  68. nabu/pipeline/dataset_validator.py +2 -2
  69. nabu/pipeline/estimators.py +193 -35
  70. nabu/pipeline/fullfield/chunked.py +34 -17
  71. nabu/pipeline/fullfield/chunked_cuda.py +7 -5
  72. nabu/pipeline/fullfield/computations.py +48 -13
  73. nabu/pipeline/fullfield/nabu_config.py +13 -13
  74. nabu/pipeline/fullfield/processconfig.py +10 -5
  75. nabu/pipeline/fullfield/reconstruction.py +1 -2
  76. nabu/pipeline/helical/fbp.py +5 -0
  77. nabu/pipeline/helical/filtering.py +12 -9
  78. nabu/pipeline/helical/gridded_accumulator.py +179 -33
  79. nabu/pipeline/helical/helical_chunked_regridded.py +262 -151
  80. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +4 -11
  81. nabu/pipeline/helical/helical_reconstruction.py +56 -18
  82. nabu/pipeline/helical/span_strategy.py +1 -1
  83. nabu/pipeline/helical/tests/test_accumulator.py +4 -0
  84. nabu/pipeline/params.py +23 -2
  85. nabu/pipeline/processconfig.py +3 -8
  86. nabu/pipeline/tests/test_chunk_reader.py +78 -0
  87. nabu/pipeline/tests/test_estimators.py +120 -2
  88. nabu/pipeline/utils.py +25 -0
  89. nabu/pipeline/writer.py +2 -0
  90. nabu/preproc/ccd_cuda.py +9 -7
  91. nabu/preproc/ctf.py +21 -26
  92. nabu/preproc/ctf_cuda.py +25 -25
  93. nabu/preproc/double_flatfield.py +14 -2
  94. nabu/preproc/double_flatfield_cuda.py +7 -11
  95. nabu/preproc/flatfield_cuda.py +23 -27
  96. nabu/preproc/phase.py +19 -24
  97. nabu/preproc/phase_cuda.py +21 -21
  98. nabu/preproc/shift_cuda.py +58 -28
  99. nabu/preproc/tests/test_ctf.py +5 -5
  100. nabu/preproc/tests/test_double_flatfield.py +2 -2
  101. nabu/preproc/tests/test_vshift.py +13 -2
  102. nabu/processing/__init__.py +0 -0
  103. nabu/processing/convolution_cuda.py +375 -0
  104. nabu/processing/fft_base.py +163 -0
  105. nabu/processing/fft_cuda.py +256 -0
  106. nabu/processing/fft_opencl.py +54 -0
  107. nabu/processing/fftshift.py +134 -0
  108. nabu/processing/histogram.py +286 -0
  109. nabu/processing/histogram_cuda.py +103 -0
  110. nabu/processing/kernel_base.py +126 -0
  111. nabu/processing/medfilt_cuda.py +159 -0
  112. nabu/processing/muladd.py +29 -0
  113. nabu/processing/muladd_cuda.py +68 -0
  114. nabu/processing/padding_base.py +71 -0
  115. nabu/processing/padding_cuda.py +75 -0
  116. nabu/processing/padding_opencl.py +77 -0
  117. nabu/processing/processing_base.py +123 -0
  118. nabu/processing/roll_opencl.py +64 -0
  119. nabu/processing/rotation.py +63 -0
  120. nabu/processing/rotation_cuda.py +66 -0
  121. nabu/processing/tests/__init__.py +0 -0
  122. nabu/processing/tests/test_fft.py +268 -0
  123. nabu/processing/tests/test_fftshift.py +71 -0
  124. nabu/{misc → processing}/tests/test_histogram.py +2 -4
  125. nabu/{cuda → processing}/tests/test_medfilt.py +1 -1
  126. nabu/processing/tests/test_muladd.py +54 -0
  127. nabu/{cuda → processing}/tests/test_padding.py +119 -75
  128. nabu/processing/tests/test_roll.py +63 -0
  129. nabu/{misc → processing}/tests/test_rotation.py +3 -2
  130. nabu/processing/tests/test_transpose.py +72 -0
  131. nabu/{misc → processing}/tests/test_unsharp.py +41 -8
  132. nabu/processing/transpose.py +126 -0
  133. nabu/processing/unsharp.py +79 -0
  134. nabu/processing/unsharp_cuda.py +53 -0
  135. nabu/processing/unsharp_opencl.py +75 -0
  136. nabu/reconstruction/fbp.py +34 -10
  137. nabu/reconstruction/fbp_base.py +35 -16
  138. nabu/reconstruction/fbp_opencl.py +7 -12
  139. nabu/reconstruction/filtering.py +2 -2
  140. nabu/reconstruction/filtering_cuda.py +13 -14
  141. nabu/reconstruction/filtering_opencl.py +3 -4
  142. nabu/reconstruction/projection.py +2 -0
  143. nabu/reconstruction/rings.py +158 -1
  144. nabu/reconstruction/rings_cuda.py +218 -58
  145. nabu/reconstruction/sinogram_cuda.py +16 -12
  146. nabu/reconstruction/tests/test_deringer.py +116 -14
  147. nabu/reconstruction/tests/test_fbp.py +22 -31
  148. nabu/reconstruction/tests/test_filtering.py +11 -2
  149. nabu/resources/dataset_analyzer.py +89 -26
  150. nabu/resources/nxflatfield.py +2 -2
  151. nabu/resources/tests/test_nxflatfield.py +1 -1
  152. nabu/resources/utils.py +9 -2
  153. nabu/stitching/alignment.py +184 -0
  154. nabu/stitching/config.py +241 -39
  155. nabu/stitching/definitions.py +6 -0
  156. nabu/stitching/frame_composition.py +4 -2
  157. nabu/stitching/overlap.py +99 -3
  158. nabu/stitching/sample_normalization.py +60 -0
  159. nabu/stitching/slurm_utils.py +10 -10
  160. nabu/stitching/tests/test_alignment.py +99 -0
  161. nabu/stitching/tests/test_config.py +16 -1
  162. nabu/stitching/tests/test_overlap.py +68 -2
  163. nabu/stitching/tests/test_sample_normalization.py +49 -0
  164. nabu/stitching/tests/test_slurm_utils.py +5 -5
  165. nabu/stitching/tests/test_utils.py +3 -33
  166. nabu/stitching/tests/test_z_stitching.py +391 -22
  167. nabu/stitching/utils.py +144 -202
  168. nabu/stitching/z_stitching.py +309 -126
  169. nabu/testutils.py +18 -0
  170. nabu/thirdparty/tomocupy_remove_stripe.py +586 -0
  171. nabu/utils.py +32 -6
  172. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/LICENSE +1 -1
  173. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/METADATA +5 -5
  174. nabu-2024.1.0rc3.dist-info/RECORD +296 -0
  175. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/WHEEL +1 -1
  176. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/entry_points.txt +5 -1
  177. nabu/conftest.py +0 -14
  178. nabu/opencl/fftshift.py +0 -92
  179. nabu/opencl/tests/test_fftshift.py +0 -55
  180. nabu/opencl/tests/test_padding.py +0 -84
  181. nabu-2023.2.1.dist-info/RECORD +0 -252
  182. /nabu/cuda/src/{fftshift.cu → dfi_fftshift.cu} +0 -0
  183. {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/top_level.txt +0 -0
nabu/preproc/ctf.py CHANGED
@@ -1,9 +1,10 @@
1
1
  import math
2
2
  import numpy as np
3
+ from scipy.fft import rfft2, irfft2, fft2, ifft2
3
4
  from ..resources.logger import LoggerOrPrint
4
5
  from ..misc import fourier_filters
5
6
  from ..misc.padding import pad_interpolate, recut
6
- from ..utils import get_num_threads
7
+ from ..utils import get_num_threads, deprecation_warning
7
8
 
8
9
 
9
10
  class GeoPars:
@@ -111,6 +112,7 @@ class CTFPhaseRetrieval:
111
112
  lim2=0.2,
112
113
  use_rfft=False,
113
114
  fftw_num_threads=None,
115
+ fft_num_threads=None,
114
116
  logger=None,
115
117
  ):
116
118
  """
@@ -138,10 +140,11 @@ class CTFPhaseRetrieval:
138
140
  use_rfft: bool, optional
139
141
  Whether to use real-to-complex (R2C) FFT instead of usual complex-to-complex (C2C).
140
142
  fftw_num_threads: bool or None or int, optional
141
- If False is passed: don't use FFTW.
142
- If None is passed: use all available threads.
143
- If a number is provided: number of threads to use for FFTW.
144
- You can pass a negative number to use N - fftw_num_threads cores.
143
+ DEPRECATED - please use fft_num_threads instead.
144
+ fft_num_threads: bool or None or int, optional
145
+ Number of threads to use for FFT.
146
+ If a number is provided: number of threads to use for FFT.
147
+ You can pass a negative number to use N - fft_num_threads cores.
145
148
  logger: optional
146
149
  a logger object
147
150
  """
@@ -152,12 +155,18 @@ class CTFPhaseRetrieval:
152
155
  self._calc_shape(shape, padded_shape, padding_mode)
153
156
  self.delta_beta = delta_beta
154
157
 
158
+ # COMPAT.
159
+ if fftw_num_threads is not None:
160
+ deprecation_warning("'fftw_num_threads' is replaced with 'fft_num_threads'", func_name="ctf_fftw")
161
+ fft_num_threads = fftw_num_threads
162
+ # ---
163
+
155
164
  self.lim = None
156
165
  self.lim1 = lim1
157
166
  self.lim2 = lim2
158
167
  self.normalize_by_mean = normalize_by_mean
159
168
  self.translation_vh = translation_vh
160
- self._setup_fft(use_rfft, fftw_num_threads)
169
+ self._setup_fft(use_rfft, fft_num_threads)
161
170
  self._get_ctf_filter()
162
171
 
163
172
  def _calc_shape(self, shape, padded_shape, padding_mode):
@@ -175,25 +184,11 @@ class CTFPhaseRetrieval:
175
184
  self.shape_padded = tuple(padded_shape)
176
185
  self.padding_mode = padding_mode
177
186
 
178
- def _setup_fft(self, use_rfft, fftw_num_threads):
187
+ def _setup_fft(self, use_rfft, fft_num_threads):
179
188
  self.use_rfft = use_rfft
180
- self._fft_func = np.fft.rfft2 if use_rfft else np.fft.fft2
181
- self._ifft_func = np.fft.irfft2 if use_rfft else np.fft.ifft2
182
- self.use_fftw = False
183
- if fftw_num_threads is False:
184
- return
185
- fftw_num_threads = get_num_threads(fftw_num_threads)
186
- if self.use_rfft and (fftw_num_threads > 0):
187
- # importing silx.math.fft creates opencl contexts all over the place
188
- # because of the silx.opencl.ocl singleton.
189
- # So, import silx as late as possible
190
- from silx.math.fft.fftw import FFTW, __have_fftw__
191
-
192
- if __have_fftw__:
193
- self.use_fftw = True
194
- self.fftw = FFTW(shape=self.shape_padded, dtype="f", num_threads=fftw_num_threads)
195
- self._fft_func = self.fftw.fft
196
- self._ifft_func = self.fftw.ifft
189
+ self._fft_func = rfft2 if use_rfft else fft2
190
+ self._ifft_func = irfft2 if use_rfft else ifft2
191
+ self.fft_num_threads = get_num_threads(fft_num_threads)
197
192
 
198
193
  def _get_ctf_filter(self):
199
194
  """
@@ -320,7 +315,7 @@ class CTFPhaseRetrieval:
320
315
  self._ctf_filter_denom = (2 * self.unreg_filter_denom * self.unreg_filter_denom + self.lim).astype(np.complex64)
321
316
 
322
317
  def _apply_filter(self, img):
323
- img_f = self._fft_func(img)
318
+ img_f = self._fft_func(img, workers=self.fft_num_threads)
324
319
  img_f *= self.unreg_filter_denom
325
320
 
326
321
  unreg_filter_denom_0_mean = self.unreg_filter_denom[0, 0]
@@ -331,7 +326,7 @@ class CTFPhaseRetrieval:
331
326
 
332
327
  ## formula 8, with regularisation to stay at a safe distance from the poles
333
328
  img_f /= self._ctf_filter_denom
334
- ph = self._ifft_func(img_f).real
329
+ ph = self._ifft_func(img_f, workers=self.fft_num_threads).real
335
330
  return ph
336
331
 
337
332
  def retrieve_phase(self, img, output=None):
nabu/preproc/ctf_cuda.py CHANGED
@@ -1,12 +1,15 @@
1
1
  import numpy as np
2
- from pycuda import gpuarray as garray
3
- from ..utils import calc_padding_lengths, updiv, get_cuda_srcfile
2
+ from ..utils import calc_padding_lengths, updiv, get_cuda_srcfile, docstring
4
3
  from ..cuda.processing import CudaProcessing
5
- from ..cuda.kernel import CudaKernel
6
- from ..cuda.padding import CudaPadding
4
+ from ..cuda.utils import __has_pycuda__
5
+ from ..processing.padding_cuda import CudaPadding
6
+ from ..processing.fft_cuda import get_fft_class
7
7
  from .phase_cuda import CudaPaganinPhaseRetrieval
8
8
  from .ctf import CTFPhaseRetrieval
9
9
 
10
+ if __has_pycuda__:
11
+ from pycuda import gpuarray as garray
12
+
10
13
 
11
14
  # TODO:
12
15
  # - better padding scheme (for now 2*shape)
@@ -17,6 +20,7 @@ class CudaCTFPhaseRetrieval(CTFPhaseRetrieval):
17
20
  Cuda back-end of CTFPhaseRetrieval
18
21
  """
19
22
 
23
+ @docstring(CTFPhaseRetrieval)
20
24
  def __init__(
21
25
  self,
22
26
  shape,
@@ -29,9 +33,11 @@ class CudaCTFPhaseRetrieval(CTFPhaseRetrieval):
29
33
  lim1=1.0e-5,
30
34
  lim2=0.2,
31
35
  use_rfft=True,
32
- fftw_num_threads=None,
36
+ fftw_num_threads=None, # COMPAT.
37
+ fft_num_threads=None,
33
38
  logger=None,
34
39
  cuda_options=None,
40
+ fft_backend="skcuda",
35
41
  ):
36
42
  """
37
43
  Initialize a CudaCTFPhaseRetrieval.
@@ -62,30 +68,26 @@ class CudaCTFPhaseRetrieval(CTFPhaseRetrieval):
62
68
  lim2=lim2,
63
69
  logger=logger,
64
70
  use_rfft=True,
65
- fftw_num_threads=False,
71
+ fft_num_threads=False,
66
72
  )
67
73
  self._init_ctf_filter()
68
74
  self._init_cuda_padding()
69
- self._init_fft()
75
+ self._init_fft(fft_backend)
70
76
  self._init_mult_kernel()
71
77
 
72
78
  def _init_ctf_filter(self):
73
79
  self._mean_scale_factor = self.unreg_filter_denom[0, 0] * np.prod(self.shape_padded)
74
- self._d_filter_num = garray.to_gpu(self.unreg_filter_denom).astype("f")
75
- self._d_filter_denom = garray.to_gpu(
76
- (1.0 / (2 * self.unreg_filter_denom * self.unreg_filter_denom + self.lim)).astype("f")
80
+ self._d_filter_num = self.cuda_processing.to_device("_d_filter_num", self.unreg_filter_denom).astype("f")
81
+ self._d_filter_denom = self.cuda_processing.to_device(
82
+ "_d_filter_denom", (1.0 / (2 * self.unreg_filter_denom * self.unreg_filter_denom + self.lim)).astype("f")
77
83
  )
78
84
 
79
85
  def _init_cuda_padding(self):
80
86
  pad_width = calc_padding_lengths(self.shape, self.shape_padded)
81
87
  # Custom coordinate transform to get directly FFT layout
82
- R, C = np.indices(self.shape, dtype=np.int32)
83
- coords_R = np.roll(
84
- np.pad(R, pad_width, mode=self.padding_mode), (-pad_width[0][0], -pad_width[1][0]), axis=(0, 1)
85
- )
86
- coords_C = np.roll(
87
- np.pad(C, pad_width, mode=self.padding_mode), (-pad_width[0][0], -pad_width[1][0]), axis=(0, 1)
88
- )
88
+ R, C = np.indices(self.shape, dtype=np.int32, sparse=True)
89
+ coords_R = np.roll(np.pad(R.ravel(), pad_width[0], mode=self.padding_mode), -pad_width[0][0])
90
+ coords_C = np.roll(np.pad(C.ravel(), pad_width[1], mode=self.padding_mode), -pad_width[1][0])
89
91
  self.cuda_padding = CudaPadding(
90
92
  self.shape,
91
93
  (coords_R, coords_C),
@@ -93,16 +95,14 @@ class CudaCTFPhaseRetrieval(CTFPhaseRetrieval):
93
95
  # propagate cuda options ?
94
96
  )
95
97
 
96
- def _init_fft(self):
97
- # Import has to be done here, otherwise scikit-cuda creates a cuda/cublas context at import
98
- from silx.math.fft.cufft import CUFFT
99
-
100
- self.cufft = CUFFT(template=np.zeros(self.shape_padded, dtype="f"))
101
- self.d_radio_padded = self.cufft.data_in
102
- self.d_radio_f = self.cufft.data_out
98
+ def _init_fft(self, fft_backend):
99
+ fft_cls = get_fft_class(backend=fft_backend)
100
+ self.cufft = fft_cls(shape=self.shape_padded, dtype=np.float32, r2c=True)
101
+ self.d_radio_padded = self.cuda_processing.allocate_array("d_radio_padded", self.shape_padded, "f")
102
+ self.d_radio_f = self.cuda_processing.allocate_array("d_radio_f", self.cufft.shape_out, np.complex64)
103
103
 
104
104
  def _init_mult_kernel(self):
105
- self.cpxmult_kernel = CudaKernel(
105
+ self.cpxmult_kernel = self.cuda_processing.kernel(
106
106
  "CTF_kernel",
107
107
  filename=get_cuda_srcfile("ElementOp.cu"),
108
108
  signature="PPPfii",
nabu/preproc/double_flatfield.py CHANGED
@@ -5,6 +5,7 @@ from silx.io.url import DataUrl
5
5
  from ..utils import check_supported, check_shape, get_2D_3D_shape
6
6
  from ..io.reader import Readers
7
7
  from ..io.writer import Writers
8
+ from .ccd import Log
8
9
 
9
10
 
10
11
  class DoubleFlatField:
@@ -22,6 +23,8 @@ class DoubleFlatField:
22
23
  average_is_on_log=False,
23
24
  sigma_filter=None,
24
25
  filter_mode="reflect",
26
+ log_clip_min=None,
27
+ log_clip_max=None,
25
28
  ):
26
29
  """
27
30
  Init double flat field by summing a series of urls and considering the same subregion of them.
@@ -55,6 +58,8 @@ class DoubleFlatField:
55
58
  self.radios_shape = get_2D_3D_shape(shape)
56
59
  self.n_angles = self.radios_shape[0]
57
60
  self.shape = self.radios_shape[1:]
61
+ self._log_clip_min = log_clip_min
62
+ self._log_clip_max = log_clip_max
58
63
  self._init_filedump(result_url, sub_region, detector_corrector)
59
64
  self._init_processing(input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode)
60
65
  self._computed = False
@@ -112,17 +117,19 @@ class DoubleFlatField:
112
117
  self.sigma_filter = None
113
118
  self.filter_mode = filter_mode
114
119
  proc = lambda x, o: np.copyto(o, x)
120
+ self._mlog = Log((1,) + self.shape, clip_min=self._log_clip_min, clip_max=self._log_clip_max)
121
+
115
122
  if self.input_is_mlog:
116
123
  if not self.average_is_on_log:
117
124
  proc = lambda x, o: np.exp(-x, out=o)
118
125
  else:
119
126
  if self.average_is_on_log:
120
- proc = lambda x, o: -np.log(x, out=o)
127
+ proc = self._proc_mlog
121
128
 
122
129
  postproc = lambda x: x
123
130
  if self.output_is_mlog:
124
131
  if not self.average_is_on_log:
125
- postproc = lambda x: -np.log(x)
132
+ postproc = self._proc_mlog
126
133
  else:
127
134
  if self.average_is_on_log:
128
135
  postproc = lambda x: np.exp(-x)
@@ -130,6 +137,11 @@ class DoubleFlatField:
130
137
  self.proc = proc
131
138
  self.postproc = postproc
132
139
 
140
+ def _proc_mlog(self, x, o):
141
+ o[:] = x[:]
142
+ self._mlog.take_logarithm(o)
143
+ return o
144
+
133
145
  def compute_double_flatfield(self, radios, recompute=False):
134
146
  """
135
147
  Read the radios and generate the "double flat field" by averaging
nabu/preproc/double_flatfield_cuda.py CHANGED
@@ -2,7 +2,8 @@ from .double_flatfield import DoubleFlatField
2
2
  from ..utils import check_shape
3
3
  from ..cuda.utils import __has_pycuda__
4
4
  from ..cuda.processing import CudaProcessing
5
- from ..misc.unsharp_cuda import CudaUnsharpMask
5
+ from ..processing.unsharp_cuda import CudaUnsharpMask
6
+ from .ccd_cuda import CudaLog
6
7
 
7
8
  if __has_pycuda__:
8
9
  import pycuda.gpuarray as garray
@@ -21,6 +22,8 @@ class CudaDoubleFlatField(DoubleFlatField):
21
22
  average_is_on_log=False,
22
23
  sigma_filter=None,
23
24
  filter_mode="reflect",
25
+ log_clip_min=None,
26
+ log_clip_max=None,
24
27
  cuda_options=None,
25
28
  ):
26
29
  """
@@ -37,6 +40,8 @@ class CudaDoubleFlatField(DoubleFlatField):
37
40
  average_is_on_log=average_is_on_log,
38
41
  sigma_filter=sigma_filter,
39
42
  filter_mode=filter_mode,
43
+ log_clip_min=log_clip_min,
44
+ log_clip_max=log_clip_max,
40
45
  )
41
46
  self._init_gaussian_filter()
42
47
 
@@ -57,16 +62,6 @@ class CudaDoubleFlatField(DoubleFlatField):
57
62
  cumath.exp(o, out=o)
58
63
  return o
59
64
 
60
- @staticmethod
61
- def _proc_mlog(x, o, min_clip=None):
62
- if min_clip is not None:
63
- garray.maximum(x, min_clip, out=o)
64
- cumath.log(o, out=o)
65
- else:
66
- cumath.log(x, out=o)
67
- o *= -1
68
- return o
69
-
70
65
  def _init_processing(self, input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode):
71
66
  self.input_is_mlog = input_is_mlog
72
67
  self.output_is_mlog = output_is_mlog
@@ -77,6 +72,7 @@ class CudaDoubleFlatField(DoubleFlatField):
77
72
  self.filter_mode = filter_mode
78
73
  # proc = lambda x,o: np.copyto(o, x)
79
74
  proc = self._proc_copy
75
+ self._mlog = CudaLog((1,) + self.shape, clip_min=self._log_clip_min, clip_max=self._log_clip_max)
80
76
  if self.input_is_mlog:
81
77
  if not self.average_is_on_log:
82
78
  # proc = lambda x,o: np.exp(-x, out=o)
nabu/preproc/flatfield_cuda.py CHANGED
@@ -1,25 +1,25 @@
1
- from typing import Union
2
1
  import numpy as np
3
- import pycuda.gpuarray as garray
2
+
3
+ from nabu.cuda.processing import CudaProcessing
4
4
  from ..preproc.flatfield import FlatFieldArrays
5
- from ..cuda.kernel import CudaKernel
6
5
  from ..utils import get_cuda_srcfile
7
6
  from ..io.reader import load_images_from_dataurl_dict
7
+ from ..cuda.utils import __has_pycuda__
8
8
 
9
9
 
10
10
  class CudaFlatFieldArrays(FlatFieldArrays):
11
11
  def __init__(
12
12
  self,
13
- radios_shape: tuple,
14
- flats: dict,
15
- darks: dict,
13
+ radios_shape,
14
+ flats,
15
+ darks,
16
16
  radios_indices=None,
17
- interpolation: str = "linear",
17
+ interpolation="linear",
18
18
  distortion_correction=None,
19
19
  nan_value=1.0,
20
20
  radios_srcurrent=None,
21
21
  flats_srcurrent=None,
22
- cuda_options: Union[dict, None] = None,
22
+ cuda_options=None,
23
23
  ):
24
24
  """
25
25
  Initialize a flat-field normalization CUDA process.
@@ -41,16 +41,10 @@ class CudaFlatFieldArrays(FlatFieldArrays):
41
41
  flats_srcurrent=flats_srcurrent,
42
42
  nan_value=nan_value,
43
43
  )
44
- self._set_cuda_options(cuda_options)
44
+ self.cuda_processing = CudaProcessing(**(cuda_options or {}))
45
45
  self._init_cuda_kernels()
46
46
  self._load_flats_and_darks_on_gpu()
47
47
 
48
- def _set_cuda_options(self, user_cuda_options):
49
- self.cuda_options = {"device_id": None, "ctx": None, "cleanup_at_exit": None}
50
- if user_cuda_options is None:
51
- user_cuda_options = {}
52
- self.cuda_options.update(user_cuda_options)
53
-
54
48
  def _init_cuda_kernels(self):
55
49
  # TODO
56
50
  if self.interpolation != "linear":
@@ -63,7 +57,7 @@ class CudaFlatFieldArrays(FlatFieldArrays):
63
57
  ]
64
58
  if self.nan_value is not None:
65
59
  options.append("-DNAN_VALUE=%f" % self.nan_value)
66
- self.cuda_kernel = CudaKernel(
60
+ self.cuda_kernel = self.cuda_processing.kernel(
67
61
  "flatfield_normalization", self._cuda_fname, signature="PPPiiiPP", options=options
68
62
  )
69
63
  self._nx = np.int32(self.shape[1])
@@ -71,17 +65,19 @@ class CudaFlatFieldArrays(FlatFieldArrays):
71
65
 
72
66
  def _load_flats_and_darks_on_gpu(self):
73
67
  # Flats
74
- self.d_flats = garray.zeros((self.n_flats,) + self.shape, np.float32)
68
+ self.d_flats = self.cuda_processing.allocate_array("d_flats", (self.n_flats,) + self.shape, np.float32)
75
69
  for i, flat_idx in enumerate(self._sorted_flat_indices):
76
70
  self.d_flats[i].set(np.ascontiguousarray(self.flats[flat_idx], dtype=np.float32))
77
71
  # Darks
78
- self.d_darks = garray.zeros((self.n_darks,) + self.shape, np.float32)
72
+ self.d_darks = self.cuda_processing.allocate_array("d_darks", (self.n_darks,) + self.shape, np.float32)
79
73
  for i, dark_idx in enumerate(self._sorted_dark_indices):
80
74
  self.d_darks[i].set(np.ascontiguousarray(self.darks[dark_idx], dtype=np.float32))
81
- self.d_darks_indices = garray.to_gpu(np.array(self._sorted_dark_indices, dtype=np.int32))
75
+ self.d_darks_indices = self.cuda_processing.to_device(
76
+ "d_darks_indices", np.array(self._sorted_dark_indices, dtype=np.int32)
77
+ )
82
78
  # Indices
83
- self.d_flats_indices = garray.to_gpu(self.flats_idx)
84
- self.d_flats_weights = garray.to_gpu(self.flats_weights)
79
+ self.d_flats_indices = self.cuda_processing.to_device("d_flats_indices", self.flats_idx)
80
+ self.d_flats_weights = self.cuda_processing.to_device("d_flats_weights", self.flats_weights)
85
81
 
86
82
  def normalize_radios(self, radios):
87
83
  """
@@ -93,7 +89,7 @@ class CudaFlatFieldArrays(FlatFieldArrays):
93
89
  radios_shape: `pycuda.gpuarray.GPUArray`
94
90
  Radios chunk.
95
91
  """
96
- if not (isinstance(radios, garray.GPUArray)):
92
+ if not (isinstance(radios, self.cuda_processing.array_class)):
97
93
  raise ValueError("Expected a pycuda.gpuarray (got %s)" % str(type(radios)))
98
94
  if radios.dtype != np.float32:
99
95
  raise ValueError("radios must be in float32 dtype (got %s)" % str(radios.dtype))
@@ -121,16 +117,16 @@ CudaFlatField = CudaFlatFieldArrays
121
117
  class CudaFlatFieldDataUrls(CudaFlatField):
122
118
  def __init__(
123
119
  self,
124
- radios_shape: tuple,
125
- flats: dict,
126
- darks: dict,
120
+ radios_shape,
121
+ flats,
122
+ darks,
127
123
  radios_indices=None,
128
- interpolation: str = "linear",
124
+ interpolation="linear",
129
125
  distortion_correction=None,
130
126
  nan_value=1.0,
131
127
  radios_srcurrent=None,
132
128
  flats_srcurrent=None,
133
- cuda_options: Union[dict, None] = None,
129
+ cuda_options=None,
134
130
  **chunk_reader_kwargs,
135
131
  ):
136
132
  flats_arrays_dict = load_images_from_dataurl_dict(flats, **chunk_reader_kwargs)
nabu/preproc/phase.py CHANGED
@@ -1,6 +1,7 @@
1
1
  from math import pi
2
2
  from bisect import bisect
3
3
  import numpy as np
4
+ from scipy.fft import rfft2, irfft2, fft2, ifft2
4
5
  from ..utils import generate_powers, get_decay, check_supported, get_num_threads, deprecation_warning
5
6
 
6
7
  #
@@ -53,6 +54,7 @@ class PaganinPhaseRetrieval:
53
54
  use_rfft=True,
54
55
  use_R2C=None,
55
56
  fftw_num_threads=None,
57
+ fft_num_threads=None,
56
58
  ):
57
59
  """
58
60
  Paganin Phase Retrieval for an infinitely distant point source.
@@ -113,9 +115,11 @@ class PaganinPhaseRetrieval:
113
115
  use_R2C: bool, optional
114
116
  DEPRECATED, use use_rfft instead
115
117
  fftw_num_threads: bool or None or int, optional
116
- Whether to use FFTW for speeding up FFT.
118
+ DEPRECATED - please use fft_num_threads
119
+ fft_num_threads: bool or None or int, optional
120
+ Number of threads for FFT.
117
121
  Default is to use all available threads. You can pass a negative number
118
- to use N - fftw_num_threads cores.
122
+ to use N - fft_num_threads cores.
119
123
 
120
124
  Important
121
125
  ----------
@@ -171,8 +175,11 @@ class PaganinPhaseRetrieval:
171
175
  # COMPAT.
172
176
  if use_R2C is not None:
173
177
  deprecation_warning("'use_R2C' is replaced with 'use_rfft'", func_name="pag_r2c")
174
- # -
175
- self._get_fft(use_rfft, fftw_num_threads)
178
+ if fftw_num_threads is not None:
179
+ deprecation_warning("'fftw_num_threads' is replaced with 'fft_num_threads'", func_name="pag_fftw")
180
+ fft_num_threads = fftw_num_threads
181
+ # ---
182
+ self._get_fft(use_rfft, fft_num_threads)
176
183
  self.compute_filter()
177
184
 
178
185
  def _init_parameters(self, distance, energy, pixel_size, delta_beta, padding):
@@ -191,28 +198,16 @@ class PaganinPhaseRetrieval:
191
198
  "reflect": self._pad_reflect,
192
199
  }
193
200
 
194
- def _get_fft(self, use_rfft, fftw_num_threads):
201
+ def _get_fft(self, use_rfft, fft_num_threads):
195
202
  self.use_rfft = use_rfft
196
203
  self.use_R2C = use_rfft # Compat.
197
- fftw_num_threads = get_num_threads(fftw_num_threads)
204
+ self.fft_num_threads = get_num_threads(fft_num_threads)
198
205
  if self.use_rfft:
199
- self.fft_func = np.fft.rfft2
200
- self.ifft_func = np.fft.irfft2
206
+ self.fft_func = rfft2
207
+ self.ifft_func = irfft2
201
208
  else:
202
- self.fft_func = np.fft.fft2
203
- self.ifft_func = np.fft.ifft2
204
- self.use_fftw = False
205
- if self.use_rfft and (fftw_num_threads > 0):
206
- # importing silx.math.fft creates opencl contexts all over the place
207
- # because of the silx.opencl.ocl singleton.
208
- # So, import silx as late as possible
209
- from silx.math.fft.fftw import FFTW, __have_fftw__
210
-
211
- if __have_fftw__:
212
- self.use_fftw = True
213
- self.fftw = FFTW(shape=self.shape_padded, dtype="f", num_threads=fftw_num_threads)
214
- self.fft_func = self.fftw.fft
215
- self.ifft_func = self.fftw.ifft
209
+ self.fft_func = fft2
210
+ self.ifft_func = ifft2
216
211
 
217
212
  def _calc_shape(self, shape, margin):
218
213
  if np.isscalar(shape):
@@ -378,9 +373,9 @@ class PaganinPhaseRetrieval:
378
373
 
379
374
  def apply_filter(self, radio, padding_method=None, output=None):
380
375
  self.pad_data(radio, padding_method=padding_method)
381
- radio_f = self.fft_func(self.data_padded)
376
+ radio_f = self.fft_func(self.data_padded, workers=self.fft_num_threads)
382
377
  radio_f *= self.paganin_filter
383
- radio_filtered = self.ifft_func(radio_f).real
378
+ radio_filtered = self.ifft_func(radio_f, workers=self.fft_num_threads).real
384
379
  s0, s1 = self.shape_inner
385
380
  ((U, _), (L, _)) = self.margin
386
381
  if output is None:
nabu/preproc/phase_cuda.py CHANGED
@@ -1,15 +1,15 @@
1
1
  import numpy as np
2
2
  import pycuda.driver as cuda
3
- from pycuda import gpuarray as garray
4
- from ..utils import get_cuda_srcfile, check_supported
5
- from .phase import PaganinPhaseRetrieval
3
+ from ..utils import get_cuda_srcfile, check_supported, docstring
6
4
  from ..cuda.processing import CudaProcessing
7
- from ..cuda.kernel import CudaKernel
5
+ from ..processing.fft_cuda import get_fft_class
6
+ from .phase import PaganinPhaseRetrieval
8
7
 
9
8
 
10
9
  class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
11
10
  supported_paddings = ["zeros", "constant", "edge"]
12
11
 
12
+ @docstring(PaganinPhaseRetrieval)
13
13
  def __init__(
14
14
  self,
15
15
  shape,
@@ -20,7 +20,9 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
20
20
  padding="edge",
21
21
  margin=None,
22
22
  cuda_options=None,
23
- fftw_num_threads=None,
23
+ fftw_num_threads=None, # COMPAT.
24
+ fft_num_threads=None,
25
+ fft_backend="skcuda",
24
26
  ):
25
27
  """
26
28
  Please refer to the documentation of
@@ -37,10 +39,10 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
37
39
  padding=padding,
38
40
  margin=margin,
39
41
  use_rfft=True,
40
- fftw_num_threads=None,
42
+ fft_num_threads=False,
41
43
  )
42
44
  self._init_gpu_arrays()
43
- self._init_fft()
45
+ self._init_fft(fft_backend)
44
46
  self._init_padding_kernel()
45
47
  self._init_mult_kernel()
46
48
 
@@ -51,25 +53,23 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
51
53
  return padding
52
54
 
53
55
  def _init_gpu_arrays(self):
54
- self.d_paganin_filter = garray.to_gpu(np.ascontiguousarray(self.paganin_filter, dtype=np.float32))
56
+ self.d_paganin_filter = self.cuda_processing.to_device(
57
+ "d_paganin_filter", np.ascontiguousarray(self.paganin_filter, dtype=np.float32)
58
+ )
55
59
 
56
60
  # overwrite parent method, don't initialize any FFT plan
57
- def _get_fft(self, use_rfft, fftw_num_threads):
61
+ def _get_fft(self, use_rfft, fft_num_threads):
58
62
  self.use_rfft = use_rfft
59
- self.use_fftw = False
60
-
61
- def _init_fft(self):
62
- # Import has to be done here, otherwise scikit-cuda creates a cuda/cublas context at import
63
- from silx.math.fft.cufft import CUFFT
64
63
 
65
- #
66
- self.cufft = CUFFT(template=self.data_padded.astype("f"))
67
- self.d_radio_padded = self.cufft.data_in
68
- self.d_radio_f = self.cufft.data_out
64
+ def _init_fft(self, fft_backend):
65
+ fft_cls = get_fft_class(backend=fft_backend)
66
+ self.cufft = fft_cls(shape=self.data_padded.shape, dtype=np.float32, r2c=True)
67
+ self.d_radio_padded = self.cuda_processing.allocate_array("d_radio_padded", self.cufft.shape, "f")
68
+ self.d_radio_f = self.cuda_processing.allocate_array("d_radio_f", self.cufft.shape_out, np.complex64)
69
69
 
70
70
  def _init_padding_kernel(self):
71
71
  kern_signature = {"constant": "Piiiiiiiiffff", "edge": "Piiiiiiii"}
72
- self.padding_kernel = CudaKernel(
72
+ self.padding_kernel = self.cuda_processing.kernel(
73
73
  "padding_%s" % self.padding,
74
74
  filename=get_cuda_srcfile("padding.cu"),
75
75
  signature=kern_signature[self.padding],
@@ -92,7 +92,7 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
92
92
  self.padding_kernel_args.extend([0, 0, 0, 0])
93
93
 
94
94
  def _init_mult_kernel(self):
95
- self.cpxmult_kernel = CudaKernel(
95
+ self.cpxmult_kernel = self.cuda_processing.kernel(
96
96
  "inplace_complexreal_mul_2Dby2D",
97
97
  filename=get_cuda_srcfile("ElementOp.cu"),
98
98
  signature="PPii",
@@ -109,7 +109,7 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
109
109
  assert data.dtype == np.float32
110
110
  # Rectangular memcopy
111
111
  # TODO profile, and if needed include this copy in the padding kernel
112
- if isinstance(data, np.ndarray) or isinstance(data, garray.GPUArray):
112
+ if isinstance(data, np.ndarray) or isinstance(data, self.cuda_processing.array_class):
113
113
  self.d_radio_padded[: self.shape[0], : self.shape[1]] = data[:, :]
114
114
  elif isinstance(data, cuda.DeviceAllocation):
115
115
  # TODO manual memcpy2D