nabu 2024.2.13__py3-none-any.whl → 2025.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. doc/doc_config.py +32 -0
  2. nabu/__init__.py +1 -1
  3. nabu/app/bootstrap_stitching.py +4 -2
  4. nabu/app/cast_volume.py +16 -14
  5. nabu/app/cli_configs.py +102 -9
  6. nabu/app/compare_volumes.py +1 -1
  7. nabu/app/composite_cor.py +2 -4
  8. nabu/app/diag_to_pix.py +5 -6
  9. nabu/app/diag_to_rot.py +10 -11
  10. nabu/app/double_flatfield.py +18 -5
  11. nabu/app/estimate_motion.py +75 -0
  12. nabu/app/multicor.py +28 -15
  13. nabu/app/parse_reconstruction_log.py +1 -0
  14. nabu/app/pcaflats.py +122 -0
  15. nabu/app/prepare_weights_double.py +1 -2
  16. nabu/app/reconstruct.py +1 -7
  17. nabu/app/reconstruct_helical.py +5 -9
  18. nabu/app/reduce_dark_flat.py +5 -4
  19. nabu/app/rotate.py +3 -1
  20. nabu/app/stitching.py +7 -2
  21. nabu/app/tests/test_reduce_dark_flat.py +2 -2
  22. nabu/app/validator.py +1 -4
  23. nabu/cuda/convolution.py +1 -1
  24. nabu/cuda/fft.py +1 -1
  25. nabu/cuda/medfilt.py +1 -1
  26. nabu/cuda/padding.py +1 -1
  27. nabu/cuda/src/backproj.cu +6 -6
  28. nabu/cuda/src/cone.cu +4 -0
  29. nabu/cuda/src/hierarchical_backproj.cu +14 -0
  30. nabu/cuda/utils.py +2 -2
  31. nabu/estimation/alignment.py +17 -31
  32. nabu/estimation/cor.py +27 -33
  33. nabu/estimation/cor_sino.py +2 -8
  34. nabu/estimation/focus.py +4 -8
  35. nabu/estimation/motion.py +557 -0
  36. nabu/estimation/tests/test_alignment.py +2 -0
  37. nabu/estimation/tests/test_motion_estimation.py +471 -0
  38. nabu/estimation/tests/test_tilt.py +1 -1
  39. nabu/estimation/tilt.py +6 -5
  40. nabu/estimation/translation.py +47 -1
  41. nabu/io/cast_volume.py +108 -18
  42. nabu/io/detector_distortion.py +5 -6
  43. nabu/io/reader.py +45 -6
  44. nabu/io/reader_helical.py +5 -4
  45. nabu/io/tests/test_cast_volume.py +2 -2
  46. nabu/io/tests/test_readers.py +41 -38
  47. nabu/io/tests/test_remove_volume.py +152 -0
  48. nabu/io/tests/test_writers.py +2 -2
  49. nabu/io/utils.py +8 -4
  50. nabu/io/writer.py +1 -2
  51. nabu/misc/fftshift.py +1 -1
  52. nabu/misc/fourier_filters.py +1 -1
  53. nabu/misc/histogram.py +1 -1
  54. nabu/misc/histogram_cuda.py +1 -1
  55. nabu/misc/padding_base.py +1 -1
  56. nabu/misc/rotation.py +1 -1
  57. nabu/misc/rotation_cuda.py +1 -1
  58. nabu/misc/tests/test_binning.py +1 -1
  59. nabu/misc/transpose.py +1 -1
  60. nabu/misc/unsharp.py +1 -1
  61. nabu/misc/unsharp_cuda.py +1 -1
  62. nabu/misc/unsharp_opencl.py +1 -1
  63. nabu/misc/utils.py +1 -1
  64. nabu/opencl/fft.py +1 -1
  65. nabu/opencl/padding.py +1 -1
  66. nabu/opencl/src/backproj.cl +6 -6
  67. nabu/opencl/utils.py +8 -8
  68. nabu/pipeline/config.py +2 -2
  69. nabu/pipeline/config_validators.py +46 -46
  70. nabu/pipeline/datadump.py +3 -3
  71. nabu/pipeline/estimators.py +271 -11
  72. nabu/pipeline/fullfield/chunked.py +103 -67
  73. nabu/pipeline/fullfield/chunked_cuda.py +5 -2
  74. nabu/pipeline/fullfield/computations.py +4 -1
  75. nabu/pipeline/fullfield/dataset_validator.py +0 -1
  76. nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
  77. nabu/pipeline/fullfield/nabu_config.py +36 -17
  78. nabu/pipeline/fullfield/processconfig.py +41 -7
  79. nabu/pipeline/fullfield/reconstruction.py +14 -10
  80. nabu/pipeline/helical/dataset_validator.py +3 -4
  81. nabu/pipeline/helical/fbp.py +4 -4
  82. nabu/pipeline/helical/filtering.py +5 -4
  83. nabu/pipeline/helical/gridded_accumulator.py +10 -11
  84. nabu/pipeline/helical/helical_chunked_regridded.py +1 -0
  85. nabu/pipeline/helical/helical_reconstruction.py +12 -9
  86. nabu/pipeline/helical/helical_utils.py +1 -2
  87. nabu/pipeline/helical/nabu_config.py +2 -1
  88. nabu/pipeline/helical/span_strategy.py +1 -0
  89. nabu/pipeline/helical/weight_balancer.py +2 -3
  90. nabu/pipeline/params.py +20 -3
  91. nabu/pipeline/tests/__init__.py +0 -0
  92. nabu/pipeline/tests/test_estimators.py +240 -3
  93. nabu/pipeline/utils.py +1 -1
  94. nabu/pipeline/writer.py +1 -1
  95. nabu/preproc/alignment.py +0 -10
  96. nabu/preproc/ccd.py +53 -3
  97. nabu/preproc/ctf.py +8 -8
  98. nabu/preproc/ctf_cuda.py +1 -1
  99. nabu/preproc/double_flatfield_cuda.py +2 -2
  100. nabu/preproc/double_flatfield_variable_region.py +0 -1
  101. nabu/preproc/flatfield.py +307 -2
  102. nabu/preproc/flatfield_cuda.py +1 -2
  103. nabu/preproc/flatfield_variable_region.py +3 -3
  104. nabu/preproc/phase.py +2 -4
  105. nabu/preproc/phase_cuda.py +2 -2
  106. nabu/preproc/shift.py +4 -2
  107. nabu/preproc/shift_cuda.py +0 -1
  108. nabu/preproc/tests/test_ctf.py +4 -4
  109. nabu/preproc/tests/test_double_flatfield.py +1 -1
  110. nabu/preproc/tests/test_flatfield.py +1 -1
  111. nabu/preproc/tests/test_paganin.py +1 -3
  112. nabu/preproc/tests/test_pcaflats.py +154 -0
  113. nabu/preproc/tests/test_vshift.py +4 -1
  114. nabu/processing/azim.py +9 -5
  115. nabu/processing/convolution_cuda.py +6 -4
  116. nabu/processing/fft_base.py +7 -3
  117. nabu/processing/fft_cuda.py +25 -164
  118. nabu/processing/fft_opencl.py +28 -6
  119. nabu/processing/fftshift.py +1 -1
  120. nabu/processing/histogram.py +1 -1
  121. nabu/processing/muladd.py +0 -1
  122. nabu/processing/padding_base.py +1 -1
  123. nabu/processing/padding_cuda.py +0 -2
  124. nabu/processing/processing_base.py +12 -6
  125. nabu/processing/rotation_cuda.py +3 -1
  126. nabu/processing/tests/test_fft.py +2 -64
  127. nabu/processing/tests/test_fftshift.py +1 -1
  128. nabu/processing/tests/test_medfilt.py +1 -3
  129. nabu/processing/tests/test_padding.py +1 -1
  130. nabu/processing/tests/test_roll.py +1 -1
  131. nabu/processing/tests/test_rotation.py +4 -2
  132. nabu/processing/unsharp_opencl.py +1 -1
  133. nabu/reconstruction/astra.py +245 -0
  134. nabu/reconstruction/cone.py +39 -9
  135. nabu/reconstruction/fbp.py +14 -0
  136. nabu/reconstruction/fbp_base.py +40 -8
  137. nabu/reconstruction/fbp_opencl.py +8 -0
  138. nabu/reconstruction/filtering.py +59 -25
  139. nabu/reconstruction/filtering_cuda.py +22 -21
  140. nabu/reconstruction/filtering_opencl.py +10 -14
  141. nabu/reconstruction/hbp.py +26 -13
  142. nabu/reconstruction/mlem.py +55 -16
  143. nabu/reconstruction/projection.py +3 -5
  144. nabu/reconstruction/sinogram.py +1 -1
  145. nabu/reconstruction/sinogram_cuda.py +0 -1
  146. nabu/reconstruction/tests/test_cone.py +37 -2
  147. nabu/reconstruction/tests/test_deringer.py +4 -4
  148. nabu/reconstruction/tests/test_fbp.py +36 -15
  149. nabu/reconstruction/tests/test_filtering.py +27 -7
  150. nabu/reconstruction/tests/test_halftomo.py +28 -2
  151. nabu/reconstruction/tests/test_mlem.py +94 -64
  152. nabu/reconstruction/tests/test_projector.py +7 -2
  153. nabu/reconstruction/tests/test_reconstructor.py +1 -1
  154. nabu/reconstruction/tests/test_sino_normalization.py +0 -1
  155. nabu/resources/dataset_analyzer.py +210 -24
  156. nabu/resources/gpu.py +4 -4
  157. nabu/resources/logger.py +4 -4
  158. nabu/resources/nxflatfield.py +103 -37
  159. nabu/resources/tests/test_dataset_analyzer.py +37 -0
  160. nabu/resources/tests/test_extract.py +11 -0
  161. nabu/resources/tests/test_nxflatfield.py +5 -5
  162. nabu/resources/utils.py +16 -10
  163. nabu/stitching/alignment.py +8 -11
  164. nabu/stitching/config.py +44 -35
  165. nabu/stitching/definitions.py +2 -2
  166. nabu/stitching/frame_composition.py +8 -10
  167. nabu/stitching/overlap.py +4 -4
  168. nabu/stitching/sample_normalization.py +5 -5
  169. nabu/stitching/slurm_utils.py +2 -2
  170. nabu/stitching/stitcher/base.py +2 -0
  171. nabu/stitching/stitcher/dumper/base.py +0 -1
  172. nabu/stitching/stitcher/dumper/postprocessing.py +1 -1
  173. nabu/stitching/stitcher/post_processing.py +11 -9
  174. nabu/stitching/stitcher/pre_processing.py +37 -31
  175. nabu/stitching/stitcher/single_axis.py +2 -3
  176. nabu/stitching/stitcher_2D.py +2 -1
  177. nabu/stitching/tests/test_config.py +10 -11
  178. nabu/stitching/tests/test_sample_normalization.py +1 -1
  179. nabu/stitching/tests/test_slurm_utils.py +1 -2
  180. nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
  181. nabu/stitching/tests/test_z_postprocessing_stitching.py +3 -3
  182. nabu/stitching/tests/test_z_preprocessing_stitching.py +27 -24
  183. nabu/stitching/utils/tests/__init__.py +0 -0
  184. nabu/stitching/utils/tests/test_post-processing.py +1 -0
  185. nabu/stitching/utils/utils.py +16 -18
  186. nabu/tests.py +0 -3
  187. nabu/testutils.py +62 -9
  188. nabu/utils.py +50 -20
  189. {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/METADATA +7 -7
  190. nabu-2025.1.0.dist-info/RECORD +328 -0
  191. {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/WHEEL +1 -1
  192. {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/entry_points.txt +2 -1
  193. nabu/app/correct_rot.py +0 -70
  194. nabu/io/tests/test_detector_distortion.py +0 -178
  195. nabu-2024.2.13.dist-info/RECORD +0 -317
  196. /nabu/{stitching → app}/tests/__init__.py +0 -0
  197. {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/licenses/LICENSE +0 -0
  198. {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/top_level.txt +0 -0
nabu/resources/dataset_analyzer.py CHANGED
@@ -1,3 +1,4 @@
+from enum import Enum
 import os
 import numpy as np
 from silx.io.url import DataUrl
@@ -7,7 +8,7 @@ from tomoscan.esrf.scan.edfscan import EDFTomoScan
 from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
 from packaging.version import parse as parse_version
 
-from ..utils import check_supported, indices_to_slices
+from ..utils import BaseClassError, check_supported, indices_to_slices, is_scalar, search_sorted
 from ..io.reader import EDFStackReader, NXDarksFlats, NXTomoReader
 from ..io.utils import get_compacted_dataslices
 from .utils import get_values_from_file, is_hdf5_extension
@@ -16,6 +17,33 @@ from .logger import LoggerOrPrint
 from ..pipeline.utils import nabu_env_settings
 
 
+# We could import the 1000+ LoC nxtomo.nxobject.nxdetector.ImageKey... or we can do this
+class ImageKey(Enum):
+    ALIGNMENT = -1
+    PROJECTION = 0
+    FLAT_FIELD = 1
+    DARK_FIELD = 2
+    INVALID = 3
+
+
+# ---
+
+_image_type = {
+    "projections": ImageKey.PROJECTION.value,
+    "projection": ImageKey.PROJECTION.value,
+    "radios": ImageKey.PROJECTION.value,
+    "radio": ImageKey.PROJECTION.value,
+    "flats": ImageKey.FLAT_FIELD.value,
+    "flat": ImageKey.FLAT_FIELD.value,
+    "darks": ImageKey.DARK_FIELD.value,
+    "dark": ImageKey.DARK_FIELD.value,
+    "static": ImageKey.ALIGNMENT.value,
+    "alignment": ImageKey.ALIGNMENT.value,
+    "return": ImageKey.ALIGNMENT.value,
+    "invalid": ImageKey.INVALID.value,
+}
+
+
 class DatasetAnalyzer:
     _scanner = None
     kind = "none"
@@ -277,6 +305,32 @@ class DatasetAnalyzer:
     def darks(self, val):
         self._reduced_darks = val
 
+    @property
+    def scan_basename(self):
+        raise BaseClassError
+
+    @property
+    def scan_dirname(self):
+        raise BaseClassError
+
+    def get_alignment_projections(self, image_sub_region=None):
+        raise NotImplementedError
+
+    @property
+    def all_angles(self):
+        raise NotImplementedError
+
+    def get_frame(self, idx): ...
+
+    @property
+    def is_360(self):
+        """
+        Return True iff the scan is 360 degrees (regardless of half-tomo mode)
+        """
+        angles = self.rotation_angles
+        delta_theta = abs(angles.max() - angles.min())
+        return abs(delta_theta - 2 * np.pi) < abs(delta_theta - np.pi)
+
 
 class EDFDatasetAnalyzer(DatasetAnalyzer):
     """
@@ -316,15 +370,15 @@ class EDFDatasetAnalyzer(DatasetAnalyzer):
         # (eg. subsampling, binning, distortion correction...)
         # (3) The following spawns one reader instance per file, which is not elegant,
         # but in principle there are typically 1-2 reduced flats in a scan
-        readers = {k: EDFStackReader([self.raw_flats[k].file_path()], **reader_kwargs) for k in self.raw_flats.keys()}
-        return {k: readers[k].load_data()[0] for k in self.raw_flats.keys()}
+        readers = {k: EDFStackReader([self.raw_flats[k].file_path()], **reader_kwargs) for k in self.raw_flats}
+        return {k: readers[k].load_data()[0] for k in self.raw_flats}
 
     def get_reduced_darks(self, **reader_kwargs):
         # See notes in get_reduced_flats() above
         if self.raw_darks in [None, {}]:
             raise FileNotFoundError("No reduced dark ('darkend.edf' or 'dark.edf') found in %s" % self.location)
-        readers = {k: EDFStackReader([self.raw_darks[k].file_path()], **reader_kwargs) for k in self.raw_darks.keys()}
-        return {k: readers[k].load_data()[0] for k in self.raw_darks.keys()}
+        readers = {k: EDFStackReader([self.raw_darks[k].file_path()], **reader_kwargs) for k in self.raw_darks}
+        return {k: readers[k].load_data()[0] for k in self.raw_darks}
 
     @property
     def files(self):
@@ -333,6 +387,20 @@ class EDFDatasetAnalyzer(DatasetAnalyzer):
     def get_reader(self, **kwargs):
         return EDFStackReader(self.files, **kwargs)
 
+    @property
+    def scan_basename(self):
+        # os.path.basename(self.dataset_scanner.path)
+        return self.dataset_scanner.get_dataset_basename()
+
+    @property
+    def scan_dirname(self):
+        return self.dataset_scanner.path
+
+    def get_excluded_projections_indices(self, including_other_frames_types=True):
+        if not (including_other_frames_types):
+            raise NotImplementedError
+        return self.dataset_scanner.get_ignored_projection_indices()
+
 
 class HDF5DatasetAnalyzer(DatasetAnalyzer):
     """
@@ -341,9 +409,6 @@ class HDF5DatasetAnalyzer(DatasetAnalyzer):
 
     _scanner = NXtomoScan
     kind = "nx"
-    # We could import the 1000+ LoC nxtomo.nxobject.nxdetector.ImageKey... or we can do this
-    _image_key_value = {"flats": 1, "darks": 2, "radios": 0}
-    #
 
     @property
     def z_translation(self):
@@ -425,7 +490,8 @@ class HDF5DatasetAnalyzer(DatasetAnalyzer):
         slices = set()
         for du in get_compacted_dataslices(images).values():
             if du.data_slice() is not None:
-                s = (du.data_slice().start, du.data_slice().stop)
+                # note: du.data_slice is a uint in recent tomoscan version
+                s = (int(du.data_slice().start), int(du.data_slice().stop))
             else:
                 s = None
             slices.add(s)
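
The int() casts above matter because unsigned numpy integers wrap around instead of going negative; a quick illustration (not from the diff):

import numpy as np

start, stop = np.uint32(5), np.uint32(10)
print(stop - start)            # 5: fine
print(start - stop)            # 4294967291 (with an overflow warning): wrap-around instead of -5
print(int(start) - int(stop))  # -5: safe once cast to Python int
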
@@ -435,7 +501,7 @@ class HDF5DatasetAnalyzer(DatasetAnalyzer):
     def _select_according_to_frame_type(self, data, frame_type):
         if data is None:
             return None
-        return data[self.dataset_scanner.image_key_control == self._image_key_value[frame_type]]
+        return data[self.dataset_scanner.image_key_control == _image_type[frame_type]]
 
     def get_reduced_flats(self, method="median", force_reload=False, **reader_kwargs):
         dkrf_reader = NXDarksFlats(
@@ -458,25 +524,132 @@ class HDF5DatasetAnalyzer(DatasetAnalyzer):
         For example, if the dataset flats are located at indices [1, 2, ..., 99], then
         frame_slices("flats") will return [slice(0, 100)].
         """
-        return indices_to_slices(
-            np.where(self.dataset_scanner.image_key_control == self._image_key_value[frame_type])[0]
-        )
+        return indices_to_slices(np.where(self.dataset_scanner.image_key_control == _image_type[frame_type])[0])
 
     def get_reader(self, **kwargs):
         return NXTomoReader(self.dataset_hdf5_url.file_path(), data_path=self.dataset_hdf5_url.data_path(), **kwargs)
 
+    @property
+    def scan_basename(self):
+        # os.path.splitext(os.path.basename(self.dataset_hdf5_url.file_path()))[0]
+        return self.dataset_scanner.get_dataset_basename()
 
-def analyze_dataset(dataset_path, extra_options=None, logger=None):
-    if not (os.path.isdir(dataset_path)):
-        if not (os.path.isfile(dataset_path)):
-            raise ValueError("Error: %s no such file or directory" % dataset_path)
-        if not (is_hdf5_extension(os.path.splitext(dataset_path)[-1].replace(".", ""))):
-            raise ValueError("Error: expected a HDF5 file")
-        dataset_analyzer_class = HDF5DatasetAnalyzer
-    else:  # directory -> assuming EDF
-        dataset_analyzer_class = EDFDatasetAnalyzer
-    dataset_structure = dataset_analyzer_class(dataset_path, extra_options=extra_options, logger=logger)
-    return dataset_structure
+    @property
+    def scan_dirname(self):
+        # os.path.dirname(di.dataset_hdf5_url.file_path())
+        return self.dataset_scanner.path
+
+    def get_alignment_projections(self, image_sub_region=None):
+        """
+        Get the extra projections (if any) that are used as "reference projections" for alignment.
+        For certain scan, when completing a (half) turn, sometimes extra projections are acquired for alignment purpose.
+
+        Returns
+        -------
+        projs: numpy.ndarray
+            Array with shape (n_projections, n_y, n_x)
+        angles: numpy.ndarray
+            Corresponding angles in degrees
+        indices:
+            Indices of projections
+        """
+        sub_region = None
+        if image_sub_region is not None:
+            sub_region = (None,) + image_sub_region
+        reader = self.get_reader(image_key=ImageKey.ALIGNMENT.value, sub_region=sub_region)
+        projs = reader.load_data()
+        indices = reader.get_frames_indices()
+        angles = get_angle_at_index(self.all_angles, indices)
+        return projs, angles, indices
+
+    @property
+    def all_angles(self):
+        return np.array(self.dataset_scanner.rotation_angle)
+
+    def get_index_from_angle(self, angle, image_key=0, return_found_angle=False):
+        """
+        Return the index of the image taken at rotation angle 'angle'.
+        By default look at the projections, i.e image_key = 0
+        """
+        all_angles = self.all_angles
+        all_indices = np.arange(len(all_angles))
+        all_image_key = self.dataset_scanner.image_key_control
+
+        idx2 = np.where(all_image_key == image_key)[0]
+        angles = all_angles[idx2]
+        idx_angles_sorted = np.argsort(angles)
+        angles_sorted = angles[idx_angles_sorted]
+
+        pos = search_sorted(angles_sorted, angle)
+        # this gives a position in "idx2", but we need the position in "all_indices"
+        idx = all_indices[idx2[idx_angles_sorted[pos]]]
+        if return_found_angle:
+            return idx, angles_sorted[pos]
+        return idx
+
+    def get_image_at_angle(self, angle_deg, image_type="projection", sub_region=None, return_angle_and_index=False):
+        image_key = _image_type[image_type]
+        idx, angle_found = self.get_index_from_angle(angle_deg, image_key=image_key, return_found_angle=True)
+
+        # Option 1:
+        if sub_region is None:
+            sub_region = (None, None)
+        # Convert absolute index to index of image_key
+        idx2 = np.searchsorted(np.where(self.dataset_scanner.image_key_control == image_key)[0], idx)
+        sub_region = (slice(idx2, idx2 + 1),) + sub_region
+        reader = self.get_reader(image_key=image_key, sub_region=sub_region)
+        img = reader.load_data()[0]
+        if return_angle_and_index:
+            return img, angle_found, idx
+        return img
+
+        # Option 2:
+        # return self.get_frame(idx)
+        # something like:
+        # [fr for fr in self.dataset_scanner.frames if fr.image_key.value == 0 and fr.rotation_angle == 180 and fr._is_control_frame is False]
+
+    def get_frame(self, idx):
+        return get_data(self.dataset_scanner.frames[idx].url)
+
+    def get_frames_indices(self, frame_type):
+        return self._select_according_to_frame_type(np.arange(self.dataset_scanner.image_key_control.size), frame_type)
+
+    def index_to_proj_number(self, proj_index):
+        """
+        Return the projection *number*, from its frame *index*.
+
+        For example if there are 11 flats before projections,
+        then projections will have indices [11, 12, .....] (possibly not contiguous)
+        while their number is [0, 1, ..., ] (contiguous, starts from 0)
+        """
+        all_projs_indices = self.get_frames_indices("projection")
+        return search_sorted(all_projs_indices, proj_index)
+
+    def get_excluded_projections_indices(self, including_other_frames_types=True):
+        # Get indices of ALL projections (even excluded ones)
+        # the index accounts for flats/darks !
+        # Get indices of excluded projs (again, accounting for flats/darks)
+        ignored_projs_indices = self.dataset_scanner.get_ignored_projection_indices()
+        ignored_projs_indices = [
+            idx for idx in ignored_projs_indices if self.dataset_scanner.frames[idx].is_control is False
+        ]
+        if including_other_frames_types:
+            return ignored_projs_indices
+        # Get indices of excluded projs, now relative to the pure projections stack
+        ignored_projs_indices_rel = [
+            self.index_to_proj_number(ignored_proj_idx_abs) for ignored_proj_idx_abs in ignored_projs_indices
+        ]
+        return ignored_projs_indices_rel
+
+
+def get_angle_at_index(all_angles, index):
+    """
+    Return the rotation angle corresponding to image index 'index'
+    """
+    if is_scalar(index):
+        return all_angles[index]
+    else:
+        return all_angles[np.array(index)]
 
 
 def get_radio_pair(dataset_info, radio_angles: tuple, return_indices=False):
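
The new get_index_from_angle() sorts the angles of the requested frame type and delegates the nearest-value lookup to nabu's search_sorted helper (imported at the top of the file). As an illustration of the same argsort-then-lookup pattern with plain numpy, assuming search_sorted returns the position of the closest value:

import numpy as np

def index_of_closest_angle(all_angles, image_keys, angle, image_key=0):
    # Frames of the requested type (e.g. image_key=0 for projections)
    idx2 = np.where(image_keys == image_key)[0]
    order = np.argsort(all_angles[idx2])
    angles_sorted = all_angles[idx2][order]
    # Nearest-neighbor lookup (assumes at least two frames of this type)
    pos = np.clip(np.searchsorted(angles_sorted, angle), 1, len(angles_sorted) - 1)
    if abs(angles_sorted[pos - 1] - angle) <= abs(angles_sorted[pos] - angle):
        pos -= 1
    return idx2[order[pos]]  # back to an absolute frame index
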
@@ -522,3 +695,16 @@ def get_radio_pair(dataset_info, radio_angles: tuple, return_indices=False):
         return radios, radios_indices
     else:
         return radios
+
+
+def analyze_dataset(dataset_path, extra_options=None, logger=None):
+    if not (os.path.isdir(dataset_path)):
+        if not (os.path.isfile(dataset_path)):
+            raise ValueError("Error: %s no such file or directory" % dataset_path)
+        if not (is_hdf5_extension(os.path.splitext(dataset_path)[-1].replace(".", ""))):
+            raise ValueError("Error: expected a HDF5 file")
+        dataset_analyzer_class = HDF5DatasetAnalyzer
+    else:  # directory -> assuming EDF
+        dataset_analyzer_class = EDFDatasetAnalyzer
+    dataset_structure = dataset_analyzer_class(dataset_path, extra_options=extra_options, logger=logger)
+    return dataset_structure
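
analyze_dataset() itself is unchanged apart from being moved to the end of the module; it dispatches on the path type. Typical usage (paths hypothetical):

from nabu.resources.dataset_analyzer import analyze_dataset

dataset_info = analyze_dataset("/path/to/scan.nx")     # HDF5 file -> HDF5DatasetAnalyzer
# dataset_info = analyze_dataset("/path/to/edf_dir/")  # directory -> EDFDatasetAnalyzer
print(dataset_info.is_360)
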
nabu/resources/gpu.py CHANGED
@@ -131,7 +131,7 @@ def pick_gpus_auto(cuda_gpus, opencl_platforms, n_gpus):
         return (gpu1["device_id"] == gpu2["device_id"]) and (gpu1["name"] == gpu2["name"])
 
     def is_in_gpus(avail_gpus, query_gpu):
-        for gpu in avail_gpus:
+        for gpu in avail_gpus:  # noqa: SIM110
            if gpu_equal(gpu, query_gpu):
                return True
        return False
@@ -142,8 +142,8 @@ def pick_gpus_auto(cuda_gpus, opencl_platforms, n_gpus):
     chosen_gpus = list(cuda_gpus.values())
     if len(chosen_gpus) >= n_gpus:
         return chosen_gpus
-    for platform, gpus in opencl_platforms.items():
-        for gpu_id, gpu in gpus.items():
+    for platform, gpus in opencl_platforms.items():  # noqa: PERF102
+        for gpu_id, gpu in gpus.items():  # noqa: PERF102
            if not (is_in_gpus(chosen_gpus, gpu)):
                # TODO prioritize some OpenCL implementations ?
                chosen_gpus.append(gpu)
@@ -166,5 +166,5 @@ def pick_gpus_nvidia(cuda_gpus, n_gpus):
     gpus_cc_sorted = sorted(gpus_cc, key=lambda x: x[1], reverse=True)
     res = []
     for i in range(n_gpus):
-        res.append(cuda_gpus[gpus_cc_sorted[i][0]])
+        res.append(cuda_gpus[gpus_cc_sorted[i][0]])  # noqa: PERF401
     return res
nabu/resources/logger.py CHANGED
@@ -2,7 +2,7 @@ import logging
 import logging.config
 
 
-class Logger(object):
+class Logger:
     def __init__(self, loggername, level="DEBUG", logfile="logger.log", console=True):
         """
         Configure a Logger object.
@@ -30,7 +30,7 @@ class Logger:
 
     def _configure_logger(self):
         conf = self._get_default_config_dict()
-        for handler in conf["handlers"].keys():
+        for handler in conf["handlers"]:
            conf["handlers"][handler]["level"] = self.level.upper()
        conf["loggers"][self.loggername]["level"] = self.level.upper()
        if not (self.console):
@@ -103,7 +103,7 @@ def LoggerOrPrint(logger):
     return logger
 
 
-class PrinterLogger(object):
+class PrinterLogger:
     def __init__(self):
        methods = [
            "debug",
@@ -122,7 +122,7 @@ LogLevel = {
     "notset": logging.NOTSET,
     "debug": logging.DEBUG,
     "info": logging.INFO,
-    "warn": logging.WARN,
+    "warn": logging.WARNING,
     "warning": logging.WARNING,
     "error": logging.ERROR,
     "critical": logging.CRITICAL,
nabu/resources/nxflatfield.py CHANGED
@@ -6,6 +6,8 @@ from silx.io import get_data
 from tomoscan.framereducer.reducedframesinfos import ReducedFramesInfos
 from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
 from ..utils import check_supported, is_writeable
+from ..preproc.flatfield import PCAFlatsDecomposer
+from ..io.reader import NXDarksFlats
 
 
 def get_frame_possible_urls(dataset_info, user_dir, output_dir):
@@ -22,7 +24,7 @@ def get_frame_possible_urls(dataset_info, user_dir, output_dir):
         Output processing directory
     """
 
-    frame_types = ["flats", "darks"]
+    frame_types = ["flats", "darks", "pcaflats"]
     h5scan = dataset_info.dataset_scanner  # tomoscan object
 
     def make_dataurl(dirname, frame_type):
@@ -32,8 +34,10 @@ def get_frame_possible_urls(dataset_info, user_dir, output_dir):
 
         if frame_type == "flats":
             dataurl_default_template = h5scan.REDUCED_FLATS_DATAURLS[0]
-        else:
+        elif frame_type == "darks":
             dataurl_default_template = h5scan.REDUCED_DARKS_DATAURLS[0]
+        elif frame_type == "pcaflats":
+            dataurl_default_template = h5scan.PCA_FLATS_DATAURLS[0]
 
         rel_file_path = dataurl_default_template.file_path().format(scan_prefix=h5scan.get_dataset_basename())
         return DataUrl(
@@ -67,7 +71,7 @@ def save_reduced_frames(dataset_info, reduced_frames_arrays, reduced_frames_urls):
     darks_flats_dir_url = reduced_frames_urls.get("user", None)
     if darks_flats_dir_url is not None:
         output_url = darks_flats_dir_url
-    elif is_writeable(os.path.dirname(reduced_frames_urls["dataset"]["flats"].file_path())):
+    elif is_writeable(os.path.abspath(os.path.dirname(reduced_frames_urls["dataset"]["flats"].file_path()))):
         output_url = reduced_frames_urls["dataset"]
     else:
         output_url = reduced_frames_urls["output"]
@@ -149,8 +153,94 @@ def data_url_exists(data_url):
     return group_exists
 
 
+def _compute_and_save_reduced_frames(flatfield_mode, dataset_info, reduced_frames_urls):
+    if flatfield_mode == "pca":
+        dfreader = NXDarksFlats(dataset_info.location)
+        darks = np.concatenate([d for d in dfreader.get_raw_darks()], axis=0)
+        flats = np.concatenate([f for f in dfreader.get_raw_flats()], axis=0)
+        pcaflats_darks = PCAFlatsDecomposer(flats, darks)
+
+        # Get "where to write". tomoscan expects a DataUrl
+        pcaflats_dir_url = reduced_frames_urls.get("user", None)
+        if pcaflats_dir_url is not None:
+            output_url = pcaflats_dir_url
+        elif is_writeable(os.path.dirname(reduced_frames_urls["dataset"]["flats"].file_path())):
+            output_url = reduced_frames_urls["dataset"]
+        else:
+            output_url = reduced_frames_urls["output"]
+        pcaflats_darks.save_decomposition(
+            path=output_url["pcaflats"].file_path(), entry=output_url["pcaflats"].data_path().strip("/").split("/")[0]
+        )
+        dataset_info.logger.info("PCA flats computed and written at %s" % (output_url["pcaflats"].file_path()))
+
+        # Update dataset_info with pca flats and dark
+        dataset_info.darks = {0: pcaflats_darks.dark}
+        flats = {0: pcaflats_darks.mean}
+        for k in range(len(pcaflats_darks.components)):
+            flats.update({k + 1: pcaflats_darks.components[k]})
+        dataset_info.flats = flats
+    else:
+        try:
+            dataset_info.flats = dataset_info.get_reduced_flats()
+            dataset_info.darks = dataset_info.get_reduced_darks()
+        except FileNotFoundError:
+            msg = "Could not find any flats and/or darks"
+            raise FileNotFoundError(msg)
+        _, flats_info, darks_info = save_reduced_frames(
+            dataset_info, {"darks": dataset_info.darks, "flats": dataset_info.flats}, reduced_frames_urls
+        )
+        dataset_info.flats_srcurrent = flats_info.machine_electric_current
+
+
+def _load_existing_flatfields(dataset_info, reduced_frames_urls, frames_types, where_to_load_from):
+    if "pcaflats" not in frames_types:
+        reduced_frames_with_info = {}
+        for frame_type in frames_types:
+            reduced_frames_with_info[frame_type] = tomoscan_load_reduced_frames(
+                dataset_info, frame_type, reduced_frames_urls[where_to_load_from][frame_type]
+            )
+            dataset_info.logger.info(
+                "Loaded %s from %s" % (frame_type, reduced_frames_urls[where_to_load_from][frame_type].file_path())
+            )
+            red_frames_dict, red_frames_info = reduced_frames_with_info[frame_type]
+            setattr(
+                dataset_info,
+                frame_type,
+                {k: get_data(red_frames_dict[k]) for k in red_frames_dict},
+            )
+            if frame_type == "flats":
+                dataset_info.flats_srcurrent = red_frames_info.machine_electric_current
+    else:
+        df_path = reduced_frames_urls[where_to_load_from]["pcaflats"].file_path()
+        entry = reduced_frames_urls[where_to_load_from]["pcaflats"].data_path()
+
+        # Update dark
+        dark_url = DataUrl(f"silx://{df_path}?{entry}/dark")
+        dark = get_data(dark_url)
+        setattr(
+            dataset_info,
+            "dark",
+            {0: dark},
+        )
+        # Update flats with principal compenents
+        # Take mean as first comp., mask as second, flats thereafter
+        flats_url = DataUrl(f"silx://{df_path}?{entry}/p_components")
+        mean_url = DataUrl(f"silx://{df_path}?{entry}/p_mean")
+        flats = get_data(flats_url)
+        mean = get_data(mean_url)
+        flats = np.concatenate([mean[np.newaxis], flats], axis=0)
+        setattr(
+            dataset_info,
+            "flats",
+            {k: flats[k] for k in range(len(flats))},
+        )
+        dataset_info.logger.info("Loaded %s from %s" % ("PCA darks/flats", df_path))
+
+
 # pylint: disable=E1136
-def update_dataset_info_flats_darks(dataset_info, flatfield_mode, output_dir=None, darks_flats_dir=None):
+def update_dataset_info_flats_darks(
+    dataset_info, flatfield_mode, loading_mode="load_if_present", output_dir=None, darks_flats_dir=None
+):
     """
     Update a DatasetAnalyzer object with reduced flats/darks (hereafter "reduced frames").
 
@@ -170,23 +260,14 @@ def update_dataset_info_flats_darks(
     if flatfield_mode is False:
         return
 
-    frames_types = ["darks", "flats"]
+    if flatfield_mode == "pca":
+        frames_types = ["pcaflats"]
+    else:
+        frames_types = ["darks", "flats"]
     reduced_frames_urls = get_frame_possible_urls(dataset_info, darks_flats_dir, output_dir)
 
-    def _compute_and_save_reduced_frames():
-        try:
-            dataset_info.flats = dataset_info.get_reduced_flats()
-            dataset_info.darks = dataset_info.get_reduced_darks()
-        except FileNotFoundError:
-            msg = "Could not find any flats and/or darks"
-            raise FileNotFoundError(msg)
-        _, flats_info, darks_info = save_reduced_frames(
-            dataset_info, {"darks": dataset_info.darks, "flats": dataset_info.flats}, reduced_frames_urls
-        )
-        dataset_info.flats_srcurrent = flats_info.machine_electric_current
-
-    if flatfield_mode == "force-compute":
-        _compute_and_save_reduced_frames()
+    if loading_mode == "force-compute":
+        _compute_and_save_reduced_frames(flatfield_mode, dataset_info, reduced_frames_urls)
         return
 
     def _can_load_from(folder_type):
@@ -202,25 +283,10 @@ def update_dataset_info_flats_darks(
     elif _can_load_from("output"):
         where_to_load_from = "output"
 
-    if where_to_load_from == None and flatfield_mode == "force-load":
+    if where_to_load_from is None and flatfield_mode == "force-load":
         raise ValueError("Could not load darks/flats (using 'force-load')")
 
     if where_to_load_from is not None:
-        reduced_frames_with_info = {}
-        for frame_type in frames_types:
-            reduced_frames_with_info[frame_type] = tomoscan_load_reduced_frames(
-                dataset_info, frame_type, reduced_frames_urls[where_to_load_from][frame_type]
-            )
-            dataset_info.logger.info(
-                "Loaded %s from %s" % (frame_type, reduced_frames_urls[where_to_load_from][frame_type].file_path())
-            )
-            red_frames_dict, red_frames_info = reduced_frames_with_info[frame_type]
-            setattr(
-                dataset_info,
-                frame_type,
-                {k: get_data(red_frames_dict[k]) for k in red_frames_dict.keys()},
-            )
-            if frame_type == "flats":
-                dataset_info.flats_srcurrent = red_frames_info.machine_electric_current
+        _load_existing_flatfields(dataset_info, reduced_frames_urls, frames_types, where_to_load_from)
     else:
-        _compute_and_save_reduced_frames()
+        _compute_and_save_reduced_frames(flatfield_mode, dataset_info, reduced_frames_urls)
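
The new "pca" mode stores a mean flat plus a few principal components instead of a single reduced flat. PCAFlatsDecomposer's internals are not shown in this diff; a conceptual numpy sketch of such a decomposition, for illustration only (function name and details hypothetical):

import numpy as np

def pca_decompose_flats(flats, n_components=3):
    # flats: (n_flats, n_y, n_x) stack of flat-field frames, n_flats >= n_components
    n, h, w = flats.shape
    flats2d = flats.reshape(n, h * w).astype(np.float64)
    mean = flats2d.mean(axis=0)
    # Principal components of the centered stack via SVD
    _, _, vt = np.linalg.svd(flats2d - mean, full_matrices=False)
    components = vt[:n_components].reshape(n_components, h, w)
    return mean.reshape(h, w), components
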
nabu/resources/tests/test_dataset_analyzer.py ADDED
@@ -0,0 +1,37 @@
+import pytest
+import numpy as np
+from nabu.testutils import get_dummy_nxtomo_info
+from nabu.resources.dataset_analyzer import analyze_dataset
+
+
+@pytest.fixture(scope="class")
+def bootstrap_nx(request):
+    cls = request.cls
+    cls.nx_fname, cls.data_desc, cls.image_key, cls.projs_vals, cls.darks_vals, cls.flats1_vals, cls.flats2_vals = (
+        get_dummy_nxtomo_info()
+    )
+
+
+@pytest.mark.usefixtures("bootstrap_nx")
+class TestNXDataset:
+
+    def test_exclude_projs_angular_range(self):
+        dataset_info_with_all_projs = analyze_dataset(self.nx_fname)
+
+        # Test exclude angular range - angles min and max in degrees
+        angular_ranges_to_test = [(0, 15), (5, 6), (50, 58.5)]
+        for angular_range in angular_ranges_to_test:
+            angle_min, angle_max = angular_range
+            dataset_info = analyze_dataset(
+                self.nx_fname,
+                extra_options={"exclude_projections": {"type": "angular_range", "range": [angle_min, angle_max]}},
+            )
+            excluded_projs_indices = dataset_info.get_excluded_projections_indices()
+            # Check that get_excluded_projections_indices() angles are correct
+            for excluded_proj_index in excluded_projs_indices:
+                frame_angle_deg = dataset_info.dataset_scanner.frames[excluded_proj_index].rotation_angle
+                assert angle_min <= frame_angle_deg and frame_angle_deg <= angle_max
+
+            assert set(dataset_info_with_all_projs.projections.keys()) - set(dataset_info.projections.keys()) == set(
+                excluded_projs_indices
+            )
nabu/resources/tests/test_extract.py ADDED
@@ -0,0 +1,11 @@
+from nabu.utils import list_match_queries
+
+
+def test_list_match_queries():
+
+    # entry0000 .... entry0099
+    avail = ["entry%04d" % i for i in range(100)]
+    assert list_match_queries(avail, "entry0000") == ["entry0000"]
+    assert list_match_queries(avail, ["entry0001"]) == ["entry0001"]
+    assert list_match_queries(avail, ["entry000?"]) == ["entry%04d" % i for i in range(10)]
+    assert list_match_queries(avail, ["entry*"]) == avail
nabu/resources/tests/test_nxflatfield.py CHANGED
@@ -80,11 +80,11 @@ class TestNXFlatField:
         output_dir = self.params.get("output_dir", None)
         if output_dir is not None:
             output_dir = output_dir.format(tempdir=self.tempdir)
-        update_dataset_info_flats_darks(dataset_info, True, output_dir=output_dir)
+        update_dataset_info_flats_darks(dataset_info, True, loading_mode="load_if_present", output_dir=output_dir)
         # After reduction (median/mean), the flats/darks are located in another file.
         # median(series_1) goes to entry/flats/idx1, mean(series_2) goes to entry/flats/idx2, etc.
-        assert set(dataset_info.flats.keys()) == set(s.start for s in self.params["flats_pos"])
-        assert set(dataset_info.darks.keys()) == set(s.start for s in self.params["darks_pos"])
+        assert set(dataset_info.flats.keys()) == set(s.start for s in self.params["flats_pos"])  # noqa: C401
+        assert set(dataset_info.darks.keys()) == set(s.start for s in self.params["darks_pos"])  # noqa: C401
 
         # Check that the computations were correct
         # Loads the entire volume in memory ! So keep the data volume small for the tests
@@ -97,8 +97,8 @@ class TestNXFlatField:
             expected_darks[s.start] = self._reduction_func["darks"](data_volume[s.start : s.stop], axis=0)
 
         flats = dataset_info.flats
-        for idx in flats.keys():
+        for idx in flats:
             assert np.allclose(flats[idx], expected_flats[idx])
         darks = dataset_info.darks
-        for idx in darks.keys():
+        for idx in darks:
             assert np.allclose(darks[idx], expected_darks[idx])