nabu 2025.1.0.dev14__py3-none-any.whl → 2025.1.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. doc/doc_config.py +32 -0
  2. nabu/__init__.py +1 -1
  3. nabu/app/cast_volume.py +9 -1
  4. nabu/app/cli_configs.py +80 -3
  5. nabu/app/estimate_motion.py +54 -0
  6. nabu/app/multicor.py +2 -4
  7. nabu/app/pcaflats.py +116 -0
  8. nabu/app/reconstruct.py +1 -7
  9. nabu/app/reduce_dark_flat.py +5 -2
  10. nabu/estimation/cor.py +1 -1
  11. nabu/estimation/motion.py +557 -0
  12. nabu/estimation/tests/test_motion_estimation.py +471 -0
  13. nabu/estimation/tilt.py +1 -1
  14. nabu/estimation/translation.py +47 -1
  15. nabu/io/cast_volume.py +100 -13
  16. nabu/io/reader.py +32 -1
  17. nabu/io/tests/test_remove_volume.py +152 -0
  18. nabu/pipeline/config_validators.py +42 -43
  19. nabu/pipeline/estimators.py +255 -0
  20. nabu/pipeline/fullfield/chunked.py +67 -43
  21. nabu/pipeline/fullfield/chunked_cuda.py +5 -2
  22. nabu/pipeline/fullfield/nabu_config.py +20 -14
  23. nabu/pipeline/fullfield/processconfig.py +17 -3
  24. nabu/pipeline/fullfield/reconstruction.py +4 -1
  25. nabu/pipeline/params.py +12 -0
  26. nabu/pipeline/tests/test_estimators.py +240 -3
  27. nabu/preproc/ccd.py +53 -3
  28. nabu/preproc/flatfield.py +306 -1
  29. nabu/preproc/shift.py +3 -1
  30. nabu/preproc/tests/test_pcaflats.py +154 -0
  31. nabu/processing/rotation_cuda.py +3 -1
  32. nabu/processing/tests/test_rotation.py +4 -2
  33. nabu/reconstruction/astra.py +245 -0
  34. nabu/reconstruction/fbp.py +7 -0
  35. nabu/reconstruction/fbp_base.py +31 -7
  36. nabu/reconstruction/fbp_opencl.py +8 -0
  37. nabu/reconstruction/filtering_opencl.py +2 -0
  38. nabu/reconstruction/mlem.py +47 -13
  39. nabu/reconstruction/tests/test_filtering.py +13 -2
  40. nabu/reconstruction/tests/test_mlem.py +91 -62
  41. nabu/resources/dataset_analyzer.py +144 -20
  42. nabu/resources/nxflatfield.py +101 -35
  43. nabu/resources/tests/test_nxflatfield.py +1 -1
  44. nabu/resources/utils.py +16 -10
  45. nabu/stitching/alignment.py +7 -7
  46. nabu/stitching/config.py +22 -20
  47. nabu/stitching/definitions.py +2 -2
  48. nabu/stitching/overlap.py +4 -4
  49. nabu/stitching/sample_normalization.py +5 -5
  50. nabu/stitching/stitcher/post_processing.py +5 -3
  51. nabu/stitching/stitcher/pre_processing.py +24 -20
  52. nabu/stitching/tests/test_config.py +3 -3
  53. nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
  54. nabu/stitching/tests/test_z_postprocessing_stitching.py +2 -2
  55. nabu/stitching/tests/test_z_preprocessing_stitching.py +23 -20
  56. nabu/stitching/utils/utils.py +7 -7
  57. nabu/testutils.py +1 -4
  58. nabu/utils.py +13 -0
  59. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/METADATA +3 -4
  60. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/RECORD +64 -57
  61. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/WHEEL +1 -1
  62. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/entry_points.txt +2 -1
  63. nabu/app/correct_rot.py +0 -62
  64. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/licenses/LICENSE +0 -0
  65. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc2.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@ import numpy
2
2
  import logging
3
3
  import h5py
4
4
  import os
5
+ import pint
5
6
  from collections.abc import Iterable
6
7
  from silx.io.url import DataUrl
7
8
  from silx.io.utils import get_data
