nabu 2024.2.14__py3-none-any.whl → 2025.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197):
  1. doc/doc_config.py +32 -0
  2. nabu/__init__.py +1 -1
  3. nabu/app/bootstrap_stitching.py +4 -2
  4. nabu/app/cast_volume.py +16 -14
  5. nabu/app/cli_configs.py +102 -9
  6. nabu/app/compare_volumes.py +1 -1
  7. nabu/app/composite_cor.py +2 -4
  8. nabu/app/diag_to_pix.py +5 -6
  9. nabu/app/diag_to_rot.py +10 -11
  10. nabu/app/double_flatfield.py +18 -5
  11. nabu/app/estimate_motion.py +75 -0
  12. nabu/app/multicor.py +28 -15
  13. nabu/app/parse_reconstruction_log.py +1 -0
  14. nabu/app/pcaflats.py +122 -0
  15. nabu/app/prepare_weights_double.py +1 -2
  16. nabu/app/reconstruct.py +1 -7
  17. nabu/app/reconstruct_helical.py +5 -9
  18. nabu/app/reduce_dark_flat.py +5 -4
  19. nabu/app/rotate.py +3 -1
  20. nabu/app/stitching.py +7 -2
  21. nabu/app/tests/test_reduce_dark_flat.py +2 -2
  22. nabu/app/validator.py +1 -4
  23. nabu/cuda/convolution.py +1 -1
  24. nabu/cuda/fft.py +1 -1
  25. nabu/cuda/medfilt.py +1 -1
  26. nabu/cuda/padding.py +1 -1
  27. nabu/cuda/src/backproj.cu +6 -6
  28. nabu/cuda/src/cone.cu +4 -0
  29. nabu/cuda/src/hierarchical_backproj.cu +14 -0
  30. nabu/cuda/utils.py +2 -2
  31. nabu/estimation/alignment.py +17 -31
  32. nabu/estimation/cor.py +27 -33
  33. nabu/estimation/cor_sino.py +2 -8
  34. nabu/estimation/focus.py +4 -8
  35. nabu/estimation/motion.py +557 -0
  36. nabu/estimation/tests/test_alignment.py +2 -0
  37. nabu/estimation/tests/test_motion_estimation.py +471 -0
  38. nabu/estimation/tests/test_tilt.py +1 -1
  39. nabu/estimation/tilt.py +6 -5
  40. nabu/estimation/translation.py +47 -1
  41. nabu/io/cast_volume.py +108 -18
  42. nabu/io/detector_distortion.py +5 -6
  43. nabu/io/reader.py +45 -6
  44. nabu/io/reader_helical.py +5 -4
  45. nabu/io/tests/test_cast_volume.py +2 -2
  46. nabu/io/tests/test_readers.py +41 -38
  47. nabu/io/tests/test_remove_volume.py +152 -0
  48. nabu/io/tests/test_writers.py +2 -2
  49. nabu/io/utils.py +8 -4
  50. nabu/io/writer.py +1 -2
  51. nabu/misc/fftshift.py +1 -1
  52. nabu/misc/fourier_filters.py +1 -1
  53. nabu/misc/histogram.py +1 -1
  54. nabu/misc/histogram_cuda.py +1 -1
  55. nabu/misc/padding_base.py +1 -1
  56. nabu/misc/rotation.py +1 -1
  57. nabu/misc/rotation_cuda.py +1 -1
  58. nabu/misc/tests/test_binning.py +1 -1
  59. nabu/misc/transpose.py +1 -1
  60. nabu/misc/unsharp.py +1 -1
  61. nabu/misc/unsharp_cuda.py +1 -1
  62. nabu/misc/unsharp_opencl.py +1 -1
  63. nabu/misc/utils.py +1 -1
  64. nabu/opencl/fft.py +1 -1
  65. nabu/opencl/padding.py +1 -1
  66. nabu/opencl/src/backproj.cl +6 -6
  67. nabu/opencl/utils.py +8 -8
  68. nabu/pipeline/config.py +2 -2
  69. nabu/pipeline/config_validators.py +46 -46
  70. nabu/pipeline/datadump.py +3 -3
  71. nabu/pipeline/estimators.py +271 -11
  72. nabu/pipeline/fullfield/chunked.py +103 -67
  73. nabu/pipeline/fullfield/chunked_cuda.py +5 -2
  74. nabu/pipeline/fullfield/computations.py +4 -1
  75. nabu/pipeline/fullfield/dataset_validator.py +0 -1
  76. nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
  77. nabu/pipeline/fullfield/nabu_config.py +36 -17
  78. nabu/pipeline/fullfield/processconfig.py +41 -7
  79. nabu/pipeline/fullfield/reconstruction.py +14 -10
  80. nabu/pipeline/helical/dataset_validator.py +3 -4
  81. nabu/pipeline/helical/fbp.py +4 -4
  82. nabu/pipeline/helical/filtering.py +5 -4
  83. nabu/pipeline/helical/gridded_accumulator.py +10 -11
  84. nabu/pipeline/helical/helical_chunked_regridded.py +1 -0
  85. nabu/pipeline/helical/helical_reconstruction.py +12 -9
  86. nabu/pipeline/helical/helical_utils.py +1 -2
  87. nabu/pipeline/helical/nabu_config.py +2 -1
  88. nabu/pipeline/helical/span_strategy.py +1 -0
  89. nabu/pipeline/helical/weight_balancer.py +2 -3
  90. nabu/pipeline/params.py +20 -3
  91. nabu/pipeline/tests/__init__.py +0 -0
  92. nabu/pipeline/tests/test_estimators.py +240 -3
  93. nabu/pipeline/utils.py +1 -1
  94. nabu/pipeline/writer.py +1 -1
  95. nabu/preproc/alignment.py +0 -10
  96. nabu/preproc/ccd.py +53 -3
  97. nabu/preproc/ctf.py +8 -8
  98. nabu/preproc/ctf_cuda.py +1 -1
  99. nabu/preproc/double_flatfield_cuda.py +2 -2
  100. nabu/preproc/double_flatfield_variable_region.py +0 -1
  101. nabu/preproc/flatfield.py +307 -2
  102. nabu/preproc/flatfield_cuda.py +1 -2
  103. nabu/preproc/flatfield_variable_region.py +3 -3
  104. nabu/preproc/phase.py +2 -4
  105. nabu/preproc/phase_cuda.py +2 -2
  106. nabu/preproc/shift.py +4 -2
  107. nabu/preproc/shift_cuda.py +0 -1
  108. nabu/preproc/tests/test_ctf.py +4 -4
  109. nabu/preproc/tests/test_double_flatfield.py +1 -1
  110. nabu/preproc/tests/test_flatfield.py +1 -1
  111. nabu/preproc/tests/test_paganin.py +1 -3
  112. nabu/preproc/tests/test_pcaflats.py +154 -0
  113. nabu/preproc/tests/test_vshift.py +4 -1
  114. nabu/processing/azim.py +9 -5
  115. nabu/processing/convolution_cuda.py +6 -4
  116. nabu/processing/fft_base.py +7 -3
  117. nabu/processing/fft_cuda.py +25 -164
  118. nabu/processing/fft_opencl.py +28 -6
  119. nabu/processing/fftshift.py +1 -1
  120. nabu/processing/histogram.py +1 -1
  121. nabu/processing/muladd.py +0 -1
  122. nabu/processing/padding_base.py +1 -1
  123. nabu/processing/padding_cuda.py +0 -2
  124. nabu/processing/processing_base.py +12 -6
  125. nabu/processing/rotation_cuda.py +3 -1
  126. nabu/processing/tests/test_fft.py +2 -64
  127. nabu/processing/tests/test_fftshift.py +1 -1
  128. nabu/processing/tests/test_medfilt.py +1 -3
  129. nabu/processing/tests/test_padding.py +1 -1
  130. nabu/processing/tests/test_roll.py +1 -1
  131. nabu/processing/tests/test_rotation.py +4 -2
  132. nabu/processing/unsharp_opencl.py +1 -1
  133. nabu/reconstruction/astra.py +245 -0
  134. nabu/reconstruction/cone.py +39 -9
  135. nabu/reconstruction/fbp.py +7 -0
  136. nabu/reconstruction/fbp_base.py +36 -5
  137. nabu/reconstruction/filtering.py +59 -25
  138. nabu/reconstruction/filtering_cuda.py +22 -21
  139. nabu/reconstruction/filtering_opencl.py +10 -14
  140. nabu/reconstruction/hbp.py +26 -13
  141. nabu/reconstruction/mlem.py +55 -16
  142. nabu/reconstruction/projection.py +3 -5
  143. nabu/reconstruction/sinogram.py +1 -1
  144. nabu/reconstruction/sinogram_cuda.py +0 -1
  145. nabu/reconstruction/tests/test_cone.py +37 -2
  146. nabu/reconstruction/tests/test_deringer.py +4 -4
  147. nabu/reconstruction/tests/test_fbp.py +36 -15
  148. nabu/reconstruction/tests/test_filtering.py +27 -7
  149. nabu/reconstruction/tests/test_halftomo.py +28 -2
  150. nabu/reconstruction/tests/test_mlem.py +94 -64
  151. nabu/reconstruction/tests/test_projector.py +7 -2
  152. nabu/reconstruction/tests/test_reconstructor.py +1 -1
  153. nabu/reconstruction/tests/test_sino_normalization.py +0 -1
  154. nabu/resources/dataset_analyzer.py +210 -24
  155. nabu/resources/gpu.py +4 -4
  156. nabu/resources/logger.py +4 -4
  157. nabu/resources/nxflatfield.py +103 -37
  158. nabu/resources/tests/test_dataset_analyzer.py +37 -0
  159. nabu/resources/tests/test_extract.py +11 -0
  160. nabu/resources/tests/test_nxflatfield.py +5 -5
  161. nabu/resources/utils.py +16 -10
  162. nabu/stitching/alignment.py +8 -11
  163. nabu/stitching/config.py +44 -35
  164. nabu/stitching/definitions.py +2 -2
  165. nabu/stitching/frame_composition.py +8 -10
  166. nabu/stitching/overlap.py +4 -4
  167. nabu/stitching/sample_normalization.py +5 -5
  168. nabu/stitching/slurm_utils.py +2 -2
  169. nabu/stitching/stitcher/base.py +2 -0
  170. nabu/stitching/stitcher/dumper/base.py +0 -1
  171. nabu/stitching/stitcher/dumper/postprocessing.py +1 -1
  172. nabu/stitching/stitcher/post_processing.py +11 -9
  173. nabu/stitching/stitcher/pre_processing.py +37 -31
  174. nabu/stitching/stitcher/single_axis.py +2 -3
  175. nabu/stitching/stitcher_2D.py +2 -1
  176. nabu/stitching/tests/test_config.py +10 -11
  177. nabu/stitching/tests/test_sample_normalization.py +1 -1
  178. nabu/stitching/tests/test_slurm_utils.py +1 -2
  179. nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
  180. nabu/stitching/tests/test_z_postprocessing_stitching.py +3 -3
  181. nabu/stitching/tests/test_z_preprocessing_stitching.py +27 -24
  182. nabu/stitching/utils/tests/__init__.py +0 -0
  183. nabu/stitching/utils/tests/test_post-processing.py +1 -0
  184. nabu/stitching/utils/utils.py +16 -18
  185. nabu/tests.py +0 -3
  186. nabu/testutils.py +62 -9
  187. nabu/utils.py +50 -20
  188. {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/METADATA +7 -7
  189. nabu-2025.1.0.dist-info/RECORD +328 -0
  190. {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/WHEEL +1 -1
  191. {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/entry_points.txt +2 -1
  192. nabu/app/correct_rot.py +0 -70
  193. nabu/io/tests/test_detector_distortion.py +0 -178
  194. nabu-2024.2.14.dist-info/RECORD +0 -317
  195. /nabu/{stitching → app}/tests/__init__.py +0 -0
  196. {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/licenses/LICENSE +0 -0
  197. {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/top_level.txt +0 -0
nabu/resources/utils.py CHANGED
@@ -1,8 +1,9 @@
1
1
  from ast import literal_eval
2
2
  import numpy as np
3
+ import pint
3
4
  from psutil import virtual_memory, cpu_count
4
- from pyunitsystem.metricsystem import MetricSystem
5
- from pyunitsystem.energysystem import EnergySI
5
+
6
+ _ureg = pint.get_application_registry()
6
7
 
7
8
 
8
9
  def get_values_from_file(fname, n_values=None, shape=None, sep=None, any_size=False):
@@ -163,12 +164,17 @@ def get_quantities_and_units(string, sep=";"):
163
164
  value, unit = value_and_unit.split()
164
165
  val = float(value)
165
166
  # Convert to SI
166
- try:
167
- # handle metrics
168
- conversion_factor = MetricSystem.from_str(unit).value
169
- except ValueError:
170
- # handle energies
171
- conversion_factor = EnergySI.from_str(unit).value / EnergySI.KILOELECTRONVOLT.value
172
-
173
- result[quantity_name] = val * conversion_factor
167
+ if unit.lower() == "kev":
168
+ current_unit = _ureg.keV
169
+ elif unit.lower() == "ev":
170
+ current_unit = _ureg.eV
171
+ else:
172
+ current_unit = _ureg(unit)
173
+ # handle energies (to move to keV)
174
+ if _ureg.keV.dimensionality == current_unit.dimensionality:
175
+ result[quantity_name] = (val * current_unit).to(_ureg.keV).magnitude
176
+ elif _ureg.meter.dimensionality == current_unit.dimensionality:
177
+ result[quantity_name] = (val * current_unit).to_base_units().magnitude
178
+ else:
179
+ raise ValueError(f"Cannot convert: {unit}")
174
180
  return result
@@ -1,13 +1,10 @@
1
+ from enum import Enum
1
2
  import h5py
2
3
  import numpy
3
4
  from typing import Union
4
- from silx.utils.enum import Enum as _Enum
5
- from tomoscan.volumebase import VolumeBase
6
- from tomoscan.esrf.volume.hdf5volume import HDF5Volume
7
- from nabu.io.utils import DatasetReader
8
5
 
9
6
 
10
- class AlignmentAxis2(_Enum):
7
+ class AlignmentAxis2(Enum):
11
8
  """Specific alignment named to help users orienting themself with specific name"""
12
9
 
13
10
  CENTER = "center"
@@ -15,7 +12,7 @@ class AlignmentAxis2(_Enum):
15
12
  RIGTH = "right"
16
13
 
17
14
 
18
- class AlignmentAxis1(_Enum):
15
+ class AlignmentAxis1(Enum):
19
16
  """Specific alignment named to help users orienting themself with specific name"""
20
17
 
21
18
  FRONT = "front"
@@ -23,7 +20,7 @@ class AlignmentAxis1(_Enum):
23
20
  BACK = "back"
24
21
 
25
22
 
26
- class _Alignment(_Enum):
23
+ class _Alignment(Enum):
27
24
  """Internal alignment to be used for 2D alignment"""
28
25
 
29
26
  LOWER_BOUNDARY = "lower boundary"
@@ -32,7 +29,7 @@ class _Alignment(_Enum):
32
29
 
33
30
  @classmethod
34
31
  def from_value(cls, value):
35
- # cast the AlignmentAxis1 and AlignmentAxis2 values to fit the generic definition
32
+ # cast the AlignmentAxis1 and AlignmentAxis2 values to fit the generic definition.
36
33
  if value in ("front", "left", AlignmentAxis1.FRONT, AlignmentAxis2.LEFT):
37
34
  return _Alignment.LOWER_BOUNDARY
38
35
  elif value in ("back", "right", AlignmentAxis1.BACK, AlignmentAxis2.RIGTH):
@@ -40,7 +37,7 @@ class _Alignment(_Enum):
40
37
  elif value in (AlignmentAxis1.CENTER, AlignmentAxis2.CENTER):
41
38
  return _Alignment.CENTER
42
39
  else:
43
- return super().from_value(value)
40
+ return super().__new__(cls, value)
44
41
 
45
42
 
46
43
  def align_frame(
@@ -106,7 +103,7 @@ def align_horizontally(data: numpy.ndarray, alignment: AlignmentAxis2, new_width
106
103
  :param HAlignment alignment: alignment strategy
107
104
  :param int new_width: output data width
108
105
  """
109
- alignment = AlignmentAxis2.from_value(alignment).value
106
+ alignment = AlignmentAxis2(alignment).value
110
107
  return align_frame(
111
108
  data=data, alignment=alignment, new_aligned_axis_size=new_width, pad_mode=pad_mode, alignment_axis=1
112
109
  )
@@ -151,7 +148,7 @@ class PaddedRawData:
151
148
  @property
152
149
  def shape(self):
153
150
  if self._shape is None:
154
- self._shape = tuple(
151
+ self._shape = tuple( # noqa: C409
155
152
  (
156
153
  self._raw_data_shape[0],
157
154
  numpy.sum(
nabu/stitching/config.py CHANGED
@@ -1,8 +1,9 @@
1
+ import pint
1
2
  from math import ceil
2
- from typing import Any, Iterable, Optional, Union, Sized
3
+ from typing import Optional, Union
4
+ from collections.abc import Iterable, Sized
3
5
  from dataclasses import dataclass
4
6
  import numpy
5
- from pyunitsystem.metricsystem import MetricSystem
6
7
  from nxtomo.paths import nxtomo
7
8
  from tomoscan.factory import Factory
8
9
  from tomoscan.identifier import VolumeIdentifier, ScanIdentifier
@@ -12,13 +13,14 @@ from ..pipeline.config_validators import (
12
13
  convert_to_bool,
13
14
  )
14
15
  from ..utils import concatenate_dict, convert_str_to_tuple
15
- from ..io.utils import get_output_volume
16
16
  from .overlap import OverlapStitchingStrategy
17
17
  from .utils.utils import ShiftAlgorithm
18
18
  from .definitions import StitchingType
19
19
  from .alignment import AlignmentAxis1, AlignmentAxis2
20
- from pyunitsystem.metricsystem import MetricSystem
21
20
 
21
+ _ureg = pint.get_application_registry()
22
+
23
+ # ruff: noqa: S105
22
24
 
23
25
  KEY_IMG_REG_METHOD = "img_reg_method"
24
26
 
@@ -128,6 +130,8 @@ SLURM_MODULES_TO_LOADS = "modules"
128
130
 
129
131
  SLURM_CLEAN_SCRIPTS = "clean_scripts"
130
132
 
133
+ SLURM_JOB_NAME = "job_name"
134
+
131
135
  # normalization by sample
132
136
 
133
137
  NORMALIZATION_BY_SAMPLE_SECTION = "normalization_by_sample"
@@ -205,7 +209,7 @@ def _str_to_dict(my_str: Union[str, dict]):
205
209
 
206
210
 
207
211
  def _dict_to_str(ddict: dict):
208
- return ";".join([f"{str(key)}={str(value)}" for key, value in ddict.items()])
212
+ return ";".join([f"{key!s}={value!s}" for key, value in ddict.items()])
209
213
 
210
214
 
211
215
  def str_to_shifts(my_str: Optional[str]) -> Union[str, tuple]:
@@ -218,7 +222,7 @@ def str_to_shifts(my_str: Optional[str]) -> Union[str, tuple]:
218
222
  if my_str == "":
219
223
  return None
220
224
  try:
221
- shift = ShiftAlgorithm.from_value(my_str)
225
+ shift = ShiftAlgorithm(my_str)
222
226
  except ValueError:
223
227
  shifts_as_str = filter(None, my_str.replace(";", ",").split(","))
224
228
  return [float(shift) for shift in shifts_as_str]
@@ -235,8 +239,8 @@ def _valid_stitching_kernels_params(my_dict: Union[dict, str]):
235
239
  my_dict = _str_to_dict(my_str=my_dict)
236
240
 
237
241
  valid_keys = (KEY_THRESHOLD_FREQUENCY, KEY_SIDE)
238
- for key in my_dict.keys():
239
- if not key in valid_keys:
242
+ for key in my_dict:
243
+ if key not in valid_keys:
240
244
  raise KeyError(f"{key} is a unrecognized key")
241
245
  return my_dict
242
246
 
@@ -253,8 +257,8 @@ def _valid_shifts_params(my_dict: Union[dict, str]):
253
257
  KEY_LOW_PASS_FILTER,
254
258
  KEY_SIDE,
255
259
  )
256
- for key in my_dict.keys():
257
- if not key in valid_keys:
260
+ for key in my_dict:
261
+ if key not in valid_keys:
258
262
  raise KeyError(f"{key} is a unrecognized key")
259
263
  return my_dict
260
264
 
@@ -334,7 +338,7 @@ class NormalizationBySample:
334
338
 
335
339
  @method.setter
336
340
  def method(self, method: Union[Method, str]) -> None:
337
- self._method = Method.from_value(method)
341
+ self._method = Method(method)
338
342
 
339
343
  @property
340
344
  def margin(self) -> int:
@@ -351,7 +355,7 @@ class NormalizationBySample:
351
355
 
352
356
  @side.setter
353
357
  def side(self, side: Union[SampleSide, str]):
354
- self._side = SampleSide.from_value(side)
358
+ self._side = SampleSide(side)
355
359
 
356
360
  @property
357
361
  def width(self) -> int:
@@ -401,16 +405,16 @@ class NormalizationBySample:
401
405
  NORMALIZATION_BY_SAMPLE_WIDTH: self.width,
402
406
  }
403
407
 
404
- def __eq__(self, __value: object) -> bool:
405
- if not isinstance(__value, NormalizationBySample):
408
+ def __eq__(self, value: object, /) -> bool:
409
+ if not isinstance(value, NormalizationBySample):
406
410
  return False
407
411
  else:
408
- return self.to_dict() == __value.to_dict()
412
+ return self.to_dict() == value.to_dict()
409
413
 
410
414
 
411
415
  @dataclass
412
416
  class SlurmConfig:
413
- "configuration for slurm jobs"
417
+ """configuration for slurm jobs"""
414
418
 
415
419
  partition: str = "" # note: must stay empty to make by default we don't use slurm (use by the configuration file)
416
420
  mem: str = "128"
@@ -421,6 +425,7 @@ class SlurmConfig:
421
425
  clean_script: bool = ""
422
426
  n_tasks: int = 1
423
427
  n_cpu_per_task: int = 4
428
+ job_name: str = ""
424
429
 
425
430
  def __post_init__(self) -> None:
426
431
  # make sure either 'modules' or 'preprocessing_command' is provided
@@ -430,7 +435,7 @@ class SlurmConfig:
430
435
  )
431
436
 
432
437
  def to_dict(self) -> dict:
433
- "dump configuration to dict"
438
+ """dump configuration to dict"""
434
439
  return {
435
440
  SLURM_PARTITION: self.partition if self.partition is not None else "",
436
441
  SLURM_MEM: self.mem,
@@ -441,6 +446,7 @@ class SlurmConfig:
441
446
  SLURM_CLEAN_SCRIPTS: self.clean_script,
442
447
  SLURM_NUMBER_OF_TASKS: self.n_tasks,
443
448
  SLURM_COR_PER_TASKS: self.n_cpu_per_task,
449
+ SLURM_JOB_NAME: self.job_name,
444
450
  }
445
451
 
446
452
  @staticmethod
@@ -457,18 +463,21 @@ class SlurmConfig:
457
463
  preprocessing_command=config.get(SLURM_PREPROCESSING_COMMAND, ""),
458
464
  modules_to_load=convert_str_to_tuple(config.get(SLURM_MODULES_TO_LOADS, "")),
459
465
  clean_script=convert_to_bool(config.get(SLURM_CLEAN_SCRIPTS, False))[0],
466
+ job_name=config.get(SLURM_JOB_NAME, ""),
460
467
  )
461
468
 
462
469
 
463
- def _cast_shift_to_str(shifts: Union[tuple, str, None]) -> str:
470
+ def _cast_shift_to_str(shifts: Union[tuple, numpy.ndarray, str, None]) -> str:
464
471
  if shifts is None:
465
472
  return ""
466
473
  elif isinstance(shifts, ShiftAlgorithm):
467
474
  return shifts.value
468
475
  elif isinstance(shifts, str):
469
476
  return shifts
470
- elif isinstance(shifts, (tuple, list)):
477
+ elif isinstance(shifts, (tuple, list, numpy.ndarray)):
471
478
  return ";".join([str(value) for value in shifts])
479
+ else:
480
+ raise TypeError(f"unexpected type: {type(shifts)}")
472
481
 
473
482
 
474
483
  @dataclass
@@ -541,12 +550,12 @@ class StitchingConfiguration:
541
550
  STITCHING_SECTION: {
542
551
  STITCHING_TYPE_FIELD: {
543
552
  "default": StitchingType.Z_PREPROC.value,
544
- "help": f"stitching to be applied. Must be in {StitchingType.values()}",
553
+ "help": f"stitching to be applied. Must be in {[st.value for st in StitchingType]}",
545
554
  "type": "required",
546
555
  },
547
556
  STITCHING_STRATEGY_FIELD: {
548
557
  "default": "cosinus weights",
549
- "help": f"Policy to apply to compute the overlap area. Must be in {OverlapStitchingStrategy.values()}.",
558
+ "help": f"Policy to apply to compute the overlap area. Must be in {[ov.value for ov in OverlapStitchingStrategy]}.",
550
559
  "type": "required",
551
560
  },
552
561
  CROSS_CORRELATION_SLICE_FIELD: {
@@ -626,7 +635,7 @@ class StitchingConfiguration:
626
635
  },
627
636
  ALIGNMENT_AXIS_2_FIELD: {
628
637
  "default": "center",
629
- "help": f"In case frame have different frame widths how to align them (so along volume axis 2). Valid keys are {AlignmentAxis2.values()}",
638
+ "help": f"In case frame have different frame widths how to align them (so along volume axis 2). Valid keys are {[aa.value for aa in AlignmentAxis2]}",
630
639
  "type": "advanced",
631
640
  },
632
641
  PAD_MODE_FIELD: {
@@ -748,7 +757,7 @@ class StitchingConfiguration:
748
757
  AXIS_2_POS_PX: _cast_shift_to_str(self.axis_2_pos_px),
749
758
  AXIS_2_POS_MM: _cast_shift_to_str(self.axis_2_pos_mm),
750
759
  AXIS_2_PARAMS: _dict_to_str(self.axis_2_params or {}),
751
- STITCHING_STRATEGY_FIELD: OverlapStitchingStrategy.from_value(self.stitching_strategy).value,
760
+ STITCHING_STRATEGY_FIELD: OverlapStitchingStrategy(self.stitching_strategy).value,
752
761
  FLIP_UD: self.flip_ud,
753
762
  FLIP_LR: self.flip_lr,
754
763
  RESCALE_FRAMES: self.rescale_frames,
@@ -927,7 +936,7 @@ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigurat
927
936
  if self.pixel_size is None:
928
937
  pixel_size_mm = ""
929
938
  else:
930
- pixel_size_mm = self.pixel_size * MetricSystem.MILLIMETER.value
939
+ pixel_size_mm = (self.pixel_size * _ureg.meter).to(_ureg.millimeter).magnitude
931
940
  return concatenate_dict(
932
941
  super().to_dict(),
933
942
  {
@@ -991,10 +1000,10 @@ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigurat
991
1000
  if pixel_size == "":
992
1001
  pixel_size = None
993
1002
  else:
994
- pixel_size = float(pixel_size) / MetricSystem.MM
1003
+ pixel_size = (float(pixel_size) * _ureg.millimeter).to_base_units().magnitude
995
1004
 
996
1005
  return cls(
997
- stitching_strategy=OverlapStitchingStrategy.from_value(
1006
+ stitching_strategy=OverlapStitchingStrategy(
998
1007
  config[STITCHING_SECTION].get(
999
1008
  STITCHING_STRATEGY_FIELD,
1000
1009
  OverlapStitchingStrategy.COSINUS_WEIGHTS,
@@ -1035,7 +1044,7 @@ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigurat
1035
1044
  config[STITCHING_SECTION].get(STITCHING_KERNELS_EXTRA_PARAMS, {}),
1036
1045
  )
1037
1046
  ),
1038
- alignment_axis_2=AlignmentAxis2.from_value(
1047
+ alignment_axis_2=AlignmentAxis2(
1039
1048
  config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER)
1040
1049
  ),
1041
1050
  pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"),
@@ -1156,11 +1165,11 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura
1156
1165
  if voxel_size == "":
1157
1166
  voxel_size = None
1158
1167
  else:
1159
- voxel_size = float(voxel_size) * MetricSystem.MM
1168
+ voxel_size = (float(voxel_size) * _ureg.millimeter).to_base_units().magnitude
1160
1169
 
1161
1170
  # on the next section the one with a default value qre the optional one
1162
1171
  return cls(
1163
- stitching_strategy=OverlapStitchingStrategy.from_value(
1172
+ stitching_strategy=OverlapStitchingStrategy(
1164
1173
  config[STITCHING_SECTION].get(
1165
1174
  STITCHING_STRATEGY_FIELD,
1166
1175
  OverlapStitchingStrategy.COSINUS_WEIGHTS,
@@ -1191,10 +1200,10 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura
1191
1200
  config[STITCHING_SECTION].get(STITCHING_KERNELS_EXTRA_PARAMS, {}),
1192
1201
  )
1193
1202
  ),
1194
- alignment_axis_1=AlignmentAxis1.from_value(
1203
+ alignment_axis_1=AlignmentAxis1(
1195
1204
  config[STITCHING_SECTION].get(ALIGNMENT_AXIS_1_FIELD, AlignmentAxis1.CENTER)
1196
1205
  ),
1197
- alignment_axis_2=AlignmentAxis2.from_value(
1206
+ alignment_axis_2=AlignmentAxis2(
1198
1207
  config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER)
1199
1208
  ),
1200
1209
  pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"),
@@ -1208,7 +1217,7 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura
1208
1217
  if self.voxel_size is None:
1209
1218
  voxel_size_mm = ""
1210
1219
  else:
1211
- voxel_size_mm = numpy.array(self.voxel_size) / MetricSystem.MM
1220
+ voxel_size_mm = numpy.array((self.voxel_size * _ureg.meter).to(_ureg.millimeter).magnitude)
1212
1221
 
1213
1222
  return concatenate_dict(
1214
1223
  super().to_dict(),
@@ -1243,7 +1252,7 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura
1243
1252
  STITCHING_SECTION: {
1244
1253
  ALIGNMENT_AXIS_1_FIELD: {
1245
1254
  "default": "center",
1246
- "help": f"alignment to apply over axis 1 if needed. Valid values are {AlignmentAxis1.values()}",
1255
+ "help": f"alignment to apply over axis 1 if needed. Valid values are {[aa for aa in AlignmentAxis1]}",
1247
1256
  "type": "advanced",
1248
1257
  }
1249
1258
  },
@@ -1274,7 +1283,7 @@ def dict_to_config_obj(config: dict):
1274
1283
  if stitching_type is None:
1275
1284
  raise ValueError("Unable to find stitching type from config dict")
1276
1285
  else:
1277
- stitching_type = StitchingType.from_value(stitching_type)
1286
+ stitching_type = StitchingType(stitching_type)
1278
1287
  if stitching_type is StitchingType.Z_POSTPROC:
1279
1288
  return PostProcessedZStitchingConfiguration.from_dict(config)
1280
1289
  elif stitching_type is StitchingType.Z_PREPROC:
@@ -1297,7 +1306,7 @@ def get_default_stitching_config(stitching_type: Optional[Union[StitchingType, s
1297
1306
  if stitching_type is None:
1298
1307
  return concatenate_dict(z_postproc_stitching_config, z_preproc_stitching_config)
1299
1308
 
1300
- stitching_type = StitchingType.from_value(stitching_type)
1309
+ stitching_type = StitchingType(stitching_type)
1301
1310
  if stitching_type is StitchingType.Z_POSTPROC:
1302
1311
  return z_postproc_stitching_config
1303
1312
  elif stitching_type is StitchingType.Z_PREPROC:
@@ -1,7 +1,7 @@
1
- from silx.utils.enum import Enum as _Enum
1
+ from enum import Enum
2
2
 
3
3
 
4
- class StitchingType(_Enum):
4
+ class StitchingType(Enum):
5
5
  Y_PREPROC = "y-preproc"
6
6
  Z_PREPROC = "z-preproc"
7
7
  Z_POSTPROC = "z-postproc"
@@ -31,7 +31,7 @@ class FrameComposition:
31
31
  )
32
32
 
33
33
  def compose(self, output_frame: numpy.ndarray, input_frames: tuple):
34
- if not output_frame.ndim in (2, 3):
34
+ if output_frame.ndim not in (2, 3):
35
35
  raise TypeError(
36
36
  f"output_frame is expected to be 2D (gray scale) or 3D (RGB(A)) and not {output_frame.ndim}"
37
37
  )
@@ -74,9 +74,10 @@ class FrameComposition:
74
74
  local_start_indices.extend(
75
75
  [ceil(key_line[1] + kernel.overlap_size / 2) for (key_line, kernel) in zip(key_lines, overlap_kernels)]
76
76
  )
77
- local_end_indices = list(
78
- [ceil(key_line[0] - kernel.overlap_size / 2) for (key_line, kernel) in zip(key_lines, overlap_kernels)]
79
- )
77
+ local_end_indices = [
78
+ ceil(key_line[0] - kernel.overlap_size / 2) for (key_line, kernel) in zip(key_lines, overlap_kernels)
79
+ ]
80
+
80
81
  local_end_indices.append(frames[-1].shape[stitching_axis])
81
82
 
82
83
  for (
@@ -155,9 +156,6 @@ class FrameComposition:
155
156
  print(
156
157
  f"stitch_frame[{stitch_global_start}:{stitch_global_end}] = stitched_frame_{i_frame}[{stitch_local_start}:{stitch_local_end}]"
157
158
  )
158
- else:
159
- i_frame += 1
160
- raw_local_start, raw_local_end, raw_global_start, raw_global_end = list(raw_composition.browse())[-1]
161
- print(
162
- f"stitch_frame[{raw_global_start}:{raw_global_end}] = frame_{i_frame}[{raw_local_start}:{raw_local_end}]"
163
- )
159
+ i_frame += 1
160
+ raw_local_start, raw_local_end, raw_global_start, raw_global_end = list(raw_composition.browse())[-1]
161
+ print(f"stitch_frame[{raw_global_start}:{raw_global_end}] = frame_{i_frame}[{raw_local_start}:{raw_local_end}]")
nabu/stitching/overlap.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import numpy
2
2
  import logging
3
3
  from typing import Optional, Union
4
- from silx.utils.enum import Enum as _Enum
4
+ from enum import Enum
5
5
  from nabu.misc import fourier_filters
6
6
  from scipy.fft import rfftn as local_fftn
7
7
  from scipy.fft import irfftn as local_ifftn
@@ -10,7 +10,7 @@ from tomoscan.utils.geometry import BoundingBox1D
10
10
  _logger = logging.getLogger(__name__)
11
11
 
12
12
 
13
- class OverlapStitchingStrategy(_Enum):
13
+ class OverlapStitchingStrategy(Enum):
14
14
  MEAN = "mean"
15
15
  COSINUS_WEIGHTS = "cosinus weights"
16
16
  LINEAR_WEIGHTS = "linear weights"
@@ -64,7 +64,7 @@ class ImageStichOverlapKernel(OverlapKernelBase):
64
64
  f"frame_width is expected to be a positive int, {frame_unstitched_axis_size} - not {frame_unstitched_axis_size} ({type(frame_unstitched_axis_size)})"
65
65
  )
66
66
 
67
- if not stitching_axis in (0, 1):
67
+ if stitching_axis not in (0, 1):
68
68
  raise ValueError(
69
69
  "stitching_axis is expected to be the axis along which stitching must be done. It should be '0' or '1'"
70
70
  )
@@ -72,7 +72,7 @@ class ImageStichOverlapKernel(OverlapKernelBase):
72
72
  self._stitching_axis = stitching_axis
73
73
  self._overlap_size = abs(overlap_size)
74
74
  self._frame_unstitched_axis_size = frame_unstitched_axis_size
75
- self._stitching_strategy = OverlapStitchingStrategy.from_value(stitching_strategy)
75
+ self._stitching_strategy = OverlapStitchingStrategy(stitching_strategy)
76
76
  self._weights_img_1 = None
77
77
  self._weights_img_2 = None
78
78
  if extra_params is None:
@@ -1,13 +1,13 @@
1
+ from enum import Enum
1
2
  import numpy
2
- from silx.utils.enum import Enum as _Enum
3
3
 
4
4
 
5
- class SampleSide(_Enum):
5
+ class SampleSide(Enum):
6
6
  LEFT = "left"
7
7
  RIGHT = "right"
8
8
 
9
9
 
10
- class Method(_Enum):
10
+ class Method(Enum):
11
11
  MEAN = "mean"
12
12
  MEDIAN = "median"
13
13
 
@@ -28,8 +28,8 @@ def normalize_frame(
28
28
  raise TypeError(f"Frame is expected to be a 2D numpy array.")
29
29
  if frame.ndim != 2:
30
30
  raise TypeError(f"Frame is expected to be a 2D numpy array. Get {frame.ndim}D")
31
- side = SampleSide.from_value(side)
32
- method = Method.from_value(method)
31
+ side = SampleSide(side)
32
+ method = Method(method)
33
33
 
34
34
  if frame.shape[1] < sample_width + margin_before_sample:
35
35
  raise ValueError(
@@ -177,7 +177,7 @@ def split_slices(slices: Union[slice, tuple], n_parts: int):
177
177
  raise TypeError(f"slices type ({type(slices)}) is not handled. Must be a slice or an Iterable")
178
178
 
179
179
 
180
- def get_working_directory(obj: TomoObject) -> Optional[str]:
180
+ def get_working_directory(obj: TomoObject) -> Optional[str]: # noqa: PLR0911
181
181
  """
182
182
  return working directory for a specific TomoObject
183
183
  """
@@ -201,4 +201,4 @@ def get_working_directory(obj: TomoObject) -> Optional[str]:
201
201
  else:
202
202
  return os.path.abspath(os.path.dirname(obj.master_file))
203
203
  else:
204
- raise RuntimeError(f"obj type not handled ({type(obj)})")
204
+ raise RuntimeError(f"obj type not handled ({type(obj)})") # noqa: TRY004
@@ -21,6 +21,8 @@ def get_obj_constant_side_length(obj: Union[NXtomoScan, VolumeBase], axis: int)
21
21
  return obj.dim_1
22
22
  elif axis in (1, 2):
23
23
  return obj.dim_2
24
+ else:
25
+ raise ValueError(f"Axis ({axis}) not handled. Should be in (0, 1, 2)")
24
26
  elif isinstance(obj, VolumeBase) and axis == 0:
25
27
  return obj.get_volume_shape()[-1]
26
28
  else:
@@ -4,7 +4,6 @@ from typing import Union, Optional
4
4
  from tomoscan.identifier import BaseIdentifier
5
5
  from nabu.stitching.config import StitchingConfiguration
6
6
  from tomoscan.volumebase import VolumeBase
7
- from contextlib import AbstractContextManager
8
7
 
9
8
 
10
9
  class DumperBase:
@@ -96,7 +96,7 @@ class OutputVolumeContext(AbstractContextManager):
96
96
  if self._file_handler is not None:
97
97
  return self._file_handler.close()
98
98
  else:
99
- self._volume.save_data()
99
+ self._volume.save_data() # noqa: RET503
100
100
 
101
101
 
102
102
  class OutputVolumeNoDDContext(OutputVolumeContext):
@@ -2,6 +2,7 @@ import logging
2
2
  import numpy
3
3
  import os
4
4
  import h5py
5
+ import pint
5
6
  from typing import Union
6
7
  from nabu.stitching.config import PostProcessedSingleAxisStitchingConfiguration
7
8
  from nabu.stitching.alignment import AlignmentAxis1
@@ -13,11 +14,9 @@ from tomoscan.esrf import NXtomoScan
13
14
  from tomoscan.series import Series
14
15
  from tomoscan.volumebase import VolumeBase
15
16
  from tomoscan.esrf.volume import HDF5Volume
16
- from typing import Iterable
17
+ from collections.abc import Iterable
17
18
  from contextlib import AbstractContextManager
18
- from pyunitsystem.metricsystem import MetricSystem
19
19
  from nabu.stitching.config import (
20
- PostProcessedSingleAxisStitchingConfiguration,
21
20
  KEY_IMG_REG_METHOD,
22
21
  )
23
22
  from nabu.stitching.utils.utils import find_volumes_relative_shifts
@@ -26,6 +25,8 @@ from .single_axis import SingleAxisStitcher
26
25
 
27
26
  _logger = logging.getLogger(__name__)
28
27
 
28
+ _ureg = pint.get_application_registry()
29
+
29
30
 
30
31
  class FlippingValueError(ValueError):
31
32
  pass
@@ -267,7 +268,7 @@ class PostProcessingStitching(SingleAxisStitcher):
267
268
  axis_N_pos_px = []
268
269
  for volume, pos_in_mm in zip(self.series, pos_as_mm):
269
270
  voxel_size_m = self.configuration.voxel_size or volume.voxel_size
270
- axis_N_pos_px.append((pos_in_mm / MetricSystem.MILLIMETER.value) / voxel_size_m[0])
271
+ axis_N_pos_px.append((pos_in_mm * _ureg.millimeter).to_base_units().magnitude / voxel_size_m[0])
271
272
  return axis_N_pos_px
272
273
  else:
273
274
  # deduce from motor position and pixel size
@@ -426,7 +427,7 @@ class PostProcessingStitching(SingleAxisStitcher):
426
427
 
427
428
  bunch_size = 50
428
429
  # how many frame to we stitch between two read from disk / save to disk
429
- with self.dumper.OutputDatasetContext(**output_dataset_args):
430
+ with self.dumper.OutputDatasetContext(**output_dataset_args): # noqa: SIM117
430
431
  # note: output_dataset is a HDF5 dataset if final volume is an HDF5 volume else is a numpy array
431
432
  with _RawDatasetsContext(
432
433
  self._input_volumes,
@@ -528,7 +529,8 @@ class _RawDatasetsContext(AbstractContextManager):
528
529
  else:
529
530
  data = volume.load_data(store=False)
530
531
  if data is None:
531
- raise ValueError(f"No data found for volume {volume.get_identifier()}")
532
+ # TODO
533
+ raise ValueError(f"No data found for volume {volume.get_identifier()}") # noqa: TRY301
532
534
  if axis_1_need_padding:
533
535
  data = self.add_padding(data=data, axis_1_dim=axis_1_dim, alignment=self.alignment_axis_1)
534
536
  datasets.append(data)
@@ -536,7 +538,7 @@ class _RawDatasetsContext(AbstractContextManager):
536
538
  # if some errors happen during loading HDF5
537
539
  for file_handled in self.__file_handlers:
538
540
  file_handled.close()
539
- raise e
541
+ raise e # noqa: TRY201
540
542
 
541
543
  return datasets
542
544
 
@@ -544,11 +546,11 @@ class _RawDatasetsContext(AbstractContextManager):
544
546
  success = True
545
547
  for file_handler in self.__file_handlers:
546
548
  success = success and file_handler.close()
547
- if exc_type is None:
549
+ if exc_type is None: # noqa: RET503
548
550
  return success
549
551
 
550
552
  def add_padding(self, data: Union[h5py.Dataset, numpy.ndarray], axis_1_dim, alignment: AlignmentAxis1):
551
- alignment = AlignmentAxis1.from_value(alignment)
553
+ alignment = AlignmentAxis1(alignment)
552
554
  if alignment is AlignmentAxis1.BACK:
553
555
  axis_1_pad_width = (axis_1_dim - data.shape[1], 0)
554
556
  elif alignment is AlignmentAxis1.CENTER: