nabu 2024.1.10__py3-none-any.whl → 2024.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nabu/__init__.py +1 -1
- nabu/app/bootstrap.py +2 -3
- nabu/app/cast_volume.py +4 -2
- nabu/app/cli_configs.py +5 -0
- nabu/app/composite_cor.py +1 -1
- nabu/app/create_distortion_map_from_poly.py +5 -6
- nabu/app/diag_to_pix.py +7 -19
- nabu/app/diag_to_rot.py +14 -29
- nabu/app/double_flatfield.py +32 -44
- nabu/app/parse_reconstruction_log.py +3 -0
- nabu/app/reconstruct.py +53 -15
- nabu/app/reconstruct_helical.py +2 -2
- nabu/app/stitching.py +27 -13
- nabu/app/tests/__init__.py +0 -0
- nabu/app/tests/test_reduce_dark_flat.py +4 -1
- nabu/cuda/kernel.py +11 -2
- nabu/cuda/processing.py +2 -2
- nabu/cuda/src/cone.cu +77 -0
- nabu/cuda/src/hierarchical_backproj.cu +271 -0
- nabu/cuda/utils.py +0 -6
- nabu/estimation/alignment.py +5 -19
- nabu/estimation/cor.py +173 -599
- nabu/estimation/cor_sino.py +356 -26
- nabu/estimation/focus.py +63 -11
- nabu/estimation/tests/test_cor.py +124 -58
- nabu/estimation/tests/test_focus.py +6 -6
- nabu/estimation/tilt.py +2 -1
- nabu/estimation/utils.py +5 -33
- nabu/io/__init__.py +1 -1
- nabu/io/cast_volume.py +1 -1
- nabu/io/reader.py +416 -21
- nabu/io/tests/test_readers.py +422 -0
- nabu/io/tests/test_writers.py +1 -102
- nabu/io/writer.py +4 -433
- nabu/opencl/kernel.py +14 -3
- nabu/opencl/processing.py +8 -0
- nabu/pipeline/config_validators.py +5 -2
- nabu/pipeline/datadump.py +12 -5
- nabu/pipeline/estimators.py +162 -188
- nabu/pipeline/fullfield/chunked.py +168 -92
- nabu/pipeline/fullfield/chunked_cuda.py +7 -3
- nabu/pipeline/fullfield/computations.py +2 -7
- nabu/pipeline/fullfield/dataset_validator.py +0 -4
- nabu/pipeline/fullfield/nabu_config.py +37 -13
- nabu/pipeline/fullfield/processconfig.py +22 -13
- nabu/pipeline/fullfield/reconstruction.py +13 -9
- nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
- nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
- nabu/pipeline/helical/helical_reconstruction.py +1 -1
- nabu/pipeline/params.py +21 -1
- nabu/pipeline/processconfig.py +1 -12
- nabu/pipeline/reader.py +146 -0
- nabu/pipeline/tests/test_estimators.py +44 -72
- nabu/pipeline/utils.py +4 -2
- nabu/pipeline/writer.py +10 -2
- nabu/preproc/ccd_cuda.py +1 -1
- nabu/preproc/ctf.py +14 -7
- nabu/preproc/ctf_cuda.py +2 -3
- nabu/preproc/double_flatfield.py +5 -12
- nabu/preproc/double_flatfield_cuda.py +2 -2
- nabu/preproc/flatfield.py +5 -1
- nabu/preproc/flatfield_cuda.py +5 -1
- nabu/preproc/phase.py +24 -73
- nabu/preproc/phase_cuda.py +5 -8
- nabu/preproc/tests/test_ctf.py +11 -7
- nabu/preproc/tests/test_flatfield.py +67 -122
- nabu/preproc/tests/test_paganin.py +54 -30
- nabu/processing/azim.py +206 -0
- nabu/processing/convolution_cuda.py +1 -1
- nabu/processing/fft_cuda.py +15 -17
- nabu/processing/histogram.py +2 -0
- nabu/processing/histogram_cuda.py +2 -1
- nabu/processing/kernel_base.py +3 -0
- nabu/processing/muladd_cuda.py +1 -0
- nabu/processing/padding_opencl.py +1 -1
- nabu/processing/roll_opencl.py +1 -0
- nabu/processing/rotation_cuda.py +2 -2
- nabu/processing/tests/test_fft.py +17 -10
- nabu/processing/unsharp_cuda.py +1 -1
- nabu/reconstruction/cone.py +104 -40
- nabu/reconstruction/fbp.py +3 -0
- nabu/reconstruction/fbp_base.py +7 -2
- nabu/reconstruction/filtering.py +20 -7
- nabu/reconstruction/filtering_cuda.py +7 -1
- nabu/reconstruction/hbp.py +424 -0
- nabu/reconstruction/mlem.py +99 -0
- nabu/reconstruction/reconstructor.py +2 -0
- nabu/reconstruction/rings_cuda.py +19 -19
- nabu/reconstruction/sinogram_cuda.py +1 -0
- nabu/reconstruction/sinogram_opencl.py +3 -1
- nabu/reconstruction/tests/test_cone.py +10 -5
- nabu/reconstruction/tests/test_deringer.py +7 -6
- nabu/reconstruction/tests/test_fbp.py +124 -10
- nabu/reconstruction/tests/test_filtering.py +13 -11
- nabu/reconstruction/tests/test_halftomo.py +30 -4
- nabu/reconstruction/tests/test_mlem.py +91 -0
- nabu/reconstruction/tests/test_reconstructor.py +8 -3
- nabu/resources/dataset_analyzer.py +142 -92
- nabu/resources/gpu.py +1 -0
- nabu/resources/nxflatfield.py +134 -125
- nabu/resources/templates/id16a_fluo.conf +42 -0
- nabu/resources/tests/test_extract.py +10 -0
- nabu/resources/tests/test_nxflatfield.py +2 -2
- nabu/stitching/alignment.py +80 -24
- nabu/stitching/config.py +105 -68
- nabu/stitching/definitions.py +1 -0
- nabu/stitching/frame_composition.py +68 -60
- nabu/stitching/overlap.py +91 -51
- nabu/stitching/single_axis_stitching.py +32 -0
- nabu/stitching/slurm_utils.py +6 -6
- nabu/stitching/stitcher/__init__.py +0 -0
- nabu/stitching/stitcher/base.py +124 -0
- nabu/stitching/stitcher/dumper/__init__.py +3 -0
- nabu/stitching/stitcher/dumper/base.py +94 -0
- nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
- nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
- nabu/stitching/stitcher/post_processing.py +555 -0
- nabu/stitching/stitcher/pre_processing.py +1068 -0
- nabu/stitching/stitcher/single_axis.py +484 -0
- nabu/stitching/stitcher/stitcher.py +0 -0
- nabu/stitching/stitcher/y_stitcher.py +13 -0
- nabu/stitching/stitcher/z_stitcher.py +45 -0
- nabu/stitching/stitcher_2D.py +278 -0
- nabu/stitching/tests/test_config.py +12 -37
- nabu/stitching/tests/test_frame_composition.py +33 -59
- nabu/stitching/tests/test_overlap.py +149 -7
- nabu/stitching/tests/test_utils.py +1 -1
- nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
- nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
- nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
- nabu/stitching/utils/__init__.py +1 -0
- nabu/stitching/utils/post_processing.py +281 -0
- nabu/stitching/utils/tests/test_post-processing.py +21 -0
- nabu/stitching/{utils.py → utils/utils.py} +79 -52
- nabu/stitching/y_stitching.py +27 -0
- nabu/stitching/z_stitching.py +32 -2281
- nabu/testutils.py +1 -152
- nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
- nabu/utils.py +158 -61
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/METADATA +24 -17
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/RECORD +145 -121
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/WHEEL +1 -1
- nabu/io/tiffwriter_zmm.py +0 -99
- nabu/pipeline/fallback_utils.py +0 -149
- nabu/pipeline/helical/tests/test_accumulator.py +0 -158
- nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
- nabu/pipeline/helical/tests/test_strategy.py +0 -61
- nabu/pipeline/helical/utils.py +0 -51
- nabu/pipeline/tests/test_chunk_reader.py +0 -74
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
nabu/stitching/stitcher/pre_processing.py (new file)
@@ -0,0 +1,1068 @@
import numpy
import logging
import h5py
import os
from typing import Iterable
from silx.io.url import DataUrl
from silx.io.utils import get_data
from datetime import datetime