@@ -26,8 +27,8 @@ from nabu.stitching.config import (
26
27
  from nabu.stitching.utils import find_projections_relative_shifts
27
28
  from functools import lru_cache as cache
28
29
  from .single_axis import SingleAxisStitcher
29
- from pyunitsystem.metricsystem import MetricSystem
30
30
 
31
+ _ureg = pint.get_application_registry()
31
32
 
32
33
  _logger = logging.getLogger(__name__)
33
34
 
@@ -238,18 +239,18 @@ class PreProcessingStitching(SingleAxisStitcher):
238
239
  if not scan_0.field_of_view == scan_1.field_of_view:
239
240
  raise ValueError(f"{scan_0} and {scan_1} have different field of view")
240
241
  # check distance
241
- if scan_0.distance is None:
242
+ if scan_0.sample_detector_distance is None:
242
243
  _logger.warning(f"no distance found for {scan_0}")
243
- elif not numpy.isclose(scan_0.distance, scan_1.distance, rtol=10e-3):
244
+ elif not numpy.isclose(scan_0.sample_detector_distance, scan_1.sample_detector_distance, rtol=10e-3):
244
245
  raise ValueError(f"{scan_0} and {scan_1} have different sample / detector distance")
245
246
  # check pixel size
246
- if not numpy.isclose(scan_0.x_pixel_size, scan_1.x_pixel_size):
247
+ if not numpy.isclose(scan_0.sample_x_pixel_size, scan_1.sample_x_pixel_size):
247
248
  raise ValueError(
248
- f"{scan_0} and {scan_1} have different x pixel size. {scan_0.x_pixel_size} vs {scan_1.x_pixel_size}"
249
+ f"{scan_0} and {scan_1} have different x pixel size. {scan_0.sample_x_pixel_size} vs {scan_1.sample_x_pixel_size}"
249
250
  )
250
- if not numpy.isclose(scan_0.y_pixel_size, scan_1.y_pixel_size):
251
+ if not numpy.isclose(scan_0.sample_y_pixel_size, scan_1.sample_y_pixel_size):
251
252
  raise ValueError(
252
- f"{scan_0} and {scan_1} have different y pixel size. {scan_0.y_pixel_size} vs {scan_1.y_pixel_size}"
253
+ f"{scan_0} and {scan_1} have different y pixel size. {scan_0.sample_y_pixel_size} vs {scan_1.sample_y_pixel_size}"
253
254
  )
254
255
 
255
256
  for scan in self.series:
@@ -291,7 +292,7 @@ class PreProcessingStitching(SingleAxisStitcher):
291
292
  axis_N_pos_px = []
292
293
  for scan, pos_in_mm in zip(self.series, pos_as_mm):
293
294
  pixel_size_m = self.configuration.pixel_size or scan.pixel_size
294
- axis_N_pos_px.append((pos_in_mm / MetricSystem.MILLIMETER.value) / pixel_size_m)
295
+ axis_N_pos_px.append((pos_in_mm * _ureg.millimeter).to_base_units().magnitude / pixel_size_m)
295
296
  return axis_N_pos_px
296
297
  else:
297
298
  # deduce from motor position and pixel size
@@ -471,7 +472,7 @@ class PreProcessingStitching(SingleAxisStitcher):
471
472
  """
472
473
  nx_tomo = NXtomo()
473
474
 
474
- nx_tomo.energy = self.series[0].energy
475
+ nx_tomo.energy = self.series[0].energy * _ureg.keV
475
476
  start_times = list(filter(None, [scan.start_time for scan in self.series]))
476
477
  end_times = list(filter(None, [scan.end_time for scan in self.series]))
477
478
 
@@ -495,9 +496,9 @@ class PreProcessingStitching(SingleAxisStitcher):
495
496
 
496
497
  # handle detector (without frames)
497
498
  nx_tomo.instrument.detector.field_of_view = self.series[0].field_of_view
498
- nx_tomo.instrument.detector.distance = self.series[0].distance
499
- nx_tomo.instrument.detector.x_pixel_size = self.series[0].x_pixel_size
500
- nx_tomo.instrument.detector.y_pixel_size = self.series[0].y_pixel_size
499
+ nx_tomo.instrument.detector.distance = self.series[0].sample_detector_distance * _ureg.meter
500
+ nx_tomo.instrument.detector.x_pixel_size = self.series[0].x_pixel_size * _ureg.meter
501
+ nx_tomo.instrument.detector.y_pixel_size = self.series[0].y_pixel_size * _ureg.meter
501
502
  nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
502
503
  nx_tomo.instrument.detector.tomo_n = n_proj
503
504
  # note: stitching process insure un-flipping of frames. So make sure transformations is defined as an empty set
@@ -505,13 +506,13 @@ class PreProcessingStitching(SingleAxisStitcher):
505
506
 
506
507
  if isinstance(self.series[0], NXtomoScan):
507
508
  # note: first scan is always the reference as order to read data (so no rotation_angle inversion here)
508
- rotation_angle = numpy.asarray(self.series[0].rotation_angle)
509
+ rotation_angle = numpy.asarray(self.series[0].rotation_angle) * _ureg.degree
509
510
  nx_tomo.sample.rotation_angle = rotation_angle[
510
511
  numpy.asarray(self.series[0].image_key_control) == ImageKey.PROJECTION.value
511
512
  ]
512
513
  elif isinstance(self.series[0], EDFTomoScan):
513
- nx_tomo.sample.rotation_angle = numpy.linspace(
514
- start=0, stop=self.series[0].scan_range, num=self.series[0].tomo_n
514
+ nx_tomo.sample.rotation_angle = (
515
+ numpy.linspace(start=0, stop=self.series[0].scan_range, num=self.series[0].tomo_n) * _ureg.degree
515
516
  )
516
517
  else:
517
518
  raise NotImplementedError(
@@ -529,8 +530,11 @@ class PreProcessingStitching(SingleAxisStitcher):
529
530
  else:
530
531
  raise RuntimeError("slices must be instance of a slice or of an iterable") # noqa: TRY004
531
532
 
532
- nx_tomo.sample.rotation_angle = apply_slices_selection(
533
- array=nx_tomo.sample.rotation_angle, slices=self._slices_to_stitch
533
+ nx_tomo.sample.rotation_angle = (
534
+ apply_slices_selection(
535
+ array=nx_tomo.sample.rotation_angle.to_base_units().magnitude, slices=self._slices_to_stitch
536
+ )
537
+ * _ureg.degree
534
538
  )
535
539
 
536
540
  # handle sample
@@ -559,7 +563,7 @@ class PreProcessingStitching(SingleAxisStitcher):
559
563
  # note: if at least one has missing values the numpy.Array(x_translation) with create an error as well
560
564
  x_translation = [0.0] * n_proj
561
565
  _logger.warning("Unable to fin input nxtomo x_translation values. Set it to 0.0")
562
- nx_tomo.sample.x_translation = x_translation
566
+ nx_tomo.sample.x_translation = x_translation * _ureg.meter
563
567
 
564
568
  y_translation = [
565
569
  get_sample_translation_for_projs(scan, "y_translation")
@@ -574,7 +578,7 @@ class PreProcessingStitching(SingleAxisStitcher):
574
578
  else:
575
579
  y_translation = [0.0] * n_proj
576
580
  _logger.warning("Unable to fin input nxtomo y_translation values. Set it to 0.0")
577
- nx_tomo.sample.y_translation = y_translation
581
+ nx_tomo.sample.y_translation = y_translation * _ureg.meter
578
582
  z_translation = [
579
583
  get_sample_translation_for_projs(scan, "z_translation")
580
584
  for scan in self.series
@@ -588,7 +592,7 @@ class PreProcessingStitching(SingleAxisStitcher):
588
592
  else:
589
593
  z_translation = [0.0] * n_proj
590
594
  _logger.warning("Unable to fin input nxtomo z_translation values. Set it to 0.0")
591
- nx_tomo.sample.z_translation = z_translation
595
+ nx_tomo.sample.z_translation = z_translation * _ureg.meter
592
596
 
593
597
  nx_tomo.sample.name = self.series[0].sample_name
594
598
 
@@ -10,7 +10,7 @@ from nabu.stitching.overlap import OverlapStitchingStrategy
10
10
  from nabu.stitching import config as stiching_config
11
11
 
12
12
 
13
- _stitching_types = list(stiching_config.StitchingType.values())
13
+ _stitching_types = [st.value for st in list(stiching_config.StitchingType)]
14
14
  _stitching_types.append(None)
15
15
 
16
16
 
@@ -37,7 +37,7 @@ def test_stitching_config(stitching_type, option_level):
37
37
 
38
38
  assert "stitching" in config
39
39
  assert "type" in config["stitching"]
40
- stitching_type = stiching_config.StitchingType.from_value(config["stitching"]["type"])
40
+ stitching_type = stiching_config.StitchingType(config["stitching"]["type"])
41
41
  if stitching_type is stiching_config.StitchingType.Z_POSTPROC:
42
42
  assert isinstance(
43
43
  stiching_config.dict_to_config_obj(config),
@@ -82,7 +82,7 @@ def test_stitching_config(stitching_type, option_level):
82
82
  assert isinstance(config_class_instance.to_dict(), dict)
83
83
 
84
84
 
85
- @pytest.mark.parametrize("stitching_strategy", OverlapStitchingStrategy.values())
85
+ @pytest.mark.parametrize("stitching_strategy", [oss for oss in OverlapStitchingStrategy])
86
86
  @pytest.mark.parametrize("overwrite_results", (True, "False", 0, "1"))
87
87
  @pytest.mark.parametrize(
88
88
  "axis_shifts",
@@ -1,6 +1,7 @@
1
1
  import os
2
2
  import pytest
3
3
  import numpy
4
+ import pint
4
5
  from tqdm import tqdm
5
6
 
6
7
  from nabu.stitching.y_stitching import y_stitching
@@ -9,6 +10,8 @@ from nxtomo.application.nxtomo import NXtomo
9
10
  from nxtomo.nxobject.nxdetector import ImageKey
10
11
  from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
11
12
 
13
+ _ureg = pint.UnitRegistry()
14
+
12
15
 
13
16
  def build_nxtomos(output_dir, flip_lr, flip_ud) -> tuple:
14
17
  r"""
@@ -52,10 +55,10 @@ def build_nxtomos(output_dir, flip_lr, flip_ud) -> tuple:
52
55
 
53
56
  n_projs = 3
54
57
  nx_tomo = NXtomo()
55
- nx_tomo.sample.x_translation = [0] * (n_projs + 2)
56
- nx_tomo.sample.y_translation = [frame_y_position] * (n_projs + 2)
57
- nx_tomo.sample.z_translation = [0] * (n_projs + 2)
58
- nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=(n_projs + 2), endpoint=False)
58
+ nx_tomo.sample.x_translation = ([0] * (n_projs + 2)) * _ureg.meter
59
+ nx_tomo.sample.y_translation = ([frame_y_position] * (n_projs + 2)) * _ureg.meter
60
+ nx_tomo.sample.z_translation = ([0] * (n_projs + 2)) * _ureg.meter
61
+ nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=(n_projs + 2), endpoint=False) * _ureg.degree
59
62
  nx_tomo.instrument.detector.image_key_control = (
60
63
  ImageKey.DARK_FIELD,
61
64
  ImageKey.FLAT_FIELD,
@@ -63,10 +66,10 @@ def build_nxtomos(output_dir, flip_lr, flip_ud) -> tuple:
63
66
  ImageKey.PROJECTION,
64
67
  ImageKey.PROJECTION,
65
68
  )
