nabu 2024.1.10__py3-none-any.whl → 2024.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nabu/__init__.py +1 -1
- nabu/app/bootstrap.py +2 -3
- nabu/app/cast_volume.py +4 -2
- nabu/app/cli_configs.py +5 -0
- nabu/app/composite_cor.py +1 -1
- nabu/app/create_distortion_map_from_poly.py +5 -6
- nabu/app/diag_to_pix.py +7 -19
- nabu/app/diag_to_rot.py +14 -29
- nabu/app/double_flatfield.py +32 -44
- nabu/app/parse_reconstruction_log.py +3 -0
- nabu/app/reconstruct.py +53 -15
- nabu/app/reconstruct_helical.py +2 -2
- nabu/app/stitching.py +27 -13
- nabu/app/tests/__init__.py +0 -0
- nabu/app/tests/test_reduce_dark_flat.py +4 -1
- nabu/cuda/kernel.py +11 -2
- nabu/cuda/processing.py +2 -2
- nabu/cuda/src/cone.cu +77 -0
- nabu/cuda/src/hierarchical_backproj.cu +271 -0
- nabu/cuda/utils.py +0 -6
- nabu/estimation/alignment.py +5 -19
- nabu/estimation/cor.py +173 -599
- nabu/estimation/cor_sino.py +356 -26
- nabu/estimation/focus.py +63 -11
- nabu/estimation/tests/test_cor.py +124 -58
- nabu/estimation/tests/test_focus.py +6 -6
- nabu/estimation/tilt.py +2 -1
- nabu/estimation/utils.py +5 -33
- nabu/io/__init__.py +1 -1
- nabu/io/cast_volume.py +1 -1
- nabu/io/reader.py +416 -21
- nabu/io/tests/test_readers.py +422 -0
- nabu/io/tests/test_writers.py +1 -102
- nabu/io/writer.py +4 -433
- nabu/opencl/kernel.py +14 -3
- nabu/opencl/processing.py +8 -0
- nabu/pipeline/config_validators.py +5 -2
- nabu/pipeline/datadump.py +12 -5
- nabu/pipeline/estimators.py +162 -188
- nabu/pipeline/fullfield/chunked.py +168 -92
- nabu/pipeline/fullfield/chunked_cuda.py +7 -3
- nabu/pipeline/fullfield/computations.py +2 -7
- nabu/pipeline/fullfield/dataset_validator.py +0 -4
- nabu/pipeline/fullfield/nabu_config.py +37 -13
- nabu/pipeline/fullfield/processconfig.py +22 -13
- nabu/pipeline/fullfield/reconstruction.py +13 -9
- nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
- nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
- nabu/pipeline/helical/helical_reconstruction.py +1 -1
- nabu/pipeline/params.py +21 -1
- nabu/pipeline/processconfig.py +1 -12
- nabu/pipeline/reader.py +146 -0
- nabu/pipeline/tests/test_estimators.py +44 -72
- nabu/pipeline/utils.py +4 -2
- nabu/pipeline/writer.py +10 -2
- nabu/preproc/ccd_cuda.py +1 -1
- nabu/preproc/ctf.py +14 -7
- nabu/preproc/ctf_cuda.py +2 -3
- nabu/preproc/double_flatfield.py +5 -12
- nabu/preproc/double_flatfield_cuda.py +2 -2
- nabu/preproc/flatfield.py +5 -1
- nabu/preproc/flatfield_cuda.py +5 -1
- nabu/preproc/phase.py +24 -73
- nabu/preproc/phase_cuda.py +5 -8
- nabu/preproc/tests/test_ctf.py +11 -7
- nabu/preproc/tests/test_flatfield.py +67 -122
- nabu/preproc/tests/test_paganin.py +54 -30
- nabu/processing/azim.py +206 -0
- nabu/processing/convolution_cuda.py +1 -1
- nabu/processing/fft_cuda.py +15 -17
- nabu/processing/histogram.py +2 -0
- nabu/processing/histogram_cuda.py +2 -1
- nabu/processing/kernel_base.py +3 -0
- nabu/processing/muladd_cuda.py +1 -0
- nabu/processing/padding_opencl.py +1 -1
- nabu/processing/roll_opencl.py +1 -0
- nabu/processing/rotation_cuda.py +2 -2
- nabu/processing/tests/test_fft.py +17 -10
- nabu/processing/unsharp_cuda.py +1 -1
- nabu/reconstruction/cone.py +104 -40
- nabu/reconstruction/fbp.py +3 -0
- nabu/reconstruction/fbp_base.py +7 -2
- nabu/reconstruction/filtering.py +20 -7
- nabu/reconstruction/filtering_cuda.py +7 -1
- nabu/reconstruction/hbp.py +424 -0
- nabu/reconstruction/mlem.py +99 -0
- nabu/reconstruction/reconstructor.py +2 -0
- nabu/reconstruction/rings_cuda.py +19 -19
- nabu/reconstruction/sinogram_cuda.py +1 -0
- nabu/reconstruction/sinogram_opencl.py +3 -1
- nabu/reconstruction/tests/test_cone.py +10 -5
- nabu/reconstruction/tests/test_deringer.py +7 -6
- nabu/reconstruction/tests/test_fbp.py +124 -10
- nabu/reconstruction/tests/test_filtering.py +13 -11
- nabu/reconstruction/tests/test_halftomo.py +30 -4
- nabu/reconstruction/tests/test_mlem.py +91 -0
- nabu/reconstruction/tests/test_reconstructor.py +8 -3
- nabu/resources/dataset_analyzer.py +142 -92
- nabu/resources/gpu.py +1 -0
- nabu/resources/nxflatfield.py +134 -125
- nabu/resources/templates/id16a_fluo.conf +42 -0
- nabu/resources/tests/test_extract.py +10 -0
- nabu/resources/tests/test_nxflatfield.py +2 -2
- nabu/stitching/alignment.py +80 -24
- nabu/stitching/config.py +105 -68
- nabu/stitching/definitions.py +1 -0
- nabu/stitching/frame_composition.py +68 -60
- nabu/stitching/overlap.py +91 -51
- nabu/stitching/single_axis_stitching.py +32 -0
- nabu/stitching/slurm_utils.py +6 -6
- nabu/stitching/stitcher/__init__.py +0 -0
- nabu/stitching/stitcher/base.py +124 -0
- nabu/stitching/stitcher/dumper/__init__.py +3 -0
- nabu/stitching/stitcher/dumper/base.py +94 -0
- nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
- nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
- nabu/stitching/stitcher/post_processing.py +555 -0
- nabu/stitching/stitcher/pre_processing.py +1068 -0
- nabu/stitching/stitcher/single_axis.py +484 -0
- nabu/stitching/stitcher/stitcher.py +0 -0
- nabu/stitching/stitcher/y_stitcher.py +13 -0
- nabu/stitching/stitcher/z_stitcher.py +45 -0
- nabu/stitching/stitcher_2D.py +278 -0
- nabu/stitching/tests/test_config.py +12 -37
- nabu/stitching/tests/test_frame_composition.py +33 -59
- nabu/stitching/tests/test_overlap.py +149 -7
- nabu/stitching/tests/test_utils.py +1 -1
- nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
- nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
- nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
- nabu/stitching/utils/__init__.py +1 -0
- nabu/stitching/utils/post_processing.py +281 -0
- nabu/stitching/utils/tests/test_post-processing.py +21 -0
- nabu/stitching/{utils.py → utils/utils.py} +79 -52
- nabu/stitching/y_stitching.py +27 -0
- nabu/stitching/z_stitching.py +32 -2281
- nabu/testutils.py +1 -152
- nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
- nabu/utils.py +158 -61
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/METADATA +24 -17
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/RECORD +145 -121
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/WHEEL +1 -1
- nabu/io/tiffwriter_zmm.py +0 -99
- nabu/pipeline/fallback_utils.py +0 -149
- nabu/pipeline/helical/tests/test_accumulator.py +0 -158
- nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
- nabu/pipeline/helical/tests/test_strategy.py +0 -61
- nabu/pipeline/helical/utils.py +0 -51
- nabu/pipeline/tests/test_chunk_reader.py +0 -74
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
- {nabu-2024.1.10.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,484 @@
|
|
1
|
+
import h5py
|
2
|
+
import numpy
|
3
|
+
import logging
|
4
|
+
from math import ceil
|
5
|
+
from typing import Optional, Iterable, Union
|
6
|
+
from tomoscan.series import Series
|
7
|
+
from tomoscan.identifier import BaseIdentifier
|
8
|
+
from nabu.stitching.stitcher.base import _StitcherBase, get_obj_constant_side_length
|
9
|
+
from nabu.stitching.stitcher_2D import stitch_raw_frames
|
10
|
+
from nabu.stitching.utils.utils import ShiftAlgorithm, from_slice_to_n_elements
|
11
|
+
from nabu.stitching.overlap import (
|
12
|
+
check_overlaps,
|
13
|
+
ImageStichOverlapKernel,
|
14
|
+
)
|
15
|
+
from nabu.stitching.config import (
|
16
|
+
SingleAxisStitchingConfiguration,
|
17
|
+
KEY_RESCALE_MIN_PERCENTILES,
|
18
|
+
KEY_RESCALE_MAX_PERCENTILES,
|
19
|
+
)
|
20
|
+
from nabu.misc.utils import rescale_data
|
21
|
+
from nabu.stitching.sample_normalization import normalize_frame as normalize_frame_by_sample
|
22
|
+
from nabu.stitching.stitcher.dumper.base import DumperBase
|
23
|
+
from silx.io.utils import get_data
|
24
|
+
from silx.io.url import DataUrl
|
25
|
+
from scipy.ndimage import shift as shift_scipy
|
26
|
+
|
27
|
+
|
28
|
+
_logger = logging.getLogger(__name__)


# Description displayed by the progress bar while stitching volumes.
# It is needed to retrieve the advancement from file when stitching remotely.
PROGRESS_BAR_STITCH_VOL_DESC = "stitch volumes"
|
34
|
+
|
35
|
+
|
36
|
+
class _SingleAxisMetaClass(type):
|
37
|
+
"""
|
38
|
+
Metaclass for single axis stitcher in order to aggregate dumper class and axis
|
39
|
+
"""
|
40
|
+
|
41
|
+
def __new__(mcls, name, bases, attrs, axis=None, dumper_cls=None):
|
42
|
+
mcls = super().__new__(mcls, name, bases, attrs)
|
43
|
+
mcls._axis = axis
|
44
|
+
mcls._dumperCls = dumper_cls
|
45
|
+
return mcls
|
46
|
+
|
47
|
+
|
48
|
+
class SingleAxisStitcher(_StitcherBase, metaclass=_SingleAxisMetaClass):
    """
    Base class for any stitcher operating along a single axis.

    The stitching axis and the dumper class are given as class keywords
    (``class Foo(SingleAxisStitcher, axis=0, dumper_cls=MyDumper)``) and stored as
    ``_axis`` / ``_dumperCls`` by `_SingleAxisMetaClass`.
    """

    def __init__(self, configuration, *args, **kwargs) -> None:
        super().__init__(configuration, *args, **kwargs)
        # the dumper is optional: a class created without `dumper_cls` has no dumper
        if self._dumperCls is not None:
            self._dumper = self._dumperCls(configuration=configuration)
        else:
            self._dumper = None

        # initial shifts: shift between two juxtaposed objects along axis 0 / 1 / 2,
        # found from position metadata or given by the user
        self._axis_0_rel_ini_shifts = []
        self._axis_1_rel_ini_shifts = []
        self._axis_2_rel_ini_shifts = []

        # final shifts: shift along axis 0 / 1 / 2 once refined by the cross correlation algorithm
        self._axis_0_rel_final_shifts = []
        self._axis_1_rel_final_shifts = []
        self._axis_2_rel_final_shifts = []

        # slices to be stitched. Obtained from calling Configuration.settle_slices
        self._slices_to_stitch = None

        # stitching width: larger object width (computed in _createOverlapKernels). Other objects will be padded
        self._stitching_constant_length = None

        def shifts_is_scalar(shifts):
            return isinstance(shifts, ShiftAlgorithm) or numpy.isscalar(shifts)

        # 'expand' a single position / shift algorithm to one value per pair of consecutive objects.
        # Looping over the axes guarantees axis 0, 1 and 2 are all handled the same way.
        for i_axis in (0, 1, 2):
            attr_name = f"axis_{i_axis}_pos_px"
            value = getattr(self.configuration, attr_name)
            if shifts_is_scalar(value):
                setattr(self.configuration, attr_name, [value] * (len(self.series) - 1))
        for i_axis in (0, 1, 2):
            attr_name = f"axis_{i_axis}_params"
            value = getattr(self.configuration, attr_name)
            if numpy.isscalar(value):
                setattr(self.configuration, attr_name, [value] * (len(self.series) - 1))

    @property
    def axis(self) -> int:
        """Stitching axis (in the 3D space), as given by the `axis` class keyword."""
        return self._axis

    @property
    def dumper(self):
        """Dumper instance created from the `dumper_cls` class keyword (None if no dumper class)."""
        return self._dumper

    @property
    def stitching_axis_in_frame_space(self):
        """
        stitching is operated in 2D (frame) space. So the axis in frame space is different than the one in 3D ebs-tomo space (https://tomo.gitlab-pages.esrf.fr/bliss-tomo/master/modelization.html)
        """
        raise NotImplementedError("Base class")

    def stitch(self, store_composition: bool = True) -> BaseIdentifier:
        """
        Run the full stitching: order inputs, check them, settle flips, compute shifts,
        create overlap kernels, stitch and finally dump the configuration.

        :param store_composition: forwarded to `_create_stitching`.
        :return: identifier of the output, as provided by the dumper.
        """
        if self.progress is not None:
            self.progress.set_description("order scans")
        self.order_input_tomo_objects()
        if self.progress is not None:
            self.progress.set_description("check inputs")
        self.check_inputs()
        self.settle_flips()

        if self.progress is not None:
            self.progress.set_description("compute shifts")
        self._compute_positions_as_px()
        self.pre_processing_computation()

        self.compute_estimated_shifts()
        self._compute_shifts()
        self._createOverlapKernels()
        if self.progress is not None:
            self.progress.set_description(PROGRESS_BAR_STITCH_VOL_DESC)

        self._create_stitching(store_composition=store_composition)
        if self.progress is not None:
            self.progress.set_description("dump configuration")
        self.dumper.save_configuration()
        return self.dumper.output_identifier

    @property
    def serie_label(self) -> str:
        """return serie name for logs"""
        return "single axis serie"

    def get_n_slices_to_stitch(self):
        """
        Return the number of slice to be stitched.

        :raises RuntimeError: if slices have not been settled yet.
        """
        if self._slices_to_stitch is None:
            raise RuntimeError("Slices needs to be settled first")
        return from_slice_to_n_elements(self._slices_to_stitch)

    def get_final_axis_positions_in_px(self) -> dict:
        """
        compute the final position (**in pixel**) from the initial position of the first object and the final relative shift computed (1)
        (1): the final relative shift is obtained from the initial shift (from motor position of provided by the user) + the refinement shift from cross correlation algorithm

        :return: dict with tomo object identifier (str) as key and a tuple of position in pixel (axis_0_pos, axis_1_pos, axis_2_pos)
        """

        def final_positions(initial_positions, rel_final_shifts, rel_ini_shifts):
            # refinement brought by the cross correlation for each pair of consecutive objects
            refinement = numpy.array(rel_final_shifts) - numpy.array(rel_ini_shifts)
            # the first object is the reference: no refinement applied to it
            cumulated = numpy.cumsum(numpy.concatenate((numpy.atleast_1d(0.0), refinement)))
            return initial_positions + cumulated

        final_pos_axis_0 = final_positions(
            self.configuration.axis_0_pos_px, self._axis_0_rel_final_shifts, self._axis_0_rel_ini_shifts
        )
        final_pos_axis_1 = final_positions(
            self.configuration.axis_1_pos_px, self._axis_1_rel_final_shifts, self._axis_1_rel_ini_shifts
        )
        final_pos_axis_2 = final_positions(
            self.configuration.axis_2_pos_px, self._axis_2_rel_final_shifts, self._axis_2_rel_ini_shifts
        )

        assert len(final_pos_axis_0) == len(final_pos_axis_1)
        assert len(final_pos_axis_0) == len(final_pos_axis_2)
        assert len(final_pos_axis_0) == len(self.series)

        return {
            tomo_obj.get_identifier().to_str(): (pos_0, pos_1, pos_2)
            for tomo_obj, (pos_0, pos_1, pos_2) in zip(
                self.series, zip(final_pos_axis_0, final_pos_axis_1, final_pos_axis_2)
            )
        }

    def settle_flips(self):
        """
        User can provide some information on existing flips at frame level.
        The goal of this step is to get one flip_lr and one flip_ud value per scan or volume.

        After this call `configuration.flip_lr` and `configuration.flip_ud` are tuples of python bool,
        with one value per element of the series.

        :raises ValueError: if a non-scalar flip does not provide one value per element to stitch.
        :raises TypeError: if a flip value is not a boolean.
        """
        self.configuration.flip_lr = self._settle_flip(self.configuration.flip_lr, name="flip_lr")
        self.configuration.flip_ud = self._settle_flip(self.configuration.flip_ud, name="flip_ud")

    def _settle_flip(self, flip, name: str) -> tuple:
        """Return one python bool per element of the series from a scalar or per-element `flip` value."""
        if numpy.isscalar(flip):
            flips = tuple([flip] * len(self.series))
        else:
            if not len(flip) == len(self.series):
                raise ValueError(f"{name} expects a scalar value or one value per element to stitch")
            flips = tuple(flip)
        for elmt in flips:
            # numpy booleans (e.g. read from a file) are accepted as well as python bool
            if not isinstance(elmt, (bool, numpy.bool_)):
                raise TypeError(f"{name} values are expected to be booleans. Got {type(elmt)}")
        return tuple(bool(elmt) for elmt in flips)

    def _createOverlapKernels(self):
        """
        after this stage the overlap kernels must be created and with the final overlap size.

        One kernel is appended to `self._overlap_kernels` per final relative shift along the stitched axis.
        The overlap size comes from the "overlap_size" key of the stitched axis parameters; when it is
        not provided (None, "None" or "") the absolute value of the relative shift is used instead.

        :raises NotImplementedError: if the stitching axis is not 0, 1 or 2.
        :raises RuntimeError: if the final shifts along the stitched axis are not defined yet.
        """
        if self.axis == 0:
            stitched_axis_rel_shifts = self._axis_0_rel_final_shifts
            stitched_axis_params = self.configuration.axis_0_params
        elif self.axis == 1:
            stitched_axis_rel_shifts = self._axis_1_rel_final_shifts
            stitched_axis_params = self.configuration.axis_1_params
        elif self.axis == 2:
            stitched_axis_rel_shifts = self._axis_2_rel_final_shifts
            stitched_axis_params = self.configuration.axis_2_params
        else:
            raise NotImplementedError

        if stitched_axis_rel_shifts is None or len(stitched_axis_rel_shifts) == 0:
            raise RuntimeError(
                f"axis {self.axis} shifts have not been defined yet. Please define them before calling this function"
            )

        overlap_size = stitched_axis_params.get("overlap_size", None)
        if overlap_size in (None, "None", ""):
            # -1 means: deduce the overlap from the relative shift
            overlap_size = -1
        else:
            overlap_size = int(overlap_size)

        self._stitching_constant_length = max(
            [get_obj_constant_side_length(obj, axis=self.axis) for obj in self.series]
        )

        for stitched_axis_shift in stitched_axis_rel_shifts:
            if overlap_size == -1:
                height = abs(stitched_axis_shift)
            else:
                height = overlap_size

            self._overlap_kernels.append(
                ImageStichOverlapKernel(
                    stitching_axis=self.stitching_axis_in_frame_space,
                    frame_unstitched_axis_size=self._stitching_constant_length,
                    stitching_strategy=self.configuration.stitching_strategy,
                    overlap_size=height,
                    extra_params=self.configuration.stitching_kernels_extra_params,
                )
            )

    @property
    def series(self) -> Series:
        return self._series

    @property
    def configuration(self) -> SingleAxisStitchingConfiguration:
        return self._configuration

    @property
    def progress(self):
        return self._progress

    @staticmethod
    def _data_bunch_iterator(slices, bunch_size):
        """
        Util to get indices by bunch until we reach n_frames.

        :param slices: a slice (contiguous frames, step ignored here) or an iterable of indices.
        :param bunch_size: maximal number of frames per bunch (slice case only).
        :return: generator of (start, end) tuples. For an iterable, one (s, s + 1) tuple per index.
        :raises ValueError: if `slices` is a slice and `bunch_size` is lower than 1 (would loop forever).
        :raises TypeError: if `slices` is neither a slice nor an iterable.
        """
        if isinstance(slices, slice):
            if bunch_size < 1:
                raise ValueError(f"bunch_size is expected to be at least 1. Got {bunch_size}")
            # note: slice step is handled at a different level
            # a missing start (slice(None, n)) means starting from the first frame
            start = end = slices.start if slices.start is not None else 0

            while True:
                start, end = end, min((end + bunch_size), slices.stop)
                yield (start, end)
                if end >= slices.stop:
                    break
        # in the case of non-contiguous frames
        elif isinstance(slices, Iterable):
            for s in slices:
                yield (s, s + 1)
        else:
            raise TypeError(f"slices is provided as {type(slices)}. When Iterable or slice is expected")

    def rescale_frames(self, frames: tuple):
        """
        Rescale the intensity of all frames according to `configuration.rescale_params`.

        The target range is given by the min / max percentiles of the first frame. For each frame,
        its own min / max percentiles and the target range are passed to `rescale_data`.
        Percentiles can be provided as numbers or as strings such as "10" or "10 %".

        :param frames: tuple of frames (numpy arrays).
        :return: tuple of rescaled frames.
        """
        _logger.info("apply rescale frames")

        def cast_percentile(percentile) -> int:
            if isinstance(percentile, str):
                # str methods return a new string: keep the cleaned value
                percentile = percentile.replace(" ", "").rstrip("%")
            return int(percentile)

        rescale_min_percentile = cast_percentile(self.configuration.rescale_params.get(KEY_RESCALE_MIN_PERCENTILES, 0))
        rescale_max_percentile = cast_percentile(
            self.configuration.rescale_params.get(KEY_RESCALE_MAX_PERCENTILES, 100)
        )

        new_min = numpy.percentile(frames[0], rescale_min_percentile)
        new_max = numpy.percentile(frames[0], rescale_max_percentile)

        def rescale(data):
            # FIXME: takes time because browse several time the dataset, twice for percentiles and twices to get min and max when calling rescale_data...
            data_min = numpy.percentile(data, rescale_min_percentile)
            data_max = numpy.percentile(data, rescale_max_percentile)
            return rescale_data(data, new_min=new_min, new_max=new_max, data_min=data_min, data_max=data_max)

        return tuple([rescale(data) for data in frames])

    def normalize_frame_by_sample(self, frames: tuple):
        """
        normalize frame from a sample picked on the left or the right.

        Parameters (side, method, margin, width) are read from `configuration.normalization_by_sample`.
        Note: inside this method `normalize_frame_by_sample` resolves to the module-level function
        imported from `nabu.stitching.sample_normalization`, not to this method.

        :param frames: tuple of frames to normalize.
        :return: tuple of normalized frames.
        """
        _logger.info("apply normalization by a sample")
        return tuple(
            [
                normalize_frame_by_sample(
                    frame=frame,
                    side=self.configuration.normalization_by_sample.side,
                    method=self.configuration.normalization_by_sample.method,
                    margin_before_sample=self.configuration.normalization_by_sample.margin,
                    sample_width=self.configuration.normalization_by_sample.width,
                )
                for frame in frames
            ]
        )

    @staticmethod
    def stitch_frames(
        frames: Union[tuple, numpy.ndarray],
        axis,
        x_relative_shifts: tuple,
        y_relative_shifts: tuple,
        output_dtype: numpy.ndarray,
        stitching_axis: int,
        overlap_kernels: tuple,
        dumper: DumperBase = None,
        check_inputs=True,
        shift_mode="nearest",
        i_frame=None,
        return_composition_cls=False,
        alignment="center",
        pad_mode="constant",
        new_width: Optional[int] = None,
    ) -> numpy.ndarray:
        """
        Shift frames according to the provided relative shifts, then stitch all the shifted frames together.
        If a dumper is provided the stitched frame is saved through it.

        :param frames: frames to stitch. Each element must be a DataUrl or a 2D numpy array.
        :param axis: axis forwarded to `check_overlaps`.
        :param x_relative_shifts: relative shift along x between two consecutive frames (len(frames) - 1 values).
        :param y_relative_shifts: relative shift along y between two consecutive frames (len(frames) - 1 values).
        :param output_dtype: data type of the stitched frame (forwarded to `stitch_raw_frames`).
        :param stitching_axis: stitching axis in frame space (0 or 1).
        :param overlap_kernels: one overlap kernel per pair of consecutive frames.
        :param dumper: optional dumper used to save the stitched frame. Nothing is saved when None.
        :param check_inputs: if True, check the coherence of the inputs and that frames are 2D.
        :param shift_mode: `mode` given to `scipy.ndimage.shift` when shifting along the unstitched axis.
        :param i_frame: frame index, forwarded to the dumper.
        :param return_composition_cls: if True, also return the frame composition.
        :param alignment: alignment forwarded to `stitch_raw_frames`.
        :param pad_mode: padding mode forwarded to `stitch_raw_frames`.
        :param new_width: forwarded to `stitch_raw_frames` as `new_unstitched_axis_size`.
        :return: the stitched frame, or (stitched frame, composition) if `return_composition_cls` is True.
        :raises ValueError: if inputs are incoherent or a frame is not 2D (only checked when `check_inputs`).
        :raises TypeError: if an element of `frames` is neither a DataUrl nor a numpy array.
        :raises NotImplementedError: if `stitching_axis` is neither 0 nor 1.
        """
        if check_inputs:
            if len(frames) < 2:
                raise ValueError(f"Not enought frames provided for stitching ({len(frames)} provided)")
            if len(frames) != len(x_relative_shifts) + 1:
                raise ValueError(
                    f"Incoherent number of shift provided ({len(x_relative_shifts)}) compare to number of frame ({len(frames)}). len(frames) - 1 expected"
                )
            if len(x_relative_shifts) != len(overlap_kernels):
                raise ValueError(
                    f"expect to have the same number of x_relative_shifts ({len(x_relative_shifts)}) and y_overlap ({len(overlap_kernels)})"
                )
            if len(y_relative_shifts) != len(overlap_kernels):
                raise ValueError(
                    f"expect to have the same number of y_relative_shifts ({len(y_relative_shifts)}) and y_overlap ({len(overlap_kernels)})"
                )

            relative_positions = [(0, 0, 0)]
            for y_rel_pos, x_rel_pos in zip(y_relative_shifts, x_relative_shifts):
                relative_positions.append(
                    (
                        y_rel_pos + relative_positions[-1][0],
                        0,  # position over axis 1 (aka y) is not handled yet
                        x_rel_pos + relative_positions[-1][2],
                    )
                )
            check_overlaps(
                frames=tuple(frames),
                positions=tuple(relative_positions),
                axis=axis,
                raise_error=False,
            )

        def check_frame_is_2d(frame):
            if frame.ndim != 2:
                raise ValueError(f"2D frame expected when {frame.ndim}D provided")

        # step 0: load data if provided as url
        data = []
        for frame in frames:
            if isinstance(frame, DataUrl):
                data_frame = get_data(frame)
            elif isinstance(frame, numpy.ndarray):
                data_frame = frame
            else:
                raise TypeError(f"frames are expected to be DataUrl or 2D numpy array. Not {type(frame)}")
            if check_inputs:
                check_frame_is_2d(data_frame)
            data.append(data_frame)

        # step 1: shift each frame (except the first one) along the unstitched axis
        if stitching_axis == 0:
            relative_shift_along_stitched_axis = y_relative_shifts
            relative_shift_along_unstitched_axis = x_relative_shifts
        elif stitching_axis == 1:
            relative_shift_along_stitched_axis = x_relative_shifts
            relative_shift_along_unstitched_axis = y_relative_shifts
        else:
            raise NotImplementedError("")

        shifted_data = [data[0]]
        for frame, relative_shift in zip(data[1:], relative_shift_along_unstitched_axis):
            # note: only the shift along the unstitched axis is applied here.
            # The shift along the stitched axis is handled by the frame composition (see key lines below).
            # Cast to a python int (truncation toward zero): a numpy.int8 cast silently wraps around
            # for shifts outside [-128, 127].
            relative_shift = int(relative_shift)
            if relative_shift == 0:
                shifted_frame = frame
            else:
                # TO speed up: should use the Fourier transform
                shifted_frame = shift_scipy(
                    frame,
                    mode=shift_mode,
                    shift=[0, -relative_shift] if stitching_axis == 0 else [-relative_shift, 0],
                    order=1,
                )
            shifted_data.append(shifted_frame)

        # step 2: create stitched frame
        # key lines are computed from the loaded data: elements of `frames` can be DataUrl (no `shape`)
        key_lines = [
            (int(frame.shape[stitching_axis] - abs(relative_shift / 2)), int(abs(relative_shift / 2)))
            for relative_shift, frame in zip(relative_shift_along_stitched_axis, data)
        ]
        stitched_frame, composition_cls = stitch_raw_frames(
            frames=shifted_data,
            key_lines=key_lines,
            overlap_kernels=overlap_kernels,
            check_inputs=check_inputs,
            output_dtype=output_dtype,
            return_composition_cls=True,
            alignment=alignment,
            pad_mode=pad_mode,
            new_unstitched_axis_size=new_width,
        )
        # `dumper` is optional (defaults to None): only save when one is provided
        if dumper is not None:
            dumper.save_stitched_frame(
                stitched_frame=stitched_frame,
                composition_cls=composition_cls,
                i_frame=i_frame,
                axis=1,
            )

        if return_composition_cls:
            return stitched_frame, composition_cls
        return stitched_frame
|
File without changes
|
@@ -0,0 +1,13 @@
|
|
1
|
+
from nabu.stitching.stitcher.pre_processing import PreProcessingStitching
|
2
|
+
from .dumper import PreProcessingStitchingDumper
|
3
|
+
|
4
|
+
|
5
|
+
class PreProcessingYStitcher(
    PreProcessingStitching,
    dumper_cls=PreProcessingStitchingDumper,
    axis=1,
):
    """Pre-processing stitcher along axis 1, saving its results through `PreProcessingStitchingDumper`."""

    @property
    def serie_label(self) -> str:
        """Name of the series, as displayed in logs."""
        label = "y-serie"
        return label
|
@@ -0,0 +1,45 @@
|
|
1
|
+
from nabu.stitching.stitcher.pre_processing import PreProcessingStitching
|
2
|
+
from nabu.stitching.stitcher.post_processing import PostProcessingStitching
|
3
|
+
from .dumper import PreProcessingStitchingDumper, PostProcessingStitchingDumperNoDD, PostProcessingStitchingDumper
|
4
|
+
from nabu.stitching.stitcher.single_axis import _SingleAxisMetaClass
|
5
|
+
|
6
|
+
|
7
|
+
class PreProcessingZStitcher(
    PreProcessingStitching,
    dumper_cls=PreProcessingStitchingDumper,
    axis=0,
):
    """Pre-processing stitcher along axis 0, saving its results through `PreProcessingStitchingDumper`."""

    @property
    def serie_label(self) -> str:
        """Name of the series, as displayed in logs (consistent with the other z stitchers)."""
        return "z-serie"

    def check_inputs(self):
        """
        Ensure input data is coherent.

        On top of the checks done by the parent class, all consecutive scans of the series
        must have the same projection width (`dim_1`).

        :raises ValueError: if two consecutive scans have a different `dim_1`.
        """
        super().check_inputs()

        for scan_0, scan_1 in zip(self.series[0:-1], self.series[1:]):
            if scan_0.dim_1 != scan_1.dim_1:
                raise ValueError(
                    f"projections width are expected to be the same. Not the case for {scan_0} ({scan_0.dim_1}) and {scan_1} ({scan_1.dim_1})"
                )
|
24
|
+
|
25
|
+
|
26
|
+
class PostProcessingZStitcher(
    PostProcessingStitching,
    metaclass=_SingleAxisMetaClass,
    dumper_cls=PostProcessingStitchingDumper,
    axis=0,
):
    """Post-processing stitcher along axis 0, saving its results through `PostProcessingStitchingDumper`."""

    @property
    def serie_label(self) -> str:
        """Name of the series, as displayed in logs."""
        label = "z-serie"
        return label
|
35
|
+
|
36
|
+
|
37
|
+
class PostProcessingZStitcherNoDD(
    PostProcessingStitching,
    metaclass=_SingleAxisMetaClass,
    dumper_cls=PostProcessingStitchingDumperNoDD,
    axis=0,
):
    """Post-processing stitcher along axis 0, saving its results through `PostProcessingStitchingDumperNoDD`."""

    @property
    def serie_label(self) -> str:
        """Name of the series, as displayed in logs."""
        label = "z-serie"
        return label
|