nabu 2024.2.14__py3-none-any.whl → 2025.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197) hide show
  1. doc/doc_config.py +32 -0
  2. nabu/__init__.py +1 -1
  3. nabu/app/bootstrap_stitching.py +4 -2
  4. nabu/app/cast_volume.py +16 -14
  5. nabu/app/cli_configs.py +102 -9
  6. nabu/app/compare_volumes.py +1 -1
  7. nabu/app/composite_cor.py +2 -4
  8. nabu/app/diag_to_pix.py +5 -6
  9. nabu/app/diag_to_rot.py +10 -11
  10. nabu/app/double_flatfield.py +18 -5
  11. nabu/app/estimate_motion.py +75 -0
  12. nabu/app/multicor.py +28 -15
  13. nabu/app/parse_reconstruction_log.py +1 -0
  14. nabu/app/pcaflats.py +122 -0
  15. nabu/app/prepare_weights_double.py +1 -2
  16. nabu/app/reconstruct.py +1 -7
  17. nabu/app/reconstruct_helical.py +5 -9
  18. nabu/app/reduce_dark_flat.py +5 -4
  19. nabu/app/rotate.py +3 -1
  20. nabu/app/stitching.py +7 -2
  21. nabu/app/tests/test_reduce_dark_flat.py +2 -2
  22. nabu/app/validator.py +1 -4
  23. nabu/cuda/convolution.py +1 -1
  24. nabu/cuda/fft.py +1 -1
  25. nabu/cuda/medfilt.py +1 -1
  26. nabu/cuda/padding.py +1 -1
  27. nabu/cuda/src/backproj.cu +6 -6
  28. nabu/cuda/src/cone.cu +4 -0
  29. nabu/cuda/src/hierarchical_backproj.cu +14 -0
  30. nabu/cuda/utils.py +2 -2
  31. nabu/estimation/alignment.py +17 -31
  32. nabu/estimation/cor.py +27 -33
  33. nabu/estimation/cor_sino.py +2 -8
  34. nabu/estimation/focus.py +4 -8
  35. nabu/estimation/motion.py +557 -0
  36. nabu/estimation/tests/test_alignment.py +2 -0
  37. nabu/estimation/tests/test_motion_estimation.py +471 -0
  38. nabu/estimation/tests/test_tilt.py +1 -1
  39. nabu/estimation/tilt.py +6 -5
  40. nabu/estimation/translation.py +47 -1
  41. nabu/io/cast_volume.py +108 -18
  42. nabu/io/detector_distortion.py +5 -6
  43. nabu/io/reader.py +45 -6
  44. nabu/io/reader_helical.py +5 -4
  45. nabu/io/tests/test_cast_volume.py +2 -2
  46. nabu/io/tests/test_readers.py +41 -38
  47. nabu/io/tests/test_remove_volume.py +152 -0
  48. nabu/io/tests/test_writers.py +2 -2
  49. nabu/io/utils.py +8 -4
  50. nabu/io/writer.py +1 -2
  51. nabu/misc/fftshift.py +1 -1
  52. nabu/misc/fourier_filters.py +1 -1
  53. nabu/misc/histogram.py +1 -1
  54. nabu/misc/histogram_cuda.py +1 -1
  55. nabu/misc/padding_base.py +1 -1
  56. nabu/misc/rotation.py +1 -1
  57. nabu/misc/rotation_cuda.py +1 -1
  58. nabu/misc/tests/test_binning.py +1 -1
  59. nabu/misc/transpose.py +1 -1
  60. nabu/misc/unsharp.py +1 -1
  61. nabu/misc/unsharp_cuda.py +1 -1
  62. nabu/misc/unsharp_opencl.py +1 -1
  63. nabu/misc/utils.py +1 -1
  64. nabu/opencl/fft.py +1 -1
  65. nabu/opencl/padding.py +1 -1
  66. nabu/opencl/src/backproj.cl +6 -6
  67. nabu/opencl/utils.py +8 -8
  68. nabu/pipeline/config.py +2 -2
  69. nabu/pipeline/config_validators.py +46 -46
  70. nabu/pipeline/datadump.py +3 -3
  71. nabu/pipeline/estimators.py +271 -11
  72. nabu/pipeline/fullfield/chunked.py +103 -67
  73. nabu/pipeline/fullfield/chunked_cuda.py +5 -2
  74. nabu/pipeline/fullfield/computations.py +4 -1
  75. nabu/pipeline/fullfield/dataset_validator.py +0 -1
  76. nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
  77. nabu/pipeline/fullfield/nabu_config.py +36 -17
  78. nabu/pipeline/fullfield/processconfig.py +41 -7
  79. nabu/pipeline/fullfield/reconstruction.py +14 -10
  80. nabu/pipeline/helical/dataset_validator.py +3 -4
  81. nabu/pipeline/helical/fbp.py +4 -4
  82. nabu/pipeline/helical/filtering.py +5 -4
  83. nabu/pipeline/helical/gridded_accumulator.py +10 -11
  84. nabu/pipeline/helical/helical_chunked_regridded.py +1 -0
  85. nabu/pipeline/helical/helical_reconstruction.py +12 -9
  86. nabu/pipeline/helical/helical_utils.py +1 -2
  87. nabu/pipeline/helical/nabu_config.py +2 -1
  88. nabu/pipeline/helical/span_strategy.py +1 -0
  89. nabu/pipeline/helical/weight_balancer.py +2 -3
  90. nabu/pipeline/params.py +20 -3
  91. nabu/pipeline/tests/__init__.py +0 -0
  92. nabu/pipeline/tests/test_estimators.py +240 -3
  93. nabu/pipeline/utils.py +1 -1
  94. nabu/pipeline/writer.py +1 -1
  95. nabu/preproc/alignment.py +0 -10
  96. nabu/preproc/ccd.py +53 -3
  97. nabu/preproc/ctf.py +8 -8
  98. nabu/preproc/ctf_cuda.py +1 -1
  99. nabu/preproc/double_flatfield_cuda.py +2 -2
  100. nabu/preproc/double_flatfield_variable_region.py +0 -1
  101. nabu/preproc/flatfield.py +307 -2
  102. nabu/preproc/flatfield_cuda.py +1 -2
  103. nabu/preproc/flatfield_variable_region.py +3 -3
  104. nabu/preproc/phase.py +2 -4
  105. nabu/preproc/phase_cuda.py +2 -2
  106. nabu/preproc/shift.py +4 -2
  107. nabu/preproc/shift_cuda.py +0 -1
  108. nabu/preproc/tests/test_ctf.py +4 -4
  109. nabu/preproc/tests/test_double_flatfield.py +1 -1
  110. nabu/preproc/tests/test_flatfield.py +1 -1
  111. nabu/preproc/tests/test_paganin.py +1 -3
  112. nabu/preproc/tests/test_pcaflats.py +154 -0
  113. nabu/preproc/tests/test_vshift.py +4 -1
  114. nabu/processing/azim.py +9 -5
  115. nabu/processing/convolution_cuda.py +6 -4
  116. nabu/processing/fft_base.py +7 -3
  117. nabu/processing/fft_cuda.py +25 -164
  118. nabu/processing/fft_opencl.py +28 -6
  119. nabu/processing/fftshift.py +1 -1
  120. nabu/processing/histogram.py +1 -1
  121. nabu/processing/muladd.py +0 -1
  122. nabu/processing/padding_base.py +1 -1
  123. nabu/processing/padding_cuda.py +0 -2
  124. nabu/processing/processing_base.py +12 -6
  125. nabu/processing/rotation_cuda.py +3 -1
  126. nabu/processing/tests/test_fft.py +2 -64
  127. nabu/processing/tests/test_fftshift.py +1 -1
  128. nabu/processing/tests/test_medfilt.py +1 -3
  129. nabu/processing/tests/test_padding.py +1 -1
  130. nabu/processing/tests/test_roll.py +1 -1
  131. nabu/processing/tests/test_rotation.py +4 -2
  132. nabu/processing/unsharp_opencl.py +1 -1
  133. nabu/reconstruction/astra.py +245 -0
  134. nabu/reconstruction/cone.py +39 -9
  135. nabu/reconstruction/fbp.py +7 -0
  136. nabu/reconstruction/fbp_base.py +36 -5
  137. nabu/reconstruction/filtering.py +59 -25
  138. nabu/reconstruction/filtering_cuda.py +22 -21
  139. nabu/reconstruction/filtering_opencl.py +10 -14
  140. nabu/reconstruction/hbp.py +26 -13
  141. nabu/reconstruction/mlem.py +55 -16
  142. nabu/reconstruction/projection.py +3 -5
  143. nabu/reconstruction/sinogram.py +1 -1
  144. nabu/reconstruction/sinogram_cuda.py +0 -1
  145. nabu/reconstruction/tests/test_cone.py +37 -2
  146. nabu/reconstruction/tests/test_deringer.py +4 -4
  147. nabu/reconstruction/tests/test_fbp.py +36 -15
  148. nabu/reconstruction/tests/test_filtering.py +27 -7
  149. nabu/reconstruction/tests/test_halftomo.py +28 -2
  150. nabu/reconstruction/tests/test_mlem.py +94 -64
  151. nabu/reconstruction/tests/test_projector.py +7 -2
  152. nabu/reconstruction/tests/test_reconstructor.py +1 -1
  153. nabu/reconstruction/tests/test_sino_normalization.py +0 -1
  154. nabu/resources/dataset_analyzer.py +210 -24
  155. nabu/resources/gpu.py +4 -4
  156. nabu/resources/logger.py +4 -4
  157. nabu/resources/nxflatfield.py +103 -37
  158. nabu/resources/tests/test_dataset_analyzer.py +37 -0
  159. nabu/resources/tests/test_extract.py +11 -0
  160. nabu/resources/tests/test_nxflatfield.py +5 -5
  161. nabu/resources/utils.py +16 -10
  162. nabu/stitching/alignment.py +8 -11
  163. nabu/stitching/config.py +44 -35
  164. nabu/stitching/definitions.py +2 -2
  165. nabu/stitching/frame_composition.py +8 -10
  166. nabu/stitching/overlap.py +4 -4
  167. nabu/stitching/sample_normalization.py +5 -5
  168. nabu/stitching/slurm_utils.py +2 -2
  169. nabu/stitching/stitcher/base.py +2 -0
  170. nabu/stitching/stitcher/dumper/base.py +0 -1
  171. nabu/stitching/stitcher/dumper/postprocessing.py +1 -1
  172. nabu/stitching/stitcher/post_processing.py +11 -9
  173. nabu/stitching/stitcher/pre_processing.py +37 -31
  174. nabu/stitching/stitcher/single_axis.py +2 -3
  175. nabu/stitching/stitcher_2D.py +2 -1
  176. nabu/stitching/tests/test_config.py +10 -11
  177. nabu/stitching/tests/test_sample_normalization.py +1 -1
  178. nabu/stitching/tests/test_slurm_utils.py +1 -2
  179. nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
  180. nabu/stitching/tests/test_z_postprocessing_stitching.py +3 -3
  181. nabu/stitching/tests/test_z_preprocessing_stitching.py +27 -24
  182. nabu/stitching/utils/tests/__init__.py +0 -0
  183. nabu/stitching/utils/tests/test_post-processing.py +1 -0
  184. nabu/stitching/utils/utils.py +16 -18
  185. nabu/tests.py +0 -3
  186. nabu/testutils.py +62 -9
  187. nabu/utils.py +50 -20
  188. {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/METADATA +7 -7
  189. nabu-2025.1.0.dist-info/RECORD +328 -0
  190. {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/WHEEL +1 -1
  191. {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/entry_points.txt +2 -1
  192. nabu/app/correct_rot.py +0 -70
  193. nabu/io/tests/test_detector_distortion.py +0 -178
  194. nabu-2024.2.14.dist-info/RECORD +0 -317
  195. /nabu/{stitching → app}/tests/__init__.py +0 -0
  196. {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/licenses/LICENSE +0 -0
  197. {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,8 @@ import numpy
2
2
  import logging
3
3
  import h5py
4
4
  import os
5
- from typing import Iterable
5
+ import pint
6
+ from collections.abc import Iterable
6
7
  from silx.io.url import DataUrl
7
8
  from silx.io.utils import get_data
8
9
  from datetime import datetime
@@ -26,8 +27,8 @@ from nabu.stitching.config import (
26
27
  from nabu.stitching.utils import find_projections_relative_shifts
27
28
  from functools import lru_cache as cache
28
29
  from .single_axis import SingleAxisStitcher
29
- from pyunitsystem.metricsystem import MetricSystem
30
30
 
31
+ _ureg = pint.get_application_registry()
31
32
 
32
33
  _logger = logging.getLogger(__name__)
33
34
 
@@ -40,7 +41,6 @@ class PreProcessingStitching(SingleAxisStitcher):
40
41
  """
41
42
 
42
43
  def __init__(self, configuration, progress=None) -> None:
43
- """ """
44
44
  if not isinstance(configuration, PreProcessedSingleAxisStitchingConfiguration):
45
45
  raise TypeError(
46
46
  f"configuration is expected to be an instance of {PreProcessedSingleAxisStitchingConfiguration}. Get {type(configuration)} instead"
@@ -239,18 +239,18 @@ class PreProcessingStitching(SingleAxisStitcher):
239
239
  if not scan_0.field_of_view == scan_1.field_of_view:
240
240
  raise ValueError(f"{scan_0} and {scan_1} have different field of view")
241
241
  # check distance
242
- if scan_0.distance is None:
242
+ if scan_0.sample_detector_distance is None:
243
243
  _logger.warning(f"no distance found for {scan_0}")
244
- elif not numpy.isclose(scan_0.distance, scan_1.distance, rtol=10e-3):
244
+ elif not numpy.isclose(scan_0.sample_detector_distance, scan_1.sample_detector_distance, rtol=10e-3):
245
245
  raise ValueError(f"{scan_0} and {scan_1} have different sample / detector distance")
246
246
  # check pixel size
247
- if not numpy.isclose(scan_0.x_pixel_size, scan_1.x_pixel_size):
247
+ if not numpy.isclose(scan_0.sample_x_pixel_size, scan_1.sample_x_pixel_size):
248
248
  raise ValueError(
249
- f"{scan_0} and {scan_1} have different x pixel size. {scan_0.x_pixel_size} vs {scan_1.x_pixel_size}"
249
+ f"{scan_0} and {scan_1} have different x pixel size. {scan_0.sample_x_pixel_size} vs {scan_1.sample_x_pixel_size}"
250
250
  )
251
- if not numpy.isclose(scan_0.y_pixel_size, scan_1.y_pixel_size):
251
+ if not numpy.isclose(scan_0.sample_y_pixel_size, scan_1.sample_y_pixel_size):
252
252
  raise ValueError(
253
- f"{scan_0} and {scan_1} have different y pixel size. {scan_0.y_pixel_size} vs {scan_1.y_pixel_size}"
253
+ f"{scan_0} and {scan_1} have different y pixel size. {scan_0.sample_y_pixel_size} vs {scan_1.sample_y_pixel_size}"
254
254
  )
255
255
 
256
256
  for scan in self.series:
@@ -292,7 +292,7 @@ class PreProcessingStitching(SingleAxisStitcher):
292
292
  axis_N_pos_px = []
293
293
  for scan, pos_in_mm in zip(self.series, pos_as_mm):
294
294
  pixel_size_m = self.configuration.pixel_size or scan.pixel_size
295
- axis_N_pos_px.append((pos_in_mm / MetricSystem.MILLIMETER.value) / pixel_size_m)
295
+ axis_N_pos_px.append((pos_in_mm * _ureg.millimeter).to_base_units().magnitude / pixel_size_m)
296
296
  return axis_N_pos_px
297
297
  else:
298
298
  # deduce from motor position and pixel size
@@ -472,7 +472,7 @@ class PreProcessingStitching(SingleAxisStitcher):
472
472
  """
473
473
  nx_tomo = NXtomo()
474
474
 
475
- nx_tomo.energy = self.series[0].energy
475
+ nx_tomo.energy = self.series[0].energy * _ureg.keV
476
476
  start_times = list(filter(None, [scan.start_time for scan in self.series]))
477
477
  end_times = list(filter(None, [scan.end_time for scan in self.series]))
478
478
 
@@ -496,9 +496,9 @@ class PreProcessingStitching(SingleAxisStitcher):
496
496
 
497
497
  # handle detector (without frames)
498
498
  nx_tomo.instrument.detector.field_of_view = self.series[0].field_of_view
499
- nx_tomo.instrument.detector.distance = self.series[0].distance
500
- nx_tomo.instrument.detector.x_pixel_size = self.series[0].x_pixel_size
501
- nx_tomo.instrument.detector.y_pixel_size = self.series[0].y_pixel_size
499
+ nx_tomo.instrument.detector.distance = self.series[0].sample_detector_distance * _ureg.meter
500
+ nx_tomo.instrument.detector.x_pixel_size = self.series[0].x_pixel_size * _ureg.meter
501
+ nx_tomo.instrument.detector.y_pixel_size = self.series[0].y_pixel_size * _ureg.meter
502
502
  nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
503
503
  nx_tomo.instrument.detector.tomo_n = n_proj
504
504
  # note: stitching process insure un-flipping of frames. So make sure transformations is defined as an empty set
@@ -506,13 +506,13 @@ class PreProcessingStitching(SingleAxisStitcher):
506
506
 
507
507
  if isinstance(self.series[0], NXtomoScan):
508
508
  # note: first scan is always the reference as order to read data (so no rotation_angle inversion here)
509
- rotation_angle = numpy.asarray(self.series[0].rotation_angle)
509
+ rotation_angle = numpy.asarray(self.series[0].rotation_angle) * _ureg.degree
510
510
  nx_tomo.sample.rotation_angle = rotation_angle[
511
511
  numpy.asarray(self.series[0].image_key_control) == ImageKey.PROJECTION.value
512
512
  ]
513
513
  elif isinstance(self.series[0], EDFTomoScan):
514
- nx_tomo.sample.rotation_angle = numpy.linspace(
515
- start=0, stop=self.series[0].scan_range, num=self.series[0].tomo_n
514
+ nx_tomo.sample.rotation_angle = (
515
+ numpy.linspace(start=0, stop=self.series[0].scan_range, num=self.series[0].tomo_n) * _ureg.degree
516
516
  )
517
517
  else:
518
518
  raise NotImplementedError(
@@ -526,12 +526,15 @@ class PreProcessingStitching(SingleAxisStitcher):
526
526
  if isinstance(slices, slice):
527
527
  return array[slices.start : slices.stop : 1]
528
528
  elif isinstance(slices, Iterable):
529
- return list([array[index] for index in slices])
529
+ return [array[index] for index in slices]
530
530
  else:
531
- raise RuntimeError("slices must be instance of a slice or of an iterable")
531
+ raise RuntimeError("slices must be instance of a slice or of an iterable") # noqa: TRY004
532
532
 
533
- nx_tomo.sample.rotation_angle = apply_slices_selection(
534
- array=nx_tomo.sample.rotation_angle, slices=self._slices_to_stitch
533
+ nx_tomo.sample.rotation_angle = (
534
+ apply_slices_selection(
535
+ array=nx_tomo.sample.rotation_angle.to_base_units().magnitude, slices=self._slices_to_stitch
536
+ )
537
+ * _ureg.degree
535
538
  )
536
539
 
537
540
  # handle sample
@@ -560,7 +563,7 @@ class PreProcessingStitching(SingleAxisStitcher):
560
563
  # note: if at least one has missing values the numpy.Array(x_translation) with create an error as well
561
564
  x_translation = [0.0] * n_proj
562
565
  _logger.warning("Unable to fin input nxtomo x_translation values. Set it to 0.0")
563
- nx_tomo.sample.x_translation = x_translation
566
+ nx_tomo.sample.x_translation = x_translation * _ureg.meter
564
567
 
565
568
  y_translation = [
566
569
  get_sample_translation_for_projs(scan, "y_translation")
@@ -575,7 +578,7 @@ class PreProcessingStitching(SingleAxisStitcher):
575
578
  else:
576
579
  y_translation = [0.0] * n_proj
577
580
  _logger.warning("Unable to fin input nxtomo y_translation values. Set it to 0.0")
578
- nx_tomo.sample.y_translation = y_translation
581
+ nx_tomo.sample.y_translation = y_translation * _ureg.meter
579
582
  z_translation = [
580
583
  get_sample_translation_for_projs(scan, "z_translation")
581
584
  for scan in self.series
@@ -589,7 +592,7 @@ class PreProcessingStitching(SingleAxisStitcher):
589
592
  else:
590
593
  z_translation = [0.0] * n_proj
591
594
  _logger.warning("Unable to fin input nxtomo z_translation values. Set it to 0.0")
592
- nx_tomo.sample.z_translation = z_translation
595
+ nx_tomo.sample.z_translation = z_translation * _ureg.meter
593
596
 
594
597
  nx_tomo.sample.name = self.series[0].sample_name
595
598
 
@@ -794,7 +797,7 @@ class PreProcessingStitching(SingleAxisStitcher):
794
797
  ):
795
798
  i_frame = 0
796
799
  _, set_of_compacted_slices = get_compacted_dataslices(scan_urls, return_url_set=True)
797
- for _, url in set_of_compacted_slices.items():
800
+ for url in set_of_compacted_slices.values():
798
801
  scan = scans[i_scan]
799
802
  url = DataUrl(
800
803
  file_path=url.file_path(),
@@ -886,6 +889,9 @@ class PreProcessingStitching(SingleAxisStitcher):
886
889
  """
887
890
  make sure reduced dark and flats are existing otherwise compute them
888
891
  """
892
+ # TODO
893
+ # ruff: noqa: SIM105, S110
894
+ # --
889
895
  for scan in self.series:
890
896
  try:
891
897
  reduced_darks, darks_infos = scan.load_reduced_darks(return_info=True)
@@ -896,7 +902,7 @@ class PreProcessingStitching(SingleAxisStitcher):
896
902
  try:
897
903
  # if we don't have write in the folder containing the .nx for example
898
904
  scan.save_reduced_darks(reduced_darks, darks_infos=darks_infos)
899
- except Exception as e:
905
+ except Exception:
900
906
  pass
901
907
  scan.set_reduced_darks(reduced_darks, darks_infos=darks_infos)
902
908
 
@@ -909,7 +915,7 @@ class PreProcessingStitching(SingleAxisStitcher):
909
915
  try:
910
916
  # if we don't have write in the folder containing the .nx for example
911
917
  scan.save_reduced_flats(reduced_flats, flats_infos=flats_infos)
912
- except Exception as e:
918
+ except Exception:
913
919
  pass
914
920
  scan.set_reduced_flats(reduced_flats, flats_infos=flats_infos)
915
921
 
@@ -979,7 +985,7 @@ class PreProcessingStitching(SingleAxisStitcher):
979
985
  ):
980
986
  i_frame = 0
981
987
  _, set_of_compacted_slices = get_compacted_dataslices(scan_urls, return_url_set=True)
982
- for _, url in set_of_compacted_slices.items():
988
+ for url in set_of_compacted_slices.values():
983
989
  scan = scans[i_scan]
984
990
  url = DataUrl(
985
991
  file_path=url.file_path(),
@@ -1000,12 +1006,12 @@ class PreProcessingStitching(SingleAxisStitcher):
1000
1006
 
1001
1007
  missing = []
1002
1008
  if len(scan.reduced_flats) == 0:
1003
- missing = "flats"
1009
+ missing.append("flats")
1004
1010
  if len(scan.reduced_darks) == 0:
1005
- missing = "darks"
1011
+ missing.append("darks")
1006
1012
 
1007
1013
  if len(missing) > 0:
1008
- _logger.warning(f"missing {'and'.join(missing)}. Unable to do flat field correction")
1014
+ _logger.warning(f"missing {' and '.join(missing)}. Unable to do flat field correction")
1009
1015
  ff_arrays = None
1010
1016
  data = raw_radios
1011
1017
  else:
@@ -1,8 +1,7 @@
1
- import h5py
2
1
  import numpy
3
2
  import logging
4
- from math import ceil
5
- from typing import Optional, Iterable, Union
3
+ from typing import Optional, Union
4
+ from collections.abc import Iterable
6
5
  from tomoscan.series import Series
7
6
  from tomoscan.identifier import BaseIdentifier
8
7
  from nabu.stitching.stitcher.base import _StitcherBase, get_obj_constant_side_length
@@ -1,3 +1,4 @@
1
+ # ruff: noqa: N999
1
2
  import numpy
2
3
  from math import ceil
3
4
  from typing import Union, Optional
@@ -19,7 +20,7 @@ def stitch_raw_frames(
19
20
  pad_mode="constant",
20
21
  new_unstitched_axis_size: Optional[int] = None,
21
22
  ) -> numpy.ndarray:
22
- """
23
+ r"""
23
24
  stitches raw frames (already shifted and flat fielded !!!) together using
24
25
  raw stitching (no pixel interpolation, y_overlap_in_px is expected to be a int).
25
26
  Sttiching depends on the kernel used.
@@ -10,7 +10,7 @@ from nabu.stitching.overlap import OverlapStitchingStrategy
10
10
  from nabu.stitching import config as stiching_config
11
11
 
12
12
 
13
- _stitching_types = list(stiching_config.StitchingType.values())
13
+ _stitching_types = [st.value for st in list(stiching_config.StitchingType)]
14
14
  _stitching_types.append(None)
15
15
 
16
16
 
@@ -37,18 +37,16 @@ def test_stitching_config(stitching_type, option_level):
37
37
 
38
38
  assert "stitching" in config
39
39
  assert "type" in config["stitching"]
40
- stitching_type = stiching_config.StitchingType.from_value(config["stitching"]["type"])
40
+ stitching_type = stiching_config.StitchingType(config["stitching"]["type"])
41
41
  if stitching_type is stiching_config.StitchingType.Z_POSTPROC:
42
42
  assert isinstance(
43
43
  stiching_config.dict_to_config_obj(config),
44
44
  stiching_config.PostProcessedSingleAxisStitchingConfiguration,
45
45
  )
46
- elif stitching_type is stiching_config.StitchingType.Z_PREPROC:
47
- assert isinstance(
48
- stiching_config.dict_to_config_obj(config),
49
- stiching_config.PreProcessedSingleAxisStitchingConfiguration,
50
- )
51
- elif stitching_type is stiching_config.StitchingType.Y_PREPROC:
46
+ elif (
47
+ stitching_type is stiching_config.StitchingType.Z_PREPROC
48
+ or stitching_type is stiching_config.StitchingType.Y_PREPROC
49
+ ):
52
50
  assert isinstance(
53
51
  stiching_config.dict_to_config_obj(config),
54
52
  stiching_config.PreProcessedSingleAxisStitchingConfiguration,
@@ -84,7 +82,7 @@ def test_stitching_config(stitching_type, option_level):
84
82
  assert isinstance(config_class_instance.to_dict(), dict)
85
83
 
86
84
 
87
- @pytest.mark.parametrize("stitching_strategy", OverlapStitchingStrategy.values())
85
+ @pytest.mark.parametrize("stitching_strategy", [oss for oss in OverlapStitchingStrategy])
88
86
  @pytest.mark.parametrize("overwrite_results", (True, "False", 0, "1"))
89
87
  @pytest.mark.parametrize(
90
88
  "axis_shifts",
@@ -92,7 +90,7 @@ def test_stitching_config(stitching_type, option_level):
92
90
  "",
93
91
  None,
94
92
  "None",
95
- "",
93
+ "", # noqa: PT014
96
94
  "skimage",
97
95
  "nabu-fft",
98
96
  ),
@@ -176,7 +174,8 @@ def test_PreProcessedZStitchingConfiguration(
176
174
 
177
175
  from_dict = stiching_config.PreProcessedZStitchingConfiguration.from_dict(pre_process_config.to_dict())
178
176
  # workaround for scans because a new object is created each time
179
- pre_process_config.settle_inputs
177
+ # ???
178
+ pre_process_config.settle_inputs # noqa: B018
180
179
  assert len(from_dict.input_scans) == len(pre_process_config.input_scans)
181
180
  from_dict.input_scans = None
182
181
  pre_process_config.input_scans = None
@@ -1,6 +1,6 @@
1
1
  import numpy
2
2
  import pytest
3
- from nabu.stitching.sample_normalization import normalize_frame, SampleSide, Method
3
+ from nabu.stitching.sample_normalization import normalize_frame
4
4
 
5
5
 
6
6
  def test_normalize_frame():
@@ -1,5 +1,4 @@
1
1
  import os
2
- import numpy
3
2
  import pytest
4
3
  from tomoscan.esrf import NXtomoScan
5
4
  from tomoscan.esrf.volume import HDF5Volume
@@ -14,7 +13,7 @@ from nabu.stitching.slurm_utils import (
14
13
  from tomoscan.esrf.mock import MockNXtomo
15
14
 
16
15
  try:
17
- import sluurp
16
+ import sluurp # noqa: F401
18
17
  except ImportError:
19
18
  has_sluurp = False
20
19
  else:
@@ -1,6 +1,7 @@
1
1
  import os
2
2
  import pytest
3
3
  import numpy
4
+ import pint
4
5
  from tqdm import tqdm
5
6
 
6
7
  from nabu.stitching.y_stitching import y_stitching
@@ -9,6 +10,8 @@ from nxtomo.application.nxtomo import NXtomo
9
10
  from nxtomo.nxobject.nxdetector import ImageKey
10
11
  from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
11
12
 
13
+ _ureg = pint.UnitRegistry()
14
+
12
15
 
13
16
  def build_nxtomos(output_dir, flip_lr, flip_ud) -> tuple:
14
17
  r"""
@@ -52,10 +55,10 @@ def build_nxtomos(output_dir, flip_lr, flip_ud) -> tuple:
52
55
 
53
56
  n_projs = 3
54
57
  nx_tomo = NXtomo()
55
- nx_tomo.sample.x_translation = [0] * (n_projs + 2)
56
- nx_tomo.sample.y_translation = [frame_y_position] * (n_projs + 2)
57
- nx_tomo.sample.z_translation = [0] * (n_projs + 2)
58
- nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=(n_projs + 2), endpoint=False)
58
+ nx_tomo.sample.x_translation = ([0] * (n_projs + 2)) * _ureg.meter
59
+ nx_tomo.sample.y_translation = ([frame_y_position] * (n_projs + 2)) * _ureg.meter
60
+ nx_tomo.sample.z_translation = ([0] * (n_projs + 2)) * _ureg.meter
61
+ nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=(n_projs + 2), endpoint=False) * _ureg.degree
59
62
  nx_tomo.instrument.detector.image_key_control = (
60
63
  ImageKey.DARK_FIELD,
61
64
  ImageKey.FLAT_FIELD,
@@ -63,10 +66,10 @@ def build_nxtomos(output_dir, flip_lr, flip_ud) -> tuple:
63
66
  ImageKey.PROJECTION,
64
67
  ImageKey.PROJECTION,
65
68
  )
66
- nx_tomo.instrument.detector.x_pixel_size = 1.0
67
- nx_tomo.instrument.detector.y_pixel_size = 1.0
68
- nx_tomo.instrument.detector.distance = 2.3
69
- nx_tomo.energy = 19.2
69
+ nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
70
+ nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
71
+ nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
72
+ nx_tomo.energy = 19.2 * _ureg.keV
70
73
  nx_tomo.instrument.detector.data = numpy.stack(
71
74
  (
72
75
  my_dark_data,
@@ -395,7 +395,7 @@ def test_vol_z_stitching_with_alignment_axis_2(tmp_path, alignment_axis_2):
395
395
  axis_2_params={"img_reg_method": ShiftAlgorithm.NONE},
396
396
  slice_for_cross_correlation="middle",
397
397
  voxel_size=None,
398
- alignment_axis_2=AlignmentAxis2.from_value(alignment_axis_2),
398
+ alignment_axis_2=AlignmentAxis2(alignment_axis_2),
399
399
  )
400
400
 
401
401
  stitcher = PostProcessZStitcher(z_stich_config, progress=None)
@@ -512,7 +512,7 @@ def test_vol_z_stitching_with_alignment_axis_1(tmp_path, alignment_axis_1):
512
512
  axis_2_params={"img_reg_method": ShiftAlgorithm.NONE},
513
513
  slice_for_cross_correlation="middle",
514
514
  voxel_size=None,
515
- alignment_axis_1=AlignmentAxis1.from_value(alignment_axis_1),
515
+ alignment_axis_1=AlignmentAxis1(alignment_axis_1),
516
516
  )
517
517
 
518
518
  stitcher = PostProcessZStitcher(z_stich_config, progress=None)
@@ -770,6 +770,6 @@ def test_data_duplication(tmp_path, data_duplication):
770
770
  if not data_duplication:
771
771
  # make sure an error is raised if we try to ask for no data duplication and if we get some flips
772
772
  z_stich_config.flip_ud = (False, True, False)
773
- with pytest.raises(ValueError):
773
+ with pytest.raises(ValueError): # noqa: PT012
774
774
  stitcher = PostProcessZStitcherNoDD(z_stich_config, progress=None)
775
775
  stitcher.stitch()
@@ -1,15 +1,16 @@
1
1
  import os
2
+ import pint
2
3
  from silx.image.phantomgenerator import PhantomGenerator
3
4
  from scipy.ndimage import shift as scipy_shift
4
5
  import numpy
5
6
  import pytest
6
7
  from nabu.stitching.config import PreProcessedZStitchingConfiguration
7
8
  from nabu.stitching.config import KEY_IMG_REG_METHOD
8
- from nabu.stitching.overlap import ImageStichOverlapKernel, OverlapStitchingStrategy
9
+ from nabu.stitching.overlap import OverlapStitchingStrategy
9
10
  from nabu.stitching.z_stitching import (
10
11
  PreProcessZStitcher,
11
12
  )
12
- from nabu.stitching.stitcher_2D import stitch_raw_frames, get_overlap_areas
13
+ from nabu.stitching.stitcher_2D import get_overlap_areas
13
14
  from nxtomo.nxobject.nxdetector import ImageKey
14
15
  from nxtomo.utils.transformation import DetYFlipTransformation, DetZFlipTransformation
15
16
  from nxtomo.application.nxtomo import NXtomo
@@ -17,6 +18,8 @@ from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
17
18
  from nabu.stitching.utils import ShiftAlgorithm
18
19
  import h5py
19
20
 
21
+ _ureg = pint.get_application_registry()
22
+
20
23
 
21
24
  _stitching_configurations = (
22
25
  # simple case where shifts are provided
@@ -82,13 +85,13 @@ def test_PreProcessZStitcher(tmp_path, dtype, configuration):
82
85
  scans = []
83
86
  for (i_frame, frame), z_pos in zip(enumerate(frames), z_position):
84
87
  nx_tomo = NXtomo()
85
- nx_tomo.sample.z_translation = [z_pos] * n_proj
86
- nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False)
88
+ nx_tomo.sample.z_translation = ([z_pos] * n_proj) * _ureg.meter
89
+ nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False) * _ureg.degree
87
90
  nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
88
- nx_tomo.instrument.detector.x_pixel_size = 1.0
89
- nx_tomo.instrument.detector.y_pixel_size = 1.0
90
- nx_tomo.instrument.detector.distance = 2.3
91
- nx_tomo.energy = 19.2
91
+ nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
92
+ nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
93
+ nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
94
+ nx_tomo.energy = 19.2 * _ureg.keV
92
95
  nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj)
93
96
 
94
97
  file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx")
@@ -160,8 +163,8 @@ def test_PreProcessZStitcher(tmp_path, dtype, configuration):
160
163
  )
161
164
 
162
165
  # check also other metadata are here
163
- assert created_nx_tomo.instrument.detector.distance.value == 2.3
164
- assert created_nx_tomo.energy.value == 19.2
166
+ assert created_nx_tomo.instrument.detector.distance == 2.3 * _ureg.meter
167
+ assert created_nx_tomo.energy == 19.2 * _ureg.keV
165
168
  numpy.testing.assert_array_equal(
166
169
  created_nx_tomo.instrument.detector.image_key_control,
167
170
  numpy.asarray([ImageKey.PROJECTION.PROJECTION] * n_proj),
@@ -228,13 +231,13 @@ def build_nxtomos(output_dir) -> tuple:
228
231
  scans = []
229
232
  for (i_frame, frame), z_pos in zip(enumerate(frames), z_positions):
230
233
  nx_tomo = NXtomo()
231
- nx_tomo.sample.z_translation = [z_pos] * n_projs
232
- nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_projs, endpoint=False)
234
+ nx_tomo.sample.z_translation = [z_pos] * n_projs * _ureg.meter
235
+ nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_projs, endpoint=False) * _ureg.degree
233
236
  nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_projs
234
- nx_tomo.instrument.detector.x_pixel_size = 1.0
235
- nx_tomo.instrument.detector.y_pixel_size = 1.0
236
- nx_tomo.instrument.detector.distance = 2.3
237
- nx_tomo.energy = 19.2
237
+ nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
238
+ nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
239
+ nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
240
+ nx_tomo.energy = 19.2 * _ureg.keV
238
241
  nx_tomo.instrument.detector.data = frame
239
242
 
240
243
  file_path = os.path.join(output_dir, f"nxtomo_{i_frame}.nx")
@@ -301,11 +304,11 @@ def test_DistributePreProcessZStitcher(tmp_path, configuration_dist):
301
304
  )
302
305
 
303
306
  if complete:
304
- len(final_nx_tomo.instrument.detector.data) == 128
307
+ assert len(final_nx_tomo.instrument.detector.data) == 100
305
308
  # test middle
306
309
  numpy.testing.assert_array_almost_equal(raw_data[1], final_nx_tomo.instrument.detector.data[1, :, :])
307
310
  else:
308
- len(final_nx_tomo.instrument.detector.data) == 3
311
+ assert len(final_nx_tomo.instrument.detector.data) == 3
309
312
  # test middle
310
313
  numpy.testing.assert_array_almost_equal(raw_data[49], final_nx_tomo.instrument.detector.data[1, :, :])
311
314
  # in the case of first, middle and last frames
@@ -373,15 +376,15 @@ def test_frame_flip(tmp_path):
373
376
  scans = []
374
377
  for (i_frame, frame), z_pos, x_flip, y_flip in zip(enumerate(frames), z_position, x_flips, y_flips):
375
378
  nx_tomo = NXtomo()
376
- nx_tomo.sample.z_translation = [z_pos] * n_proj
377
- nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False)
379
+ nx_tomo.sample.z_translation = [z_pos] * n_proj * _ureg.meter
380
+ nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False) * _ureg.degree
378
381
  nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
379
- nx_tomo.instrument.detector.x_pixel_size = 1.0
380
- nx_tomo.instrument.detector.y_pixel_size = 1.0
381
- nx_tomo.instrument.detector.distance = 2.3
382
+ nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
383
+ nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
384
+ nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
382
385
  nx_tomo.instrument.detector.transformations.add_transformation(DetZFlipTransformation(flip=x_flip))
383
386
  nx_tomo.instrument.detector.transformations.add_transformation(DetYFlipTransformation(flip=y_flip))
384
- nx_tomo.energy = 19.2
387
+ nx_tomo.energy = 19.2 * _ureg.keV
385
388
  nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj)
386
389
 
387
390
  file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx")
File without changes
@@ -1,3 +1,4 @@
1
+ # noqa: N999
1
2
  import pytest
2
3
 
3
4
  from nabu.stitching.stitcher.single_axis import PROGRESS_BAR_STITCH_VOL_DESC
@@ -1,3 +1,4 @@
1
+ from enum import Enum
1
2
  from packaging.version import parse as parse_version
2
3
  from typing import Optional, Union
3
4
  import logging
@@ -6,10 +7,8 @@ import numpy
6
7
  from tomoscan.scanbase import TomoScanBase
7
8
  from tomoscan.volumebase import VolumeBase
8
9
  from nxtomo.utils.transformation import build_matrix, DetYFlipTransformation
9
- from silx.utils.enum import Enum as _Enum
10
10
  from scipy.fft import rfftn as local_fftn
11
11
  from scipy.fft import irfftn as local_ifftn
12
- from ..overlap import OverlapStitchingStrategy, ImageStichOverlapKernel
13
12
  from ..alignment import AlignmentAxis1, AlignmentAxis2, PaddedRawData
14
13
  from ...misc import fourier_filters
15
14
  from ...estimation.alignment import AlignmentBase
@@ -37,7 +36,7 @@ else:
37
36
  __has_sk_phase_correlation__ = True
38
37
 
39
38
 
40
- class ShiftAlgorithm(_Enum):
39
+ class ShiftAlgorithm(Enum):
41
40
  """All generic shift search algorithm"""
42
41
 
43
42
  NABU_FFT = "nabu-fft"
@@ -59,7 +58,7 @@ class ShiftAlgorithm(_Enum):
59
58
  if value in ("", None):
60
59
  return ShiftAlgorithm.NONE
61
60
  else:
62
- return super().from_value(value=value)
61
+ return super().__new__(cls, value)
63
62
 
64
63
 
65
64
  def find_frame_relative_shifts(
@@ -73,9 +72,9 @@ def find_frame_relative_shifts(
73
72
  y_shifts_params: Optional[dict] = None,
74
73
  ):
75
74
  """
76
- :param overlap_axis: axis in [0, 1] on which the overlap exists. In image space. So 0 is aka y and 1 as x
75
+ :param overlap_axis: axis in [0, 1] on which the overlap exists. In image space. So 0 is aka y and 1 as x.
77
76
  """
78
- if not overlap_axis in (0, 1):
77
+ if overlap_axis not in (0, 1):
79
78
  raise ValueError(f"overlap_axis should be in (0, 1). Get {overlap_axis}")
80
79
  from nabu.stitching.config import (
81
80
  KEY_LOW_PASS_FILTER,
@@ -146,7 +145,7 @@ def find_frame_relative_shifts(
146
145
  }
147
146
 
148
147
  res_algo = {}
149
- for shift_alg in set((x_cross_correlation_function, y_cross_correlation_function)):
148
+ for shift_alg in set((x_cross_correlation_function, y_cross_correlation_function)): # noqa: C405
150
149
  if shift_alg not in shift_methods:
151
150
  raise ValueError(f"requested image alignment function not handled ({shift_alg})")
152
151
  try:
@@ -200,8 +199,8 @@ def find_volumes_relative_shifts(
200
199
  else:
201
200
  raise ValueError(f"Stitching is done in 3D space. Expect axis to be in [0,2]. Get {overlap_axis}")
202
201
 
203
- alignment_axis_2 = AlignmentAxis2.from_value(alignment_axis_2)
204
- alignment_axis_1 = AlignmentAxis1.from_value(alignment_axis_1)
202
+ alignment_axis_2 = AlignmentAxis2(alignment_axis_2)
203
+ alignment_axis_1 = AlignmentAxis1(alignment_axis_1)
205
204
  assert dim_axis_1 > 0, "dim_axis_1 <= 0"
206
205
 
207
206
  if isinstance(slice_for_shift, str):
@@ -249,7 +248,7 @@ def find_volumes_relative_shifts(
249
248
 
250
249
  w_window_size = int(y_shifts_params.get(KEY_WINDOW_SIZE, 400))
251
250
  start_overlap = max(estimated_shifts[0] // 2 - w_window_size // 2, 0)
252
- end_overlap = min(estimated_shifts[0] // 2 + w_window_size // 2, min(upper_frame.shape[0], lower_frame.shape[0]))
251
+ end_overlap = min(estimated_shifts[0] // 2 + w_window_size // 2, upper_frame.shape[0], lower_frame.shape[0])
253
252
 
254
253
  if start_overlap == 0:
255
254
  overlap_upper_frame = upper_frame[-end_overlap:]
@@ -385,12 +384,10 @@ def find_projections_relative_shifts(
385
384
  cor_options=cor_options,
386
385
  )
387
386
 
388
- estimated_shifts = tuple(
389
- [
390
- estimated_shifts[0],
391
- (lower_scan_pos - upper_scan_pos),
392
- ]
393
- )
387
+ estimated_shifts = [
388
+ estimated_shifts[0],
389
+ (lower_scan_pos - upper_scan_pos),
390
+ ]
394
391
  x_cross_correlation_function = ShiftAlgorithm.NONE
395
392
 
396
393
  # } else we will compute shift from the flat projections
@@ -464,7 +461,8 @@ def find_projections_relative_shifts(
464
461
  start_overlap = max(estimated_shifts[axis_proj_space] // 2 - w_window_size // 2, 0)
465
462
  end_overlap = min(
466
463
  estimated_shifts[axis_proj_space] // 2 + w_window_size // 2,
467
- min(upper_proj.shape[axis_proj_space], lower_proj.shape[axis_proj_space]),
464
+ upper_proj.shape[axis_proj_space],
465
+ lower_proj.shape[axis_proj_space],
468
466
  )
469
467
  o_upper_sel = numpy.array(range(-end_overlap, -start_overlap))
470
468
  overlap_upper_frame = numpy.take_along_axis(
@@ -502,7 +500,7 @@ def find_shift_correlate(img1, img2, padding_mode="reflect"):
502
500
  padding_mode,
503
501
  )
504
502
 
505
- img_shape = img1.shape[-2:]
503
+ img_shape = cc.shape # Because cc.shape can differ from img_2.shape (e.g. in case of odd nb of cols)
506
504
  cc_vs = numpy.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
507
505
  cc_hs = numpy.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])
508
506
 
nabu/tests.py CHANGED
@@ -1,6 +1,3 @@
1
- #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
-
4
1
  import sys
5
2
  import os
6
3
  import pytest