66
- nx_tomo.instrument.detector.x_pixel_size = 1.0
67
- nx_tomo.instrument.detector.y_pixel_size = 1.0
68
- nx_tomo.instrument.detector.distance = 2.3
69
- nx_tomo.energy = 19.2
69
+ nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
70
+ nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
71
+ nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
72
+ nx_tomo.energy = 19.2 * _ureg.keV
70
73
  nx_tomo.instrument.detector.data = numpy.stack(
71
74
  (
72
75
  my_dark_data,
@@ -395,7 +395,7 @@ def test_vol_z_stitching_with_alignment_axis_2(tmp_path, alignment_axis_2):
395
395
  axis_2_params={"img_reg_method": ShiftAlgorithm.NONE},
396
396
  slice_for_cross_correlation="middle",
397
397
  voxel_size=None,
398
- alignment_axis_2=AlignmentAxis2.from_value(alignment_axis_2),
398
+ alignment_axis_2=AlignmentAxis2(alignment_axis_2),
399
399
  )
400
400
 
401
401
  stitcher = PostProcessZStitcher(z_stich_config, progress=None)
@@ -512,7 +512,7 @@ def test_vol_z_stitching_with_alignment_axis_1(tmp_path, alignment_axis_1):
512
512
  axis_2_params={"img_reg_method": ShiftAlgorithm.NONE},
513
513
  slice_for_cross_correlation="middle",
514
514
  voxel_size=None,
515
- alignment_axis_1=AlignmentAxis1.from_value(alignment_axis_1),
515
+ alignment_axis_1=AlignmentAxis1(alignment_axis_1),
516
516
  )