from nxtomo.nxobject.nxdetector import ImageKey
from nxtomo.application.nxtomo import NXtomo
from nxtomo.nxobject.nxtransformations import NXtransformations
from nxtomo.utils.transformation import build_matrix, DetYFlipTransformation, DetZFlipTransformation
from nxtomo.paths.nxtomo import get_paths as _get_nexus_paths

from tomoscan.io import HDF5File
from tomoscan.series import Series
from tomoscan.esrf import NXtomoScan, EDFTomoScan
from tomoscan.esrf.scan.utils import (
    get_compacted_dataslices,
)  # this version has a 'return_url_set' needed here. At one point they should be merged together
from nabu.stitching.config import (
    PreProcessedSingleAxisStitchingConfiguration,
    KEY_IMG_REG_METHOD,
)
from nabu.stitching.utils import find_projections_relative_shifts
from functools import lru_cache as cache
from .single_axis import SingleAxisStitcher
from pyunitsystem.metricsystem import MetricSystem


_logger = logging.getLogger(__name__)


class PreProcessingStitching(SingleAxisStitcher):
    """
    loader to be used when save data during pre-processing stitching (on projections). Output is expected to be an NXtomo

    warning: axis are provided according to the `acquisition space <https://tomo.gitlab-pages.esrf.fr/bliss-tomo/master/modelization.html>`_
    """

    def __init__(self, configuration, progress=None) -> None:
        """ """
        if not isinstance(configuration, PreProcessedSingleAxisStitchingConfiguration):
            raise TypeError(
                f"configuration is expected to be an instance of {PreProcessedSingleAxisStitchingConfiguration}. Get {type(configuration)} instead"
            )
        super().__init__(configuration, progress=progress)
        self._series = Series("series", iterable=configuration.input_scans, use_identifiers=False)
        self._reading_orders = []
        # TODO: rename flips to axis_0_flips, axis_1_flips, axis_2_flips...
        self._x_flips = []
        self._y_flips = []
        self._z_flips = []

        # 'expend' auto shift request if only set once for all
        if numpy.isscalar(self.configuration.axis_0_pos_px):
            self.configuration.axis_0_pos_px = [
                self.configuration.axis_0_pos_px,
            ] * (len(self.series) - 1)
        if numpy.isscalar(self.configuration.axis_1_pos_px):
            self.configuration.axis_1_pos_px = [
                self.configuration.axis_1_pos_px,
            ] * (len(self.series) - 1)
        if numpy.isscalar(self.configuration.axis_1_pos_px):
            self.configuration.axis_1_pos_px = [
                self.configuration.axis_1_pos_px,
            ] * (len(self.series) - 1)

        if self.configuration.axis_0_params is None:
            self.configuration.axis_0_params = {}
        if self.configuration.axis_1_params is None:
            self.configuration.axis_1_params = {}
        if self.configuration.axis_2_params is None:
            self.configuration.axis_2_params = {}

    def pre_processing_computation(self):
        self.compute_reduced_flats_and_darks()

    @property
    def stitching_axis_in_frame_space(self):
        if self.axis == 0:
            return 0
        elif self.axis == 1:
            return 1
        elif self.axis == 2:
            raise NotImplementedError(
                "pre-processing stitching along axis 2 is not handled. This would require to do interpolation between frame along the rotation angle. Just not possible"
            )
        else:
            raise NotImplementedError(f"stitching axis must be in (0, 1, 2). Get {self.axis}")

    @property
    def x_flips(self) -> list:
        return self._x_flips

    @property
    def y_flips(self) -> list:
        return self._y_flips

    def order_input_tomo_objects(self):

        def get_min_bound(scan):
            return scan.get_bounding_box(axis=self.axis).min

        # order scans along the stitched axis
        if self.axis == 0:
            position_along_stitched_axis = self.configuration.axis_0_pos_px
        elif self.axis == 1:
            position_along_stitched_axis = self.configuration.axis_1_pos_px
        else:
            raise ValueError(
                "stitching cannot be done along axis 2 for pre-processing. This would require to interpolate frame between different rotation angle"
            )
        # if axis 0 position is provided then use directly it
        if position_along_stitched_axis is not None and len(position_along_stitched_axis) > 0:
            order = numpy.argsort(position_along_stitched_axis)[::-1]
            sorted_series = Series(
                self.series.name,
                numpy.take_along_axis(numpy.array(self.series[:]), order, axis=0),
                use_identifiers=False,
            )
        else:
            # else use bounding box
            sorted_series = Series(
                self.series.name,
                sorted(self.series[:], key=get_min_bound, reverse=True),
                use_identifiers=False,
            )
        if sorted_series != self.series:
            if sorted_series[:] != self.series[::-1]:
                raise ValueError(
                    f"Unable to get comprehensive input. Axis {self.axis} (decreasing) ordering is not respected."
                )
            else:
                _logger.warning(
                    f"decreasing order haven't been respected. Need to reorder {self.serie_label} ({[str(scan) for scan in sorted_series[:]]}). Will also reorder overlap height, stitching height and invert shifts"
                )
                if self.configuration.axis_0_pos_mm is not None:
                    self.configuration.axis_0_pos_mm = self.configuration.axis_0_pos_mm[::-1]
                if self.configuration.axis_0_pos_px is not None:
                    self.configuration.axis_0_pos_px = self.configuration.axis_0_pos_px[::-1]
                if self.configuration.axis_1_pos_mm is not None:
                    self.configuration.axis_1_pos_mm = self.configuration.axis_1_pos_mm[::-1]
                if self.configuration.axis_1_pos_px is not None:
                    self.configuration.axis_1_pos_px = self.configuration.axis_1_pos_px[::-1]
                if self.configuration.axis_2_pos_mm is not None:
                    self.configuration.axis_2_pos_mm = self.configuration.axis_2_pos_mm[::-1]
                if self.configuration.axis_2_pos_px is not None:
                    self.configuration.axis_2_pos_px = self.configuration.axis_2_pos_px[::-1]
                if not numpy.isscalar(self._configuration.flip_ud):
                    self._configuration.flip_ud = self._configuration.flip_ud[::-1]
                if not numpy.isscalar(self._configuration.flip_lr):
                    self._configuration.flip_ud = self._configuration.flip_lr[::-1]

        self._series = sorted_series

    def check_inputs(self):
        """
        insure input data is coherent
        """
        n_scans = len(self.series)
        if n_scans == 0:
            raise ValueError("no scan to stich together")

        for scan in self.series:
            from tomoscan.scanbase import TomoScanBase

            if not isinstance(scan, TomoScanBase):
                raise TypeError(f"z-preproc stitching expects instances of {TomoScanBase}. {type(scan)} provided.")

        # check output file path and data path are provided
        if self.configuration.output_file_path in (None, ""):
            raise ValueError("output_file_path should be provided to the configuration")
        if self.configuration.output_data_path in (None, ""):
            raise ValueError("output_data_path should be provided to the configuration")

        # check number of shift provided
        for axis_pos_px, axis_name in zip(
            (
                self.configuration.axis_0_pos_px,
                self.configuration.axis_1_pos_px,
                self.configuration.axis_1_pos_px,
                self.configuration.axis_0_pos_mm,
                self.configuration.axis_1_pos_mm,
                self.configuration.axis_2_pos_mm,
            ),
            (
                "axis_0_pos_px",
                "axis_1_pos_px",
                "axis_2_pos_px",
                "axis_0_pos_mm",
                "axis_1_pos_mm",
                "axis_2_pos_mm",
            ),
        ):
            if isinstance(axis_pos_px, Iterable) and len(axis_pos_px) != (n_scans):
                raise ValueError(f"{axis_name} expect {n_scans} shift defined. Get {len(axis_pos_px)}")

        self._reading_orders = []
        # the first scan will define the expected reading orderd, and expected flip.
        # if all scan are flipped then we will keep it this way
        self._reading_orders.append(1)

        # check scans are coherent (nb projections, rotation angle, energy...)
        for scan_0, scan_1 in zip(self.series[0:-1], self.series[1:]):
            if len(scan_0.projections) != len(scan_1.projections):
                raise ValueError(f"{scan_0} and {scan_1} have a different number of projections")
            if isinstance(scan_0, NXtomoScan) and isinstance(scan_1, NXtomoScan):
                # check rotation (only of is an NXtomoScan)
                scan_0_angles = numpy.asarray(scan_0.rotation_angle)
                scan_0_projections_angles = scan_0_angles[
                    numpy.asarray(scan_0.image_key_control) == ImageKey.PROJECTION.value
                ]
                scan_1_angles = numpy.asarray(scan_1.rotation_angle)
                scan_1_projections_angles = scan_1_angles[
                    numpy.asarray(scan_1.image_key_control) == ImageKey.PROJECTION.value
                ]
                if not numpy.allclose(scan_0_projections_angles, scan_1_projections_angles, atol=10e-1):
                    if numpy.allclose(
                        scan_0_projections_angles,
                        scan_1_projections_angles[::-1],
                        atol=10e-1,
                    ):
                        reading_order = -1 * self._reading_orders[-1]
                    else:
                        raise ValueError(f"Angles from {scan_0} and {scan_1} are different")
                else:
                    reading_order = 1 * self._reading_orders[-1]
                self._reading_orders.append(reading_order)
            # check energy
            if scan_0.energy is None:
                _logger.warning(f"no energy found for {scan_0}")
            elif not numpy.isclose(scan_0.energy, scan_1.energy, rtol=1e-03):
                _logger.warning(
                    f"different energy found between {scan_0} ({scan_0.energy}) and {scan_1} ({scan_1.energy})"
                )
            # check FOV
            if not scan_0.field_of_view == scan_1.field_of_view:
                raise ValueError(f"{scan_0} and {scan_1} have different field of view")
            # check distance
            if scan_0.distance is None:
                _logger.warning(f"no distance found for {scan_0}")
            elif not numpy.isclose(scan_0.distance, scan_1.distance, rtol=10e-3):
                raise ValueError(f"{scan_0} and {scan_1} have different sample / detector distance")
            # check pixel size
            if not numpy.isclose(scan_0.x_pixel_size, scan_1.x_pixel_size):
                raise ValueError(
                    f"{scan_0} and {scan_1} have different x pixel size. {scan_0.x_pixel_size} vs {scan_1.x_pixel_size}"
                )
            if not numpy.isclose(scan_0.y_pixel_size, scan_1.y_pixel_size):
                raise ValueError(
                    f"{scan_0} and {scan_1} have different y pixel size. {scan_0.y_pixel_size} vs {scan_1.y_pixel_size}"
                )

        for scan in self.series:
            # check x, y and z translation are constant (only if is an NXtomoScan)
            if isinstance(scan, NXtomoScan):
                if scan.x_translation is not None and not numpy.isclose(
                    min(scan.x_translation), max(scan.x_translation)
                ):
                    _logger.warning(
                        "x translations appears to be evolving over time. Might end up with wrong stitching"
                    )
                if scan.y_translation is not None and not numpy.isclose(
                    min(scan.y_translation), max(scan.y_translation)
                ):
                    _logger.warning(
                        "y translations appears to be evolving over time. Might end up with wrong stitching"
                    )
                if scan.z_translation is not None and not numpy.isclose(
                    min(scan.z_translation), max(scan.z_translation)
                ):
                    _logger.warning(
                        "z translations appears to be evolving over time. Might end up with wrong stitching"
                    )

    def _compute_positions_as_px(self):
        """insure we have or we can deduce an estimated position as pixel"""

        def get_position_as_px_on_axis(axis, pos_as_px, pos_as_mm):
            if pos_as_px is not None:
                if pos_as_mm is not None:
                    raise ValueError(
                        f"position of axis {axis} is provided twice: as mm and as px. Please provide one only ({pos_as_mm} vs {pos_as_px})"
                    )
                else:
                    return pos_as_px

            elif pos_as_mm is not None:
                # deduce from position given in configuration and pixel size
                axis_N_pos_px = []
                for scan, pos_in_mm in zip(self.series, pos_as_mm):
                    pixel_size_m = self.configuration.pixel_size or scan.pixel_size
                    axis_N_pos_px.append((pos_in_mm / MetricSystem.MILLIMETER.value) / pixel_size_m)
                return axis_N_pos_px
            else:
                # deduce from motor position and pixel size
                axis_N_pos_px = []
                base_position_m = self.series[0].get_bounding_box(axis=axis).min
                for scan in self.series:
                    pixel_size_m = self.configuration.pixel_size or scan.pixel_size
                    scan_axis_bb = scan.get_bounding_box(axis=axis)
                    axis_N_mean_pos_m = (scan_axis_bb.max - scan_axis_bb.min) / 2 + scan_axis_bb.min
                    axis_N_mean_rel_pos_m = axis_N_mean_pos_m - base_position_m
                    axis_N_pos_px.append(int(axis_N_mean_rel_pos_m / pixel_size_m))
                return axis_N_pos_px

        for axis, property_px_name, property_mm_name in zip(
            (0, 1, 2),
            (
                "axis_0_pos_px",
                "axis_1_pos_px",
                "axis_2_pos_px",
            ),
            (
                "axis_0_pos_mm",
                "axis_1_pos_mm",
                "axis_2_pos_mm",
            ),
        ):
            assert hasattr(
                self.configuration, property_px_name
            ), f"configuration API changed. should have {property_px_name}"
            assert hasattr(
                self.configuration, property_mm_name
            ), f"configuration API changed. should have {property_px_name}"
            try:
                new_px_position = get_position_as_px_on_axis(
                    axis=axis,
                    pos_as_px=getattr(self.configuration, property_px_name),
                    pos_as_mm=getattr(self.configuration, property_mm_name),
                )
            except ValueError:
                # when unable to find the position
                if axis == self.axis:
                    # if we cannot find position over the stitching axis then raise an error: unable to process without
                    raise
                else:
                    _logger.warning(f"Unable to find position over axis {axis}. Set them to zero")
                    setattr(
                        self.configuration,
                        property_px_name,
                        numpy.array([0] * len(self.series)),
                    )
            else:
                setattr(
                    self.configuration,
                    property_px_name,
                    new_px_position,
                )

        # clear position in mm as the one we will used are the px one
        self.configuration.axis_0_pos_mm = None
        self.configuration.axis_1_pos_mm = None
        self.configuration.axis_2_pos_mm = None

        # add some log
        if self.configuration.axis_2_pos_mm is not None or self.configuration.axis_2_pos_px is not None:
            _logger.warning("axis 2 position is not handled by the stitcher. Will be ignored")
        axis_0_pos = ", ".join([f"{pos}px" for pos in self.configuration.axis_0_pos_px])
        axis_1_pos = ", ".join([f"{pos}px" for pos in self.configuration.axis_1_pos_px])
        axis_2_pos = ", ".join([f"{pos}px" for pos in self.configuration.axis_2_pos_px])
        _logger.info(f"axis 0 position to be used: " + axis_0_pos)
        _logger.info(f"axis 1 position to be used: " + axis_1_pos)
        _logger.info(f"axis 2 position to be used: " + axis_2_pos)
        _logger.info(f"stitching will be applied along axis: {self.axis}")

    def compute_estimated_shifts(self):
        if self.axis == 0:
            # if we want to stitch over axis 0 (aka z)
            axis_0_pos_px = self.configuration.axis_0_pos_px
            self._axis_0_rel_ini_shifts = []
            # compute overlap along axis 0
            for upper_scan, lower_scan, upper_scan_axis_0_pos, lower_scan_axis_0_pos in zip(
                self.series[:-1], self.series[1:], axis_0_pos_px[:-1], axis_0_pos_px[1:]
            ):
                upper_scan_pos = upper_scan_axis_0_pos - upper_scan.dim_2 / 2
                lower_scan_high_pos = lower_scan_axis_0_pos + lower_scan.dim_2 / 2
                # simple test of overlap. More complete test are run by check_overlaps later
                if lower_scan_high_pos <= upper_scan_pos:
                    raise ValueError(f"no overlap found between {upper_scan} and {lower_scan}")
                self._axis_0_rel_ini_shifts.append(
                    int(lower_scan_high_pos - upper_scan_pos)  # overlap are expected to be int for now
                )
            self._axis_1_rel_ini_shifts = self.from_abs_pos_to_rel_pos(self.configuration.axis_1_pos_px)
            self._axis_2_rel_ini_shifts = [0.0] * (len(self.series) - 1)
        elif self.axis == 1:
            # if we want to stitch over axis 1 (aka Y in acquisition reference - which is x in frame reference)
            axis_1_pos_px = self.configuration.axis_1_pos_px
            self._axis_1_rel_ini_shifts = []
            # compute overlap along axis 0
            for left_scan, right_scan, left_scan_axis_1_pos, right_scan_axis_1_pos in zip(
                self.series[:-1], self.series[1:], axis_1_pos_px[:-1], axis_1_pos_px[1:]
            ):
                left_scan_pos = left_scan_axis_1_pos - left_scan.dim_1 / 2
                right_scan_high_pos = right_scan_axis_1_pos + right_scan.dim_1 / 2
                # simple test of overlap. More complete test are run by check_overlaps later
                if right_scan_high_pos <= left_scan_pos:
                    raise ValueError(f"no overlap found between {left_scan} and {right_scan}")
                self._axis_1_rel_ini_shifts.append(
                    int(right_scan_high_pos - left_scan_pos)  # overlap are expected to be int for now
                )
            self._axis_0_rel_ini_shifts = self.from_abs_pos_to_rel_pos(self.configuration.axis_0_pos_px)
            self._axis_2_rel_ini_shifts = [0.0] * (len(self.series) - 1)
        else:
            raise NotImplementedError("stitching only forseen for axis 0 and 1 for now")

    def _compute_shifts(self):
        """
        compute all shift requested (set to 'auto' in the configuration)

        """
        n_scans = len(self.configuration.input_scans)
        if n_scans == 0:
            raise ValueError("no scan to stich provided")

        projection_for_shift = self.configuration.slice_for_cross_correlation or "middle"
        if self.axis not in (0, 1):
            raise NotImplementedError("only stitching over axis 0 and 2 are handled for pre-processing stitching")

        final_rel_shifts = []
        for (
            scan_0,
            scan_1,
            order_s0,
            order_s1,
            x_rel_shift,
            y_rel_shift,
        ) in zip(
            self.series[:-1],
            self.series[1:],
            self.reading_orders[:-1],
            self.reading_orders[1:],
            self._axis_1_rel_ini_shifts,
            self._axis_0_rel_ini_shifts,
        ):
            x_cross_algo = self.configuration.axis_1_params.get(KEY_IMG_REG_METHOD, None)
            y_cross_algo = self.configuration.axis_0_params.get(KEY_IMG_REG_METHOD, None)

            # compute relative shift
            found_shift_y, found_shift_x = find_projections_relative_shifts(
                upper_scan=scan_0,
                lower_scan=scan_1,
                projection_for_shift=projection_for_shift,
                x_cross_correlation_function=x_cross_algo,
                y_cross_correlation_function=y_cross_algo,
                x_shifts_params=self.configuration.axis_1_params,  # image x map acquisition axis 1 (Y)
                y_shifts_params=self.configuration.axis_0_params,  # image y map acquisition axis 0 (Z)
                invert_order=order_s1 != order_s0,
                estimated_shifts=(y_rel_shift, x_rel_shift),
                axis=self.axis,
            )
            final_rel_shifts.append(
                (found_shift_y, found_shift_x),
            )

        # set back values. Now position should start at 0
        self._axis_0_rel_final_shifts = [final_shift[0] for final_shift in final_rel_shifts]
        self._axis_1_rel_final_shifts = [final_shift[1] for final_shift in final_rel_shifts]
        self._axis_2_rel_final_shifts = [0.0] * len(final_rel_shifts)
        _logger.info(f"axis 1 relative shifts (x in radio ref) to be used will be {self._axis_0_rel_final_shifts}")
        print(f"axis 1 relative shifts (x in radio ref) to be used will be {self._axis_0_rel_final_shifts}")
        _logger.info(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_1_rel_final_shifts}")
        print(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_1_rel_final_shifts}")

    def _create_nx_tomo(self, store_composition: bool = False):
        """
        create final NXtomo with stitched frames.
        Policy: save all projections flat fielded. So this NXtomo will only contain projections (no dark and no flat).
        But nabu will be able to reconstruct it with field `flatfield` set to False
        """
        nx_tomo = NXtomo()

        nx_tomo.energy = self.series[0].energy
        start_times = list(filter(None, [scan.start_time for scan in self.series]))
        end_times = list(filter(None, [scan.end_time for scan in self.series]))

        if len(start_times) > 0:
            nx_tomo.start_time = (
                numpy.asarray([numpy.datetime64(start_time) for start_time in start_times]).min().astype(datetime)
            )
        else:
            _logger.warning("Unable to find any start_time from input")
        if len(end_times) > 0:
            nx_tomo.end_time = (
                numpy.asarray([numpy.datetime64(end_time) for end_time in end_times]).max().astype(datetime)
            )
        else:
            _logger.warning("Unable to find any end_time from input")

        title = ";".join([scan.sequence_name or "" for scan in self.series])
        nx_tomo.title = f"stitch done from {title}"

        self._slices_to_stitch, n_proj = self.configuration.settle_slices()

        # handle detector (without frames)
        nx_tomo.instrument.detector.field_of_view = self.series[0].field_of_view
        nx_tomo.instrument.detector.distance = self.series[0].distance
        nx_tomo.instrument.detector.x_pixel_size = self.series[0].x_pixel_size
        nx_tomo.instrument.detector.y_pixel_size = self.series[0].y_pixel_size
        nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
        nx_tomo.instrument.detector.tomo_n = n_proj
        # note: stitching process insure un-flipping of frames. So make sure transformations is defined as an empty set
        nx_tomo.instrument.detector.transformations = NXtransformations()

        if isinstance(self.series[0], NXtomoScan):
            # note: first scan is always the reference as order to read data (so no rotation_angle inversion here)
            rotation_angle = numpy.asarray(self.series[0].rotation_angle)
            nx_tomo.sample.rotation_angle = rotation_angle[
                numpy.asarray(self.series[0].image_key_control) == ImageKey.PROJECTION.value
            ]
        elif isinstance(self.series[0], EDFTomoScan):
            nx_tomo.sample.rotation_angle = numpy.linspace(
                start=0, stop=self.series[0].scan_range, num=self.series[0].tomo_n
            )
        else:
            raise NotImplementedError(
                f"scan type ({type(self.series[0])} is not handled)",
                NXtomoScan,
                isinstance(self.series[0], NXtomoScan),
            )

        # do a sub selection of the rotation angle if a we are only computing a part of the slices
        def apply_slices_selection(array, slices, allow_empty: bool = False):
            if isinstance(slices, slice):
                return array[slices.start : slices.stop : 1]
            elif isinstance(slices, Iterable):
                return list([array[index] for index in slices])
            else:
                raise RuntimeError("slices must be instance of a slice or of an iterable")

        nx_tomo.sample.rotation_angle = apply_slices_selection(
            array=nx_tomo.sample.rotation_angle, slices=self._slices_to_stitch
        )

        # handle sample
        if False not in [isinstance(scan, NXtomoScan) for scan in self.series]:

            def get_sample_translation_for_projs(scan: NXtomoScan, attr):
                values = numpy.array(getattr(scan, attr))
                mask = scan.image_key_control == ImageKey.PROJECTION.value
                return values[mask]

            # we consider the new x, y and z position to be at the center of the one created
            x_translation = [
                get_sample_translation_for_projs(scan, "x_translation")
                for scan in self.series
                if scan.x_translation is not None
            ]
            if len(x_translation) > 0:
                # if there is some metadata about {x|y|z} translations
                # we want to take the mean of each frame for each projections
                x_translation = apply_slices_selection(
                    numpy.array(x_translation).mean(axis=0),
                    slices=self._slices_to_stitch,
                )
            else:
                # if no NXtomo has information about x_translation.
                # note: if at least one has missing values the numpy.Array(x_translation) with create an error as well
                x_translation = [0.0] * n_proj
                _logger.warning("Unable to fin input nxtomo x_translation values. Set it to 0.0")
            nx_tomo.sample.x_translation = x_translation

            y_translation = [
                get_sample_translation_for_projs(scan, "y_translation")
                for scan in self.series
                if scan.y_translation is not None
            ]
            if len(y_translation) > 0:
                y_translation = apply_slices_selection(
                    numpy.array(y_translation).mean(axis=0),
                    slices=self._slices_to_stitch,
                )
            else:
                y_translation = [0.0] * n_proj
                _logger.warning("Unable to fin input nxtomo y_translation values. Set it to 0.0")
            nx_tomo.sample.y_translation = y_translation
            z_translation = [
                get_sample_translation_for_projs(scan, "z_translation")
                for scan in self.series
                if scan.z_translation is not None
            ]
            if len(z_translation) > 0:
                z_translation = apply_slices_selection(
                    numpy.array(z_translation).mean(axis=0),
                    slices=self._slices_to_stitch,
                )
            else:
                z_translation = [0.0] * n_proj
                _logger.warning("Unable to fin input nxtomo z_translation values. Set it to 0.0")
            nx_tomo.sample.z_translation = z_translation

            nx_tomo.sample.name = self.series[0].sample_name

        # compute stitched frame shape
        if self.axis == 0:
            stitched_frame_shape = (
                n_proj,
                (
                    numpy.asarray([scan.dim_2 for scan in self.series]).sum()
                    - numpy.asarray([abs(overlap) for overlap in self._axis_0_rel_final_shifts]).sum()
                ),
                self._stitching_constant_length,
            )
        elif self.axis == 1:
            stitched_frame_shape = (
                n_proj,
                self._stitching_constant_length,
                (
                    numpy.asarray([scan.dim_1 for scan in self.series]).sum()
                    - numpy.asarray([abs(overlap) for overlap in self._axis_1_rel_final_shifts]).sum()
                ),
            )
        else:
            raise NotImplementedError("stitching on pre-processing along axis 2 (x-ray direction) is not handled")

        if stitched_frame_shape[0] < 1 or stitched_frame_shape[1] < 1 or stitched_frame_shape[2] < 1:
            raise RuntimeError(f"Error in stitched frame shape calculation. {stitched_frame_shape} found.")
        # get expected output dataset first (just in case output and input files are the same)
        first_proj_idx = sorted(self.series[0].projections.keys())[0]
        first_proj_url = self.series[0].projections[first_proj_idx]
        if h5py.is_hdf5(first_proj_url.file_path()):
            first_proj_url = DataUrl(
                file_path=first_proj_url.file_path(),
                data_path=first_proj_url.data_path(),
                scheme="h5py",
            )

        # first save the NXtomo entry without the frame
        # dicttonx will fail if the folder does not exists
        dir_name = os.path.dirname(self.configuration.output_file_path)
        if dir_name not in (None, ""):
            os.makedirs(dir_name, exist_ok=True)
        nx_tomo.save(
            file_path=self.configuration.output_file_path,
            data_path=self.configuration.output_data_path,
            nexus_path_version=self.configuration.output_nexus_version,
            overwrite=self.configuration.overwrite_results,
        )

        transformation_matrices = {
            scan.get_identifier()
            .to_str()
            .center(80, "-"): numpy.array2string(build_matrix(scan.get_detector_transformations(tuple())))
            for scan in self.series
        }
        _logger.info(
            "scan detector transformation matrices are:\n"
            "\n".join(["/n".join(item) for item in transformation_matrices.items()])
        )

        _logger.info(
            f"reading order is {self.reading_orders}",
        )

        def get_output_data_type():
            return numpy.float32  # because we will apply flat field correction on it and they are not raw data

        output_dtype = get_output_data_type()
        # append frames ("instrument/detector/data" dataset)
        with HDF5File(self.configuration.output_file_path, mode="a") as h5f:
            # note: nx_tomo.save already handles the possible overwrite conflict by removing
            # self.configuration.output_file_path or raising an error

            stitched_frame_path = "/".join(
                [
                    self.configuration.output_data_path,
                    _get_nexus_paths(self.configuration.output_nexus_version).PROJ_PATH,
                ]
            )
            self.dumper.output_dataset = h5f.create_dataset(
                name=stitched_frame_path,
                shape=stitched_frame_shape,
                dtype=output_dtype,
            )
            # TODO: we could also create in several time and create a virtual dataset from it.
            scans_projections_indexes = []
            for scan, reverse in zip(self.series, self.reading_orders):
                scans_projections_indexes.append(sorted(scan.projections.keys(), reverse=(reverse == -1)))
            if self.progress:
                self.progress.total = self.get_n_slices_to_stitch()

            if isinstance(self._slices_to_stitch, slice):
                step = self._slices_to_stitch.step or 1
            else:
                step = 1
            i_proj = 0
            for bunch_start, bunch_end in self._data_bunch_iterator(slices=self._slices_to_stitch, bunch_size=50):
                for data_frames in self._get_bunch_of_data(
                    bunch_start,
                    bunch_end,
                    step=step,
                    scans=self.series,
                    scans_projections_indexes=scans_projections_indexes,
                    flip_ud_arr=self.configuration.flip_ud,
                    flip_lr_arr=self.configuration.flip_lr,
                    reading_orders=self.reading_orders,
                ):
                    if self.configuration.rescale_frames:
                        data_frames = self.rescale_frames(data_frames)
                    if self.configuration.normalization_by_sample.is_active():
                        data_frames = self.normalize_frame_by_sample(data_frames)

                    sf = SingleAxisStitcher.stitch_frames(
                        frames=data_frames,
                        axis=self.axis,
                        x_relative_shifts=self._axis_1_rel_final_shifts,
                        y_relative_shifts=self._axis_0_rel_final_shifts,
                        overlap_kernels=self._overlap_kernels,
                        i_frame=i_proj,
                        output_dtype=output_dtype,
                        dumper=self.dumper,
                        return_composition_cls=store_composition if i_proj == 0 else False,
                        stitching_axis=self.axis,
                        pad_mode=self.configuration.pad_mode,
                        alignment=self.configuration.alignment_axis_2,
                        new_width=self._stitching_constant_length,
                        check_inputs=i_proj == 0,  # on process check on the first iteration
                    )
                    if i_proj == 0 and store_composition:
                        _, self._frame_composition = sf
                    if self.progress is not None:
                        self.progress.update()

                    i_proj += 1

            # create link to this dataset that can be missing
            # "data/data" link
            if "data" in h5f[self.configuration.output_data_path]:
                data_group = h5f[self.configuration.output_data_path]["data"]
                if not stitched_frame_path.startswith("/"):
                    stitched_frame_path = "/" + stitched_frame_path
                data_group["data"] = h5py.SoftLink(stitched_frame_path)
                if "default" not in h5f[self.configuration.output_data_path].attrs:
                    h5f[self.configuration.output_data_path].attrs["default"] = "data"
                for attr_name, attr_value in zip(
                    ("NX_class", "SILX_style/axis_scale_types", "signal"),
                    ("NXdata", ["linear", "linear"], "data"),
                ):
                    if attr_name not in data_group.attrs:
                        data_group.attrs[attr_name] = attr_value

        return nx_tomo

    def _create_stitching(self, store_composition):
        self._create_nx_tomo(store_composition=store_composition)

    @staticmethod
    def get_bunch_of_data(
        bunch_start: int,
        bunch_end: int,
        step: int,
        scans: tuple,
        scans_projections_indexes: tuple,
        reading_orders: tuple,
        flip_lr_arr: tuple,
        flip_ud_arr: tuple,
    ):
        """
        goal is to load contiguous projections as much as possible...

        :param int bunch_start: begining of the bunch
        :param int bunch_end: end of the bunch
        :param int scans: ordered scan for which we want to get data
        :param scans_projections_indexes: tuple with scans and scan projection indexes to be loaded
        :param tuple flip_lr_arr: extra information from the user to left-right flip frames
        :param tuple flip_ud_arr: extra information from the user to up-down flip frames
        :return: list of list. For each frame we want to stitch contains the (flat fielded) frames to stich together
        """
        assert len(scans) == len(scans_projections_indexes)
        assert isinstance(flip_lr_arr, tuple)
        assert isinstance(flip_ud_arr, tuple)
        assert isinstance(step, int)
        scans_proj_urls = []
        # for each scan store the real indices and the data url

        for scan, scan_projection_indexes in zip(scans, scans_projections_indexes):
            scan_proj_urls = {}
            # for each scan get the list of url to be loaded
            for i_proj in range(bunch_start, bunch_end):
                if i_proj % step != 0:
                    continue
                proj_index_in_full_scan = scan_projection_indexes[i_proj]
                scan_proj_urls[proj_index_in_full_scan] = scan.projections[proj_index_in_full_scan]
            scans_proj_urls.append(scan_proj_urls)

        # then load data
        all_scan_final_data = numpy.empty((bunch_end - bunch_start, len(scans)), dtype=object)
        from nabu.preproc.flatfield import FlatFieldArrays

        for i_scan, (scan_urls, scan_flip_lr, scan_flip_ud, reading_order) in enumerate(
            zip(scans_proj_urls, flip_lr_arr, flip_ud_arr, reading_orders)
        ):
            i_frame = 0
            _, set_of_compacted_slices = get_compacted_dataslices(scan_urls, return_url_set=True)
            for _, url in set_of_compacted_slices.items():
                scan = scans[i_scan]
                url = DataUrl(
                    file_path=url.file_path(),
                    data_path=url.data_path(),
                    scheme="silx",
                    data_slice=url.data_slice(),
                )
                raw_radios = get_data(url)[::reading_order]
                radio_indices = url.data_slice()
                if isinstance(radio_indices, slice):
                    step = radio_indices.step if radio_indices is not None else 1
                    radio_indices = numpy.arange(
                        start=radio_indices.start,
                        stop=radio_indices.stop,
                        step=step,
                        dtype=numpy.int16,
                    )

                missing = []
                if len(scan.reduced_flats) == 0:
                    missing = "flats"
                if len(scan.reduced_darks) == 0:
                    missing = "darks"

                if len(missing) > 0:
                    _logger.warning(f"missing {'and'.join(missing)}. Unable to do flat field correction")
                    ff_arrays = None
                    data = raw_radios
                else:
                    has_reduced_metadata = (
                        scan.reduced_flats_infos is not None
                        and len(scan.reduced_flats_infos.machine_electric_current) > 0
                        and scan.reduced_darks_infos is not None
                        and len(scan.reduced_darks_infos.machine_electric_current) > 0
                    )
                    if not has_reduced_metadata:
                        _logger.warning("no metadata about current found. Won't normalize according to machine current")

                    ff_arrays = FlatFieldArrays(
                        radios_shape=(len(radio_indices), scan.dim_2, scan.dim_1),
                        flats=scan.reduced_flats,
                        darks=scan.reduced_darks,
                        radios_indices=radio_indices,
                        radios_srcurrent=scan.electric_current[radio_indices] if has_reduced_metadata else None,
                        flats_srcurrent=(
                            scan.reduced_flats_infos.machine_electric_current if has_reduced_metadata else None
                        ),
                    )
                    # note: we need to cast radios to float 32. Darks and flats are cast to anyway
                    data = ff_arrays.normalize_radios(raw_radios.astype(numpy.float32))

                transformations = list(scans[i_scan].get_detector_transformations(tuple()))
                if scan_flip_lr:
                    transformations.append(DetZFlipTransformation(flip=True))
                if scan_flip_ud:
                    transformations.append(DetYFlipTransformation(flip=True))

                transformation_matrix_det_space = build_matrix(transformations)
                if transformation_matrix_det_space is None or numpy.allclose(
                    transformation_matrix_det_space, numpy.identity(3)
                ):
                    flip_ud = False
                    flip_lr = False
                elif numpy.array_equal(transformation_matrix_det_space, PreProcessingStitching._get_UD_flip_matrix()):
                    flip_ud = True
                    flip_lr = False
                elif numpy.allclose(transformation_matrix_det_space, PreProcessingStitching._get_LR_flip_matrix()):
                    flip_ud = False
                    flip_lr = True
                elif numpy.allclose(
                    transformation_matrix_det_space, PreProcessingStitching._get_UD_AND_LR_flip_matrix()
                ):
                    flip_ud = True
                    flip_lr = True
                else:
                    raise ValueError("case not handled... For now only handle up-down flip as left-right flip")

                for frame in data:
                    if flip_ud:
                        frame = numpy.flipud(frame)
                    if flip_lr:
                        frame = numpy.fliplr(frame)
                    all_scan_final_data[i_frame, i_scan] = frame
                    i_frame += 1

        return all_scan_final_data

    def compute_reduced_flats_and_darks(self):
        """
        make sure reduced dark and flats are existing otherwise compute them
        """
        for scan in self.series:
            try:
                reduced_darks, darks_infos = scan.load_reduced_darks(return_info=True)
            except:
                _logger.info("no reduced dark found. Try to compute them.")
            if reduced_darks in (None, {}):
                reduced_darks, darks_infos = scan.compute_reduced_darks(return_info=True)
                try:
                    # if we don't have write in the folder containing the .nx for example
                    scan.save_reduced_darks(reduced_darks, darks_infos=darks_infos)
                except Exception as e:
                    pass
            scan.set_reduced_darks(reduced_darks, darks_infos=darks_infos)

            try:
                reduced_flats, flats_infos = scan.load_reduced_flats(return_info=True)
            except:
                _logger.info("no reduced flats found. Try to compute them.")
            if reduced_flats in (None, {}):
                reduced_flats, flats_infos = scan.compute_reduced_flats(return_info=True)
                try:
                    # if we don't have write in the folder containing the .nx for example
                    scan.save_reduced_flats(reduced_flats, flats_infos=flats_infos)
                except Exception as e:
                    pass
            scan.set_reduced_flats(reduced_flats, flats_infos=flats_infos)

    @staticmethod
    @cache(maxsize=None)
    def _get_UD_flip_matrix():
        return DetYFlipTransformation(flip=True).as_matrix()

    @staticmethod
    @cache(maxsize=None)
    def _get_LR_flip_matrix():
        return DetZFlipTransformation(flip=True).as_matrix()

    @staticmethod
    @cache(maxsize=None)
    def _get_UD_AND_LR_flip_matrix():
        return numpy.matmul(
            PreProcessingStitching._get_UD_flip_matrix(),
            PreProcessingStitching._get_LR_flip_matrix(),
        )

    @staticmethod
    def _get_bunch_of_data(
        bunch_start: int,
        bunch_end: int,
        step: int,
        scans: tuple,
        scans_projections_indexes: tuple,
        reading_orders: tuple,
        flip_lr_arr: tuple,
        flip_ud_arr: tuple,
    ):
        """
        goal is to load contiguous projections as much as possible...

        :param int bunch_start: begining of the bunch
        :param int bunch_end: end of the bunch
        :param int scans: ordered scan for which we want to get data
        :param scans_projections_indexes: tuple with scans and scan projection indexes to be loaded
        :param tuple flip_lr_arr: extra information from the user to left-right flip frames
        :param tuple flip_ud_arr: extra information from the user to up-down flip frames
        :return: list of list. For each frame we want to stitch contains the (flat fielded) frames to stich together
        """
        assert len(scans) == len(scans_projections_indexes)
        assert isinstance(flip_lr_arr, tuple)
        assert isinstance(flip_ud_arr, tuple)
        assert isinstance(step, int)
        scans_proj_urls = []
        # for each scan store the real indices and the data url

        for scan, scan_projection_indexes in zip(scans, scans_projections_indexes):
            scan_proj_urls = {}
            # for each scan get the list of url to be loaded
            for i_proj in range(bunch_start, bunch_end):
                if i_proj % step != 0:
                    continue
                proj_index_in_full_scan = scan_projection_indexes[i_proj]
                scan_proj_urls[proj_index_in_full_scan] = scan.projections[proj_index_in_full_scan]
            scans_proj_urls.append(scan_proj_urls)

        # then load data
        all_scan_final_data = numpy.empty((bunch_end - bunch_start, len(scans)), dtype=object)
        from nabu.preproc.flatfield import FlatFieldArrays

        for i_scan, (scan_urls, scan_flip_lr, scan_flip_ud, reading_order) in enumerate(
            zip(scans_proj_urls, flip_lr_arr, flip_ud_arr, reading_orders)
        ):
            i_frame = 0
            _, set_of_compacted_slices = get_compacted_dataslices(scan_urls, return_url_set=True)
            for _, url in set_of_compacted_slices.items():
                scan = scans[i_scan]
                url = DataUrl(
                    file_path=url.file_path(),
                    data_path=url.data_path(),
                    scheme="silx",
                    data_slice=url.data_slice(),
                )
                raw_radios = get_data(url)[::reading_order]
                radio_indices = url.data_slice()
                if isinstance(radio_indices, slice):
                    step = radio_indices.step if radio_indices is not None else 1
                    radio_indices = numpy.arange(
                        start=radio_indices.start,
                        stop=radio_indices.stop,
                        step=step,
                        dtype=numpy.int16,
                    )

                missing = []
                if len(scan.reduced_flats) == 0:
                    missing = "flats"
                if len(scan.reduced_darks) == 0:
                    missing = "darks"

                if len(missing) > 0:
                    _logger.warning(f"missing {'and'.join(missing)}. Unable to do flat field correction")
                    ff_arrays = None
                    data = raw_radios
                else:
                    has_reduced_metadata = (
                        scan.reduced_flats_infos is not None
                        and len(scan.reduced_flats_infos.machine_electric_current) > 0
                        and scan.reduced_darks_infos is not None
                        and len(scan.reduced_darks_infos.machine_electric_current) > 0
                    )
                    if not has_reduced_metadata:
                        _logger.warning("no metadata about current found. Won't normalize according to machine current")

                    ff_arrays = FlatFieldArrays(
                        radios_shape=(len(radio_indices), scan.dim_2, scan.dim_1),
                        flats=scan.reduced_flats,
                        darks=scan.reduced_darks,
                        radios_indices=radio_indices,
                        radios_srcurrent=scan.electric_current[radio_indices] if has_reduced_metadata else None,
                        flats_srcurrent=(
                            scan.reduced_flats_infos.machine_electric_current if has_reduced_metadata else None
                        ),
                    )
                    # note: we need to cast radios to float 32. Darks and flats are cast to anyway
                    data = ff_arrays.normalize_radios(raw_radios.astype(numpy.float32))

                transformations = list(scans[i_scan].get_detector_transformations(tuple()))
                if scan_flip_lr:
                    transformations.append(DetZFlipTransformation(flip=True))
                if scan_flip_ud:
                    transformations.append(DetYFlipTransformation(flip=True))

                transformation_matrix_det_space = build_matrix(transformations)
                if transformation_matrix_det_space is None or numpy.allclose(
                    transformation_matrix_det_space, numpy.identity(3)
                ):
                    flip_ud = False
                    flip_lr = False
                elif numpy.array_equal(transformation_matrix_det_space, PreProcessingStitching._get_UD_flip_matrix()):
                    flip_ud = True
                    flip_lr = False
                elif numpy.allclose(transformation_matrix_det_space, PreProcessingStitching._get_LR_flip_matrix()):
                    flip_ud = False
                    flip_lr = True
                elif numpy.allclose(
                    transformation_matrix_det_space, PreProcessingStitching._get_UD_AND_LR_flip_matrix()
                ):
                    flip_ud = True
                    flip_lr = True
                else:
                    raise ValueError("case not handled... For now only handle up-down flip as left-right flip")

                for frame in data:
                    if flip_ud:
                        frame = numpy.flipud(frame)
                    if flip_lr:
                        frame = numpy.fliplr(frame)
                    all_scan_final_data[i_frame, i_scan] = frame
                    i_frame += 1

        return all_scan_final_data