nabu 2024.2.13__py3-none-any.whl → 2025.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198) hide show
  1. doc/doc_config.py +32 -0
  2. nabu/__init__.py +1 -1
  3. nabu/app/bootstrap_stitching.py +4 -2
  4. nabu/app/cast_volume.py +16 -14
  5. nabu/app/cli_configs.py +102 -9
  6. nabu/app/compare_volumes.py +1 -1
  7. nabu/app/composite_cor.py +2 -4
  8. nabu/app/diag_to_pix.py +5 -6
  9. nabu/app/diag_to_rot.py +10 -11
  10. nabu/app/double_flatfield.py +18 -5
  11. nabu/app/estimate_motion.py +75 -0
  12. nabu/app/multicor.py +28 -15
  13. nabu/app/parse_reconstruction_log.py +1 -0
  14. nabu/app/pcaflats.py +122 -0
  15. nabu/app/prepare_weights_double.py +1 -2
  16. nabu/app/reconstruct.py +1 -7
  17. nabu/app/reconstruct_helical.py +5 -9
  18. nabu/app/reduce_dark_flat.py +5 -4
  19. nabu/app/rotate.py +3 -1
  20. nabu/app/stitching.py +7 -2
  21. nabu/app/tests/test_reduce_dark_flat.py +2 -2
  22. nabu/app/validator.py +1 -4
  23. nabu/cuda/convolution.py +1 -1
  24. nabu/cuda/fft.py +1 -1
  25. nabu/cuda/medfilt.py +1 -1
  26. nabu/cuda/padding.py +1 -1
  27. nabu/cuda/src/backproj.cu +6 -6
  28. nabu/cuda/src/cone.cu +4 -0
  29. nabu/cuda/src/hierarchical_backproj.cu +14 -0
  30. nabu/cuda/utils.py +2 -2
  31. nabu/estimation/alignment.py +17 -31
  32. nabu/estimation/cor.py +27 -33
  33. nabu/estimation/cor_sino.py +2 -8
  34. nabu/estimation/focus.py +4 -8
  35. nabu/estimation/motion.py +557 -0
  36. nabu/estimation/tests/test_alignment.py +2 -0
  37. nabu/estimation/tests/test_motion_estimation.py +471 -0
  38. nabu/estimation/tests/test_tilt.py +1 -1
  39. nabu/estimation/tilt.py +6 -5
  40. nabu/estimation/translation.py +47 -1
  41. nabu/io/cast_volume.py +108 -18
  42. nabu/io/detector_distortion.py +5 -6
  43. nabu/io/reader.py +45 -6
  44. nabu/io/reader_helical.py +5 -4
  45. nabu/io/tests/test_cast_volume.py +2 -2
  46. nabu/io/tests/test_readers.py +41 -38
  47. nabu/io/tests/test_remove_volume.py +152 -0
  48. nabu/io/tests/test_writers.py +2 -2
  49. nabu/io/utils.py +8 -4
  50. nabu/io/writer.py +1 -2
  51. nabu/misc/fftshift.py +1 -1
  52. nabu/misc/fourier_filters.py +1 -1
  53. nabu/misc/histogram.py +1 -1
  54. nabu/misc/histogram_cuda.py +1 -1
  55. nabu/misc/padding_base.py +1 -1
  56. nabu/misc/rotation.py +1 -1
  57. nabu/misc/rotation_cuda.py +1 -1
  58. nabu/misc/tests/test_binning.py +1 -1
  59. nabu/misc/transpose.py +1 -1
  60. nabu/misc/unsharp.py +1 -1
  61. nabu/misc/unsharp_cuda.py +1 -1
  62. nabu/misc/unsharp_opencl.py +1 -1
  63. nabu/misc/utils.py +1 -1
  64. nabu/opencl/fft.py +1 -1
  65. nabu/opencl/padding.py +1 -1
  66. nabu/opencl/src/backproj.cl +6 -6
  67. nabu/opencl/utils.py +8 -8
  68. nabu/pipeline/config.py +2 -2
  69. nabu/pipeline/config_validators.py +46 -46
  70. nabu/pipeline/datadump.py +3 -3
  71. nabu/pipeline/estimators.py +271 -11
  72. nabu/pipeline/fullfield/chunked.py +103 -67
  73. nabu/pipeline/fullfield/chunked_cuda.py +5 -2
  74. nabu/pipeline/fullfield/computations.py +4 -1
  75. nabu/pipeline/fullfield/dataset_validator.py +0 -1
  76. nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
  77. nabu/pipeline/fullfield/nabu_config.py +36 -17
  78. nabu/pipeline/fullfield/processconfig.py +41 -7
  79. nabu/pipeline/fullfield/reconstruction.py +14 -10
  80. nabu/pipeline/helical/dataset_validator.py +3 -4
  81. nabu/pipeline/helical/fbp.py +4 -4
  82. nabu/pipeline/helical/filtering.py +5 -4
  83. nabu/pipeline/helical/gridded_accumulator.py +10 -11
  84. nabu/pipeline/helical/helical_chunked_regridded.py +1 -0
  85. nabu/pipeline/helical/helical_reconstruction.py +12 -9
  86. nabu/pipeline/helical/helical_utils.py +1 -2
  87. nabu/pipeline/helical/nabu_config.py +2 -1
  88. nabu/pipeline/helical/span_strategy.py +1 -0
  89. nabu/pipeline/helical/weight_balancer.py +2 -3
  90. nabu/pipeline/params.py +20 -3
  91. nabu/pipeline/tests/__init__.py +0 -0
  92. nabu/pipeline/tests/test_estimators.py +240 -3
  93. nabu/pipeline/utils.py +1 -1
  94. nabu/pipeline/writer.py +1 -1
  95. nabu/preproc/alignment.py +0 -10
  96. nabu/preproc/ccd.py +53 -3
  97. nabu/preproc/ctf.py +8 -8
  98. nabu/preproc/ctf_cuda.py +1 -1
  99. nabu/preproc/double_flatfield_cuda.py +2 -2
  100. nabu/preproc/double_flatfield_variable_region.py +0 -1
  101. nabu/preproc/flatfield.py +307 -2
  102. nabu/preproc/flatfield_cuda.py +1 -2
  103. nabu/preproc/flatfield_variable_region.py +3 -3
  104. nabu/preproc/phase.py +2 -4
  105. nabu/preproc/phase_cuda.py +2 -2
  106. nabu/preproc/shift.py +4 -2
  107. nabu/preproc/shift_cuda.py +0 -1
  108. nabu/preproc/tests/test_ctf.py +4 -4
  109. nabu/preproc/tests/test_double_flatfield.py +1 -1
  110. nabu/preproc/tests/test_flatfield.py +1 -1
  111. nabu/preproc/tests/test_paganin.py +1 -3
  112. nabu/preproc/tests/test_pcaflats.py +154 -0
  113. nabu/preproc/tests/test_vshift.py +4 -1
  114. nabu/processing/azim.py +9 -5
  115. nabu/processing/convolution_cuda.py +6 -4
  116. nabu/processing/fft_base.py +7 -3
  117. nabu/processing/fft_cuda.py +25 -164
  118. nabu/processing/fft_opencl.py +28 -6
  119. nabu/processing/fftshift.py +1 -1
  120. nabu/processing/histogram.py +1 -1
  121. nabu/processing/muladd.py +0 -1
  122. nabu/processing/padding_base.py +1 -1
  123. nabu/processing/padding_cuda.py +0 -2
  124. nabu/processing/processing_base.py +12 -6
  125. nabu/processing/rotation_cuda.py +3 -1
  126. nabu/processing/tests/test_fft.py +2 -64
  127. nabu/processing/tests/test_fftshift.py +1 -1
  128. nabu/processing/tests/test_medfilt.py +1 -3
  129. nabu/processing/tests/test_padding.py +1 -1
  130. nabu/processing/tests/test_roll.py +1 -1
  131. nabu/processing/tests/test_rotation.py +4 -2
  132. nabu/processing/unsharp_opencl.py +1 -1
  133. nabu/reconstruction/astra.py +245 -0
  134. nabu/reconstruction/cone.py +39 -9
  135. nabu/reconstruction/fbp.py +14 -0
  136. nabu/reconstruction/fbp_base.py +40 -8
  137. nabu/reconstruction/fbp_opencl.py +8 -0
  138. nabu/reconstruction/filtering.py +59 -25
  139. nabu/reconstruction/filtering_cuda.py +22 -21
  140. nabu/reconstruction/filtering_opencl.py +10 -14
  141. nabu/reconstruction/hbp.py +26 -13
  142. nabu/reconstruction/mlem.py +55 -16
  143. nabu/reconstruction/projection.py +3 -5
  144. nabu/reconstruction/sinogram.py +1 -1
  145. nabu/reconstruction/sinogram_cuda.py +0 -1
  146. nabu/reconstruction/tests/test_cone.py +37 -2
  147. nabu/reconstruction/tests/test_deringer.py +4 -4
  148. nabu/reconstruction/tests/test_fbp.py +36 -15
  149. nabu/reconstruction/tests/test_filtering.py +27 -7
  150. nabu/reconstruction/tests/test_halftomo.py +28 -2
  151. nabu/reconstruction/tests/test_mlem.py +94 -64
  152. nabu/reconstruction/tests/test_projector.py +7 -2
  153. nabu/reconstruction/tests/test_reconstructor.py +1 -1
  154. nabu/reconstruction/tests/test_sino_normalization.py +0 -1
  155. nabu/resources/dataset_analyzer.py +210 -24
  156. nabu/resources/gpu.py +4 -4
  157. nabu/resources/logger.py +4 -4
  158. nabu/resources/nxflatfield.py +103 -37
  159. nabu/resources/tests/test_dataset_analyzer.py +37 -0
  160. nabu/resources/tests/test_extract.py +11 -0
  161. nabu/resources/tests/test_nxflatfield.py +5 -5
  162. nabu/resources/utils.py +16 -10
  163. nabu/stitching/alignment.py +8 -11
  164. nabu/stitching/config.py +44 -35
  165. nabu/stitching/definitions.py +2 -2
  166. nabu/stitching/frame_composition.py +8 -10
  167. nabu/stitching/overlap.py +4 -4
  168. nabu/stitching/sample_normalization.py +5 -5
  169. nabu/stitching/slurm_utils.py +2 -2
  170. nabu/stitching/stitcher/base.py +2 -0
  171. nabu/stitching/stitcher/dumper/base.py +0 -1
  172. nabu/stitching/stitcher/dumper/postprocessing.py +1 -1
  173. nabu/stitching/stitcher/post_processing.py +11 -9
  174. nabu/stitching/stitcher/pre_processing.py +37 -31
  175. nabu/stitching/stitcher/single_axis.py +2 -3
  176. nabu/stitching/stitcher_2D.py +2 -1
  177. nabu/stitching/tests/test_config.py +10 -11
  178. nabu/stitching/tests/test_sample_normalization.py +1 -1
  179. nabu/stitching/tests/test_slurm_utils.py +1 -2
  180. nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
  181. nabu/stitching/tests/test_z_postprocessing_stitching.py +3 -3
  182. nabu/stitching/tests/test_z_preprocessing_stitching.py +27 -24
  183. nabu/stitching/utils/tests/__init__.py +0 -0
  184. nabu/stitching/utils/tests/test_post-processing.py +1 -0
  185. nabu/stitching/utils/utils.py +16 -18
  186. nabu/tests.py +0 -3
  187. nabu/testutils.py +62 -9
  188. nabu/utils.py +50 -20
  189. {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/METADATA +7 -7
  190. nabu-2025.1.0.dist-info/RECORD +328 -0
  191. {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/WHEEL +1 -1
  192. {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/entry_points.txt +2 -1
  193. nabu/app/correct_rot.py +0 -70
  194. nabu/io/tests/test_detector_distortion.py +0 -178
  195. nabu-2024.2.13.dist-info/RECORD +0 -317
  196. /nabu/{stitching → app}/tests/__init__.py +0 -0
  197. {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/licenses/LICENSE +0 -0
  198. {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,152 @@
1
+ from pathlib import Path
2
+ from os import path, mkdir, rename
3
+ import numpy as np
4
+ import pytest
5
+ from tomoscan.esrf.volume import (
6
+ EDFVolume,
7
+ HDF5Volume,
8
+ JP2KVolume,
9
+ MultiTIFFVolume,
10
+ TIFFVolume,
11
+ )
12
+ from tomoscan.esrf.volume.jp2kvolume import has_glymur
13
+ from nabu.io.writer import merge_hdf5_files
14
+ from nabu.io.cast_volume import remove_volume
15
+
16
+
17
+ def test_remove_single_frame_volume(tmpdir):
18
+ """
19
+ Test volume removal for tiff, jp2 and EDF
20
+ """
21
+ # Have to use a not-too-small size because of jp2k
22
+ data = np.arange(10 * 40 * 50, dtype="f").reshape((10, 40, 50))
23
+
24
+ volume_classes = [EDFVolume, TIFFVolume, JP2KVolume]
25
+ if not (has_glymur):
26
+ volume_classes.pop()
27
+ for volume_cls in volume_classes:
28
+ ext = volume_cls.DEFAULT_DATA_EXTENSION
29
+ folder = path.join(tmpdir, f"{ext}_vol")
30
+ volume_basename = f"{ext}_basename"
31
+
32
+ vol_writer = volume_cls(folder=folder, volume_basename=volume_basename, overwrite=True, start_index=0)
33
+ vol_writer.data = data
34
+ vol_writer.save()
35
+
36
+ vol_reader = volume_cls(folder=folder, volume_basename=volume_basename)
37
+ assert path.isdir(folder), f"Expected to find a folder {folder}"
38
+ remove_volume(vol_reader)
39
+ assert not (path.isdir(folder)), f"Expected to have removed the folder {folder}"
40
+
41
+ vol_writer.save()
42
+ vol_reader = volume_cls(folder=folder, volume_basename=volume_basename)
43
+ Path(path.join(folder, f"unexpected.{ext}")).touch()
44
+ with pytest.raises(RuntimeError) as exc:
45
+ remove_volume(vol_reader, check=True)
46
+ assert "Unexpected files present" in str(exc.value), "Expected check to find extraneous files"
47
+
48
+
49
+ def test_remove_multiframe_volume(tmpdir):
50
+ """
51
+ Test volume removal for "multiframe" formats (HDF5, tiff3D)
52
+ The HDF5 files considered in this test do not have virtual sources
53
+ """
54
+ data = np.arange(3 * 4 * 5, dtype="f").reshape((3, 4, 5))
55
+
56
+ for ext, volume_cls in {"h5": HDF5Volume, "tiff": MultiTIFFVolume}.items():
57
+ file_path = path.join(tmpdir, f"{ext}_vol.{ext}")
58
+
59
+ init_kwargs = {"file_path": file_path}
60
+ if ext == "h5":
61
+ init_kwargs["data_path"] = "entry"
62
+ vol_writer = volume_cls(**init_kwargs)
63
+ vol_writer.data = data
64
+ vol_writer.save()
65
+
66
+ vol_reader = volume_cls(**init_kwargs)
67
+ assert path.isfile(file_path), f"Expected to find a {ext} volume at {file_path}"
68
+ remove_volume(vol_reader)
69
+ assert not (path.isfile(file_path)), f"Expected to have removed {file_path}"
70
+
71
+
72
+ def test_remove_hdf5_multiple_entries(tmpdir):
73
+ data = np.arange(3 * 4 * 5, dtype="f").reshape((3, 4, 5))
74
+ file_path = path.join(tmpdir, "h5_vol.h5")
75
+ vol_writer_1 = HDF5Volume(file_path=file_path, data_path="entry0000")
76
+ vol_writer_1.data = data
77
+ vol_writer_1.save()
78
+ vol_writer_2 = HDF5Volume(file_path=file_path, data_path="entry0001")
79
+ vol_writer_2.data = data + 10
80
+ vol_writer_2.save()
81
+ vol_reader = HDF5Volume(file_path=file_path, data_path="entry0000")
82
+ with pytest.raises(NotImplementedError) as exc:
83
+ remove_volume(vol_reader, check=True)
84
+ assert "Removing a HDF5 volume with more than one entry is not supported" in str(
85
+ exc.value
86
+ ), "Expected an error message"
87
+
88
+
89
+ def test_remove_nabu_hdf5_reconstruction(tmpdir):
90
+ """
91
+ Test removal of HDF5 reconstruction generated by nabu (i.e with virtual sources)
92
+ """
93
+
94
+ entry = "entry"
95
+ process_name = "reconstruction"
96
+
97
+ master_file_path = path.join(tmpdir, "sample_naburec.hdf5")
98
+ associated_dir = path.join(tmpdir, "sample_naburec")
99
+ if not (path.isdir(associated_dir)):
100
+ mkdir(associated_dir)
101
+
102
+ n_chunks = 5
103
+ local_files = []
104
+ for i in range(n_chunks):
105
+ fname = "sample_naburec_%06d.h5" % i
106
+ partial_rec_abspath = path.join(associated_dir, fname)
107
+ local_files.append(f"sample_naburec/{fname}")
108
+ # local_files.append(fname)
109
+ vol = HDF5Volume(partial_rec_abspath, data_path=f"{entry}/{process_name}")
110
+ vol.data = np.arange(3 * 4 * 5, dtype="f").reshape((3, 4, 5))
111
+ vol.save()
112
+
113
+ h5_path = f"{entry}/{process_name}/results/data"
114
+
115
+ merge_hdf5_files(
116
+ local_files,
117
+ h5_path,
118
+ master_file_path,
119
+ process_name,
120
+ output_entry=entry,
121
+ output_filemode="a",
122
+ processing_index=0,
123
+ config=None,
124
+ base_dir=path.dirname(associated_dir),
125
+ axis=0,
126
+ overwrite=True,
127
+ )
128
+
129
+ assert path.isfile(master_file_path), f"Expected to find the master file at {master_file_path}"
130
+ assert path.isdir(associated_dir)
131
+ for local_file in local_files:
132
+ partial_rec_file = path.join(tmpdir, local_file)
133
+ assert path.isfile(partial_rec_file), f"Expected to find partial file at {partial_rec_file}"
134
+
135
+ # Check that the virtual links are handled properly
136
+ # sample_rec.hdf5 should reference sample_rec/sample_rec_{i}.h5
137
+ renamed_master_file_path = (
138
+ path.join(path.dirname(master_file_path), path.basename(master_file_path).split(".")[0]) + "_renamed" + ".h5"
139
+ )
140
+ rename(master_file_path, renamed_master_file_path)
141
+ h5_vol = HDF5Volume(file_path=renamed_master_file_path, data_path=f"{entry}/{process_name}")
142
+ with pytest.raises(ValueError) as exc:
143
+ remove_volume(h5_vol)
144
+ expected_error_message = f"The virtual sources in {renamed_master_file_path}:{process_name}/results/data reference the directory sample_naburec, but expected was sample_naburec_renamed"
145
+ assert str(exc.value) == expected_error_message
146
+
147
+ # Check removal in normal circumstances
148
+ rename(renamed_master_file_path, master_file_path)
149
+ h5_vol = HDF5Volume(file_path=master_file_path, data_path=f"{entry}/{process_name}")
150
+ remove_volume(h5_vol)
151
+ assert not (path.isfile(master_file_path)), f"Expected the master file at {master_file_path} to have been removed"
152
+ assert not (path.isdir(associated_dir))
@@ -80,7 +80,7 @@ class TestNXWriter:
80
80
  writer.write(self.data, "test_no_overwrite")
81
81
 
82
82
  writer2 = NXProcessWriter(fname, entry="entry0000", overwrite=False)
83
- with pytest.raises((RuntimeError, OSError)) as ex:
83
+ with pytest.raises((RuntimeError, OSError)):
84
84
  writer2.write(self.data, "test_no_overwrite")
85
85
 
86
- message = "Error should have been raised for trying to overwrite, but got the following: %s" % str(ex.value)
86
+ # message = "Error should have been raised for trying to overwrite, but got the following: %s" % str(ex.value)
nabu/io/utils.py CHANGED
@@ -8,6 +8,8 @@ from tomoscan.volumebase import VolumeBase
8
8
  from tomoscan.esrf import EDFVolume, HDF5Volume, TIFFVolume, JP2KVolume, MultiTIFFVolume
9
9
  from tomoscan.io import HDF5File
10
10
 
11
+ from nabu.utils import first_generator_item
12
+
11
13
 
12
14
  # This function might be moved elsewhere
13
15
  def get_compacted_dataslices(urls, subsampling=None, begin=0):
@@ -100,7 +102,7 @@ def get_compacted_dataslices(urls, subsampling=None, begin=0):
100
102
 
101
103
  def get_first_hdf5_entry(fname):
102
104
  with HDF5File(fname, "r") as fid:
103
- entry = list(fid.keys())[0]
105
+ entry = first_generator_item(fid.keys())
104
106
  return entry
105
107
 
106
108
 
@@ -189,7 +191,7 @@ class _BaseReader(contextlib.AbstractContextManager):
189
191
  if url.scheme() not in ("silx", "h5py"):
190
192
  raise ValueError("Valid scheme are silx and h5py")
191
193
  if url.data_slice() is not None:
192
- raise ValueError("Data slices are not managed. Data path should " "point to a bliss node (h5py.Group)")
194
+ raise ValueError("Data slices are not managed. Data path should point to a bliss node (h5py.Group)")
193
195
  self._url = url
194
196
  self._file_handler = None
195
197
 
@@ -207,7 +209,7 @@ class EntryReader(_BaseReader):
207
209
  else:
208
210
  entry = self._file_handler[self._url.data_path()]
209
211
  if not isinstance(entry, h5py.Group):
210
- raise ValueError("Data path should point to a bliss node (h5py.Group)")
212
+ raise TypeError("Data path should point to a bliss node (h5py.Group)")
211
213
  return entry
212
214
 
213
215
 
@@ -218,7 +220,7 @@ class DatasetReader(_BaseReader):
218
220
  self._file_handler = HDF5File(self._url.file_path(), mode="r")
219
221
  entry = self._file_handler[self._url.data_path()]
220
222
  if not isinstance(entry, h5py.Dataset):
221
- raise ValueError("Data path ({}) should point to a dataset (h5py.Dataset)".format(self._url.path()))
223
+ raise TypeError(f"Data path ({self._url.path()}) should point to a dataset (h5py.Dataset)")
222
224
  return entry
223
225
 
224
226
 
@@ -261,3 +263,5 @@ def get_output_volume(location: str, file_prefix: Optional[str], file_format: st
261
263
  return MultiTIFFVolume(file_path=location)
262
264
  else:
263
265
  return TIFFVolume(folder=location, volume_basename=file_prefix)
266
+ else:
267
+ raise ValueError
nabu/io/writer.py CHANGED
@@ -13,7 +13,6 @@ try:
13
13
  except:
14
14
  from h5py import File as HDF5File
15
15
  from tomoscan.esrf import RawVolume
16
- from tomoscan.esrf.volume.jp2kvolume import has_glymur as __have_jp2k__
17
16
  from .. import version as nabu_version
18
17
  from ..utils import merged_shape
19
18
  from .utils import convert_dict_values
@@ -183,7 +182,7 @@ class NXVolVolume(NXProcessWriter):
183
182
  volume_basename = file_prefix = kwargs.get("volume_basename", None)
184
183
  start_index = kwargs.get("start_index", None)
185
184
  overwrite = kwargs.get("overwrite", False)
186
- data_path = entry = kwargs.get("data_path", None)
185
+ entry = kwargs.get("data_path", None)
187
186
  self._process_name = kwargs.get("process_name", "reconstruction")
188
187
  if any([param is None for param in [folder, volume_basename, start_index, entry]]):
189
188
  raise ValueError("Need the following parameters: folder, volume_basename, start_index, data_path")
nabu/misc/fftshift.py CHANGED
@@ -1,4 +1,4 @@
1
- from ..processing.fftshift import *
1
+ from ..processing.fftshift import * # noqa: F403
2
2
  from ..utils import deprecation_warning
3
3
 
4
4
  deprecation_warning(
@@ -1,7 +1,7 @@
1
- # -*- coding: utf-8 -*-
2
1
  """
3
2
  Fourier filters.
4
3
  """
4
+
5
5
  from functools import lru_cache
6
6
  import numpy as np
7
7
  import scipy.special as spspe
nabu/misc/histogram.py CHANGED
@@ -1,4 +1,4 @@
1
- from ..processing.histogram import *
1
+ from ..processing.histogram import * # noqa: F403
2
2
  from ..utils import deprecation_warning
3
3
 
4
4
  deprecation_warning(
@@ -1,4 +1,4 @@
1
- from ..processing.histogram_cuda import *
1
+ from ..processing.histogram_cuda import * # noqa: F403
2
2
  from ..utils import deprecation_warning
3
3
 
4
4
  deprecation_warning(
nabu/misc/padding_base.py CHANGED
@@ -1,4 +1,4 @@
1
- from ..processing.padding_base import *
1
+ from ..processing.padding_base import * # noqa: F403
2
2
  from ..utils import deprecation_warning
3
3
 
4
4
  deprecation_warning(
nabu/misc/rotation.py CHANGED
@@ -1,4 +1,4 @@
1
- from ..processing.rotation import *
1
+ from ..processing.rotation import * # noqa: F403
2
2
  from ..utils import deprecation_warning
3
3
 
4
4
  deprecation_warning(
@@ -1,4 +1,4 @@
1
- from ..processing.rotation_cuda import *
1
+ from ..processing.rotation_cuda import * # noqa: F403
2
2
  from ..utils import deprecation_warning
3
3
 
4
4
  deprecation_warning(
@@ -1,7 +1,7 @@
1
1
  from itertools import product
2
2
  import numpy as np
3
3
  import pytest
4
- from nabu.misc.binning import *
4
+ from nabu.misc.binning import binning
5
5
 
6
6
 
7
7
  @pytest.fixture(scope="class")
nabu/misc/transpose.py CHANGED
@@ -1,4 +1,4 @@
1
- from ..processing.transpose import *
1
+ from ..processing.transpose import * # noqa: F403
2
2
  from ..utils import deprecation_warning
3
3
 
4
4
  deprecation_warning(
nabu/misc/unsharp.py CHANGED
@@ -1,4 +1,4 @@
1
- from ..processing.unsharp import *
1
+ from ..processing.unsharp import * # noqa: F403
2
2
  from ..utils import deprecation_warning
3
3
 
4
4
  deprecation_warning("nabu.misc.unsharp has been moved to nabu.processing.unsharp", do_print=True, func_name="unsharp")
nabu/misc/unsharp_cuda.py CHANGED
@@ -1,4 +1,4 @@
1
- from ..processing.unsharp_cuda import *
1
+ from ..processing.unsharp_cuda import * # noqa: F403
2
2
  from ..utils import deprecation_warning
3
3
 
4
4
  deprecation_warning(
@@ -1,4 +1,4 @@
1
- from ..processing.unsharp_opencl import *
1
+ from ..processing.unsharp_opencl import * # noqa: F403
2
2
  from ..utils import deprecation_warning
3
3
 
4
4
  deprecation_warning(
nabu/misc/utils.py CHANGED
@@ -41,7 +41,7 @@ def psnr(img1, img2):
41
41
  #
42
42
 
43
43
 
44
- class ConvolutionInfos(object):
44
+ class ConvolutionInfos:
45
45
  allowed_axes = {
46
46
  "1D": [None],
47
47
  "separable_2D_1D_2D": [None, (0, 1), (1, 0)],
nabu/opencl/fft.py CHANGED
@@ -1,4 +1,4 @@
1
- from ..processing.fft_opencl import *
1
+ from ..processing.fft_opencl import * # noqa: F403
2
2
  from ..utils import deprecation_warning
3
3
 
4
4
  deprecation_warning(
nabu/opencl/padding.py CHANGED
@@ -1,4 +1,4 @@
1
- from ..processing.padding_opencl import *
1
+ from ..processing.padding_opencl import * # noqa: F403
2
2
  from ..utils import deprecation_warning
3
3
 
4
4
  deprecation_warning(
@@ -14,7 +14,7 @@ static inline int is_in_circle(float x, float y, float center_x, float center_y,
14
14
  This will return arr[y][x] where y is an int (exact access) and x is a float (linear interp horizontally)
15
15
  */
16
16
  static inline float linear_interpolation(global float* arr, int Nx, float x, int y) {
17
- if (x < 0 || x >= Nx) return 0.0f; // texture address mode CLAMP_TO_EDGE
17
+ if (x < -0.5f || x > Nx - 0.5f) return 0.0f; // texture address mode BORDER (CLAMP_TO_EDGE continues with edge)
18
18
  int xm = (int) floor(x);
19
19
  int xp = (int) ceil(x);
20
20
  if ((xm == xp) || (xp >= Nx)) return arr[y*Nx+xm];
@@ -53,7 +53,7 @@ kernel void backproj(
53
53
  uint Gy = get_global_size(1);
54
54
 
55
55
  #ifdef USE_TEXTURES
56
- const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_LINEAR;
56
+ const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_LINEAR;
57
57
  #endif
58
58
 
59
59
  // (xr, yr) (xrp, yr)
@@ -111,10 +111,10 @@ kernel void backproj(
111
111
  #endif
112
112
 
113
113
  #ifdef USE_TEXTURES
114
- if (h1 >= 0 && h1 < num_bins) sum1 += read_imagef(d_sino, sampler, (float2) (h1 +0.5f,proj +0.5f)).x;
115
- if (h2 >= 0 && h2 < num_bins) sum2 += read_imagef(d_sino, sampler, (float2) (h2 +0.5f,proj +0.5f)).x;
116
- if (h3 >= 0 && h3 < num_bins) sum3 += read_imagef(d_sino, sampler, (float2) (h3 +0.5f,proj +0.5f)).x;
117
- if (h4 >= 0 && h4 < num_bins) sum4 += read_imagef(d_sino, sampler, (float2) (h4 +0.5f,proj +0.5f)).x;
114
+ sum1 += read_imagef(d_sino, sampler, (float2) (h1 +0.5f,proj +0.5f)).x;
115
+ sum2 += read_imagef(d_sino, sampler, (float2) (h2 +0.5f,proj +0.5f)).x;
116
+ sum3 += read_imagef(d_sino, sampler, (float2) (h3 +0.5f,proj +0.5f)).x;
117
+ sum4 += read_imagef(d_sino, sampler, (float2) (h4 +0.5f,proj +0.5f)).x;
118
118
  #else
119
119
  if (h1 >= 0 && h1 < num_bins) sum1 += linear_interpolation(d_sino, num_bins, h1, proj);
120
120
  if (h2 >= 0 && h2 < num_bins) sum2 += linear_interpolation(d_sino, num_bins, h2, proj);
nabu/opencl/utils.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import numpy as np
2
- from ..utils import check_supported
2
+ from ..utils import check_supported, first_generator_item
3
3
 
4
4
  try:
5
5
  import pyopencl as cl
@@ -138,11 +138,11 @@ def collect_opencl_gpus():
138
138
  Return a dictionary of platforms and brief description of each OpenCL-compatible
139
139
  GPU with a few fields
140
140
  """
141
- gpus, error_msg = detect_opencl_gpus()
141
+ gpus_detected, error_msg = detect_opencl_gpus()
142
142
  if error_msg is not None:
143
143
  return None
144
144
  opencl_gpus = {}
145
- for platform, gpus in gpus.items():
145
+ for platform, gpus in gpus_detected.items():
146
146
  for gpu_id, gpu in gpus.items():
147
147
  if platform not in opencl_gpus:
148
148
  opencl_gpus[platform] = {}
@@ -215,22 +215,22 @@ def pick_opencl_cpu_platform(opencl_cpus):
215
215
  raise ValueError("No CPU to pick")
216
216
  name2device = {}
217
217
  for platform, devices in opencl_cpus.items():
218
- for device_id, device_desc in devices.items():
218
+ for device_id, device_desc in devices.items(): # noqa: PERF102
219
219
  name2device.setdefault(device_desc["name"], [])
220
220
  name2device[device_desc["name"]].append(platform)
221
221
  if len(name2device) > 1:
222
222
  raise ValueError("Expected at most one CPU but got %d: %s" % (len(name2device), list(name2device.keys())))
223
- cpu_name = list(name2device.keys())[0]
223
+ cpu_name = first_generator_item(name2device.keys())
224
224
  platforms = name2device[cpu_name]
225
225
  # Several platforms for the same CPU
226
226
  res = opencl_cpus[platforms[0]]
227
- if len(platforms) > 1:
227
+ if len(platforms) > 1: # noqa: SIM102
228
228
  if "intel" in cpu_name.lower():
229
229
  for platform in platforms:
230
230
  if "intel" in platform.lower():
231
231
  res = opencl_cpus[platform]
232
232
  #
233
- return res[list(res.keys())[0]]
233
+ return res[first_generator_item(res.keys())]
234
234
 
235
235
 
236
236
  def allocate_texture(ctx, shape, support_1D=False):
@@ -289,4 +289,4 @@ def copy_to_texture(queue, dst_texture, src_array, dtype=np.float32):
289
289
  src_array = np.ascontiguousarray(src_array, dtype=dtype)
290
290
  return cl.enqueue_copy(queue, dst_texture, src_array, origin=(0, 0), region=shape[::-1])
291
291
  else:
292
- raise ValueError("Unknown source array type")
292
+ raise TypeError("Unknown source array type")
nabu/pipeline/config.py CHANGED
@@ -156,7 +156,7 @@ def _extract_nabuconfig_section(section, default_config):
156
156
 
157
157
  def _extract_nabuconfig_keyvals(default_config):
158
158
  res = {}
159
- for section in default_config.keys():
159
+ for section in default_config:
160
160
  res[section] = _extract_nabuconfig_section(section, default_config)
161
161
  return res
162
162
 
@@ -257,7 +257,7 @@ def overwrite_config(conf, overwritten_params):
257
257
  if section not in conf:
258
258
  raise ValueError("Unknown section %s" % section)
259
259
  current_section = conf[section]
260
- for key in params.keys():
260
+ for key in params:
261
261
  if key not in current_section:
262
262
  raise ValueError("Unknown parameter '%s' in section '%s'" % (key, section))
263
263
  conf[section][key] = overwritten_params[section][key]
@@ -1,8 +1,9 @@
1
+ # ruff: noqa: F405
1
2
  import os
2
3
 
3
4
  path = os.path
4
- from ..utils import check_supported, is_writeable
5
- from .params import *
5
+ from ..utils import check_supported, deprecation_warning, is_writeable
6
+ from .params import * # noqa: F403
6
7
 
7
8
  """
8
9
  A validator is a function with
@@ -96,12 +97,21 @@ def convert_to_bool_noerr(val):
96
97
  return res
97
98
 
98
99
 
99
- def name_range_checker(name, valid_names, descr, replacements=None):
100
+ def name_range_checker(name, available_names, descr):
101
+ """
102
+ Check whether a parameter name is valid, against a list or dictionary of names.
103
+ """
100
104
  name = name.strip().lower()
101
- if replacements is not None and name in replacements:
102
- name = replacements[name]
103
- valid = name in valid_names
104
- assert valid, "Invalid %s '%s'. Available are %s" % (descr, name, str(valid_names))
105
+ valid = name in available_names
106
+ if isinstance(available_names, dict):
107
+ # handle replacements, eg. {"edge": "edges"}
108
+ name = available_names[name]
109
+ # we could use .keys() instead to be more permissive to the user
110
+ available_names_str = str(set(available_names.values()))
111
+ else:
112
+ # assuming list
113
+ available_names_str = str(available_names)
114
+ assert valid, "Invalid %s '%s'. Available are %s" % (descr, name, available_names_str)
105
115
  return name
106
116
 
107
117
 
@@ -310,7 +320,7 @@ def optional_nonzero_float_validator(val):
310
320
  assert error is None, "Invalid number"
311
321
  else:
312
322
  val_float = None
313
- if val_float is not None:
323
+ if val_float is not None: # noqa: SIM102
314
324
  if abs(val_float) < 1e-6:
315
325
  val_float = None
316
326
  return val_float
@@ -323,7 +333,7 @@ def optional_tuple_of_floats_validator(val):
323
333
  err_msg = "Expected a tuple of two numbers, but got %s" % val
324
334
  try:
325
335
  res = tuple(float(x) for x in val.strip("()").split(","))
326
- except Exception as exc:
336
+ except Exception:
327
337
  raise ValueError(err_msg)
328
338
  if len(res) != 2:
329
339
  raise ValueError(err_msg)
@@ -337,9 +347,7 @@ def cor_validator(val):
337
347
  return val_float
338
348
  if len(val.strip()) == 0:
339
349
  return None
340
- val = name_range_checker(
341
- val.lower(), set(cor_methods.values()), "center of rotation estimation method", replacements=cor_methods
342
- )
350
+ val = name_range_checker(val, cor_methods, "center of rotation estimation method")
343
351
  return val
344
352
 
345
353
 
@@ -350,9 +358,7 @@ def tilt_validator(val):
350
358
  return val_float
351
359
  if len(val.strip()) == 0:
352
360
  return None
353
- val = name_range_checker(
354
- val.lower(), set(tilt_methods.values()), "automatic detector tilt estimation method", replacements=tilt_methods
355
- )
361
+ val = name_range_checker(val, tilt_methods, "automatic detector tilt estimation method")
356
362
  return val
357
363
 
358
364
 
@@ -394,53 +400,55 @@ def cor_slice_validator(val):
394
400
 
395
401
 
396
402
  @validator
397
- def flatfield_enabled_validator(val):
398
- return name_range_checker(val, set(flatfield_modes.values()), "flatfield mode", replacements=flatfield_modes)
403
+ def flatfield_validator(val):
404
+ ret = name_range_checker(val, flatfield_modes, "flatfield mode")
405
+ if ret in ["force-load", "force-compute"]:
406
+ deprecation_warning(
407
+ f"Using 'flatfield = {ret}' is deprecated since version 2025.1.0. Please use the parameter 'flatfield_loading_mode'",
408
+ )
409
+ return ret
410
+
411
+
412
+ @validator
413
+ def flatfield_loading_mode_validator(val):
414
+ return name_range_checker(val, flatfield_loading_mode, "flatfield mode")
399
415
 
400
416
 
401
417
  @validator
402
418
  def phase_method_validator(val):
403
- return name_range_checker(
404
- val, set(phase_retrieval_methods.values()), "phase retrieval method", replacements=phase_retrieval_methods
405
- )
419
+ return name_range_checker(val, phase_retrieval_methods, "phase retrieval method")
406
420
 
407
421
 
408
422
  @validator
409
423
  def detector_distortion_correction_validator(val):
410
424
  return name_range_checker(
411
425
  val,
412
- set(detector_distortion_correction_methods.values()),
426
+ detector_distortion_correction_methods,
413
427
  "detector_distortion_correction_methods",
414
- replacements=detector_distortion_correction_methods,
415
428
  )
416
429
 
417
430
 
418
431
  @validator
419
432
  def unsharp_method_validator(val):
420
- return name_range_checker(
421
- val, set(unsharp_methods.values()), "unsharp mask method", replacements=phase_retrieval_methods
422
- )
433
+ return name_range_checker(val, unsharp_methods, "unsharp mask method")
423
434
 
424
435
 
425
436
  @validator
426
437
  def padding_mode_validator(val):
427
- return name_range_checker(val, set(padding_modes.values()), "padding mode", replacements=padding_modes)
438
+ return name_range_checker(val, padding_modes, "padding mode")
428
439
 
429
440
 
430
441
  @validator
431
442
  def reconstruction_method_validator(val):
432
- return name_range_checker(
433
- val, set(reconstruction_methods.values()), "reconstruction method", replacements=reconstruction_methods
434
- )
443
+ return name_range_checker(val, reconstruction_methods, "reconstruction method")
435
444
 
436
445
 
437
446
  @validator
438
447
  def fbp_filter_name_validator(val):
439
448
  return name_range_checker(
440
449
  val,
441
- set(fbp_filters.values()),
450
+ fbp_filters,
442
451
  "FBP filter",
443
- replacements=fbp_filters,
444
452
  )
445
453
 
446
454
 
@@ -448,29 +456,24 @@ def fbp_filter_name_validator(val):
448
456
  def reconstruction_implementation_validator(val):
449
457
  return name_range_checker(
450
458
  val,
451
- set(reco_implementations.values()),
459
+ reco_implementations,
452
460
  "Reconstruction method implementation",
453
- replacements=reco_implementations,
454
461
  )
455
462
 
456
463
 
457
464
  @validator
458
465
  def optimization_algorithm_name_validator(val):
459
- return name_range_checker(
460
- val, set(optim_algorithms.values()), "optimization algorithm name", replacements=iterative_methods
461
- )
466
+ return name_range_checker(val, optim_algorithms, "optimization algorithm name")
462
467
 
463
468
 
464
469
  @validator
465
470
  def output_file_format_validator(val):
466
- return name_range_checker(val, set(files_formats.values()), "output file format", replacements=files_formats)
471
+ return name_range_checker(val, files_formats, "output file format")
467
472
 
468
473
 
469
474
  @validator
470
475
  def distribution_method_validator(val):
471
- val = name_range_checker(
472
- val, set(distribution_methods.values()), "workload distribution method", replacements=distribution_methods
473
- )
476
+ val = name_range_checker(val, distribution_methods, "workload distribution method")
474
477
  # TEMP.
475
478
  if val != "local":
476
479
  raise NotImplementedError("Computation method '%s' is not implemented yet" % val)
@@ -480,9 +483,7 @@ def distribution_method_validator(val):
480
483
 
481
484
  @validator
482
485
  def sino_normalization_validator(val):
483
- val = name_range_checker(
484
- val, set(sino_normalizations.values()), "sinogram normalization method", replacements=sino_normalizations
485
- )
486
+ val = name_range_checker(val, sino_normalizations, "sinogram normalization method")
486
487
  return val
487
488
 
488
489
 
@@ -490,9 +491,8 @@ def sino_normalization_validator(val):
490
491
  def sino_deringer_methods(val):
491
492
  val = name_range_checker(
492
493
  val,
493
- set(rings_methods.values()),
494
+ rings_methods,
494
495
  "sinogram rings artefacts correction method",
495
- replacements=rings_methods,
496
496
  )
497
497
  return val
498
498
 
@@ -555,7 +555,7 @@ def nonempty_string_validator(val):
555
555
 
556
556
  @validator
557
557
  def logging_validator(val):
558
- return name_range_checker(val, set(log_levels.values()), "logging level", replacements=log_levels)
558
+ return name_range_checker(val, log_levels, "logging level")
559
559
 
560
560
 
561
561
  @validator