517
517
 
518
518
  stitcher = PostProcessZStitcher(z_stich_config, progress=None)
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import pint
2
3
  from silx.image.phantomgenerator import PhantomGenerator
3
4
  from scipy.ndimage import shift as scipy_shift
4
5
  import numpy
@@ -17,6 +18,8 @@ from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
17
18
  from nabu.stitching.utils import ShiftAlgorithm
18
19
  import h5py
19
20
 
21
+ _ureg = pint.get_application_registry()
22
+
20
23
 
21
24
  _stitching_configurations = (
22
25
  # simple case where shifts are provided
@@ -82,13 +85,13 @@ def test_PreProcessZStitcher(tmp_path, dtype, configuration):
82
85
  scans = []
83
86
  for (i_frame, frame), z_pos in zip(enumerate(frames), z_position):
84
87
  nx_tomo = NXtomo()
85
- nx_tomo.sample.z_translation = [z_pos] * n_proj
86
- nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False)
88
+ nx_tomo.sample.z_translation = ([z_pos] * n_proj) * _ureg.meter
89
+ nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False) * _ureg.degree
87
90
  nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
88
- nx_tomo.instrument.detector.x_pixel_size = 1.0
89
- nx_tomo.instrument.detector.y_pixel_size = 1.0
90
- nx_tomo.instrument.detector.distance = 2.3
91
- nx_tomo.energy = 19.2
91
+ nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
92
+ nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
93
+ nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
94
+ nx_tomo.energy = 19.2 * _ureg.keV
92
95
  nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj)
