nabu 2024.1.9__py3-none-any.whl → 2024.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151)
  1. nabu/__init__.py +1 -1
  2. nabu/app/bootstrap.py +2 -3
  3. nabu/app/cast_volume.py +4 -2
  4. nabu/app/cli_configs.py +5 -0
  5. nabu/app/composite_cor.py +1 -1
  6. nabu/app/create_distortion_map_from_poly.py +5 -6
  7. nabu/app/diag_to_pix.py +7 -19
  8. nabu/app/diag_to_rot.py +14 -29
  9. nabu/app/double_flatfield.py +32 -44
  10. nabu/app/parse_reconstruction_log.py +3 -0
  11. nabu/app/reconstruct.py +53 -15
  12. nabu/app/reconstruct_helical.py +2 -2
  13. nabu/app/stitching.py +27 -13
  14. nabu/app/tests/test_reduce_dark_flat.py +4 -1
  15. nabu/cuda/kernel.py +11 -2
  16. nabu/cuda/processing.py +2 -2
  17. nabu/cuda/src/cone.cu +77 -0
  18. nabu/cuda/src/hierarchical_backproj.cu +271 -0
  19. nabu/cuda/utils.py +0 -6
  20. nabu/estimation/alignment.py +5 -19
  21. nabu/estimation/cor.py +173 -599
  22. nabu/estimation/cor_sino.py +356 -26
  23. nabu/estimation/focus.py +63 -11
  24. nabu/estimation/tests/test_cor.py +124 -58
  25. nabu/estimation/tests/test_focus.py +6 -6
  26. nabu/estimation/tilt.py +2 -1
  27. nabu/estimation/utils.py +5 -33
  28. nabu/io/__init__.py +1 -1
  29. nabu/io/cast_volume.py +1 -1
  30. nabu/io/reader.py +416 -21
  31. nabu/io/tests/test_readers.py +422 -0
  32. nabu/io/tests/test_writers.py +1 -102
  33. nabu/io/writer.py +4 -433
  34. nabu/opencl/kernel.py +14 -3
  35. nabu/opencl/processing.py +8 -0
  36. nabu/pipeline/config_validators.py +5 -2
  37. nabu/pipeline/datadump.py +12 -5
  38. nabu/pipeline/estimators.py +162 -188
  39. nabu/pipeline/fullfield/chunked.py +168 -92
  40. nabu/pipeline/fullfield/chunked_cuda.py +7 -3
  41. nabu/pipeline/fullfield/computations.py +2 -7
  42. nabu/pipeline/fullfield/dataset_validator.py +0 -4
  43. nabu/pipeline/fullfield/nabu_config.py +37 -13
  44. nabu/pipeline/fullfield/processconfig.py +22 -13
  45. nabu/pipeline/fullfield/reconstruction.py +13 -9
  46. nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
  47. nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
  48. nabu/pipeline/helical/helical_reconstruction.py +1 -1
  49. nabu/pipeline/params.py +21 -1
  50. nabu/pipeline/processconfig.py +1 -12
  51. nabu/pipeline/reader.py +146 -0
  52. nabu/pipeline/tests/test_estimators.py +44 -72
  53. nabu/pipeline/utils.py +4 -2
  54. nabu/pipeline/writer.py +10 -2
  55. nabu/preproc/ccd_cuda.py +1 -1
  56. nabu/preproc/ctf.py +14 -7
  57. nabu/preproc/ctf_cuda.py +2 -3
  58. nabu/preproc/double_flatfield.py +5 -12
  59. nabu/preproc/double_flatfield_cuda.py +2 -2
  60. nabu/preproc/flatfield.py +5 -1
  61. nabu/preproc/flatfield_cuda.py +5 -1
  62. nabu/preproc/phase.py +24 -73
  63. nabu/preproc/phase_cuda.py +5 -8
  64. nabu/preproc/tests/test_ctf.py +11 -7
  65. nabu/preproc/tests/test_flatfield.py +67 -122
  66. nabu/preproc/tests/test_paganin.py +54 -30
  67. nabu/processing/azim.py +206 -0
  68. nabu/processing/convolution_cuda.py +1 -1
  69. nabu/processing/fft_cuda.py +15 -17
  70. nabu/processing/histogram.py +2 -0
  71. nabu/processing/histogram_cuda.py +2 -1
  72. nabu/processing/kernel_base.py +3 -0
  73. nabu/processing/muladd_cuda.py +1 -0
  74. nabu/processing/padding_opencl.py +1 -1
  75. nabu/processing/roll_opencl.py +1 -0
  76. nabu/processing/rotation_cuda.py +2 -2
  77. nabu/processing/tests/test_fft.py +17 -10
  78. nabu/processing/unsharp_cuda.py +1 -1
  79. nabu/reconstruction/cone.py +104 -40
  80. nabu/reconstruction/fbp.py +3 -0
  81. nabu/reconstruction/fbp_base.py +7 -2
  82. nabu/reconstruction/filtering.py +20 -7
  83. nabu/reconstruction/filtering_cuda.py +7 -1
  84. nabu/reconstruction/hbp.py +424 -0
  85. nabu/reconstruction/mlem.py +99 -0
  86. nabu/reconstruction/reconstructor.py +2 -0
  87. nabu/reconstruction/rings_cuda.py +19 -19
  88. nabu/reconstruction/sinogram_cuda.py +1 -0
  89. nabu/reconstruction/sinogram_opencl.py +3 -1
  90. nabu/reconstruction/tests/test_cone.py +10 -5
  91. nabu/reconstruction/tests/test_deringer.py +7 -6
  92. nabu/reconstruction/tests/test_fbp.py +124 -10
  93. nabu/reconstruction/tests/test_filtering.py +13 -11
  94. nabu/reconstruction/tests/test_halftomo.py +30 -4
  95. nabu/reconstruction/tests/test_mlem.py +91 -0
  96. nabu/reconstruction/tests/test_reconstructor.py +8 -3
  97. nabu/resources/dataset_analyzer.py +142 -92
  98. nabu/resources/gpu.py +1 -0
  99. nabu/resources/nxflatfield.py +134 -125
  100. nabu/resources/templates/id16a_fluo.conf +42 -0
  101. nabu/resources/tests/test_extract.py +10 -0
  102. nabu/resources/tests/test_nxflatfield.py +2 -2
  103. nabu/stitching/alignment.py +80 -24
  104. nabu/stitching/config.py +105 -68
  105. nabu/stitching/definitions.py +1 -0
  106. nabu/stitching/frame_composition.py +68 -60
  107. nabu/stitching/overlap.py +91 -51
  108. nabu/stitching/single_axis_stitching.py +32 -0
  109. nabu/stitching/slurm_utils.py +6 -6
  110. nabu/stitching/stitcher/__init__.py +0 -0
  111. nabu/stitching/stitcher/base.py +124 -0
  112. nabu/stitching/stitcher/dumper/__init__.py +3 -0
  113. nabu/stitching/stitcher/dumper/base.py +94 -0
  114. nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
  115. nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
  116. nabu/stitching/stitcher/post_processing.py +555 -0
  117. nabu/stitching/stitcher/pre_processing.py +1068 -0
  118. nabu/stitching/stitcher/single_axis.py +484 -0
  119. nabu/stitching/stitcher/stitcher.py +0 -0
  120. nabu/stitching/stitcher/y_stitcher.py +13 -0
  121. nabu/stitching/stitcher/z_stitcher.py +45 -0
  122. nabu/stitching/stitcher_2D.py +278 -0
  123. nabu/stitching/tests/test_config.py +12 -37
  124. nabu/stitching/tests/test_frame_composition.py +33 -59
  125. nabu/stitching/tests/test_overlap.py +149 -7
  126. nabu/stitching/tests/test_utils.py +1 -1
  127. nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
  128. nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
  129. nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
  130. nabu/stitching/utils/__init__.py +1 -0
  131. nabu/stitching/utils/post_processing.py +281 -0
  132. nabu/stitching/utils/tests/test_post-processing.py +21 -0
  133. nabu/stitching/{utils.py → utils/utils.py} +79 -52
  134. nabu/stitching/y_stitching.py +27 -0
  135. nabu/stitching/z_stitching.py +32 -2263
  136. nabu/testutils.py +1 -152
  137. nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
  138. nabu/utils.py +158 -61
  139. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/METADATA +10 -3
  140. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/RECORD +144 -121
  141. nabu/io/tiffwriter_zmm.py +0 -99
  142. nabu/pipeline/fallback_utils.py +0 -149
  143. nabu/pipeline/helical/tests/test_accumulator.py +0 -158
  144. nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
  145. nabu/pipeline/helical/tests/test_strategy.py +0 -61
  146. nabu/pipeline/helical/utils.py +0 -51
  147. nabu/pipeline/tests/test_chunk_reader.py +0 -74
  148. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
  149. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/WHEEL +0 -0
  150. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
  151. {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
nabu/pipeline/writer.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from os import path
2
2
  from tomoscan.esrf import TIFFVolume, MultiTIFFVolume, EDFVolume, JP2KVolume
3
+ from tomoscan.esrf.volume.singleframebase import VolumeSingleFrameBase
3
4
  from ..utils import check_supported, get_num_threads
4
5
  from ..resources.logger import LoggerOrPrint
5
6
  from ..io.writer import NXProcessWriter, HSTVolVolume, NXVolVolume
@@ -8,7 +9,6 @@ from .params import files_formats
8
9
 
9
10
 
10
11
  class WriterManager:
11
-
12
12
  """
13
13
  This class is a wrapper on top of all "writers".
14
14
  It will create the right "writer" with all the necessary options, and the histogram writer.
@@ -114,6 +114,7 @@ class WriterManager:
114
114
  return vol_writer.data_url.file_path()
115
115
 
116
116
  def _init_writer(self):
117
+ self._writer_was_already_initialized = self.extra_options.get("writer_initialized", False)
117
118
  if self.file_format in ["tiff", "edf", "jp2", "hdf5"]:
118
119
  writer_kwargs = {
119
120
  "folder": self.output_dir,
@@ -144,6 +145,11 @@ class WriterManager:
144
145
  self._h5_entry = self.metadata.get("entry", "entry")
145
146
  self.writer = self._writer_classes[self.file_format](**writer_kwargs)
146
147
  self.fname = self.get_fname(self.writer)
148
+ # In certain cases, tomoscan needs to remove any previous existing volume filess
149
+ # and avoid calling 'clean_output_data' when writing downstream (for chunk processing)
150
+ if isinstance(self.writer, VolumeSingleFrameBase):
151
+ self.writer.skip_existing_data_files_removal = self._writer_was_already_initialized
152
+ # ---
147
153
  if path.exists(self.fname):
148
154
  err = "File already exists: %s" % self.fname
149
155
  if self.overwrite:
@@ -188,7 +194,9 @@ class WriterManager:
188
194
  self.writer.metadata = self.metadata
189
195
  self.writer.save_metadata()
190
196
 
191
- def write_data(self, data):
197
+ def write_data(self, data, metadata=None):
192
198
  self.writer.data = data
199
+ if metadata is not None:
200
+ self.writer.metadata = metadata
193
201
  self.writer.save()
194
202
  # self._write_metadata()
nabu/preproc/ccd_cuda.py CHANGED
@@ -118,7 +118,7 @@ class CudaLog(Log):
118
118
  self._nthreadsperblock = (16, 16, 4) # TODO tune ?
119
119
  self._nblocks = tuple([updiv(n, p) for n, p in zip([nx, ny, nz], self._nthreadsperblock)])
120
120
 
121
- self.nlog_kernel = CudaKernel(
121
+ self.nlog_kernel = CudaKernel( # pylint: disable=E0606
122
122
  "nlog",
123
123
  filename=self._nlog_srcfile,
124
124
  signature="Piiiff",
nabu/preproc/ctf.py CHANGED
@@ -18,7 +18,7 @@ class GeoPars:
18
18
  self,
19
19
  z1_vh=None,
20
20
  z2=None,
21
- pix_size_det=None,
21
+ pix_size_det=1e-6,
22
22
  wavelength=None,
23
23
  magnification=True,
24
24
  length_scale=10.0e-6,
@@ -33,8 +33,9 @@ class GeoPars:
33
33
  and the horizontaly focused source (vertical line) for KB mirrors.
34
34
  z2 : float
35
35
  the sample detector distance (meters).
36
- pix_size_det: float
37
- pixel size (meters)
36
+ pix_size_det: float or tuple
37
+ pixel size in meters.
38
+ If a tuple is passed, it is interpreted as (horizontal_size, vertical_size)
38
39
  wavelength: float
39
40
  beam wave length (meters).
40
41
  magnification: boolean defaults to True
@@ -55,7 +56,11 @@ class GeoPars:
55
56
  self.z1_vh = np.array([z1_vh, z1_vh])
56
57
  self.z2 = z2
57
58
  self.magnification = magnification
58
- self.pix_size_det = pix_size_det
59
+ if np.isscalar(pix_size_det):
60
+ self.pix_size_det_xy = (pix_size_det, pix_size_det)
61
+ else:
62
+ self.pix_size_det_xy = pix_size_det
63
+ self.pix_size_det = self.pix_size_det_xy[0] # COMPAT
59
64
 
60
65
  if self.magnification and self.z1_vh is not None:
61
66
  self.M_vh = (self.z1_vh + self.z2) / self.z1_vh
@@ -69,7 +74,9 @@ class GeoPars:
69
74
 
70
75
  self.maxM = self.M_vh.max()
71
76
 
72
- self.pix_size_rec = self.pix_size_det / self.maxM # we bring everything to highest magnification
77
+ # we bring everything to highest magnification
78
+ self.pix_size_rec_xy = [p / self.maxM for p in self.pix_size_det_xy]
79
+ self.pix_size_rec = self.pix_size_rec_xy[0] # COMPAT
73
80
 
74
81
  which_unit = int(np.sum(np.array([self.pix_size_rec > small for small in [1.0e-6, 1.0e-7]]).astype(np.int32)))
75
82
  self.pixelsize_string = [
@@ -208,8 +215,8 @@ class CTFPhaseRetrieval:
208
215
  padded_img_shape = self.shape_padded
209
216
  fsample_vh = np.array(
210
217
  [
211
- self.geo_pars.length_scale / self.geo_pars.pix_size_rec,
212
- self.geo_pars.length_scale / self.geo_pars.pix_size_rec,
218
+ self.geo_pars.length_scale / self.geo_pars.pix_size_rec_xy[1],
219
+ self.geo_pars.length_scale / self.geo_pars.pix_size_rec_xy[0],
213
220
  ]
214
221
  )
215
222
 
nabu/preproc/ctf_cuda.py CHANGED
@@ -15,7 +15,6 @@ if __has_pycuda__:
15
15
  # - better padding scheme (for now 2*shape)
16
16
  # - rework inheritance scheme ? (base class SingleDistancePhaseRetrieval and its cuda counterpart)
17
17
  class CudaCTFPhaseRetrieval(CTFPhaseRetrieval):
18
-
19
18
  """
20
19
  Cuda back-end of CTFPhaseRetrieval
21
20
  """
@@ -37,7 +36,7 @@ class CudaCTFPhaseRetrieval(CTFPhaseRetrieval):
37
36
  fft_num_threads=None,
38
37
  logger=None,
39
38
  cuda_options=None,
40
- fft_backend="skcuda",
39
+ fft_backend="vkfft",
41
40
  ):
42
41
  """
43
42
  Initialize a CudaCTFPhaseRetrieval.
@@ -130,7 +129,7 @@ class CudaCTFPhaseRetrieval(CTFPhaseRetrieval):
130
129
  self.set_input(image)
131
130
  self.cuda_padding.pad(image, output=self.d_radio_padded)
132
131
  if self.normalize_by_mean:
133
- m = garray.sum(self.d_radio_padded).get() / np.prod(self.shape_padded)
132
+ m = garray.sum(self.d_radio_padded).get() / np.prod(self.shape_padded) # pylint: disable=E0606
134
133
  self.d_radio_padded /= m
135
134
  self.cufft.fft(self.d_radio_padded, output=self.d_radio_f)
136
135
  self.cpxmult_kernel(*self._cpxmult_kernel_args, **self._cpxmult_kernel_kwargs)
nabu/preproc/double_flatfield.py CHANGED
@@ -2,9 +2,9 @@ from os import path
2
2
  import numpy as np
3
3
  from scipy.ndimage import gaussian_filter
4
4
  from silx.io.url import DataUrl
5
- from ..utils import check_supported, check_shape, get_2D_3D_shape
6
- from ..io.reader import Readers
7
- from ..io.writer import Writers
5
+ from ..utils import check_shape, get_2D_3D_shape
6
+ from ..io.reader import HDF5Reader
7
+ from ..io.writer import NXProcessWriter
8
8
  from .ccd import Log
9
9
 
10
10
 
@@ -64,12 +64,6 @@ class DoubleFlatField:
64
64
  self._init_processing(input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode)
65
65
  self._computed = False
66
66
 
67
- def _get_reader_writer_class(self):
68
- ext = path.splitext(self.result_url.file_path())[-1].replace(".", "")
69
- check_supported(ext, list(Writers.keys()), "file format")
70
- self._writer_cls = Writers[ext]
71
- self._reader_cls = Readers[ext]
72
-
73
67
  def _load_dff_dump(self):
74
68
  res = self.reader.get_data(self.result_url)
75
69
  if self.detector_corrector is not None:
@@ -98,15 +92,14 @@ class DoubleFlatField:
98
92
  self.reader = None
99
93
  if self.result_url is None:
100
94
  return
101
- self._get_reader_writer_class()
102
95
  if path.exists(result_url.file_path()):
103
96
  if detector_corrector is None:
104
97
  adapted_subregion = sub_region
105
98
  else:
106
99
  adapted_subregion = self.detector_corrector.get_adapted_subregion(sub_region)
107
- self.reader = self._reader_cls(sub_region=adapted_subregion)
100
+ self.reader = HDF5Reader(sub_region=adapted_subregion)
108
101
  else:
109
- self.writer = self._writer_cls(self.result_url.file_path())
102
+ self.writer = NXProcessWriter(self.result_url.file_path())
110
103
 
111
104
  def _init_processing(self, input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode):
112
105
  self.input_is_mlog = input_is_mlog
nabu/preproc/double_flatfield_cuda.py CHANGED
@@ -59,7 +59,7 @@ class CudaDoubleFlatField(DoubleFlatField):
59
59
  def _proc_expm(x, o):
60
60
  o[:] = x[:]
61
61
  o[:] *= -1
62
- cumath.exp(o, out=o)
62
+ cumath.exp(o, out=o) # pylint: disable=E0606
63
63
  return o
64
64
 
65
65
  def _init_processing(self, input_is_mlog, output_is_mlog, average_is_on_log, sigma_filter, filter_mode):
@@ -108,7 +108,7 @@ class CudaDoubleFlatField(DoubleFlatField):
108
108
  recompute: bool, optional
109
109
  Whether to recompute the double flatfield if already computed.
110
110
  """
111
- if not (isinstance(radios, garray.GPUArray)):
111
+ if not (isinstance(radios, garray.GPUArray)): # pylint: disable=E0606
112
112
  raise ValueError("Expected pycuda.gpuarray.GPUArray for radios")
113
113
  if self._computed and not (recompute):
114
114
  return self.doubleflatfield
nabu/preproc/flatfield.py CHANGED
@@ -2,7 +2,7 @@ from multiprocessing.pool import ThreadPool
2
2
  from bisect import bisect_left
3
3
  import numpy as np
4
4
  from ..io.reader import load_images_from_dataurl_dict
5
- from ..utils import check_supported, get_num_threads
5
+ from ..utils import check_supported, deprecated_class, get_num_threads
6
6
 
7
7
 
8
8
  class FlatFieldArrays:
@@ -228,6 +228,7 @@ class FlatFieldArrays:
228
228
  f_idx, weights = _interp_nearest(idx, prev_next)
229
229
  elif self.interpolation == "linear":
230
230
  f_idx, weights = _interp_linear(idx, prev_next)
231
+ # pylint: disable=E0606
231
232
  self.flats_idx[i] = f_idx
232
233
  self.flats_weights[i] = weights
233
234
 
@@ -376,6 +377,9 @@ class FlatFieldArrays:
376
377
  FlatField = FlatFieldArrays
377
378
 
378
379
 
380
+ @deprecated_class(
381
+ "FlatFieldDataUrls is deprecated since 2024.2.0 and will be removed in a future version", do_print=True
382
+ )
379
383
  class FlatFieldDataUrls(FlatField):
380
384
  def __init__(
381
385
  self,
nabu/preproc/flatfield_cuda.py CHANGED
@@ -2,7 +2,7 @@ import numpy as np
2
2
 
3
3
  from nabu.cuda.processing import CudaProcessing
4
4
  from ..preproc.flatfield import FlatFieldArrays
5
- from ..utils import get_cuda_srcfile
5
+ from ..utils import deprecated_class, get_cuda_srcfile
6
6
  from ..io.reader import load_images_from_dataurl_dict
7
7
  from ..cuda.utils import __has_pycuda__
8
8
 
@@ -114,6 +114,9 @@ class CudaFlatFieldArrays(FlatFieldArrays):
114
114
  CudaFlatField = CudaFlatFieldArrays
115
115
 
116
116
 
117
+ @deprecated_class(
118
+ "CudaFlatFieldDataUrls is deprecated since version 2024.2.0 and will be removed in a future version", do_print=True
119
+ )
117
120
  class CudaFlatFieldDataUrls(CudaFlatField):
118
121
  def __init__(
119
122
  self,
@@ -138,6 +141,7 @@ class CudaFlatFieldDataUrls(CudaFlatField):
138
141
  radios_indices=radios_indices,
139
142
  interpolation=interpolation,
140
143
  distortion_correction=distortion_correction,
144
+ nan_value=nan_value,
141
145
  radios_srcurrent=radios_srcurrent,
142
146
  flats_srcurrent=flats_srcurrent,
143
147
  cuda_options=cuda_options,
nabu/preproc/phase.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
4
  from scipy.fft import rfft2, irfft2, fft2, ifft2
5
5
  from ..utils import generate_powers, get_decay, check_supported, get_num_threads, deprecation_warning
6
6
 
7
- #
7
+ # COMPAT.
8
8
  from .ctf import CTFPhaseRetrieval
9
9
 
10
10
  #
@@ -50,7 +50,6 @@ class PaganinPhaseRetrieval:
50
50
  delta_beta=250.0,
51
51
  pixel_size=1e-6,
52
52
  padding="edge",
53
- margin=None,
54
53
  use_rfft=True,
55
54
  use_R2C=None,
56
55
  fftw_num_threads=None,
@@ -73,42 +72,13 @@ class PaganinPhaseRetrieval:
73
72
  delta_beta: float, optional
74
73
  delta/beta ratio, where n = (1 - delta) + i*beta is the complex
75
74
  refractive index of the sample.
76
- pixel_size : float, optional
75
+ pixel_size : float or tuple, optional
77
76
  Detector pixel size in meters. Default is 1e-6 (one micron)
77
+ If a tuple is passed, the pixel size is set as (horizontal_size, vertical_size).
78
78
  padding : str, optional
79
79
  Padding method. Available are "zeros", "mean", "edge", "sym",
80
80
  "reflect". Default is "edge".
81
81
  Please refer to the "Padding" section below for more details.
82
- margin: tuple, optional
83
- The user may provide integers values U, D, L, R as a tuple under the
84
- form ((U, D), (L, R)) (same syntax as numpy.pad()).
85
- The resulting filtered radio will have a size equal to
86
- (size_vertic - U - D, size_horiz - L - R).
87
- These values serve to create a "margin" for the filtering process,
88
- where U, D, L R are the margin of the Up, Down, Left and Right part,
89
- respectively.
90
- The filtering is done on a subset of the input radio. The subset
91
- size is (Nrows - U - D, Ncols - R - L).
92
- The margins is used to do the padding for the rest of the padded
93
- array.
94
-
95
- For example in one dimension, where ``padding="edge"``::
96
-
97
- <------------------------------ padded_size --------------------------->
98
- [padding=edge | padding=data | radio data | padding=data | padding=edge]
99
- <------ N2 ---><----- L -----><- (N-L-R)--><----- R -----><----- N2 --->
100
-
101
- Some or all the values U, D, L, R can be 0. In this case,
102
- the padding of the parts related to the zero values will
103
- fall back to the one of "padding" parameter.
104
- For example, if padding="edge" and L, R are 0, then
105
- the left and right parts will be padded with the edges, while
106
- the Up and Down parts will be padded using the the user-provided
107
- margins of the radio, and the final data will have shape
108
- (Nrows - U - D, Ncols).
109
- Some or all the values U, D, L, R can be the string "auto".
110
- In this case, the values of U, D, L, R are automatically computed
111
- as a function of the Paganin filter width.
112
82
  use_rfft: bool, optional
113
83
  Whether to use Real-to-Complex (R2C) transform instead of
114
84
  standard Complex-to-Complex transform, providing better performances
@@ -171,7 +141,7 @@ class PaganinPhaseRetrieval:
171
141
  Journal of Microscopy, Vol 206, Part 1, 2002
172
142
  """
173
143
  self._init_parameters(distance, energy, pixel_size, delta_beta, padding)
174
- self._calc_shape(shape, margin)
144
+ self._calc_shape(shape)
175
145
  # COMPAT.
176
146
  if use_R2C is not None:
177
147
  deprecation_warning("'use_R2C' is replaced with 'use_rfft'", func_name="pag_r2c")
@@ -186,7 +156,13 @@ class PaganinPhaseRetrieval:
186
156
  self.distance_cm = distance * 1e2
187
157
  self.distance_micron = distance * 1e6
188
158
  self.energy_kev = energy
189
- self.pixel_size_micron = pixel_size * 1e6
159
+ if np.isscalar(pixel_size):
160
+ self.pixel_size_xy_micron = (pixel_size * 1e6, pixel_size * 1e6)
161
+ else:
162
+ self.pixel_size_xy_micron = pixel_size * 1e6
163
+ # COMPAT.
164
+ self.pixel_size_micron = self.pixel_size_xy_micron[0]
165
+ #
190
166
  self.delta_beta = delta_beta
191
167
  self.wavelength_micron = 1.23984199e-3 / self.energy_kev
192
168
  self.padding = padding
@@ -209,34 +185,14 @@ class PaganinPhaseRetrieval:
209
185
  self.fft_func = fft2
210
186
  self.ifft_func = ifft2
211
187
 
212
- def _calc_shape(self, shape, margin):
188
+ def _calc_shape(self, shape):
213
189
  if np.isscalar(shape):
214
190
  shape = (shape, shape)
215
191
  else:
216
192
  assert len(shape) == 2
217
193
  self.shape = shape
218
- self._set_margin_value(margin)
219
194
  self._calc_padded_shape()
220
195
 
221
- def _set_margin_value(self, margin):
222
- self.margin = margin
223
- if margin is None:
224
- self.shape_inner = self.shape
225
- self.use_margin = False
226
- self.margin = ((0, 0), (0, 0))
227
- return
228
- self.use_margin = True
229
- try:
230
- ((U, D), (L, R)) = margin
231
- except ValueError:
232
- raise ValueError("Expected margin in the format ((U, D), (L, R))")
233
- for val in [U, D, L, R]:
234
- if isinstance(val, str) and val != "auto":
235
- raise ValueError("Expected either an integer, or 'auto'")
236
- if int(val) != val or val < 0:
237
- raise ValueError("Expected positive integers for margin values")
238
- self.shape_inner = (self.shape[0] - U - D, self.shape[1] - L - R)
239
-
240
196
  def _calc_padded_shape(self):
241
197
  """
242
198
  Compute the padded shape.
@@ -257,19 +213,15 @@ class PaganinPhaseRetrieval:
257
213
  nx0 : length of original data
258
214
  nx_p : total length of padded data
259
215
  """
260
- n_y, n_x = self.shape_inner
261
- n_y0, n_x0 = self.shape
262
- n_y_p = self._get_next_power(max(2 * n_y, n_y0))
263
- n_x_p = self._get_next_power(max(2 * n_x, n_x0))
216
+ n_y, n_x = self.shape
217
+ n_y_p = self._get_next_power(2 * n_y)
218
+ n_x_p = self._get_next_power(2 * n_x)
264
219
  self.shape_padded = (n_y_p, n_x_p)
265
220
  self.data_padded = np.zeros((n_y_p, n_x_p), dtype=np.float64)
266
-
267
- ((U, D), (L, R)) = self.margin
268
- n_y0, n_x0 = self.shape
269
- self.pad_top_len = (n_y_p - n_y0) // 2
270
- self.pad_bottom_len = n_y_p - n_y0 - self.pad_top_len
271
- self.pad_left_len = (n_x_p - n_x0) // 2
272
- self.pad_right_len = n_x_p - n_x0 - self.pad_left_len
221
+ self.pad_top_len = (n_y_p - n_y) // 2
222
+ self.pad_bottom_len = n_y_p - n_y - self.pad_top_len
223
+ self.pad_left_len = (n_x_p - n_x) // 2
224
+ self.pad_right_len = n_x_p - n_x - self.pad_left_len
273
225
 
274
226
  def _get_next_power(self, n):
275
227
  """
@@ -284,8 +236,8 @@ class PaganinPhaseRetrieval:
284
236
  def compute_filter(self):
285
237
  nyp, nxp = self.shape_padded
286
238
  fftfreq = np.fft.rfftfreq if self.use_rfft else np.fft.fftfreq
287
- fy = np.fft.fftfreq(nyp, d=self.pixel_size_micron)
288
- fx = fftfreq(nxp, d=self.pixel_size_micron)
239
+ fy = np.fft.fftfreq(nyp, d=self.pixel_size_xy_micron[1])
240
+ fx = fftfreq(nxp, d=self.pixel_size_xy_micron[0])
289
241
  self._coords_grid = np.add.outer(fy**2, fx**2)
290
242
  #
291
243
  k2 = self._coords_grid
@@ -376,12 +328,11 @@ class PaganinPhaseRetrieval:
376
328
  radio_f = self.fft_func(self.data_padded, workers=self.fft_num_threads)
377
329
  radio_f *= self.paganin_filter
378
330
  radio_filtered = self.ifft_func(radio_f, workers=self.fft_num_threads).real
379
- s0, s1 = self.shape_inner
380
- ((U, _), (L, _)) = self.margin
331
+ s0, s1 = self.shape
381
332
  if output is None:
382
- return radio_filtered[U : U + s0, L : L + s1]
333
+ return radio_filtered[:s0, :s1]
383
334
  else:
384
- output[:, :] = radio_filtered[U : U + s0, L : L + s1]
335
+ output[:, :] = radio_filtered[:s0, :s1]
385
336
  return output
386
337
 
387
338
  def lmicron_to_db(self, Lmicron):
nabu/preproc/phase_cuda.py CHANGED
@@ -18,11 +18,10 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
18
18
  delta_beta=250.0,
19
19
  pixel_size=1e-6,
20
20
  padding="edge",
21
- margin=None,
22
21
  cuda_options=None,
23
22
  fftw_num_threads=None, # COMPAT.
24
23
  fft_num_threads=None,
25
- fft_backend="skcuda",
24
+ fft_backend="vkfft",
26
25
  ):
27
26
  """
28
27
  Please refer to the documentation of
@@ -37,7 +36,6 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
37
36
  delta_beta=delta_beta,
38
37
  pixel_size=pixel_size,
39
38
  padding=padding,
40
- margin=margin,
41
39
  use_rfft=True,
42
40
  fft_num_threads=False,
43
41
  )
@@ -118,14 +116,13 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval):
118
116
  raise ValueError("Expected either numpy array, pycuda array or pycuda buffer")
119
117
 
120
118
  def get_output(self, output):
121
- s0, s1 = self.shape_inner
122
- ((U, _), (L, _)) = self.margin
119
+ s0, s1 = self.shape
123
120
  if output is None:
124
121
  # copy D2H
125
- return self.d_radio_padded[U : U + s0, L : L + s1].get()
126
- assert output.shape == self.shape_inner
122
+ return self.d_radio_padded[:s0, :s1].get()
123
+ assert output.shape == self.shape
127
124
  assert output.dtype == np.float32
128
- output[:, :] = self.d_radio_padded[U : U + s0, L : L + s1]
125
+ output[:, :] = self.d_radio_padded[:s0, :s1]
129
126
  return output
130
127
 
131
128
  def apply_filter(self, radio, output=None):
nabu/preproc/tests/test_ctf.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import pytest
2
2
  import numpy as np
3
3
  import scipy.interpolate
4
+ from nabu.processing.fft_cuda import get_available_fft_implems
4
5
  from nabu.testutils import get_data as nabu_get_data
5
6
  from nabu.testutils import __do_long_tests__
6
7
  from nabu.preproc.flatfield import FlatFieldArrays
