nabu 2025.1.0.dev14__py3-none-any.whl → 2025.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. nabu/__init__.py +1 -1
  2. nabu/app/cast_volume.py +12 -1
  3. nabu/app/cli_configs.py +80 -3
  4. nabu/app/estimate_motion.py +54 -0
  5. nabu/app/multicor.py +2 -4
  6. nabu/app/pcaflats.py +116 -0
  7. nabu/app/reconstruct.py +1 -7
  8. nabu/app/reduce_dark_flat.py +5 -2
  9. nabu/estimation/cor.py +1 -1
  10. nabu/estimation/motion.py +557 -0
  11. nabu/estimation/tests/test_motion_estimation.py +471 -0
  12. nabu/estimation/tilt.py +1 -1
  13. nabu/estimation/translation.py +47 -1
  14. nabu/io/cast_volume.py +94 -13
  15. nabu/io/reader.py +32 -1
  16. nabu/io/tests/test_remove_volume.py +152 -0
  17. nabu/pipeline/config_validators.py +42 -43
  18. nabu/pipeline/estimators.py +255 -0
  19. nabu/pipeline/fullfield/chunked.py +67 -43
  20. nabu/pipeline/fullfield/chunked_cuda.py +5 -2
  21. nabu/pipeline/fullfield/nabu_config.py +17 -11
  22. nabu/pipeline/fullfield/processconfig.py +8 -2
  23. nabu/pipeline/fullfield/reconstruction.py +3 -0
  24. nabu/pipeline/params.py +12 -0
  25. nabu/pipeline/tests/test_estimators.py +240 -3
  26. nabu/preproc/ccd.py +53 -3
  27. nabu/preproc/flatfield.py +306 -1
  28. nabu/preproc/shift.py +3 -1
  29. nabu/preproc/tests/test_pcaflats.py +154 -0
  30. nabu/processing/rotation_cuda.py +3 -1
  31. nabu/processing/tests/test_rotation.py +4 -2
  32. nabu/reconstruction/fbp.py +7 -0
  33. nabu/reconstruction/fbp_base.py +31 -7
  34. nabu/reconstruction/fbp_opencl.py +8 -0
  35. nabu/reconstruction/filtering_opencl.py +2 -0
  36. nabu/reconstruction/mlem.py +47 -13
  37. nabu/reconstruction/tests/test_filtering.py +13 -2
  38. nabu/reconstruction/tests/test_mlem.py +91 -62
  39. nabu/resources/dataset_analyzer.py +144 -20
  40. nabu/resources/nxflatfield.py +101 -35
  41. nabu/resources/tests/test_nxflatfield.py +1 -1
  42. nabu/resources/utils.py +16 -10
  43. nabu/stitching/alignment.py +7 -7
  44. nabu/stitching/config.py +22 -20
  45. nabu/stitching/definitions.py +2 -2
  46. nabu/stitching/overlap.py +4 -4
  47. nabu/stitching/sample_normalization.py +5 -5
  48. nabu/stitching/stitcher/post_processing.py +5 -3
  49. nabu/stitching/stitcher/pre_processing.py +24 -20
  50. nabu/stitching/tests/test_config.py +3 -3
  51. nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
  52. nabu/stitching/tests/test_z_postprocessing_stitching.py +2 -2
  53. nabu/stitching/tests/test_z_preprocessing_stitching.py +23 -20
  54. nabu/stitching/utils/utils.py +7 -7
  55. nabu/testutils.py +1 -4
  56. nabu/utils.py +13 -0
  57. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/METADATA +3 -4
  58. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/RECORD +62 -57
  59. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/WHEEL +1 -1
  60. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/entry_points.txt +2 -1
  61. nabu/app/correct_rot.py +0 -62
  62. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/licenses/LICENSE +0 -0
  63. {nabu-2025.1.0.dev14.dist-info → nabu-2025.1.0rc1.dist-info}/top_level.txt +0 -0
nabu/io/reader.py CHANGED
@@ -4,6 +4,7 @@ from math import ceil
4
4
  from multiprocessing.pool import ThreadPool
5
5
  from posixpath import sep as posix_sep, join as posix_join
6
6
  import numpy as np
7
+ from h5py import Dataset
7
8
  from silx.io import get_data
8
9
  from silx.io.dictdump import h5todict
9
10
  from tomoscan.io import HDF5File
@@ -770,7 +771,7 @@ class NXDarksFlats:
770
771
  reduced_frames = [self._reduce_func[method](frames, axis=0) for frames in raw_frames]
771
772
  reader = getattr(self, "%s_reader" % what)
772
773
  if as_dict:
773
- return {k: v for k, v in zip([s.start for s in reader._image_key_slices], reduced_frames)} # noqa: C416
774
+ return {k: v for k, v in zip([s.start for s in reader._image_key_slices], reduced_frames)}
774
775
  return reduced_frames
775
776
 
776
777
  def get_raw_darks(self, force_reload=False, as_multiple_array=True):
@@ -987,6 +988,12 @@ def get_entry_from_h5_path(h5_path):
987
988
  return v[0] or v[1]
988
989
 
989
990
 
991
def list_hdf5_entries(fname):
    """
    Return the names of the top-level members of a HDF5 file.

    Parameters
    ----------
    fname: str
        Path to the HDF5 file, opened read-only.

    Returns
    -------
    list of str
        Names of the top-level keys, in the order given by the file object.
    """
    with HDF5File(fname, "r") as h5_file:
        return [*h5_file.keys()]
995
+
996
+
990
997
  def check_virtual_sources_exist(fname, data_path):
991
998
  with HDF5File(fname, "r") as f:
992
999
  if data_path not in f:
@@ -1006,6 +1013,30 @@ def check_virtual_sources_exist(fname, data_path):
1006
1013
  return True
1007
1014
 
1008
1015
 
1016
def get_hdf5_file_all_virtual_sources(file_path_or_obj, return_only_filenames=False):
    """
    List the virtual sources of every virtual dataset found in a HDF5 file/group.

    Parameters
    ----------
    file_path_or_obj: str, os.PathLike, or opened h5py object
        Either a path to a HDF5 file, or an already-opened object providing `visititems()`.
        A path is opened read-only and closed before returning, even if an error occurs.
        An already-opened object is left open.
    return_only_filenames: bool, optional
        If True, only the `file_name` of each virtual source is reported.

    Returns
    -------
    result: list of dict
        One single-key dict per virtual dataset, mapping the dataset name
        (as passed by `visititems`) to its list of virtual sources (or file names).
    """
    import os

    result = []

    def collect_vsources(name, obj):
        # Returns None, so that visititems() keeps visiting all members
        if isinstance(obj, Dataset) and obj.is_virtual:
            vs = obj.virtual_sources()
            if return_only_filenames:
                vs = [vs_.file_name for vs_ in vs]
            result.append({name: vs})

    if isinstance(file_path_or_obj, (str, os.PathLike)):
        # We opened the file ourselves: make sure it is closed even if visiting fails
        with HDF5File(os.fspath(file_path_or_obj), "r") as fdesc:
            fdesc.visititems(collect_vsources)
    else:
        file_path_or_obj.visititems(collect_vsources)
    return result
1038
+
1039
+
1009
1040
  def import_h5_to_dict(h5file, h5path, asarray=False):