93
96
 
94
97
  file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx")
@@ -160,8 +163,8 @@ def test_PreProcessZStitcher(tmp_path, dtype, configuration):
160
163
  )
161
164
 
162
165
  # check also other metadata are here
163
- assert created_nx_tomo.instrument.detector.distance.value == 2.3
164
- assert created_nx_tomo.energy.value == 19.2
166
+ assert created_nx_tomo.instrument.detector.distance == 2.3 * _ureg.meter
167
+ assert created_nx_tomo.energy == 19.2 * _ureg.keV
165
168
  numpy.testing.assert_array_equal(
166
169
  created_nx_tomo.instrument.detector.image_key_control,
167
170
  numpy.asarray([ImageKey.PROJECTION.PROJECTION] * n_proj),
@@ -228,13 +231,13 @@ def build_nxtomos(output_dir) -> tuple:
228
231
  scans = []
229
232
  for (i_frame, frame), z_pos in zip(enumerate(frames), z_positions):
230
233
  nx_tomo = NXtomo()
231
- nx_tomo.sample.z_translation = [z_pos] * n_projs
232
- nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_projs, endpoint=False)
234
+ nx_tomo.sample.z_translation = [z_pos] * n_projs * _ureg.meter
235
+ nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_projs, endpoint=False) * _ureg.degree
233
236
  nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_projs
234
- nx_tomo.instrument.detector.x_pixel_size = 1.0
235
- nx_tomo.instrument.detector.y_pixel_size = 1.0
236
- nx_tomo.instrument.detector.distance = 2.3
237
- nx_tomo.energy = 19.2
237
+ nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
238
+ nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
239
+ nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
240
+ nx_tomo.energy = 19.2 * _ureg.keV
238
241
  nx_tomo.instrument.detector.data = frame