@@ -9,11 +10,14 @@ from nabu.preproc import ctf
9
10
  from nabu.estimation.distortion import estimate_flat_distortion
10
11
  from nabu.misc.filters import correct_spikes
11
12
  from nabu.preproc.distortion import DistortionCorrection
12
- from nabu.cuda.utils import __has_pycuda__, __has_cufft__, get_cuda_context
13
+ from nabu.cuda.utils import __has_pycuda__, get_cuda_context
13
14
 
14
- if __has_pycuda__ and __has_cufft__:
15
+ __has_cufft__ = False
16
+ if __has_pycuda__:
15
17
  from nabu.preproc.ctf_cuda import CudaCTFPhaseRetrieval
16
- import pycuda.gpuarray as garray
18
+
19
+ avail_fft = get_available_fft_implems()
20
+ __has_cufft__ = len(avail_fft) > 0
17
21
 
18
22
 
19
23
  @pytest.fixture(scope="class")
@@ -39,7 +43,7 @@ def bootstrap_TestCtf(request):
39
43
  cls.padded_img_shape_vh = test_data["padded_img_shape_vh"]
40
44
  cls.z1_vh = test_data["z1_vh"]
41
45
  cls.z2 = test_data["z2"]
42
- cls.pix_size_det = test_data["pix_size_det"]
46
+ cls.pix_size_det = test_data["pix_size_det"][()]
43
47
  cls.length_scale = test_data["length_scale"]