1010
1041
  """
1011
1042
  Wrapper on top of silx.io.dictdump.dicttoh5 replacing "None" with None
@@ -0,0 +1,152 @@
1
+ from pathlib import Path
2
+ from os import path, mkdir, rename
3
+ import numpy as np
4
+ import pytest
5
+ from tomoscan.esrf.volume import (
6
+ EDFVolume,
7
+ HDF5Volume,
8
+ JP2KVolume,
9
+ MultiTIFFVolume,
10
+ TIFFVolume,
11
+ )
12
+ from tomoscan.esrf.volume.jp2kvolume import has_glymur
13
+ from nabu.io.writer import merge_hdf5_files
14
+ from nabu.io.cast_volume import remove_volume
15
+
16
+
17
def test_remove_single_frame_volume(tmpdir):
    """
    Test volume removal for tiff, jp2 and EDF
    """
    # Have to use a not-too-small size because of jp2k
    data = np.arange(10 * 40 * 50, dtype="f").reshape((10, 40, 50))

    volume_classes = [EDFVolume, TIFFVolume, JP2KVolume]
    if not (has_glymur):
        # JP2KVolume is the last element: skip it when glymur is not available
        volume_classes.pop()
    for volume_cls in volume_classes:
        ext = volume_cls.DEFAULT_DATA_EXTENSION
        folder = path.join(tmpdir, f"{ext}_vol")
        volume_basename = f"{ext}_basename"

        vol_writer = volume_cls(folder=folder, volume_basename=volume_basename, overwrite=True, start_index=0)
        vol_writer.data = data
        vol_writer.save()

        vol_reader = volume_cls(folder=folder, volume_basename=volume_basename)
        assert path.isdir(folder), f"Expected to find a folder {folder}"
        remove_volume(vol_reader)
        assert not (path.isdir(folder)), f"Expected to have removed the folder {folder}"

        # Write the volume again and add an extraneous file: removal with check=True must refuse
        vol_writer.save()
        vol_reader = volume_cls(folder=folder, volume_basename=volume_basename)
        Path(path.join(folder, f"unexpected.{ext}")).touch()
        with pytest.raises(RuntimeError) as exc:
            remove_volume(vol_reader, check=True)
        assert "Unexpected files present" in str(exc.value), "Expected check to find extraneous files"
47
+
48
+
49
def test_remove_multiframe_volume(tmpdir):
    """
    Test volume removal for "multiframe" formats (HDF5, tiff3D)
    The HDF5 files considered in this test do not have virtual sources
    """
    data = np.arange(3 * 4 * 5, dtype="f").reshape((3, 4, 5))

    for ext, volume_cls in {"h5": HDF5Volume, "tiff": MultiTIFFVolume}.items():
        file_path = path.join(tmpdir, f"{ext}_vol.{ext}")

        init_kwargs = {"file_path": file_path}
        if ext == "h5":
            init_kwargs["data_path"] = "entry"
        vol_writer = volume_cls(**init_kwargs)
        vol_writer.data = data
        vol_writer.save()

        vol_reader = volume_cls(**init_kwargs)
        assert path.isfile(file_path), f"Expected to find a {ext} volume at {file_path}"
        remove_volume(vol_reader)
        assert not (path.isfile(file_path)), f"Expected to have removed {file_path}"
70
+
71
+
72
def test_remove_hdf5_multiple_entries(tmpdir):
    """
    A HDF5 file holding several volume entries must not be removed: remove_volume
    is expected to raise NotImplementedError.
    """
    base_data = np.arange(3 * 4 * 5, dtype="f").reshape((3, 4, 5))
    h5_file = path.join(tmpdir, "h5_vol.h5")
    # Two entries in the same file, with different contents
    for entry_name, offset in (("entry0000", 0), ("entry0001", 10)):
        writer = HDF5Volume(file_path=h5_file, data_path=entry_name)
        writer.data = base_data + offset
        writer.save()

    reader = HDF5Volume(file_path=h5_file, data_path="entry0000")
    with pytest.raises(NotImplementedError) as exc_info:
        remove_volume(reader, check=True)
    expected_msg = "Removing a HDF5 volume with more than one entry is not supported"
    assert expected_msg in str(exc_info.value), "Expected an error message"
87
+
88
+
89
def test_remove_nabu_hdf5_reconstruction(tmpdir):
    """
    Test removal of HDF5 reconstruction generated by nabu (i.e with virtual sources)
    """

    entry = "entry"
    process_name = "reconstruction"

    master_file_path = path.join(tmpdir, "sample_naburec.hdf5")
    associated_dir = path.join(tmpdir, "sample_naburec")
    if not (path.isdir(associated_dir)):
        mkdir(associated_dir)

    # Write partial reconstructions in associated_dir; keep their paths relative to tmpdir
    n_chunks = 5
    local_files = []
    for i in range(n_chunks):
        fname = "sample_naburec_%06d.h5" % i
        partial_rec_abspath = path.join(associated_dir, fname)
        local_files.append(f"sample_naburec/{fname}")
        vol = HDF5Volume(partial_rec_abspath, data_path=f"{entry}/{process_name}")
        vol.data = np.arange(3 * 4 * 5, dtype="f").reshape((3, 4, 5))
        vol.save()

    h5_path = f"{entry}/{process_name}/results/data"

    # Merge the partial reconstructions into the master file
    merge_hdf5_files(
        local_files,
        h5_path,
        master_file_path,
        process_name,
        output_entry=entry,
        output_filemode="a",
        processing_index=0,
        config=None,
        base_dir=path.dirname(associated_dir),
        axis=0,
        overwrite=True,
    )

    assert path.isfile(master_file_path), f"Expected to find the master file at {master_file_path}"
    assert path.isdir(associated_dir)
    for i, local_file in enumerate(local_files):
        partial_rec_file = path.join(tmpdir, local_file)
        assert path.isfile(partial_rec_file), f"Expected to find partial file number {i} at {partial_rec_file}"

    # Check that the virtual links are handled properly:
    # sample_naburec.hdf5 should reference sample_naburec/sample_naburec_{i}.h5,
    # so once the master file is renamed, removal must be refused.
    renamed_master_file_path = (
        path.join(path.dirname(master_file_path), path.basename(master_file_path).split(".")[0]) + "_renamed" + ".h5"
    )
    rename(master_file_path, renamed_master_file_path)
    h5_vol = HDF5Volume(file_path=renamed_master_file_path, data_path=f"{entry}/{process_name}")
    with pytest.raises(ValueError) as exc:
        remove_volume(h5_vol)
    expected_error_message = f"The virtual sources in {renamed_master_file_path}:{process_name}/results/data reference the directory sample_naburec, but expected was sample_naburec_renamed"
    assert str(exc.value) == expected_error_message

    # Check removal in normal circumstances
    rename(renamed_master_file_path, master_file_path)
    h5_vol = HDF5Volume(file_path=master_file_path, data_path=f"{entry}/{process_name}")
    remove_volume(h5_vol)
    assert not (path.isfile(master_file_path)), f"Expected to have removed the master file at {master_file_path}"
    assert not (path.isdir(associated_dir)), f"Expected to have removed the directory {associated_dir}"
@@ -2,7 +2,7 @@
2
2
  import os
3
3
 
4
4
  path = os.path
5
- from ..utils import check_supported, is_writeable
5
+ from ..utils import check_supported, deprecation_warning, is_writeable
6
6
  from .params import * # noqa: F403
7
7
 
8
8
  """