239
242
 
240
243
  file_path = os.path.join(output_dir, f"nxtomo_{i_frame}.nx")
@@ -373,15 +376,15 @@ def test_frame_flip(tmp_path):
373
376
  scans = []
374
377
  for (i_frame, frame), z_pos, x_flip, y_flip in zip(enumerate(frames), z_position, x_flips, y_flips):
375
378
  nx_tomo = NXtomo()
376
- nx_tomo.sample.z_translation = [z_pos] * n_proj
377
- nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False)
379
+ nx_tomo.sample.z_translation = [z_pos] * n_proj * _ureg.meter
380
+ nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False) * _ureg.degree
378
381
  nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
379
- nx_tomo.instrument.detector.x_pixel_size = 1.0
380
- nx_tomo.instrument.detector.y_pixel_size = 1.0
381
- nx_tomo.instrument.detector.distance = 2.3
382
+ nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
383
+ nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
384
+ nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
382
385
  nx_tomo.instrument.detector.transformations.add_transformation(DetZFlipTransformation(flip=x_flip))
383
386
  nx_tomo.instrument.detector.transformations.add_transformation(DetYFlipTransformation(flip=y_flip))
384
- nx_tomo.energy = 19.2
387
+ nx_tomo.energy = 19.2 * _ureg.keV
385
388
  nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj)
386
389
 
387
390
  file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx")
@@ -1,3 +1,4 @@
1
+ from enum import Enum
1
2
  from packaging.version import parse as parse_version
2
3
  from typing import Optional, Union
3
4
  import logging
@@ -6,7 +7,6 @@ import numpy
6
7
  from tomoscan.scanbase import TomoScanBase
7
8
  from tomoscan.volumebase import VolumeBase
8
9
  from nxtomo.utils.transformation import build_matrix, DetYFlipTransformation
9
- from silx.utils.enum import Enum as _Enum
10
10
  from scipy.fft import rfftn as local_fftn
11
11
  from scipy.fft import irfftn as local_ifftn
12
12
  from ..alignment import AlignmentAxis1, AlignmentAxis2, PaddedRawData
@@ -36,7 +36,7 @@ else:
36
36
  __has_sk_phase_correlation__ = True
37
37
 
38
38
 
39
- class ShiftAlgorithm(_Enum):
39
+ class ShiftAlgorithm(Enum):
40
40
  """All generic shift search algorithm"""
41
41
 
42
42
  NABU_FFT = "nabu-fft"
@@ -58,7 +58,7 @@ class ShiftAlgorithm(_Enum):
58
58
  if value in ("", None):
59
59
  return ShiftAlgorithm.NONE
60
60
  else:
61
- return super().from_value(value=value)
61
+ return super().__new__(cls, value)
62
62
 
63
63
 
64
64
  def find_frame_relative_shifts(
@@ -72,7 +72,7 @@ def find_frame_relative_shifts(
72
72
  y_shifts_params: Optional[dict] = None,
73
73
  ):
74
74
  """
75
- :param overlap_axis: axis in [0, 1] on which the overlap exists. In image space. So 0 is aka y and 1 as x
75
+ :param overlap_axis: axis in [0, 1] on which the overlap exists. In image space. So 0 is aka y and 1 as x.
76
76
  """
77
77
  if overlap_axis not in (0, 1):
78
78
  raise ValueError(f"overlap_axis should be in (0, 1). Get {overlap_axis}")
@@ -199,8 +199,8 @@ def find_volumes_relative_shifts(
199
199
  else:
200
200
  raise ValueError(f"Stitching is done in 3D space. Expect axis to be in [0,2]. Get {overlap_axis}")
201
201
 
202
- alignment_axis_2 = AlignmentAxis2.from_value(alignment_axis_2)
203
- alignment_axis_1 = AlignmentAxis1.from_value(alignment_axis_1)
202
+ alignment_axis_2 = AlignmentAxis2(alignment_axis_2)
203
+ alignment_axis_1 = AlignmentAxis1(alignment_axis_1)
204
204
  assert dim_axis_1 > 0, "dim_axis_1 <= 0"
205
205
 
206
206
  if isinstance(slice_for_shift, str):
@@ -500,7 +500,7 @@ def find_shift_correlate(img1, img2, padding_mode="reflect"):
500
500
  padding_mode,
501
501
  )
502
502
 
503
- img_shape = img1.shape[-2:]
503
+ img_shape = cc.shape # Because cc.shape can differ from img_2.shape (e.g. in case of odd nb of cols)
504
504
  cc_vs = numpy.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
505
505
  cc_hs = numpy.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])