44
48
  cls.wavelength = test_data["wave_length"]
45
49
  cls.remove_spikes_threshold = test_data["remove_spikes_threshold"]
@@ -174,7 +178,7 @@ class TestCtf:
174
178
  phase = ctf_filter.retrieve_phase(img)
175
179
 
176
180
  message = "retrieved phase and reference result differ beyond the accepted tolerance"
177
- assert np.abs(phase - self.expected_result).max() < self.abs_tol * (
181
+ assert np.abs(phase - self.expected_result).max() < 10 * self.abs_tol * (
178
182
  np.abs(self.expected_result).mean()
179
183
  ), message
180
184
 
@@ -219,7 +223,7 @@ class TestCtf:
219
223
  phase_fft = ctf_fft.retrieve_phase(img)
220
224
  self.check_result(phase_r2c, self.ref_plain, "Something wrong with CtfFilter-FFT")
221
225
 
222
- @pytest.mark.skipif(not (__has_pycuda__ and __has_cufft__), reason="pycuda and scikit-cuda")
226
+ @pytest.mark.skipif(not (__has_pycuda__ and __has_cufft__), reason="pycuda and (scikit-cuda or vkfft)")
223
227
  def test_cuda_ctf(self):
224
228
  data = nabu_get_data("brain_phantom.npz")["data"]
225
229
  delta_beta = 50.0
@@ -243,7 +247,7 @@ class TestCtf:
243
247
  )
244
248
  ref = ctf_filter.retrieve_phase(data)
245
249
 
246
- d_data = garray.to_gpu(data)
250
+ d_data = cuda_ctf_filter.cuda_processing.to_device("_d_data", data)
247
251
  res = cuda_ctf_filter.retrieve_phase(d_data).get()
248
252
  err_max = np.max(np.abs(res - ref))
249
253