@@ -97,12 +97,21 @@ def convert_to_bool_noerr(val):
97
97
  return res
98
98
 
99
99
 
100
def name_range_checker(name, available_names, descr):
    """
    Check whether a parameter name is valid, against a list or dictionary of names.

    Parameters
    ----------
    name: str
        User-provided name. It is stripped and lower-cased before checking.
    available_names: list, set, tuple or dict
        Accepted names. When a dict is given, its keys are the accepted names and
        the corresponding values are the names returned (eg. {"edge": "edges"}).
    descr: str
        Human-readable description of the parameter, used in the error message.

    Returns
    -------
    name: str
        The normalized name, or the mapped value when `available_names` is a dict.

    Raises
    ------
    AssertionError
        If the name is not one of the available names.
    """
    name = name.strip().lower()
    is_mapping = isinstance(available_names, dict)
    if is_mapping:
        # we could use .keys() instead to be more permissive to the user
        available_names_str = str(set(available_names.values()))
    else:
        # assuming list
        available_names_str = str(available_names)
    if name not in available_names:
        # Checked before the dict lookup (which would otherwise raise a bare KeyError),
        # and raised explicitly so the check survives "python -O" while keeping the exception type.
        raise AssertionError("Invalid %s '%s'. Available are %s" % (descr, name, available_names_str))
    if is_mapping:
        # handle replacements, eg. {"edge": "edges"}
        name = available_names[name]
    return name
107
116
 
108
117
 
@@ -338,9 +347,7 @@ def cor_validator(val):
338
347
  return val_float
339
348
  if len(val.strip()) == 0:
340
349
  return None
341
- val = name_range_checker(
342
- val.lower(), set(cor_methods.values()), "center of rotation estimation method", replacements=cor_methods
343
- )
350
+ val = name_range_checker(val, cor_methods, "center of rotation estimation method")
344
351
  return val
345
352
 
346
353
 
@@ -351,9 +358,7 @@ def tilt_validator(val):
351
358
  return val_float
352
359
  if len(val.strip()) == 0:
353
360
  return None
354
- val = name_range_checker(
355
- val.lower(), set(tilt_methods.values()), "automatic detector tilt estimation method", replacements=tilt_methods
356
- )
361
+ val = name_range_checker(val, tilt_methods, "automatic detector tilt estimation method")
357
362
  return val
358
363
 
359
364
 
@@ -395,53 +400,55 @@ def cor_slice_validator(val):
395
400
 
396
401
 
397
402
@validator
def flatfield_validator(val):
    """
    Validate the 'flatfield' parameter, warning about deprecated loading values.
    """
    mode = name_range_checker(val, flatfield_modes, "flatfield mode")
    deprecated_modes = ("force-load", "force-compute")
    if mode in deprecated_modes:
        deprecation_warning(
            f"Using 'flatfield = {mode}' is deprecated since version 2025.1.0. Please use the parameter 'flatfield_loading_mode'",
        )
    return mode
410
+
411
+
412
@validator
def flatfield_loading_mode_validator(val):
    """
    Validate the 'flatfield_loading_mode' parameter against `flatfield_loading_mode`.
    """
    # Use this parameter's own description in the error message (not "flatfield mode")
    return name_range_checker(val, flatfield_loading_mode, "flatfield loading mode")
400
415
 
401
416
 
402
417
@validator
def phase_method_validator(val):
    """Validate the phase retrieval method name and return its normalized form."""
    checked = name_range_checker(
        name=val, available_names=phase_retrieval_methods, descr="phase retrieval method"
    )
    return checked
407
420
 
408
421
 