506
506
 
nabu/testutils.py CHANGED
@@ -41,10 +41,7 @@ def generate_tests_scenarios(configurations):
41
41
  - the key is the name of a parameter
42
42
  - the value is one value of this parameter
43
43
  """
44
- scenarios = [
45
- {key: val for key, val in zip(configurations.keys(), p_)} # noqa: C416
46
- for p_ in product(*configurations.values())
47
- ]
44
+ scenarios = [{key: val for key, val in zip(configurations.keys(), p_)} for p_ in product(*configurations.values())]
48
45
  return scenarios
49
46
 
50
47
 
nabu/utils.py CHANGED
@@ -1,3 +1,4 @@
1
+ from bisect import bisect_left
1
2
  from fnmatch import fnmatch
2
3
  from functools import partial
3
4
  import os
@@ -538,6 +539,18 @@ def restore_items_in_list(list_, removed_items):
538
539
  list_.insert(idx, val)
539
540
 
540
541
 
542
+ def search_sorted(arr, val):
543
+ """
544
+ Binary search that returns the "nearest" index given a query,
545
+ i.e find "i" that minimizes abs(arr[i] - val)
546
+ It does not return the "insersion point" contrarily to numpy.searchsorted() or bisect_left
547
+ """
548
+ pos = bisect_left(arr, val)
549
+ if pos == len(arr):
550
+ return len(arr) - 1
551
+ return pos - 1 if abs(val - arr[pos - 1]) < abs(arr[pos] - val) else pos
552
+
553
+
541
554
  def check_supported(param_value, available, param_desc):
542
555
  if param_value not in available:
543
556
  raise ValueError("Unsupported %s '%s'. Available are: %s" % (param_desc, param_value, str(available)))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nabu
3
- Version: 2025.1.0.dev14
3
+ Version: 2025.1.0rc2
4
4
  Summary: Nabu - Tomography software
5
5
  Author-email: Pierre Paleo <pierre.paleo@esrf.fr>, Henri Payno <henri.payno@esrf.fr>, Alessandro Mirone <mirone@esrf.fr>, Jérôme Lesaint <jerome.lesaint@esrf.fr>
6
6
  Maintainer-email: Pierre Paleo <pierre.paleo@esrf.fr>
@@ -14,7 +14,6 @@ Classifier: Development Status :: 5 - Production/Stable
14
14
  Classifier: Intended Audience :: Developers
15
15
  Classifier: Intended Audience :: Science/Research
16
16
  Classifier: Programming Language :: Python :: 3
17
- Classifier: Programming Language :: Python :: 3.8
18
17
  Classifier: Programming Language :: Python :: 3.9
19
18
  Classifier: Programming Language :: Python :: 3.10
20
19
  Classifier: Programming Language :: Python :: 3.11
@@ -26,14 +25,14 @@ Classifier: Operating System :: MacOS :: MacOS X
26
25
  Classifier: Operating System :: POSIX
27
26
  Classifier: Topic :: Scientific/Engineering :: Physics
28
27
  Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
29
- Requires-Python: >=3.8
28
+ Requires-Python: >=3.9
30
29
  Description-Content-Type: text/markdown
31
30
  License-File: LICENSE
32
31
  Requires-Dist: numpy>1.9.0
33
32
  Requires-Dist: scipy
34
33
  Requires-Dist: h5py>=3.0
35
34
  Requires-Dist: silx>=0.15.0
36
- Requires-Dist: tomoscan>=2.1.5
35
+ Requires-Dist: tomoscan>=2.2.0
37
36
  Requires-Dist: psutil
38
37
  Requires-Dist: pytest
39
38
  Requires-Dist: tifffile