nabu 2025.1.0.dev13__py3-none-any.whl → 2025.1.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nabu/__init__.py +1 -1
- nabu/app/cast_volume.py +12 -1
- nabu/app/cli_configs.py +81 -4
- nabu/app/estimate_motion.py +54 -0
- nabu/app/multicor.py +2 -4
- nabu/app/pcaflats.py +116 -0
- nabu/app/reconstruct.py +1 -7
- nabu/app/reduce_dark_flat.py +5 -2
- nabu/estimation/cor.py +1 -1
- nabu/estimation/motion.py +557 -0
- nabu/estimation/tests/test_motion_estimation.py +471 -0
- nabu/estimation/tilt.py +1 -1
- nabu/estimation/translation.py +47 -1
- nabu/io/cast_volume.py +94 -13
- nabu/io/reader.py +32 -1
- nabu/io/tests/test_remove_volume.py +152 -0
- nabu/pipeline/config_validators.py +42 -43
- nabu/pipeline/estimators.py +255 -0
- nabu/pipeline/fullfield/chunked.py +67 -43
- nabu/pipeline/fullfield/chunked_cuda.py +5 -2
- nabu/pipeline/fullfield/nabu_config.py +17 -11
- nabu/pipeline/fullfield/processconfig.py +8 -2
- nabu/pipeline/fullfield/reconstruction.py +3 -0
- nabu/pipeline/params.py +12 -0
- nabu/pipeline/tests/test_estimators.py +240 -3
- nabu/preproc/ccd.py +53 -3
- nabu/preproc/flatfield.py +306 -1
- nabu/preproc/shift.py +3 -1
- nabu/preproc/tests/test_pcaflats.py +154 -0
- nabu/processing/rotation_cuda.py +3 -1
- nabu/processing/tests/test_rotation.py +4 -2
- nabu/reconstruction/fbp.py +7 -0
- nabu/reconstruction/fbp_base.py +31 -7
- nabu/reconstruction/fbp_opencl.py +8 -0
- nabu/reconstruction/filtering_opencl.py +2 -0
- nabu/reconstruction/mlem.py +51 -14
- nabu/reconstruction/tests/test_filtering.py +13 -2
- nabu/reconstruction/tests/test_mlem.py +91 -62
- nabu/resources/dataset_analyzer.py +144 -20
- nabu/resources/nxflatfield.py +101 -35
- nabu/resources/tests/test_nxflatfield.py +1 -1
- nabu/resources/utils.py +16 -10
- nabu/stitching/alignment.py +7 -7
- nabu/stitching/config.py +22 -20
- nabu/stitching/definitions.py +2 -2
- nabu/stitching/overlap.py +4 -4
- nabu/stitching/sample_normalization.py +5 -5
- nabu/stitching/stitcher/post_processing.py +5 -3
- nabu/stitching/stitcher/pre_processing.py +24 -20
- nabu/stitching/tests/test_config.py +3 -3
- nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
- nabu/stitching/tests/test_z_postprocessing_stitching.py +2 -2
- nabu/stitching/tests/test_z_preprocessing_stitching.py +23 -20
- nabu/stitching/utils/utils.py +7 -7
- nabu/testutils.py +1 -4
- nabu/utils.py +13 -0
- {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/METADATA +3 -4
- {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/RECORD +62 -57
- {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/WHEEL +1 -1
- {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/entry_points.txt +2 -1
- nabu/app/correct_rot.py +0 -62
- {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/licenses/LICENSE +0 -0
- {nabu-2025.1.0.dev13.dist-info → nabu-2025.1.0rc1.dist-info}/top_level.txt +0 -0
nabu/io/reader.py
CHANGED
@@ -4,6 +4,7 @@ from math import ceil
|
|
4
4
|
from multiprocessing.pool import ThreadPool
|
5
5
|
from posixpath import sep as posix_sep, join as posix_join
|
6
6
|
import numpy as np
|
7
|
+
from h5py import Dataset
|
7
8
|
from silx.io import get_data
|
8
9
|
from silx.io.dictdump import h5todict
|
9
10
|
from tomoscan.io import HDF5File
|
@@ -770,7 +771,7 @@ class NXDarksFlats:
|
|
770
771
|
reduced_frames = [self._reduce_func[method](frames, axis=0) for frames in raw_frames]
|
771
772
|
reader = getattr(self, "%s_reader" % what)
|
772
773
|
if as_dict:
|
773
|
-
return {k: v for k, v in zip([s.start for s in reader._image_key_slices], reduced_frames)}
|
774
|
+
return {k: v for k, v in zip([s.start for s in reader._image_key_slices], reduced_frames)}
|
774
775
|
return reduced_frames
|
775
776
|
|
776
777
|
def get_raw_darks(self, force_reload=False, as_multiple_array=True):
|
@@ -987,6 +988,12 @@ def get_entry_from_h5_path(h5_path):
|
|
987
988
|
return v[0] or v[1]
|
988
989
|
|
989
990
|
|
991
|
+
def list_hdf5_entries(fname):
    """
    Return the names of the top-level items (entries) of a HDF5 file, as a list.
    The file is opened read-only and closed before returning.
    """
    with HDF5File(fname, "r") as h5f:
        return list(h5f.keys())
|
995
|
+
|
996
|
+
|
990
997
|
def check_virtual_sources_exist(fname, data_path):
|
991
998
|
with HDF5File(fname, "r") as f:
|
992
999
|
if data_path not in f:
|
@@ -1006,6 +1013,30 @@ def check_virtual_sources_exist(fname, data_path):
|
|
1006
1013
|
return True
|
1007
1014
|
|
1008
1015
|
|
1016
|
+
def get_hdf5_file_all_virtual_sources(file_path_or_obj, return_only_filenames=False):
    """
    List the virtual sources of every virtual dataset found in a HDF5 file.

    Parameters
    ----------
    file_path_or_obj: str, os.PathLike, or opened HDF5 file/group
        Either a path to a HDF5 file (opened read-only here and always closed before
        returning, even on error), or an already-opened object providing 'visititems()'
        (left open).
    return_only_filenames: bool, optional
        If True, only the 'file_name' of each virtual source is returned.

    Returns
    -------
    result: list of dict
        One dict {dataset_name: virtual_sources} per virtual dataset, in visiting order.
        'virtual_sources' is the output of 'Dataset.virtual_sources()', or the list of the
        corresponding file names when 'return_only_filenames' is True.
    """
    import os

    result = []

    def collect_vsources(name, obj):
        if isinstance(obj, Dataset) and obj.is_virtual:
            vs = obj.virtual_sources()
            if return_only_filenames:
                vs = [vs_.file_name for vs_ in vs]
            result.append({name: vs})

    if isinstance(file_path_or_obj, (str, os.PathLike)):
        # Context manager: the file we opened is closed even if visiting raises
        with HDF5File(os.fspath(file_path_or_obj), "r") as fdesc:
            fdesc.visititems(collect_vsources)
    else:
        file_path_or_obj.visititems(collect_vsources)
    return result
|
1038
|
+
|
1039
|
+
|
1009
1040
|
def import_h5_to_dict(h5file, h5path, asarray=False):
|
1010
1041
|
"""
|
1011
1042
|
Wrapper on top of silx.io.dictdump.dicttoh5 replacing "None" with None
|
@@ -0,0 +1,152 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from os import path, mkdir, rename
|
3
|
+
import numpy as np
|
4
|
+
import pytest
|
5
|
+
from tomoscan.esrf.volume import (
|
6
|
+
EDFVolume,
|
7
|
+
HDF5Volume,
|
8
|
+
JP2KVolume,
|
9
|
+
MultiTIFFVolume,
|
10
|
+
TIFFVolume,
|
11
|
+
)
|
12
|
+
from tomoscan.esrf.volume.jp2kvolume import has_glymur
|
13
|
+
from nabu.io.writer import merge_hdf5_files
|
14
|
+
from nabu.io.cast_volume import remove_volume
|
15
|
+
|
16
|
+
|
17
|
+
def test_remove_single_frame_volume(tmpdir):
    """
    Test volume removal for tiff, jp2 and EDF
    """
    # Have to use a not-too-small size because of jp2k
    data = np.arange(10 * 40 * 50, dtype="f").reshape((10, 40, 50))

    volume_classes = [EDFVolume, TIFFVolume]
    if has_glymur:
        # JPEG2000 volumes are only tested when glymur is available
        volume_classes.append(JP2KVolume)
    for volume_cls in volume_classes:
        ext = volume_cls.DEFAULT_DATA_EXTENSION
        folder = path.join(tmpdir, f"{ext}_vol")
        volume_basename = f"{ext}_basename"

        vol_writer = volume_cls(folder=folder, volume_basename=volume_basename, overwrite=True, start_index=0)
        vol_writer.data = data
        vol_writer.save()

        vol_reader = volume_cls(folder=folder, volume_basename=volume_basename)
        assert path.isdir(folder), f"Expected to find a folder {folder}"
        remove_volume(vol_reader)
        assert not path.isdir(folder), f"Expected to have removed the folder {folder}"

        # Re-create the volume and add an extraneous file: removal with check=True must refuse
        vol_writer.save()
        vol_reader = volume_cls(folder=folder, volume_basename=volume_basename)
        Path(path.join(folder, f"unexpected.{ext}")).touch()
        with pytest.raises(RuntimeError) as exc:
            remove_volume(vol_reader, check=True)
        assert "Unexpected files present" in str(exc.value), "Expected check to find extraneous files"
|
47
|
+
|
48
|
+
|
49
|
+
def test_remove_multiframe_volume(tmpdir):
    """
    Test volume removal for "multiframe" formats (HDF5, tiff3D)
    The HDF5 files considered in this test do not have virtual sources
    """
    data = np.arange(3 * 4 * 5, dtype="f").reshape((3, 4, 5))

    for ext, volume_cls in {"h5": HDF5Volume, "tiff": MultiTIFFVolume}.items():
        file_path = path.join(tmpdir, f"{ext}_vol.{ext}")

        init_kwargs = {"file_path": file_path}
        if ext == "h5":
            init_kwargs["data_path"] = "entry"
        vol_writer = volume_cls(**init_kwargs)
        vol_writer.data = data
        vol_writer.save()

        vol_reader = volume_cls(**init_kwargs)
        assert path.isfile(file_path), f"Expected to find a {ext} volume at {file_path}"
        remove_volume(vol_reader)
        assert not path.isfile(file_path), f"Expected to have removed {file_path}"
|
70
|
+
|
71
|
+
|
72
|
+
def test_remove_hdf5_multiple_entries(tmpdir):
    """
    Removing a HDF5 volume whose file holds several entries must be refused.
    """
    base = np.arange(3 * 4 * 5, dtype="f").reshape((3, 4, 5))
    h5_file = path.join(tmpdir, "h5_vol.h5")
    # Two volumes stored in two distinct entries of the same file
    for data_path, volume_data in (("entry0000", base), ("entry0001", base + 10)):
        writer = HDF5Volume(file_path=h5_file, data_path=data_path)
        writer.data = volume_data
        writer.save()

    reader = HDF5Volume(file_path=h5_file, data_path="entry0000")
    with pytest.raises(NotImplementedError) as excinfo:
        remove_volume(reader, check=True)
    expected_msg = "Removing a HDF5 volume with more than one entry is not supported"
    assert expected_msg in str(excinfo.value), "Expected an error message"
|
87
|
+
|
88
|
+
|
89
|
+
def test_remove_nabu_hdf5_reconstruction(tmpdir):
    """
    Test removal of HDF5 reconstruction generated by nabu (i.e with virtual sources)
    """
    entry = "entry"
    process_name = "reconstruction"

    master_file_path = path.join(tmpdir, "sample_naburec.hdf5")
    associated_dir = path.join(tmpdir, "sample_naburec")
    if not path.isdir(associated_dir):
        mkdir(associated_dir)

    # Write partial reconstructions in the associated directory.
    # 'local_files' holds their paths relative to tmpdir (which is passed as 'base_dir' below).
    n_chunks = 5
    local_files = []
    for i in range(n_chunks):
        fname = "sample_naburec_%06d.h5" % i
        partial_rec_abspath = path.join(associated_dir, fname)
        local_files.append(f"sample_naburec/{fname}")
        vol = HDF5Volume(partial_rec_abspath, data_path=f"{entry}/{process_name}")
        vol.data = np.arange(3 * 4 * 5, dtype="f").reshape((3, 4, 5))
        vol.save()

    h5_path = f"{entry}/{process_name}/results/data"

    merge_hdf5_files(
        local_files,
        h5_path,
        master_file_path,
        process_name,
        output_entry=entry,
        output_filemode="a",
        processing_index=0,
        config=None,
        base_dir=path.dirname(associated_dir),
        axis=0,
        overwrite=True,
    )

    assert path.isfile(master_file_path), f"Expected to find the master file at {master_file_path}"
    assert path.isdir(associated_dir)
    for i, local_file in enumerate(local_files):
        partial_rec_file = path.join(tmpdir, local_file)
        assert path.isfile(partial_rec_file), f"Expected to find partial file number {i} at {partial_rec_file}"

    # Check that the virtual links are handled properly
    # sample_naburec.hdf5 should reference sample_naburec/sample_naburec_{i}.h5
    renamed_master_file_path = (
        path.join(path.dirname(master_file_path), path.basename(master_file_path).split(".")[0]) + "_renamed" + ".h5"
    )
    rename(master_file_path, renamed_master_file_path)
    h5_vol = HDF5Volume(file_path=renamed_master_file_path, data_path=f"{entry}/{process_name}")
    with pytest.raises(ValueError) as exc:
        remove_volume(h5_vol)
    expected_error_message = f"The virtual sources in {renamed_master_file_path}:{process_name}/results/data reference the directory sample_naburec, but expected was sample_naburec_renamed"
    assert str(exc.value) == expected_error_message

    # Check removal in normal circumstances
    rename(renamed_master_file_path, master_file_path)
    h5_vol = HDF5Volume(file_path=master_file_path, data_path=f"{entry}/{process_name}")
    remove_volume(h5_vol)
    assert not path.isfile(master_file_path), f"Expected to have removed the master file at {master_file_path}"
    assert not path.isdir(associated_dir)
|
@@ -2,7 +2,7 @@
|
|
2
2
|
import os
|
3
3
|
|
4
4
|
path = os.path
|
5
|
-
from ..utils import check_supported, is_writeable
|
5
|
+
from ..utils import check_supported, deprecation_warning, is_writeable
|
6
6
|
from .params import * # noqa: F403
|
7
7
|
|
8
8
|
"""
|
@@ -97,12 +97,21 @@ def convert_to_bool_noerr(val):
|
|
97
97
|
return res
|
98
98
|
|
99
99
|
|
100
|
-
def name_range_checker(name, available_names, descr):
    """
    Check whether a parameter name is valid, against a list or dictionary of names.

    Parameters
    ----------
    name: str
        Name to check. It is stripped and lower-cased before the lookup.
    available_names: dict or list
        If a dict, maps each accepted spelling to its canonical name, eg. {"edge": "edges"}.
        Otherwise, a collection of accepted names.
    descr: str
        Human-readable description of the parameter, used in the error message.

    Returns
    -------
    name: str
        The canonical name (the dict value when 'available_names' is a dict).

    Raises
    ------
    AssertionError
        If the name is not valid. It is raised explicitly rather than through an
        'assert' statement, so the check is not stripped when Python runs with -O.
    """
    name = name.strip().lower()
    is_mapping = isinstance(available_names, dict)
    if is_mapping:
        # we could use .keys() instead to be more permissive to the user
        available_names_str = str(set(available_names.values()))
    else:
        # assuming list
        available_names_str = str(available_names)
    # Validate before the dict lookup, otherwise an invalid name would surface
    # as a bare KeyError instead of the explanatory message below
    if name not in available_names:
        raise AssertionError("Invalid %s '%s'. Available are %s" % (descr, name, available_names_str))
    if is_mapping:
        # handle replacements, eg. {"edge": "edges"}
        name = available_names[name]
    return name
|
107
116
|
|
108
117
|
|
@@ -338,9 +347,7 @@ def cor_validator(val):
|
|
338
347
|
return val_float
|
339
348
|
if len(val.strip()) == 0:
|
340
349
|
return None
|
341
|
-
val = name_range_checker(
|
342
|
-
val.lower(), set(cor_methods.values()), "center of rotation estimation method", replacements=cor_methods
|
343
|
-
)
|
350
|
+
val = name_range_checker(val, cor_methods, "center of rotation estimation method")
|
344
351
|
return val
|
345
352
|
|
346
353
|
|
@@ -351,9 +358,7 @@ def tilt_validator(val):
|
|
351
358
|
return val_float
|
352
359
|
if len(val.strip()) == 0:
|
353
360
|
return None
|
354
|
-
val = name_range_checker(
|
355
|
-
val.lower(), set(tilt_methods.values()), "automatic detector tilt estimation method", replacements=tilt_methods
|
356
|
-
)
|
361
|
+
val = name_range_checker(val, tilt_methods, "automatic detector tilt estimation method")
|
357
362
|
return val
|
358
363
|
|
359
364
|
|
@@ -395,53 +400,55 @@ def cor_slice_validator(val):
|
|
395
400
|
|
396
401
|
|
397
402
|
@validator
def flatfield_validator(val):
    """
    Validate the 'flatfield' parameter against 'flatfield_modes'.
    The values "force-load" and "force-compute" are still accepted, but emit a deprecation warning.
    """
    mode = name_range_checker(val, flatfield_modes, "flatfield mode")
    deprecated_modes = ("force-load", "force-compute")
    if mode in deprecated_modes:
        deprecation_warning(
            f"Using 'flatfield = {mode}' is deprecated since version 2025.1.0. Please use the parameter 'flatfield_loading_mode'",
        )
    return mode
|
410
|
+
|
411
|
+
|
412
|
+
@validator
def flatfield_loading_mode_validator(val):
    """
    Validate the 'flatfield_loading_mode' parameter against 'flatfield_loading_mode'.
    """
    # The description now names this parameter (it previously said "flatfield mode",
    # which pointed users to the wrong parameter in error messages)
    return name_range_checker(val, flatfield_loading_mode, "flatfield loading mode")
|
400
415
|
|
401
416
|
|
402
417
|
@validator
def phase_method_validator(val):
    """Validate the phase retrieval method name against 'phase_retrieval_methods'."""
    method = name_range_checker(val, phase_retrieval_methods, "phase retrieval method")
    return method
|
407
420
|
|
408
421
|
|
409
422
|
@validator
def detector_distortion_correction_validator(val):
    """Validate the detector distortion correction method against 'detector_distortion_correction_methods'."""
    method = name_range_checker(
        val, detector_distortion_correction_methods, "detector_distortion_correction_methods"
    )
    return method
|
417
429
|
|
418
430
|
|
419
431
|
@validator
def unsharp_method_validator(val):
    """Validate the unsharp mask method name against 'unsharp_methods'."""
    method = name_range_checker(val, unsharp_methods, "unsharp mask method")
    return method
|
424
434
|
|
425
435
|
|
426
436
|
@validator
def padding_mode_validator(val):
    """Validate the padding mode against 'padding_modes'."""
    mode = name_range_checker(val, padding_modes, "padding mode")
    return mode
|
429
439
|
|
430
440
|
|
431
441
|
@validator
def reconstruction_method_validator(val):
    """Validate the reconstruction method name against 'reconstruction_methods'."""
    method = name_range_checker(val, reconstruction_methods, "reconstruction method")
    return method
|
436
444
|
|
437
445
|
|
438
446
|
@validator
def fbp_filter_name_validator(val):
    """Validate the FBP filter name against 'fbp_filters'."""
    filter_name = name_range_checker(val, fbp_filters, "FBP filter")
    return filter_name
|
446
453
|
|
447
454
|
|
@@ -449,29 +456,24 @@ def fbp_filter_name_validator(val):
|
|
449
456
|
def reconstruction_implementation_validator(val):
    """Validate the reconstruction method implementation against 'reco_implementations'."""
    implementation = name_range_checker(val, reco_implementations, "Reconstruction method implementation")
    return implementation
|
456
462
|
|
457
463
|
|
458
464
|
@validator
def optimization_algorithm_name_validator(val):
    """Validate the optimization algorithm name against 'optim_algorithms'."""
    algorithm = name_range_checker(val, optim_algorithms, "optimization algorithm name")
    return algorithm
|
463
467
|
|
464
468
|
|
465
469
|
@validator
def output_file_format_validator(val):
    """Validate the output file format against 'files_formats'."""
    file_format = name_range_checker(val, files_formats, "output file format")
    return file_format
|
468
472
|
|
469
473
|
|
470
474
|
@validator
|
471
475
|
def distribution_method_validator(val):
|
472
|
-
val = name_range_checker(
|
473
|
-
val, set(distribution_methods.values()), "workload distribution method", replacements=distribution_methods
|
474
|
-
)
|
476
|
+
val = name_range_checker(val, distribution_methods, "workload distribution method")
|
475
477
|
# TEMP.
|
476
478
|
if val != "local":
|
477
479
|
raise NotImplementedError("Computation method '%s' is not implemented yet" % val)
|
@@ -481,9 +483,7 @@ def distribution_method_validator(val):
|
|
481
483
|
|
482
484
|
@validator
def sino_normalization_validator(val):
    """Validate the sinogram normalization method against 'sino_normalizations'."""
    return name_range_checker(val, sino_normalizations, "sinogram normalization method")
|
488
488
|
|
489
489
|
|
@@ -491,9 +491,8 @@ def sino_normalization_validator(val):
|
|
491
491
|
def sino_deringer_methods(val):
    """Validate the sinogram rings artefacts correction method against 'rings_methods'."""
    return name_range_checker(val, rings_methods, "sinogram rings artefacts correction method")
|
499
498
|
|
@@ -556,7 +555,7 @@ def nonempty_string_validator(val):
|
|
556
555
|
|
557
556
|
@validator
def logging_validator(val):
    """Validate the logging level against 'log_levels'."""
    level = name_range_checker(val, log_levels, "logging level")
    return level
|
560
559
|
|
561
560
|
|
562
561
|
@validator
|
nabu/pipeline/estimators.py
CHANGED
@@ -9,6 +9,7 @@ import scipy.fft # pylint: disable=E0611
|
|
9
9
|
from silx.io import get_data
|
10
10
|
import math
|
11
11
|
from scipy import ndimage as nd
|
12
|
+
|
12
13
|
from ..preproc.flatfield import FlatField
|
13
14
|
from ..estimation.cor import (
|
14
15
|
CenterOfRotation,
|
@@ -17,8 +18,10 @@ from ..estimation.cor import (
|
|
17
18
|
CenterOfRotationGrowingWindow,
|
18
19
|
CenterOfRotationOctaveAccurate,
|
19
20
|
)
|
21
|
+
from .. import version as nabu_version
|
20
22
|
from ..estimation.cor_sino import SinoCorInterface, CenterOfRotationFourierAngles, CenterOfRotationVo
|
21
23
|
from ..estimation.tilt import CameraTilt
|
24
|
+
from ..estimation.motion import MotionEstimation
|
22
25
|
from ..estimation.utils import is_fullturn_scan
|
23
26
|
from ..resources.logger import LoggerOrPrint
|
24
27
|
from ..resources.utils import extract_parameters
|
@@ -989,3 +992,255 @@ class DetectorTiltEstimator:
|
|
989
992
|
|
990
993
|
# alias
|
991
994
|
TiltFinder = DetectorTiltEstimator
|
995
|
+
|
996
|
+
|
997
|
+
def estimate_translations(dataset_info, do_flatfield=True):
    """
    Placeholder, not implemented yet: the arguments are ignored and None is returned.
    """
    return None
|
998
|
+
|
999
|
+
|
1000
|
+
class TranslationsEstimator:
    """
    Estimate the sample translations (horizontal "xy" and vertical "z") during a scan,
    and optionally the center of rotation, using MotionEstimation.

    Two acquisition geometries are handled:
      - 360° scans: one projection out of 'angular_subsampling' is read. The first half
        of this stack is compared with the second half (pairs of opposite projections).
        This assumes a circular trajectory with ordered rotation angles.
      - other scans: the "return projections" (from dataset_info.get_alignment_projections())
        are compared with the regular projections acquired at the closest angles.

    Parameters
    ----------
    dataset_info:
        Dataset description object. The members used here are: logger, is_360, is_halftomo,
        n_angles, rotation_angles, flats, darks, get_reader(), get_alignment_projections()
        and get_image_at_angle().
    do_flatfield: bool, optional
        Whether to flat-field the projections before estimating shifts.
    rot_center: float, optional
        Center of rotation. If None (default), it is estimated along with the translations.
    halftomo_side: str, number, bool or None, optional
        - False: never restrict the horizontal field of view.
        - None (default): use "right" if dataset_info.is_halftomo, otherwise behave like False.
        - "left" / "right": only use the first / last 'window_size' columns.
        - number: approximate CoR location (relative to the left-most column). A window of
          'window_size' columns centered on it is used.
        When the window is used and 'rot_center' is provided, the window is centered on
        'rot_center' instead. Any value other than None/False requires a 360° dataset.
    angular_subsampling: int, optional
        For 360° scans, read one projection out of 'angular_subsampling'.
    deg_xy: int, optional
        'degree' passed to MotionEstimation.estimate_horizontal_motion().
    deg_z: int, optional
        'degree' passed to MotionEstimation.estimate_vertical_motion().
    shifts_estimator: str, optional
        Name of the image-shift estimation method passed to MotionEstimation.
    extra_options: dict, optional
        Supported key: "window_size" (default 200), the width in pixels of the detector
        region used for half-tomography.
    """

    _default_extra_options = {
        "window_size": 200,
    }

    def __init__(
        self,
        dataset_info,
        do_flatfield=True,
        rot_center=None,
        halftomo_side=None,
        angular_subsampling=10,
        deg_xy=2,
        deg_z=2,
        shifts_estimator="phase_cross_correlation",
        extra_options=None,
    ):
        self._configure_extra_options(extra_options)
        self.logger = LoggerOrPrint(dataset_info.logger)
        self.dataset_info = dataset_info
        self.angular_subsampling = angular_subsampling
        self.do_360 = self.dataset_info.is_360
        self.do_flatfield = do_flatfield
        self.radios = None
        self._deg_xy = deg_xy
        self._deg_z = deg_z
        self._shifts_estimator = shifts_estimator
        self._shifts_estimator_kwargs = {}
        self._cor = rot_center
        # Must come after do_360, _cor and _shifts_estimator_kwargs, which it reads/updates
        self._configure_halftomo(halftomo_side)
        self._estimate_cor = self._cor is None
        # Filled by estimate_motion()
        self.sample_shifts_xy = None
        self.sample_shifts_z = None

    def _configure_extra_options(self, extra_options):
        # User-provided options override the class defaults
        self.extra_options = self._default_extra_options.copy()
        self.extra_options.update(extra_options or {})

    def _configure_halftomo(self, halftomo_side):
        """
        Set 'halftomo_side' and the horizontal window [_start_x:_end_x] of the projections
        used for the estimation. (None, None) means the whole detector width.
        """
        # Always define the window: estimate_motion() reads these attributes unconditionally
        self._start_x = None
        self._end_x = None
        if halftomo_side is False:
            # Force disable halftomo
            self.halftomo_side = False
            return
        if (halftomo_side is not None) and not (self.do_360):
            raise ValueError(
                "Expected 360° dataset for half-tomography, but this dataset does not look like a 360° dataset"
            )
        if halftomo_side is None:
            if self.dataset_info.is_halftomo:
                halftomo_side = "right"
            else:
                self.halftomo_side = False
                return
        self.halftomo_side = halftomo_side
        window_size = self.extra_options["window_size"]
        if self._cor is not None:
            # In this case we look for shifts around the CoR
            self._start_x = int(self._cor - window_size / 2)
            self._end_x = int(self._cor + window_size / 2)
        elif halftomo_side == "right":
            self._start_x = -window_size
            self._end_x = None
        elif halftomo_side == "left":
            self._start_x = 0
            self._end_x = window_size
        elif is_scalar(halftomo_side):
            # Expect approximate location of CoR, relative to left-most column
            self._start_x = int(halftomo_side - window_size / 2)
            self._end_x = int(halftomo_side + window_size / 2)
        else:
            raise ValueError(
                f"Expected 'halftomo_side' to be either 'left', 'right', or an integer (got {halftomo_side})"
            )
        self.logger.debug(f"[MotionEstimation] Half-tomo looking at [{self._start_x}:{self._end_x}]")
        # For half-tomo, skimage.registration.phase_cross_correlation might look a bit too far away
        if (
            self._shifts_estimator == "phase_cross_correlation"
            and self._shifts_estimator_kwargs.get("overlap_ratio", 0.3) >= 0.3
        ):
            self._shifts_estimator_kwargs.update({"overlap_ratio": 0.2})

    def _load_data(self):
        """
        Read the projections used for the estimation (see the class docstring).
        """
        self.logger.debug("[MotionEstimation] reading data")
        if self.do_360:
            # In this case we compare pairs of opposite projections.
            # If rotation angles were arbitrary, we should do something like
            #   for angle in dataset_info.rotation_angles:
            #       img, angle_deg, idx = dataset_info.get_image_at_angle(
            #           np.degrees(angle)+180, return_angle_and_index=True
            #       )
            # Most of the time (always ?), the dataset was acquired with a circular trajectory,
            # so we can use angles dataset_info.rotation_angles[::self.angular_subsampling],
            # which amounts to reading one radio out of "angular_subsampling".

            # TODO account for more general rotation angles. The following will only work for circular trajectory and ordered angles
            self._reader = self.dataset_info.get_reader(
                sub_region=(slice(None, None, self.angular_subsampling), slice(None), slice(None))
            )
            self.radios = self._reader.load_data()
            self.angles = self.dataset_info.rotation_angles[:: self.angular_subsampling]
            self._radios_idx = self._reader.get_frames_indices()
            self.logger.debug("[MotionEstimation] This is a 360° scan, will use pairs of opposite projections")
        else:
            # In this case we use the "return projections", i.e special projections acquired
            # at several angles (eg. [180, 90, 0]) before ending the scan
            return_projs, return_angles_deg, return_idx = self.dataset_info.get_alignment_projections()
            self._angles_return = np.radians(return_angles_deg)
            self._radios_return = return_projs
            self._radios_idx_return = return_idx

            # Regular ("outwards") projections acquired at the same angles as the return projections
            projs = []
            angles_rad = []
            projs_idx = []
            for angle_deg in return_angles_deg:
                proj, rot_angle_deg, proj_idx = self.dataset_info.get_image_at_angle(
                    angle_deg, image_type="projection", return_angle_and_index=True
                )
                projs.append(proj)
                angles_rad.append(np.radians(rot_angle_deg))
                projs_idx.append(proj_idx)
            self._radios_outwards = np.array(projs)
            self._angles_outward = np.array(angles_rad)
            self._radios_idx_outwards = np.array(projs_idx)
            self.logger.debug("[MotionEstimation] This is a 180° scan, will use 'return projections'")

    def _apply_flatfield(self):
        """
        Flat-field the loaded projections in-place, if 'do_flatfield' is set.
        """
        if not (self.do_flatfield):
            return
        self.logger.debug("[MotionEstimation] flatfield")
        if self.do_360:
            self._flatfield = FlatField(
                self.radios.shape,
                flats=self.dataset_info.flats,
                darks=self.dataset_info.darks,
                radios_indices=self._radios_idx,
            )
            self._flatfield.normalize_radios(self.radios)
        else:
            # 180 + return projs
            self._flatfield_outwards = FlatField(
                self._radios_outwards.shape,
                flats=self.dataset_info.flats,
                darks=self.dataset_info.darks,
                radios_indices=self._radios_idx_outwards,
            )
            self._flatfield_outwards.normalize_radios(self._radios_outwards)
            self._flatfield_return = FlatField(
                self._radios_return.shape,
                flats=self.dataset_info.flats,
                darks=self.dataset_info.darks,
                radios_indices=self._radios_idx_return,
            )
            # Return projections must be normalized with the FlatField built for their own
            # shape and frame indices
            self._flatfield_return.normalize_radios(self._radios_return)

    def estimate_motion(self):
        """
        Load the data, flat-field it (if requested) and estimate the sample motion.

        Returns
        -------
        estimated_shifts_h:
            Output of MotionEstimation.estimate_horizontal_motion() (horizontal shifts)
        estimated_shifts_v:
            Output of MotionEstimation.estimate_vertical_motion() (vertical shifts)
        cor:
            Center of rotation
        """
        self._load_data()
        self._apply_flatfield()

        n_projs_tot = self.dataset_info.n_angles
        if self.do_360:
            n_a = self.radios.shape[0]
            # See notes above - this works only for circular trajectory / ordered angles
            projs_stack1 = self.radios[: n_a // 2]
            projs_stack2 = self.radios[n_a // 2 :]
            angles1 = self.angles[: n_a // 2]
            angles2 = self.angles[n_a // 2 :]
            indices1 = (self._radios_idx - self._radios_idx[0])[: n_a // 2]
            indices2 = (self._radios_idx - self._radios_idx[0])[n_a // 2 :]
        else:
            projs_stack1 = self._radios_outwards
            projs_stack2 = self._radios_return
            angles1 = self._angles_outward
            angles2 = self._angles_return
            indices1 = self._radios_idx_outwards - self._radios_idx_outwards.min()
            indices2 = self._radios_idx_return - self._radios_idx_outwards.min()

        # Uncropped detector width. self.radios is only loaded for 360° scans, so it cannot be used here
        full_width = projs_stack1.shape[-1]
        if self._start_x is not None:
            # Compute Motion Estimation on subset of images (eg. for half-tomo)
            projs_stack1 = projs_stack1[..., self._start_x : self._end_x]
            projs_stack2 = projs_stack2[..., self._start_x : self._end_x]

        self.motion_estimator = MotionEstimation(
            projs_stack1,
            projs_stack2,
            angles1,
            angles2,
            indices1,
            indices2,
            n_projs_tot,
            shifts_estimator=self._shifts_estimator,
            shifts_estimator_kwargs=self._shifts_estimator_kwargs,
        )

        self.logger.debug("[MotionEstimation] estimating shifts")

        estimated_shifts_v = self.motion_estimator.estimate_vertical_motion(degree=self._deg_z)
        estimated_shifts_h, cor = self.motion_estimator.estimate_horizontal_motion(degree=self._deg_xy, cor=self._cor)
        if self._start_x is not None:
            # CoR was estimated on the cropped images: account for the window position
            cor += (self._start_x % full_width) + (projs_stack1.shape[-1] - 1) / 2.0

        self.sample_shifts_xy = estimated_shifts_h
        self.sample_shifts_z = estimated_shifts_v
        if self._cor is None:
            self.logger.info(
                "[MotionEstimation] Estimated center of rotation (relative to middle of detector): %.2f" % cor
            )
        return estimated_shifts_h, estimated_shifts_v, cor

    def generate_translations_movements_file(self, filename, fmt="%.3f", only=None):
        """
        Write a text file (numpy.savetxt) with, for every rotation angle of the dataset,
        the (u, v) detector shifts that compensate the estimated sample motion.

        Parameters
        ----------
        filename: str
            Output file path.
        fmt: str, optional
            Number format passed to numpy.savetxt.
        only: str, optional
            "horizontal" zeroes the v column, "vertical" zeroes the u column.
            None (default) keeps both.

        Raises
        ------
        RuntimeError
            If estimate_motion() was not called before.
        ValueError
            If 'only' is not None, "horizontal" or "vertical".
        """
        if self.sample_shifts_xy is None:
            raise RuntimeError("Need to run estimate_motion() first")

        angles = self.dataset_info.rotation_angles
        cor = self._cor or 0
        txy_est_all_angles = self.motion_estimator.apply_fit_horiz(angles=angles)
        tz_est_all_angles = self.motion_estimator.apply_fit_vertic(angles=angles)
        estimated_shifts_vu_all_angles = self.motion_estimator.convert_sample_motion_to_detector_shifts(
            txy_est_all_angles, tz_est_all_angles, angles, cor=cor
        )
        estimated_shifts_vu_all_angles[:, 1] -= cor
        # (v, u) -> (u, v), with opposite sign to compensate the motion
        correct_shifts_uv = -estimated_shifts_vu_all_angles[:, ::-1]

        if only is not None:
            if only == "horizontal":
                correct_shifts_uv[:, 1] = 0
            elif only == "vertical":
                correct_shifts_uv[:, 0] = 0
            else:
                raise ValueError("Expected 'only' to be either None, 'horizontal' or 'vertical'")

        header = f"Generated by nabu {nabu_version} : {str(self)}"
        np.savetxt(filename, correct_shifts_uv, fmt=fmt, header=header)

    def __str__(self):
        params = f"do_flatfield={self.do_flatfield}, rot_center={self._cor}, angular_subsampling={self.angular_subsampling}"
        if self.sample_shifts_xy is not None:
            params += f", shifts_estimator={self.motion_estimator.shifts_estimator}"
        # Close the parenthesis after all parameters, including the optional one
        return f"{self.__class__.__name__}({params})"
|