409
422
@validator
def detector_distortion_correction_validator(val):
    """Validate the detector distortion correction method name."""
    checked = name_range_checker(
        name=val,
        available_names=detector_distortion_correction_methods,
        descr="detector_distortion_correction_methods",
    )
    return checked
417
429
 
418
430
 
419
431
@validator
def unsharp_method_validator(val):
    """Validate the unsharp mask method name."""
    checked = name_range_checker(name=val, available_names=unsharp_methods, descr="unsharp mask method")
    return checked
424
434
 
425
435
 
426
436
@validator
def padding_mode_validator(val):
    """Validate the padding mode name."""
    checked = name_range_checker(name=val, available_names=padding_modes, descr="padding mode")
    return checked
429
439
 
430
440
 
431
441
@validator
def reconstruction_method_validator(val):
    """Validate the reconstruction method name."""
    checked = name_range_checker(
        name=val, available_names=reconstruction_methods, descr="reconstruction method"
    )
    return checked
436
444
 
437
445
 
438
446
@validator
def fbp_filter_name_validator(val):
    """Validate the FBP filter name."""
    checked = name_range_checker(name=val, available_names=fbp_filters, descr="FBP filter")
    return checked
446
453
 
447
454
 
@@ -449,29 +456,24 @@ def fbp_filter_name_validator(val):
449
456
def reconstruction_implementation_validator(val):
    """Validate the reconstruction implementation name."""
    checked = name_range_checker(
        name=val,
        available_names=reco_implementations,
        descr="Reconstruction method implementation",
    )
    return checked
456
462
 
457
463
 
458
464
@validator
def optimization_algorithm_name_validator(val):
    """Validate the optimization algorithm name."""
    checked = name_range_checker(
        name=val, available_names=optim_algorithms, descr="optimization algorithm name"
    )
    return checked
463
467
 
464
468
 
465
469
@validator
def output_file_format_validator(val):
    """Validate the output file format name."""
    checked = name_range_checker(name=val, available_names=files_formats, descr="output file format")
    return checked
468
472
 
469
473
 
470
474
  @validator
471
475
  def distribution_method_validator(val):
472
- val = name_range_checker(
473
- val, set(distribution_methods.values()), "workload distribution method", replacements=distribution_methods
474
- )
476
+ val = name_range_checker(val, distribution_methods, "workload distribution method")
475
477
  # TEMP.
476
478
  if val != "local":
477
479
  raise NotImplementedError("Computation method '%s' is not implemented yet" % val)
@@ -481,9 +483,7 @@ def distribution_method_validator(val):
481
483
 
482
484
@validator
def sino_normalization_validator(val):
    """Validate the sinogram normalization method name."""
    return name_range_checker(
        name=val, available_names=sino_normalizations, descr="sinogram normalization method"
    )
488
488
 
489
489
 
@@ -491,9 +491,8 @@ def sino_normalization_validator(val):
491
491
def sino_deringer_methods(val):
    """Validate the sinogram rings artefacts correction method name."""
    return name_range_checker(
        name=val,
        available_names=rings_methods,
        descr="sinogram rings artefacts correction method",
    )
499
498
 
@@ -556,7 +555,7 @@ def nonempty_string_validator(val):
556
555
 
557
556
@validator
def logging_validator(val):
    """Validate the logging level name."""
    checked = name_range_checker(name=val, available_names=log_levels, descr="logging level")
    return checked
560
559
 
561
560
 
562
561
  @validator
@@ -9,6 +9,7 @@ import scipy.fft # pylint: disable=E0611
9
9
  from silx.io import get_data
10
10
  import math
11
11
  from scipy import ndimage as nd
12
+
12
13
  from ..preproc.flatfield import FlatField
