nabu 2025.1.0.dev14__py3-none-any.whl → 2025.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. nabu/__init__.py +1 -1
  2. nabu/app/cast_volume.py +12 -1
  3. nabu/app/cli_configs.py +80 -3
  4. nabu/app/estimate_motion.py +54 -0
  5. nabu/app/multicor.py +2 -4
  6. nabu/app/pcaflats.py +116 -0
  7. nabu/app/reconstruct.py +1 -7
  8. nabu/app/reduce_dark_flat.py +5 -2
  9. nabu/estimation/cor.py +1 -1
  10. nabu/estimation/motion.py +557 -0
  11. nabu/estimation/tests/test_motion_estimation.py +471 -0
  12. nabu/estimation/tilt.py +1 -1
  13. nabu/estimation/translation.py +47 -1
  14. nabu/io/cast_volume.py +94 -13
  15. nabu/io/reader.py +32 -1
  16. nabu/io/tests/test_remove_volume.py +152 -0
  17. nabu/pipeline/config_validators.py +42 -43
  18. nabu/pipeline/estimators.py +255 -0
  19. nabu/pipeline/fullfield/chunked.py +67 -43
  20. nabu/pipeline/fullfield/chunked_cuda.py +5 -2
  21. nabu/pipeline/fullfield/nabu_config.py +17 -11
  22. nabu/pipeline/fullfield/processconfig.py +8 -2
  23. nabu/pipeline/fullfield/reconstruction.py +3 -0
  24. nabu/pipeline/params.py +12 -0
  25. nabu/pipeline/tests/test_estimators.py +240 -3
  26. nabu/preproc/ccd.py +53 -3
  27. nabu/preproc/flatfield.py +306 -1
  28. nabu/preproc/shift.py +3 -1
  29. nabu/preproc/tests/test_pcaflats.py +154 -0
  30. nabu/processing/rotation_cuda.py +3 -1
  31. nabu/processing/tests/test_rotation.py +4 -2
  32. nabu/reconstruction/fbp.py +7 -0
  33. nabu/reconstruction/fbp_base.py +31 -7
  34. nabu/reconstruction/fbp_opencl.py +8 -0
  35. nabu/reconstruction/filtering_opencl.py +2 -0
  36. nabu/reconstruction/mlem.py +47 -13
  37. nabu/reconstruction/tests/test_filtering.py +13 -2
  38. nabu/reconstruction/tests/test_mlem.py +91 -62
  39. nabu/resources/dataset_analyzer.py +144 -20
  40. nabu/resources/nxflatfield.py +101 -35
  41. nabu/resources/tests/test_nxflatfield.py +1 -1
  42. nabu/resources/utils.py +16 -10
  43. nabu/stitching/alignment.py +7 -7
  44. nabu/stitching/config.py +22 -20
  45. nabu/stitching/definitions.py +2 -2
  46. nabu/stitching/overlap.py +4 -4
  47. nabu/stitching/sample_normalization.py +5 -5
  48. nabu/stitching/stitcher/post_processing.py +5 -3
  49. nabu/stitching/stitcher/pre_processing.py +24 -20
  50. nabu/stitching/tests/test_config.py +3 -3
  51. nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
  52. nabu/stitching/tests/test_z_postprocessing_stitching.py +2 -2
  53. nabu/stitching/tests/test_z_preprocessing_stitching.py +23 -20
  54. nabu/stitching/utils/utils.py +7 -7
  55. nabu/testutils.py +1 -4
  56. nabu/utils.py +13 -0
  57. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/METADATA +3 -4
  58. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/RECORD +62 -57
  59. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/WHEEL +1 -1
  60. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/entry_points.txt +2 -1
  61. nabu/app/correct_rot.py +0 -62
  62. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/licenses/LICENSE +0 -0
  63. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/top_level.txt +0 -0
nabu/pipeline/fullfield/chunked.py CHANGED
@@ -10,7 +10,7 @@ from ...resources.utils import extract_parameters
  from ...misc.binning import binning as image_binning
  from ...io.reader import EDFStackReader, HDF5Loader, NXTomoReader
  from ...preproc.ccd import Log, CCDFilter
- from ...preproc.flatfield import FlatField
+ from ...preproc.flatfield import FlatField, PCAFlatsNormalizer
  from ...preproc.distortion import DistortionCorrection
  from ...preproc.shift import VerticalShift
  from ...preproc.double_flatfield import DoubleFlatField
@@ -45,6 +45,7 @@ class ChunkedPipeline:

  backend = "numpy"
  FlatFieldClass = FlatField
+ PCAFlatFieldClass = PCAFlatsNormalizer
  DoubleFlatFieldClass = DoubleFlatField
  CCDCorrectionClass = CCDFilter
  PaganinPhaseRetrievalClass = PaganinPhaseRetrieval
@@ -393,50 +394,68 @@ class ChunkedPipeline:

  @use_options("flatfield", "flatfield")
  def _init_flatfield(self):
- self._ff_options = self.processing_options["flatfield"].copy()
-
- # This won't work when resuming from a step (i.e before FF), because we rely on H5Loader()
- # which re-compacts the data. When data is re-compacted, we have to know the original radios positions.
- # These positions can be saved in the "file_dump" metadata, but it is not loaded for now
- # (the process_config object is re-built from scratch every time)
- self._ff_options["projs_indices"] = self.chunk_reader.get_frames_indices()
-
- if self._ff_options.get("normalize_srcurrent", False):
- a_start_idx, a_end_idx = self.sub_region[0]
- subs = self.process_config.subsampling_factor
- self._ff_options["radios_srcurrent"] = self._ff_options["radios_srcurrent"][a_start_idx:a_end_idx:subs]
-
- distortion_correction = None
- if self._ff_options["do_flat_distortion"]:
- self.logger.info("Flats distortion correction will be applied")
- self.FlatFieldClass = FlatField # no GPU implementation available, force this backend
- estimation_kwargs = {}
- estimation_kwargs.update(self._ff_options["flat_distortion_params"])
- estimation_kwargs["logger"] = self.logger
- distortion_correction = DistortionCorrection(
- estimation_method="fft-correlation", estimation_kwargs=estimation_kwargs, correction_method="interpn"
- )
+ if self.processing_options["flatfield"]:
+ self._ff_options = self.processing_options["flatfield"].copy()
+
+ # This won't work when resuming from a step (i.e before FF), because we rely on H5Loader()
+ # which re-compacts the data. When data is re-compacted, we have to know the original radios positions.
+ # These positions can be saved in the "file_dump" metadata, but it is not loaded for now
+ # (the process_config object is re-built from scratch every time)
+ self._ff_options["projs_indices"] = self.chunk_reader.get_frames_indices()
+
+ if self._ff_options.get("normalize_srcurrent", False):
+ a_start_idx, a_end_idx = self.sub_region[0]
+ subs = self.process_config.subsampling_factor
+ self._ff_options["radios_srcurrent"] = self._ff_options["radios_srcurrent"][a_start_idx:a_end_idx:subs]
+
+ distortion_correction = None
+ if self._ff_options["do_flat_distortion"]:
+ self.logger.info("Flats distortion correction will be applied")
+ self.FlatFieldClass = FlatField # no GPU implementation available, force this backend
+ estimation_kwargs = {}
+ estimation_kwargs.update(self._ff_options["flat_distortion_params"])
+ estimation_kwargs["logger"] = self.logger
+ distortion_correction = DistortionCorrection(
+ estimation_method="fft-correlation",
+ estimation_kwargs=estimation_kwargs,
+ correction_method="interpn",
+ )

- # Reduced darks/flats are loaded, but we have to crop them on the current sub-region
- # and possibly do apply some pre-processing (binning, distortion correction, ...)
- darks_flats = load_darks_flats(
- self.dataset_info,
- self.sub_region[1:],
- processing_func=self._ff_processing_function,
- processing_func_args=self._ff_processing_function_args,
- )
+ if self.processing_options["flatfield"]["method"].lower() != "pca":
+ # Reduced darks/flats are loaded, but we have to crop them on the current sub-region
+ # and possibly do apply some pre-processing (binning, distortion correction, ...)
+ darks_flats = load_darks_flats(
+ self.dataset_info,
+ self.sub_region[1:],
+ processing_func=self._ff_processing_function,
+ processing_func_args=self._ff_processing_function_args,
+ )

- # FlatField parameter "radios_indices" must account for subsampling
- self.flatfield = self.FlatFieldClass(
- self.radios_shape,
- flats=darks_flats["flats"],
- darks=darks_flats["darks"],
- radios_indices=self._ff_options["projs_indices"],
- interpolation="linear",
- distortion_correction=distortion_correction,
- radios_srcurrent=self._ff_options["radios_srcurrent"],
- flats_srcurrent=self._ff_options["flats_srcurrent"],
- )
+ # FlatField parameter "radios_indices" must account for subsampling
+ self.flatfield = self.FlatFieldClass(
+ self.radios_shape,
+ flats=darks_flats["flats"],
+ darks=darks_flats["darks"],
+ radios_indices=self._ff_options["projs_indices"],
+ interpolation="linear",
+ distortion_correction=distortion_correction,
+ radios_srcurrent=self._ff_options["radios_srcurrent"],
+ flats_srcurrent=self._ff_options["flats_srcurrent"],
+ )
+ else:
+ flats = self.process_config.dataset_info.flats
+ darks = self.process_config.dataset_info.darks
+ if len(darks) != 1:
+ raise ValueError(f"There should be only one reduced dark. Found {len(darks)}.")
+ else:
+ dark_key = list(darks.keys())[0]
+ nb_pca_components = len(flats) - 1
+ img_subregion = tuple(slice(*sr) for sr in self.sub_region[1:])
+ self.flatfield = self.PCAFlatFieldClass(
+ np.array([flats[k][img_subregion] for k in range(1, nb_pca_components)]),
+ darks[dark_key][img_subregion],
+ flats[0][img_subregion], # Mean
+ )

  @use_options("double_flatfield", "double_flatfield")
  def _init_double_flatfield(self):
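The PCA branch above wires pre-reduced flats (the mean flat first, then the principal components) into PCAFlatsNormalizer. As background, PCA flat-field normalization ("eigen flat fields") models each dark-subtracted radiograph as the mean flat plus a linear combination of the flats' principal components. The snippet below is a minimal numpy sketch of that principle only; it is not the actual PCAFlatsNormalizer implementation, and the least-squares fit of the component weights is an assumption made for illustration:

    import numpy as np

    def pca_flatfield_sketch(radio, dark, mean_flat, components):
        # radio, dark, mean_flat: (Ny, Nx) arrays; components: (n_comp, Ny, Nx)
        residual = (radio - dark) - mean_flat
        c = components.reshape(components.shape[0], -1)  # flatten each component
        # Fit the residual onto the principal components (illustrative choice)
        weights, *_ = np.linalg.lstsq(c.T, residual.ravel(), rcond=None)
        estimated_flat = mean_flat + (weights[:, None, None] * components).sum(axis=0)
        return (radio - dark) / estimated_flat  # normalized radiograph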
@@ -647,6 +666,11 @@ class ChunkedPipeline:
  "v_max_for_v_shifts": None,
  "v_min_for_u_shifts": 0,
  "v_max_for_u_shifts": None,
+ "scale_factor": 1.0 / options["voxel_size_cm"][0],
+ "clip_outer_circle": options["clip_outer_circle"],
+ "outer_circle_value": options["outer_circle_value"],
+ "filter_cutoff": options["fbp_filter_cutoff"],
+ "crop_filtered_data": options["crop_filtered_data"],
  },
  )

nabu/pipeline/fullfield/chunked_cuda.py CHANGED
@@ -75,8 +75,11 @@ class CudaChunkedPipeline(ChunkedPipeline):
  # Decide when to transfer data to GPU. Normally it's right after reading the data,
  # But sometimes a part of the processing is done on CPU.
  self._when_to_transfer_radios_on_gpu = "read_data"
- if self.flatfield is not None and self.flatfield.distortion_correction is not None:
- self._when_to_transfer_radios_on_gpu = "flatfield"
+ if self.flatfield is not None:
+ use_flats_distortion = getattr(self.flatfield, "distortion_correction", None) is not None
+ use_pca_flats = self.processing_options["flatfield"]["method"].lower() == "pca"
+ if use_flats_distortion or use_pca_flats:
+ self._when_to_transfer_radios_on_gpu = "flatfield"

  def _init_cuda(self, cuda_options):
  if not (__has_pycuda__):
nabu/pipeline/fullfield/nabu_config.py CHANGED
@@ -23,7 +23,7 @@ nabu_config = {
  },
  "darks_flats_dir": {
  "default": "",
- "help": "Path to a directory where XXX_flats.h5 and XXX_darks.h5 are to be found, where 'XXX' denotes the dataset basename. If these files are found, then reduced flats/darks will be loaded from them. Otherwise, reduced flats/darks will be saved to there once computed, either in the .nx directory, or in the output directory. Mind that the HDF5 entry corresponds to the one of the dataset.",
+ "help": "Path to a directory where XXX_flats.h5 and XXX_darks.h5 are to be found, where 'XXX' denotes the dataset basename. If these files are found, then reduced flats/darks will be loaded from them. Otherwise, reduced flats/darks will be saved there once computed, either in the .nx directory, or in the output directory. Mind that the HDF5 entry corresponds to the one of the dataset.",
  "validator": optional_directory_location_validator,
  "type": "optional",
  },
@@ -41,7 +41,7 @@ nabu_config = {
  },
  "projections_subsampling": {
  "default": "1",
- "help": "Projections subsampling factor: take one projection out of 'projection_subsampling'. The format can be an integer (take 1 projection out of N), or N:M (take 1 projection out of N, start with the projection number M)\nFor example: 2 (or 2:0) to reconstruct from even projections, 2:1 to reconstruct from odd projections.",
+ "help": "Projections subsampling factor: take one projection out of 'projections_subsampling'. The format can be an integer (take 1 projection out of N), or N:M (take 1 projection out of N, start with the projection number M)\nFor example: 2 (or 2:0) to reconstruct from even projections, 2:1 to reconstruct from odd projections.",
  "validator": projections_subsampling_validator,
  "type": "advanced",
  },
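For illustration, the N:M format described in the help text above amounts to slicing the projection stack with [M::N]; a hypothetical snippet, not nabu code:

    import numpy as np
    proj_indices = np.arange(10)       # stand-in for 10 projection indices
    step, start = 2, 1                 # "2:1": one projection out of 2, starting at projection 1
    print(proj_indices[start::step])   # [1 3 5 7 9] -> reconstruct from odd projections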
@@ -61,13 +61,19 @@ nabu_config = {
  "preproc": {
  "flatfield": {
  "default": "1",
- "help": "How to perform flat-field normalization. The parameter value can be:\n - 1 or True: enabled.\n - 0 or False: disabled\n - forced or force-load: perform flatfield regardless of the dataset by attempting to load darks/flats\n - force-compute: perform flatfield, ignore all .h5 files containing already computed darks/flats.",
- "validator": flatfield_enabled_validator,
+ "help": "How to perform flat-field normalization. The parameter value can be:\n - 1 or True: enabled.\n - 0 or False: disabled\n - pca: perform a normalization via Principal Component Analysis decomposition PCA-flat-field normalization",
+ "validator": flatfield_validator,
  "type": "required",
  },
+ "flatfield_loading_mode": {
+ "default": "load_if_present",
+ "help": "How to load/compute flat-field. This parameter can be:\n - load_if_present (default) or empty string: Use the existing flatfield files, if existing.\n - force-load: perform flatfield regardless of the dataset by attempting to load darks/flats\n - force-compute: perform flatfield, ignore all .h5 files containing already computed darks/flats.",
+ "validator": flatfield_loading_mode_validator,
+ "type": "optional",
+ },
  "flat_distortion_correction_enabled": {
  "default": "0",
- "help": "Whether to correct for flat distortion. If activated, each radio is correlated with its corresponding flat, in order to determine and correct the flat distortion.",
+ "help": "Whether to correct for flat distortion. If activated, each radiograph is correlated with its corresponding flat, in order to determine and correct the flat distortion.",
  "validator": boolean_validator,
  "type": "advanced",
  },
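Taken together, the new keys split "what to do" (flatfield) from "how to obtain the darks/flats" (flatfield_loading_mode). A hypothetical [preproc] excerpt of a nabu configuration file, shown for illustration only:

    [preproc]
    # Use PCA flat-field normalization
    flatfield = pca
    # Re-compute reduced darks/flats, ignoring existing .h5 files
    flatfield_loading_mode = force-compute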
@@ -113,7 +119,7 @@ nabu_config = {
  "double_flatfield": {
  "default": "0",
  "help": "Whether to perform 'double flat-field' filtering (this can help to remove rings artefacts). Possible values:\n - 1 or True: enabled.\n - 0 or False: disabled\n - force-load: use an existing DFF file regardless of the dataset\n - force-compute: re-compute the DFF, ignore all existing .h5 files containing already computed DFF",
- "validator": flatfield_enabled_validator,
+ "validator": flatfield_validator,
  "type": "optional",
  },
  "dff_sigma": {
@@ -172,7 +178,7 @@ nabu_config = {
  },
  "rotate_projections_center": {
  "default": "",
- "help": "Center of rotation when 'tilt_correction' is non-empty. By default the center of rotation is the middle of each radio, i.e ((Nx-1)/2.0, (Ny-1)/2.0).",
+ "help": "Center of rotation when 'tilt_correction' is non-empty. By default the center of rotation is the middle of each radiograph, i.e ((Nx-1)/2.0, (Ny-1)/2.0).",
  "validator": optional_tuple_of_floats_validator,
  "type": "advanced",
  },
@@ -272,7 +278,7 @@ nabu_config = {
  },
  "cor_slice": {
  "default": "",
- "help": "Which slice to use for estimating the Center of Rotation (CoR). This parameter can be an integer or 'top', 'middle', 'bottom'.\nIf provided, the CoR will be estimated from the correspondig sinogram, and 'cor_options' can contain the parameter 'subsampling'.",
+ "help": "Which slice to use for estimating the Center of Rotation (CoR). This parameter can be an integer or 'top', 'middle', 'bottom'.\nIf provided, the CoR will be estimated from the corresponding sinogram, and 'cor_options' can contain the parameter 'subsampling'.",
  "validator": cor_slice_validator,
  "type": "advanced",
  },
@@ -479,7 +485,7 @@ nabu_config = {
  },
  "postproc": {
  "output_histogram": {
- "default": "0",
+ "default": "1",
  "help": "Whether to compute a histogram of the volume.",
  "validator": boolean_validator,
  "type": "optional",
@@ -544,7 +550,7 @@ nabu_config = {
  "pipeline": {
  "save_steps": {
  "default": "",
- "help": "Save intermediate results. This is a list of comma-separated processing steps, for ex: flatfield, phase, sinogram.\nEach step generates a HDF5 file in the form name_file_prefix.hdf5 (ex. 'sinogram_file_prefix.hdf5')",
+ "help": "Save intermediate results. This is a list of comma-separated processing steps, for ex: flatfield, phase, sinogram.\nEach step generates a HDF5 file in the form name_file_prefix.hdf5 (e.g. 'sinogram_file_prefix.hdf5')",
  "validator": optional_string_validator,
  "type": "optional",
  },
@@ -556,7 +562,7 @@ nabu_config = {
  },
  "steps_file": {
  "default": "",
- "help": "File where the intermediate processing steps are written. By default it is empty, and intermediate processing steps are written in the same directory as the reconstructions, with a file prefix, ex. sinogram_mydataset.hdf5.",
+ "help": "File where the intermediate processing steps are written. By default it is empty, and intermediate processing steps are written in the same directory as the reconstructions, with a file prefix, e.g. sinogram_mydataset.hdf5.",
  "validator": optional_output_file_path_validator,
  "type": "advanced",
  },
nabu/pipeline/fullfield/processconfig.py CHANGED
@@ -75,16 +75,19 @@ class ProcessConfig(ProcessConfigBase):
  Update the 'dataset_info' (DatasetAnalyzer class instance) data structure with options from user configuration.
  """
  self.logger.debug("Updating dataset information with user configuration")
- if self.dataset_info.kind == "nx":
+ if self.dataset_info.kind == "nx" and self.nabu_config["preproc"]["flatfield"]:
  update_dataset_info_flats_darks(
  self.dataset_info,
  self.nabu_config["preproc"]["flatfield"],
+ loading_mode=self.nabu_config["preproc"]["flatfield_loading_mode"],
  output_dir=self.nabu_config["output"]["location"],
  darks_flats_dir=self.nabu_config["dataset"]["darks_flats_dir"],
  )
  elif self.dataset_info.kind == "edf":
  self.dataset_info.flats = self.dataset_info.get_reduced_flats()
  self.dataset_info.darks = self.dataset_info.get_reduced_darks()
+ else:
+ raise TypeError("Unknown dataset format")
  self.rec_params = self.nabu_config["reconstruction"]

  subsampling_factor, subsampling_start = self.nabu_config["dataset"]["projections_subsampling"]
@@ -425,8 +428,10 @@ class ProcessConfig(ProcessConfigBase):
  # Flat-field
  #
  if nabu_config["preproc"]["flatfield"]:
+ ff_method = "pca" if nabu_config["preproc"]["flatfield"] == "pca" else "default"
  tasks.append("flatfield")
  options["flatfield"] = {
+ "method": ff_method,
  # Data reader handles binning/subsampling by itself,
  # but FlatField needs "real" indices (after binning/subsampling)
  "projs_indices": self.projs_indices(subsampling=False),
@@ -434,7 +439,7 @@ class ProcessConfig(ProcessConfigBase):
  "do_flat_distortion": nabu_config["preproc"]["flat_distortion_correction_enabled"],
  "flat_distortion_params": extract_parameters(nabu_config["preproc"]["flat_distortion_params"]),
  }
- normalize_srcurrent = nabu_config["preproc"]["normalize_srcurrent"]
+ normalize_srcurrent = nabu_config["preproc"]["normalize_srcurrent"] and ff_method == "default"
  radios_srcurrent = None
  flats_srcurrent = None
  if normalize_srcurrent:
@@ -458,6 +463,7 @@ class ProcessConfig(ProcessConfigBase):
  if len(dataset_info.darks) > 1:
  self.logger.warning("Cannot do flat-field with more than one reduced dark. Taking the first one.")
  dataset_info.darks = dataset_info.darks[sorted(dataset_info.darks.keys())[0]]
+
  #
  # Spikes filter
  #
nabu/pipeline/fullfield/reconstruction.py CHANGED
@@ -261,6 +261,9 @@ class FullFieldReconstructor:
  if (self.process_config.dataset_info.detector_tilt or 0) > 15:
  force_grouped_mode = True
  msg = "Radios rotation with a large angle needs to process full radios"
+ if self.process_config.processing_options.get("flatfield", {}).get("method", "default") == "pca":
+ force_grouped_mode = True
+ msg = "PCA-Flatfield normalization needs to process full radios"
  if self.process_config.resume_from_step == "sinogram" and force_grouped_mode:
  self.logger.warning("Cannot use grouped-radios processing when resuming from sinogram")
  force_grouped_mode = False
nabu/pipeline/params.py CHANGED
@@ -3,6 +3,17 @@ flatfield_modes = {
  "1": True,
  "false": False,
  "0": False,
+ # These three should be removed after a while (moved to 'flatfield_loading_mode')
+ "forced": "force-load",
+ "force-load": "force-load",
+ "force-compute": "force-compute",
+ #
+ "pca": "pca",
+ }
+
+ flatfield_loading_mode = {
+ "": "load_if_present",
+ "load_if_present": "load_if_present",
  "forced": "force-load",
  "force-load": "force-load",
  "force-compute": "force-compute",
@@ -77,6 +88,7 @@ iterative_methods = {
  optim_algorithms = {
  "chambolle": "chambolle-pock",
  "chambollepock": "chambolle-pock",
+ "chambolle-pock": "chambolle-pock",
  "fista": "fista",
  }

nabu/pipeline/tests/test_estimators.py CHANGED
@@ -1,14 +1,23 @@
  import os
+ from tempfile import TemporaryDirectory
  import pytest
  import numpy as np
- from nabu.testutils import utilstest, __do_long_tests__
- from nabu.resources.dataset_analyzer import HDF5DatasetAnalyzer, analyze_dataset
+ from pint import get_application_registry
+ from nxtomo import NXtomo
+ from nabu.testutils import utilstest, __do_long_tests__, get_data
+ from nabu.resources.dataset_analyzer import HDF5DatasetAnalyzer, analyze_dataset, ImageKey
  from nabu.resources.nxflatfield import update_dataset_info_flats_darks
  from nabu.resources.utils import extract_parameters
- from nabu.pipeline.estimators import CompositeCOREstimator
+ from nabu.pipeline.estimators import CompositeCOREstimator, TranslationsEstimator
  from nabu.pipeline.config import parse_nabu_config_file
  from nabu.pipeline.estimators import SinoCORFinder, CORFinder

+ from nabu.estimation.tests.test_motion_estimation import (
+ check_motion_estimation,
+ project_volume,
+ _create_translations_vector,
+ )
+

  #
  # Test CoR estimation with "composite-coarse-to-fine" (aka "near" in the legacy system vocable)
@@ -119,3 +128,231 @@ class TestCorNearPos:
  cor = finder.find_cor()
  message = f"Computed CoR {cor} and expected CoR {self.true_cor} do not coincide. Near_pos options was set to {cor_options.get('near_pos',None)}."
  assert np.isclose(self.true_cor + 0.5, cor, atol=self.abs_tol), message
+
+
+ def _add_fake_flats_and_dark_to_data(data, n_darks=10, n_flats=21, dark_val=1, flat_val=3):
+ img_shape = data.shape[1:]
+ # Use constant darks/flats, to avoid "reduction" (mean/median) issues
+ fake_darks = np.ones((n_darks,) + img_shape, dtype=np.uint16) * dark_val
+ fake_flats = np.ones((n_flats,) + img_shape, dtype=np.uint16) * flat_val
+ return data * (fake_flats[0, 0, 0] - fake_darks[0, 0, 0]) + fake_darks[0, 0, 0], fake_darks, fake_flats
+
+
+ def _generate_nx_for_180_dataset(volume, output_file_path, n_darks=10, n_flats=21):
+
+ n_angles = 250
+ cor = -10
+
+ alpha_x = 4
+ beta_x = 3
+ alpha_y = -5
+ beta_y = 10
+ beta_z = 0
+ orig_det_dist = 0
+
+ angles0 = np.linspace(0, np.pi, n_angles, False)
+ return_angles = np.deg2rad([180.0, 135.0, 90.0, 45.0, 0.0])
+ angles = np.hstack([angles0, return_angles]).ravel()
+ a = np.arange(angles0.size + return_angles.size) / angles0.size
+
+ tx = _create_translations_vector(a, alpha_x, beta_x)
+ ty = _create_translations_vector(a, alpha_y, beta_y)
+ tz = _create_translations_vector(a, 0, beta_z)
+
+ sinos = project_volume(volume, angles, -tx, -ty, -tz, cor=-cor, orig_det_dist=orig_det_dist)
+ data = np.moveaxis(sinos, 1, 0)
+
+ sample_motion_xy = np.stack([-tx, ty], axis=1)
+ sample_motion_z = -tz
+ angles_deg = np.degrees(angles0)
+ return_angles_deg = np.degrees(return_angles)
+ n_return_radios = len(return_angles_deg)
+ n_radios = data.shape[0] - n_return_radios
+
+ ureg = get_application_registry()
+ fake_raw_data, darks, flats = _add_fake_flats_and_dark_to_data(data, n_darks=n_darks, n_flats=n_flats)
+
+ nxtomo = NXtomo()
+ nxtomo.instrument.detector.data = np.concatenate(
+ [
+ darks,
+ flats,
+ fake_raw_data, # radios + return radios (in float32 !)
+ ]
+ )
+ image_key_control = np.concatenate(
+ [
+ [ImageKey.DARK_FIELD.value] * n_darks,
+ [ImageKey.FLAT_FIELD.value] * n_flats,
+ [ImageKey.PROJECTION.value] * n_radios,
+ [ImageKey.ALIGNMENT.value] * n_return_radios,
+ ]
+ )
+ nxtomo.instrument.detector.image_key_control = image_key_control
+
+ rotation_angle = np.concatenate(
+ [np.zeros(n_darks, dtype="f"), np.zeros(n_flats, dtype="f"), angles_deg, return_angles_deg]
+ )
+ nxtomo.sample.rotation_angle = rotation_angle * ureg.degree
+ nxtomo.instrument.detector.field_of_view = "Full"
+ nxtomo.instrument.detector.x_pixel_size = nxtomo.instrument.detector.y_pixel_size = 1 * ureg.micrometer
+ nxtomo.save(file_path=output_file_path, data_path="entry", overwrite=True)
+
+ return sample_motion_xy, sample_motion_z, cor
+
+
+ def _generate_nx_for_360_dataset(volume, output_file_path, n_darks=10, n_flats=21):
+
+ n_angles = 250
+ cor = -5.5
+
+ alpha_x = -2
+ beta_x = 7.0
+ alpha_y = -2
+ beta_y = 3
+ beta_z = 100
+ orig_det_dist = 0
+
+ angles = np.linspace(0, 2 * np.pi, n_angles, False)
+ a = np.linspace(0, 1, angles.size, endpoint=False) # theta/theta_max
+
+ tx = _create_translations_vector(a, alpha_x, beta_x)
+ ty = _create_translations_vector(a, alpha_y, beta_y)
+ tz = _create_translations_vector(a, 0, beta_z)
+
+ sinos = project_volume(volume, angles, -tx, -ty, -tz, cor=-cor, orig_det_dist=orig_det_dist)
+ data = np.moveaxis(sinos, 1, 0)
+
+ sample_motion_xy = np.stack([-tx, ty], axis=1)
+ sample_motion_z = -tz
+ angles_deg = np.degrees(angles)
+
+ ureg = get_application_registry()
+
+ fake_raw_data, darks, flats = _add_fake_flats_and_dark_to_data(data, n_darks=n_darks, n_flats=n_flats)
+
+ nxtomo = NXtomo()
+ nxtomo.instrument.detector.data = np.concatenate([darks, flats, fake_raw_data]) # in float32 !
+
+ image_key_control = np.concatenate(
+ [
+ [ImageKey.DARK_FIELD.value] * n_darks,
+ [ImageKey.FLAT_FIELD.value] * n_flats,
+ [ImageKey.PROJECTION.value] * data.shape[0],
+ ]
+ )
+ nxtomo.instrument.detector.image_key_control = image_key_control
+
+ rotation_angle = np.concatenate(
+ [
+ np.zeros(n_darks, dtype="f"),
+ np.zeros(n_flats, dtype="f"),
+ angles_deg,
+ ]
+ )
+ nxtomo.sample.rotation_angle = rotation_angle * ureg.degree
+ nxtomo.instrument.detector.field_of_view = "Full"
+ nxtomo.instrument.detector.x_pixel_size = nxtomo.instrument.detector.y_pixel_size = 1 * ureg.micrometer
+ nxtomo.save(file_path=output_file_path, data_path="entry", overwrite=True)
+
+ return sample_motion_xy, sample_motion_z, cor
+
+
+ @pytest.fixture(scope="class")
+ def setup_test_motion_estimator(request):
+ cls = request.cls
+ cls.volume = get_data("motion/mri_volume_subsampled.npy")
+
+
+ @pytest.mark.skipif(not (__do_long_tests__), reason="need environment variable NABU_LONG_TESTS=1")
+ @pytest.mark.usefixtures("setup_test_motion_estimator")
+ class TestMotionEstimator:
+
+ def _setup(self, tmpdir):
+ # pytest uses some weird data structure for "tmpdir"
+ if not (isinstance(tmpdir, str)):
+ tmpdir = str(tmpdir)
+ #
+ if getattr(self, "volume", None) is None:
+ self.volume = get_data("motion/mri_volume_subsampled.npy")
+
+ def test_estimate_motion_360_dataset(self, tmpdir, verbose=False):
+ self._setup(tmpdir)
+ nx_file_path = os.path.join(tmpdir, "mri_projected_360_motion.nx")
+ sample_motion_xy, sample_motion_z, cor = _generate_nx_for_360_dataset(self.volume, nx_file_path)
+
+ dataset_info = analyze_dataset(nx_file_path)
+
+ translations_estimator = TranslationsEstimator(
+ dataset_info, do_flatfield=True, rot_center=cor, angular_subsampling=5, deg_xy=2, deg_z=2
+ )
+ estimated_shifts_h, estimated_shifts_v, estimated_cor = translations_estimator.estimate_motion()
+
+ s = translations_estimator.angular_subsampling
+ if verbose:
+ translations_estimator.motion_estimator.plot_detector_shifts(cor=cor)
+ translations_estimator.motion_estimator.plot_movements(
+ cor=cor,
+ angles_rad=dataset_info.rotation_angles[::s],
+ gt_xy=sample_motion_xy[::s, :],
+ gt_z=sample_motion_z[::s],
+ )
+ check_motion_estimation(
+ translations_estimator.motion_estimator,
+ dataset_info.rotation_angles[::s],
+ cor,
+ sample_motion_xy[::s, :],
+ sample_motion_z[::s],
+ fit_error_shifts_tol_vu=(0.2, 0.2),
+ fit_error_det_tol_vu=(1e-5, 5e-2),
+ fit_error_tol_xyz=(0.05, 0.05, 0.05),
+ fit_error_det_all_angles_tol_vu=(1e-5, 0.05),
+ )
+
+ def test_estimate_motion_180_dataset(self, tmpdir, verbose=False):
+ self._setup(tmpdir)
+ nx_file_path = os.path.join(tmpdir, "mri_projected_180_motion.nx")
+
+ sample_motion_xy, sample_motion_z, cor = _generate_nx_for_180_dataset(self.volume, nx_file_path)
+
+ dataset_info = analyze_dataset(nx_file_path)
+
+ translations_estimator = TranslationsEstimator(
+ dataset_info,
+ do_flatfield=True,
+ rot_center=cor,
+ angular_subsampling=2,
+ deg_xy=2,
+ deg_z=2,
+ shifts_estimator="DetectorTranslationAlongBeam",
+ )
+ estimated_shifts_h, estimated_shifts_v, estimated_cor = translations_estimator.estimate_motion()
+
+ if verbose:
+ translations_estimator.motion_estimator.plot_detector_shifts(cor=cor)
+ translations_estimator.motion_estimator.plot_movements(
+ cor=cor,
+ angles_rad=dataset_info.rotation_angles,
+ gt_xy=sample_motion_xy[: dataset_info.n_angles],
+ gt_z=sample_motion_z[: dataset_info.n_angles],
+ )
+
+ check_motion_estimation(
+ translations_estimator.motion_estimator,
+ dataset_info.rotation_angles,
+ cor,
+ sample_motion_xy,
+ sample_motion_z,
+ fit_error_shifts_tol_vu=(0.02, 0.1),
+ fit_error_det_tol_vu=(1e-2, 0.5),
+ fit_error_tol_xyz=(0.5, 2, 1e-2),
+ fit_error_det_all_angles_tol_vu=(1e-2, 2),
+ )
+
+
+ if __name__ == "__main__":
+
+ T = TestMotionEstimator()
+ with TemporaryDirectory(suffix="_motion", prefix="nabu_testdata") as tmpdir:
+ T.test_estimate_motion_360_dataset(tmpdir, verbose=True)
+ T.test_estimate_motion_180_dataset(tmpdir, verbose=True)