13
14
  from ..estimation.cor import (
14
15
  CenterOfRotation,
@@ -17,8 +18,10 @@ from ..estimation.cor import (
17
18
  CenterOfRotationGrowingWindow,
18
19
  CenterOfRotationOctaveAccurate,
19
20
  )
21
+ from .. import version as nabu_version
20
22
  from ..estimation.cor_sino import SinoCorInterface, CenterOfRotationFourierAngles, CenterOfRotationVo
21
23
  from ..estimation.tilt import CameraTilt
24
+ from ..estimation.motion import MotionEstimation
22
25
  from ..estimation.utils import is_fullturn_scan
23
26
  from ..resources.logger import LoggerOrPrint
24
27
  from ..resources.utils import extract_parameters
@@ -989,3 +992,255 @@ class DetectorTiltEstimator:
989
992
 
990
993
  # alias
991
994
  TiltFinder = DetectorTiltEstimator
995
+
996
+
997
def estimate_translations(dataset_info, do_flatfield=True):
    """
    Estimate the sample translations of a dataset, using default settings.

    Convenience wrapper around TranslationsEstimator. Previously this was an empty
    stub silently returning None.

    Parameters
    ----------
    dataset_info: object
        Dataset description, as expected by TranslationsEstimator.
    do_flatfield: bool, optional
        Whether to apply flat-field normalization before estimating the motion.

    Returns
    -------
    tuple
        The (horizontal shifts, vertical shifts, center of rotation) returned by
        TranslationsEstimator.estimate_motion().
    """
    estimator = TranslationsEstimator(dataset_info, do_flatfield=do_flatfield)
    return estimator.estimate_motion()
998
+
999
+
1000
class TranslationsEstimator:
    """
    Estimate sample translations during a scan, with MotionEstimation, by comparing
    projections acquired at opposite angles.

    - For a 360° dataset, one projection out of `angular_subsampling` is read. The first
      half of this stack is compared with the second half.
    - Otherwise, the "return projections" (acquired at a few angles at the end of the scan)
      are compared with the projections acquired at the same angles during the scan.

    Parameters
    ----------
    dataset_info: object
        Dataset description. Uses is_360, is_halftomo, logger, rotation_angles, n_angles,
        flats, darks, get_reader(), get_alignment_projections() and get_image_at_angle().
    do_flatfield: bool
        Whether to apply flat-field normalization before estimation.
    rot_center: float, optional
        Center of rotation. If None, it is estimated by MotionEstimation.
    halftomo_side: False, None, "left", "right" or scalar
        False disables half-tomography handling. None enables it only if
        dataset_info.is_halftomo (using the "right" side). Otherwise, estimation is
        restricted to a window of extra_options["window_size"] columns.
    angular_subsampling: int
        For 360° datasets, use one projection out of this number.
    deg_xy, deg_z: int
        "degree" passed to MotionEstimation.estimate_horizontal_motion / estimate_vertical_motion.
    shifts_estimator: str
        Name of the shifts estimation method passed to MotionEstimation.
    extra_options: dict, optional
        Overrides of `_default_extra_options`.
    """

    _default_extra_options = {
        "window_size": 200,
    }

    def __init__(
        self,
        dataset_info,
        do_flatfield=True,
        rot_center=None,
        halftomo_side=None,
        angular_subsampling=10,
        deg_xy=2,
        deg_z=2,
        shifts_estimator="phase_cross_correlation",
        extra_options=None,
    ):
        self._configure_extra_options(extra_options)
        self.logger = LoggerOrPrint(dataset_info.logger)
        self.dataset_info = dataset_info
        self.angular_subsampling = angular_subsampling
        self.do_360 = self.dataset_info.is_360
        self.do_flatfield = do_flatfield
        self.radios = None
        self._deg_xy = deg_xy
        self._deg_z = deg_z
        self._shifts_estimator = shifts_estimator
        self._shifts_estimator_kwargs = {}
        self._cor = rot_center
        # Must come after extra_options, _cor, do_360 and _shifts_estimator_kwargs are set
        self._configure_halftomo(halftomo_side)
        self._estimate_cor = self._cor is None
        # Filled by estimate_motion()
        self.sample_shifts_xy = None
        self.sample_shifts_z = None

    def _configure_extra_options(self, extra_options):
        # Class defaults, overridden by user-provided options
        self.extra_options = self._default_extra_options.copy()
        self.extra_options.update(extra_options or {})

    def _configure_halftomo(self, halftomo_side):
        """
        Set self.halftomo_side and the column range [self._start_x:self._end_x]
        used for estimation (both None means "whole image width").
        """
        # Always define the column range, so that estimate_motion() can rely on it
        # even when half-tomography is disabled.
        self._start_x = None
        self._end_x = None
        if halftomo_side is False:
            # Force disable halftomo
            self.halftomo_side = False
            return
        if (halftomo_side is not None) and not (self.do_360):
            raise ValueError(
                "Expected 360° dataset for half-tomography, but this dataset does not look like a 360° dataset"
            )
        if halftomo_side is None:
            if self.dataset_info.is_halftomo:
                halftomo_side = "right"
            else:
                self.halftomo_side = False
                return
        self.halftomo_side = halftomo_side
        window_size = self.extra_options["window_size"]
        if self._cor is not None:
            # In this case we look for shifts around the CoR
            self._start_x = int(self._cor - window_size / 2)
            self._end_x = int(self._cor + window_size / 2)
        elif halftomo_side == "right":
            self._start_x = -window_size
            self._end_x = None
        elif halftomo_side == "left":
            self._start_x = 0
            self._end_x = window_size
        elif is_scalar(halftomo_side):
            # Expect approximate location of CoR, relative to left-most column
            self._start_x = int(halftomo_side - window_size / 2)
            self._end_x = int(halftomo_side + window_size / 2)
        else:
            raise ValueError(
                f"Expected 'halftomo_side' to be either 'left', 'right', or an integer (got {halftomo_side})"
            )
        self.logger.debug(f"[MotionEstimation] Half-tomo looking at [{self._start_x}:{self._end_x}]")
        # For half-tomo, skimage.registration.phase_cross_correlation might look a bit too far away
        if (
            self._shifts_estimator == "phase_cross_correlation"
            and self._shifts_estimator_kwargs.get("overlap_ratio", 0.3) >= 0.3
        ):
            self._shifts_estimator_kwargs.update({"overlap_ratio": 0.2})

    def _load_data(self):
        self.logger.debug("[MotionEstimation] reading data")
        if self.do_360:
            # In this case we compare pairs of opposite projections.
            # If rotation angles were arbitrary, we should do something like
            #   for angle in dataset_info.rotation_angles:
            #       img, angle_deg, idx = dataset_info.get_image_at_angle(
            #           np.degrees(angle) + 180, return_angle_and_index=True
            #       )
            # Most of the time (always ?), the dataset was acquired with a circular trajectory,
            # so we can use angles dataset_info.rotation_angles[::self.angular_subsampling],
            # which amounts to reading one radio out of "angular_subsampling".

            # TODO account for more general rotation angles. The following will only work for circular trajectory and ordered angles
            self._reader = self.dataset_info.get_reader(
                sub_region=(slice(None, None, self.angular_subsampling), slice(None), slice(None))
            )
            self.radios = self._reader.load_data()
            self.angles = self.dataset_info.rotation_angles[:: self.angular_subsampling]
            self._radios_idx = self._reader.get_frames_indices()
            self.logger.debug("[MotionEstimation] This is a 360° scan, will use pairs of opposite projections")
        else:
            # In this case we use the "return projections", i.e special projections acquired
            # at several angles (eg. [180, 90, 0]) before ending the scan
            return_projs, return_angles_deg, return_idx = self.dataset_info.get_alignment_projections()
            self._angles_return = np.radians(return_angles_deg)
            self._radios_return = return_projs
            self._radios_idx_return = return_idx

            # Projections acquired during the scan at (approximately) the same angles
            projs = []
            angles_rad = []
            projs_idx = []
            for angle_deg in return_angles_deg:
                proj, rot_angle_deg, proj_idx = self.dataset_info.get_image_at_angle(
                    angle_deg, image_type="projection", return_angle_and_index=True
                )
                projs.append(proj)
                angles_rad.append(np.radians(rot_angle_deg))
                projs_idx.append(proj_idx)
            self._radios_outwards = np.array(projs)
            self._angles_outward = np.array(angles_rad)
            self._radios_idx_outwards = np.array(projs_idx)
            self.logger.debug("[MotionEstimation] This is a 180° scan, will use 'return projections'")

    def _apply_flatfield(self):
        if not (self.do_flatfield):
            return
        self.logger.debug("[MotionEstimation] flatfield")
        if self.do_360:
            self._flatfield = FlatField(
                self.radios.shape,
                flats=self.dataset_info.flats,
                darks=self.dataset_info.darks,
                radios_indices=self._radios_idx,
            )
            self._flatfield.normalize_radios(self.radios)
        else:
            # 180 + return projs
            self._flatfield_outwards = FlatField(
                self._radios_outwards.shape,
                flats=self.dataset_info.flats,
                darks=self.dataset_info.darks,
                radios_indices=self._radios_idx_outwards,
            )
            self._flatfield_outwards.normalize_radios(self._radios_outwards)
            self._flatfield_return = FlatField(
                self._radios_return.shape,
                flats=self.dataset_info.flats,
                darks=self.dataset_info.darks,
                radios_indices=self._radios_idx_return,
            )
            # The return projections must be normalized with their own flat-field object
            # (built with their own radios indices)
            self._flatfield_return.normalize_radios(self._radios_return)

    def estimate_motion(self):
        """
        Load data, apply flat-field (if enabled), and estimate the sample motion.

        Returns
        -------
        estimated_shifts_h, estimated_shifts_v, cor
            Results of MotionEstimation.estimate_horizontal_motion / estimate_vertical_motion.
            When a column window was used, cor is shifted back to full-image column coordinates.
        """
        self._load_data()
        self._apply_flatfield()

        n_projs_tot = self.dataset_info.n_angles
        if self.do_360:
            n_a = self.radios.shape[0]
            # See notes in _load_data() - this works only for circular trajectory / ordered angles
            projs_stack1 = self.radios[: n_a // 2]
            projs_stack2 = self.radios[n_a // 2 :]
            angles1 = self.angles[: n_a // 2]
            angles2 = self.angles[n_a // 2 :]
            indices1 = (self._radios_idx - self._radios_idx[0])[: n_a // 2]
            indices2 = (self._radios_idx - self._radios_idx[0])[n_a // 2 :]
        else:
            projs_stack1 = self._radios_outwards
            projs_stack2 = self._radios_return
            angles1 = self._angles_outward
            angles2 = self._angles_return
            indices1 = self._radios_idx_outwards - self._radios_idx_outwards.min()
            indices2 = self._radios_idx_return - self._radios_idx_outwards.min()

        # Full image width, taken before cropping. self.radios is None for 180° scans,
        # so it cannot be used for that purpose.
        full_width = projs_stack1.shape[-1]
        if self._start_x is not None:
            # Compute Motion Estimation on subset of images (eg. for half-tomo)
            projs_stack1 = projs_stack1[..., self._start_x : self._end_x]
            projs_stack2 = projs_stack2[..., self._start_x : self._end_x]

        self.motion_estimator = MotionEstimation(
            projs_stack1,
            projs_stack2,
            angles1,
            angles2,
            indices1,
            indices2,
            n_projs_tot,
            shifts_estimator=self._shifts_estimator,
            shifts_estimator_kwargs=self._shifts_estimator_kwargs,
        )

        self.logger.debug("[MotionEstimation] estimating shifts")

        estimated_shifts_v = self.motion_estimator.estimate_vertical_motion(degree=self._deg_z)
        estimated_shifts_h, cor = self.motion_estimator.estimate_horizontal_motion(degree=self._deg_xy, cor=self._cor)
        if self._start_x is not None:
            # Convert back from the cropped window to the full image
            # (modulo handles negative start, eg. -window_size for "right" side)
            cor += (self._start_x % full_width) + (projs_stack1.shape[-1] - 1) / 2.0

        self.sample_shifts_xy = estimated_shifts_h
        self.sample_shifts_z = estimated_shifts_v
        if self._cor is None:
            self.logger.info(
                "[MotionEstimation] Estimated center of rotation (relative to middle of detector): %.2f" % cor
            )
        return estimated_shifts_h, estimated_shifts_v, cor

    def generate_translations_movements_file(self, filename, fmt="%.3f", only=None):
        """
        Write, for every rotation angle of the dataset, the (u, v) shifts opposite to the
        detector shifts derived from the fitted sample motion.

        Parameters
        ----------
        filename: str
            Output text file (written with numpy.savetxt).
        fmt: str
            Number format passed to numpy.savetxt.
        only: None, "horizontal" or "vertical"
            If "horizontal", the second column is set to zero. If "vertical",
            the first column is set to zero.

        Raises
        ------
        RuntimeError
            If estimate_motion() was not called before.
        ValueError
            If `only` has an unexpected value.
        """
        if self.sample_shifts_xy is None:
            raise RuntimeError("Need to run estimate_motion() first")

        angles = self.dataset_info.rotation_angles
        cor = self._cor or 0
        txy_est_all_angles = self.motion_estimator.apply_fit_horiz(angles=angles)
        tz_est_all_angles = self.motion_estimator.apply_fit_vertic(angles=angles)
        estimated_shifts_vu_all_angles = self.motion_estimator.convert_sample_motion_to_detector_shifts(
            txy_est_all_angles, tz_est_all_angles, angles, cor=cor
        )
        estimated_shifts_vu_all_angles[:, 1] -= cor
        # (v, u) -> (u, v), with opposite sign to compensate the motion
        correct_shifts_uv = -estimated_shifts_vu_all_angles[:, ::-1]

        if only is not None:
            if only == "horizontal":
                correct_shifts_uv[:, 1] = 0
            elif only == "vertical":
                correct_shifts_uv[:, 0] = 0
            else:
                raise ValueError("Expected 'only' to be either None, 'horizontal' or 'vertical'")

        header = f"Generated by nabu {nabu_version} : {str(self)}"
        np.savetxt(filename, correct_shifts_uv, fmt=fmt, header=header)

    def __str__(self):
        params = f"do_flatfield={self.do_flatfield}, rot_center={self._cor}, angular_subsampling={self.angular_subsampling}"
        if self.sample_shifts_xy is not None:
            # Keep every parameter inside the parentheses
            params += f", shifts_estimator={self.motion_estimator.shifts_estimator}"
        return f"{self.__class__.__name__}({params})"