nabu 2024.1.9__py3-none-any.whl → 2024.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nabu/__init__.py +1 -1
- nabu/app/bootstrap.py +2 -3
- nabu/app/cast_volume.py +4 -2
- nabu/app/cli_configs.py +5 -0
- nabu/app/composite_cor.py +1 -1
- nabu/app/create_distortion_map_from_poly.py +5 -6
- nabu/app/diag_to_pix.py +7 -19
- nabu/app/diag_to_rot.py +14 -29
- nabu/app/double_flatfield.py +32 -44
- nabu/app/parse_reconstruction_log.py +3 -0
- nabu/app/reconstruct.py +53 -15
- nabu/app/reconstruct_helical.py +2 -2
- nabu/app/stitching.py +27 -13
- nabu/app/tests/test_reduce_dark_flat.py +4 -1
- nabu/cuda/kernel.py +11 -2
- nabu/cuda/processing.py +2 -2
- nabu/cuda/src/cone.cu +77 -0
- nabu/cuda/src/hierarchical_backproj.cu +271 -0
- nabu/cuda/utils.py +0 -6
- nabu/estimation/alignment.py +5 -19
- nabu/estimation/cor.py +173 -599
- nabu/estimation/cor_sino.py +356 -26
- nabu/estimation/focus.py +63 -11
- nabu/estimation/tests/test_cor.py +124 -58
- nabu/estimation/tests/test_focus.py +6 -6
- nabu/estimation/tilt.py +2 -1
- nabu/estimation/utils.py +5 -33
- nabu/io/__init__.py +1 -1
- nabu/io/cast_volume.py +1 -1
- nabu/io/reader.py +416 -21
- nabu/io/tests/test_readers.py +422 -0
- nabu/io/tests/test_writers.py +1 -102
- nabu/io/writer.py +4 -433
- nabu/opencl/kernel.py +14 -3
- nabu/opencl/processing.py +8 -0
- nabu/pipeline/config_validators.py +5 -2
- nabu/pipeline/datadump.py +12 -5
- nabu/pipeline/estimators.py +162 -188
- nabu/pipeline/fullfield/chunked.py +168 -92
- nabu/pipeline/fullfield/chunked_cuda.py +7 -3
- nabu/pipeline/fullfield/computations.py +2 -7
- nabu/pipeline/fullfield/dataset_validator.py +0 -4
- nabu/pipeline/fullfield/nabu_config.py +37 -13
- nabu/pipeline/fullfield/processconfig.py +22 -13
- nabu/pipeline/fullfield/reconstruction.py +13 -9
- nabu/pipeline/helical/helical_chunked_regridded.py +1 -1
- nabu/pipeline/helical/helical_chunked_regridded_cuda.py +1 -0
- nabu/pipeline/helical/helical_reconstruction.py +1 -1
- nabu/pipeline/params.py +21 -1
- nabu/pipeline/processconfig.py +1 -12
- nabu/pipeline/reader.py +146 -0
- nabu/pipeline/tests/test_estimators.py +44 -72
- nabu/pipeline/utils.py +4 -2
- nabu/pipeline/writer.py +10 -2
- nabu/preproc/ccd_cuda.py +1 -1
- nabu/preproc/ctf.py +14 -7
- nabu/preproc/ctf_cuda.py +2 -3
- nabu/preproc/double_flatfield.py +5 -12
- nabu/preproc/double_flatfield_cuda.py +2 -2
- nabu/preproc/flatfield.py +5 -1
- nabu/preproc/flatfield_cuda.py +5 -1
- nabu/preproc/phase.py +24 -73
- nabu/preproc/phase_cuda.py +5 -8
- nabu/preproc/tests/test_ctf.py +11 -7
- nabu/preproc/tests/test_flatfield.py +67 -122
- nabu/preproc/tests/test_paganin.py +54 -30
- nabu/processing/azim.py +206 -0
- nabu/processing/convolution_cuda.py +1 -1
- nabu/processing/fft_cuda.py +15 -17
- nabu/processing/histogram.py +2 -0
- nabu/processing/histogram_cuda.py +2 -1
- nabu/processing/kernel_base.py +3 -0
- nabu/processing/muladd_cuda.py +1 -0
- nabu/processing/padding_opencl.py +1 -1
- nabu/processing/roll_opencl.py +1 -0
- nabu/processing/rotation_cuda.py +2 -2
- nabu/processing/tests/test_fft.py +17 -10
- nabu/processing/unsharp_cuda.py +1 -1
- nabu/reconstruction/cone.py +104 -40
- nabu/reconstruction/fbp.py +3 -0
- nabu/reconstruction/fbp_base.py +7 -2
- nabu/reconstruction/filtering.py +20 -7
- nabu/reconstruction/filtering_cuda.py +7 -1
- nabu/reconstruction/hbp.py +424 -0
- nabu/reconstruction/mlem.py +99 -0
- nabu/reconstruction/reconstructor.py +2 -0
- nabu/reconstruction/rings_cuda.py +19 -19
- nabu/reconstruction/sinogram_cuda.py +1 -0
- nabu/reconstruction/sinogram_opencl.py +3 -1
- nabu/reconstruction/tests/test_cone.py +10 -5
- nabu/reconstruction/tests/test_deringer.py +7 -6
- nabu/reconstruction/tests/test_fbp.py +124 -10
- nabu/reconstruction/tests/test_filtering.py +13 -11
- nabu/reconstruction/tests/test_halftomo.py +30 -4
- nabu/reconstruction/tests/test_mlem.py +91 -0
- nabu/reconstruction/tests/test_reconstructor.py +8 -3
- nabu/resources/dataset_analyzer.py +142 -92
- nabu/resources/gpu.py +1 -0
- nabu/resources/nxflatfield.py +134 -125
- nabu/resources/templates/id16a_fluo.conf +42 -0
- nabu/resources/tests/test_extract.py +10 -0
- nabu/resources/tests/test_nxflatfield.py +2 -2
- nabu/stitching/alignment.py +80 -24
- nabu/stitching/config.py +105 -68
- nabu/stitching/definitions.py +1 -0
- nabu/stitching/frame_composition.py +68 -60
- nabu/stitching/overlap.py +91 -51
- nabu/stitching/single_axis_stitching.py +32 -0
- nabu/stitching/slurm_utils.py +6 -6
- nabu/stitching/stitcher/__init__.py +0 -0
- nabu/stitching/stitcher/base.py +124 -0
- nabu/stitching/stitcher/dumper/__init__.py +3 -0
- nabu/stitching/stitcher/dumper/base.py +94 -0
- nabu/stitching/stitcher/dumper/postprocessing.py +356 -0
- nabu/stitching/stitcher/dumper/preprocessing.py +60 -0
- nabu/stitching/stitcher/post_processing.py +555 -0
- nabu/stitching/stitcher/pre_processing.py +1068 -0
- nabu/stitching/stitcher/single_axis.py +484 -0
- nabu/stitching/stitcher/stitcher.py +0 -0
- nabu/stitching/stitcher/y_stitcher.py +13 -0
- nabu/stitching/stitcher/z_stitcher.py +45 -0
- nabu/stitching/stitcher_2D.py +278 -0
- nabu/stitching/tests/test_config.py +12 -37
- nabu/stitching/tests/test_frame_composition.py +33 -59
- nabu/stitching/tests/test_overlap.py +149 -7
- nabu/stitching/tests/test_utils.py +1 -1
- nabu/stitching/tests/test_y_preprocessing_stitching.py +132 -0
- nabu/stitching/tests/{test_z_stitching.py → test_z_postprocessing_stitching.py} +167 -561
- nabu/stitching/tests/test_z_preprocessing_stitching.py +431 -0
- nabu/stitching/utils/__init__.py +1 -0
- nabu/stitching/utils/post_processing.py +281 -0
- nabu/stitching/utils/tests/test_post-processing.py +21 -0
- nabu/stitching/{utils.py → utils/utils.py} +79 -52
- nabu/stitching/y_stitching.py +27 -0
- nabu/stitching/z_stitching.py +32 -2263
- nabu/testutils.py +1 -152
- nabu/thirdparty/tomocupy_remove_stripe.py +43 -9
- nabu/utils.py +158 -61
- {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/METADATA +10 -3
- {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/RECORD +144 -121
- nabu/io/tiffwriter_zmm.py +0 -99
- nabu/pipeline/fallback_utils.py +0 -149
- nabu/pipeline/helical/tests/test_accumulator.py +0 -158
- nabu/pipeline/helical/tests/test_pipeline_elements_full.py +0 -355
- nabu/pipeline/helical/tests/test_strategy.py +0 -61
- nabu/pipeline/helical/utils.py +0 -51
- nabu/pipeline/tests/test_chunk_reader.py +0 -74
- {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/LICENSE +0 -0
- {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/WHEEL +0 -0
- {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/entry_points.txt +0 -0
- {nabu-2024.1.9.dist-info → nabu-2024.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,281 @@
|
|
1
|
+
import os
|
2
|
+
import logging
|
3
|
+
from typing import Optional, Union
|
4
|
+
from nabu import version as nabu_version
|
5
|
+
from nabu.stitching.config import (
|
6
|
+
PreProcessedSingleAxisStitchingConfiguration,
|
7
|
+
PostProcessedSingleAxisStitchingConfiguration,
|
8
|
+
SingleAxisStitchingConfiguration,
|
9
|
+
)
|
10
|
+
from nabu.stitching.stitcher.single_axis import PROGRESS_BAR_STITCH_VOL_DESC
|
11
|
+
from nabu.io.writer import get_datetime
|
12
|
+
from tomoscan.factory import Factory as TomoscanFactory
|
13
|
+
from silx.io.dictdump import dicttonx
|
14
|
+
from nxtomo.application.nxtomo import NXtomo
|
15
|
+
from tomoscan.utils.volume import concatenate as concatenate_volumes
|
16
|
+
from tomoscan.esrf.volume import HDF5Volume
|
17
|
+
from contextlib import AbstractContextManager
|
18
|
+
from threading import Thread
|
19
|
+
from time import sleep
|
20
|
+
|
21
|
+
_logger = logging.getLogger(__name__)
|
22
|
+
|
23
|
+
|
24
|
+
class StitchingPostProcAggregation:
    """
    Aggregate the partial results of a remote (slurm) stitching into the final volume or NXtomo.

    For remote stitching each process stitches a part of the volume or of the projections.
    Once all of them are finished, this class concatenates those parts into the final
    NXtomo (pre-processing stitching) or the final volume (post-processing stitching).

    Please be careful with the API: a tomwer class already inherits from this one.

    :param stitching_config: configuration of the stitching
    :param futures: futures of the submitted jobs, as a dict mapping the identifier of the
        tomo object produced by each job to the job future.
        Exactly one of `futures` and `existing_objs_ids` must be provided.
    :param existing_objs_ids: identifiers of already existing partial results (no job to wait for).
        Exactly one of `futures` and `existing_objs_ids` must be provided.
    :param progress_bars: tqdm progress bars, one per job (keyed by job)
    """

    def __init__(
        self,
        stitching_config: SingleAxisStitchingConfiguration,
        futures: Optional[dict] = None,
        existing_objs_ids: Optional[tuple] = None,
        progress_bars: Optional[dict] = None,
    ) -> None:
        if not isinstance(stitching_config, SingleAxisStitchingConfiguration):
            raise TypeError(f"stitching_config should be an instance of {SingleAxisStitchingConfiguration}")
        if (existing_objs_ids is None) == (futures is None):
            raise ValueError("Exactly one of existing_objs_ids or futures should be provided (got both or none)")
        if progress_bars is not None and not isinstance(progress_bars, dict):
            raise TypeError(f"'progress_bars' should be None or an instance of a dict. Got {type(progress_bars)}")
        self._futures = futures
        self._stitching_config = stitching_config
        self._existing_objs_ids = existing_objs_ids
        self._progress_bars = progress_bars or {}

    @property
    def futures(self):
        return self._futures

    @property
    def progress_bars(self) -> dict:
        return self._progress_bars

    def retrieve_tomo_objects(self) -> list:
        """
        Return the tomo objects to be aggregated, either from the futures or from `existing_objs_ids`.

        When futures are used, wait for every job to be done. If any job failed or was
        canceled, raise a single RuntimeError describing all of them: no concatenation must
        be done from an incomplete set of results.

        :raises RuntimeError: if some job failed or was canceled
        """
        if self._existing_objs_ids is not None:
            scan_ids = self._existing_objs_ids
        else:
            _logger.info(
                "wait for slurm job to be completed. Advancement will be created once slurm job output file will be available"
            )
            scan_ids = []
            canceled = []
            failed = []
            for obj_id, future in self.futures.items():
                # check cancellation first: (as for concurrent.futures) result() raises on a canceled future
                if future.cancelled():
                    canceled.append(future)
                    continue
                try:
                    future.result()
                except Exception as exc:
                    # the future may also have been canceled while we were waiting for it
                    if future.cancelled():
                        canceled.append(future)
                    else:
                        failed.append((future, exc))
                else:
                    scan_ids.append(obj_id)

            if failed or canceled:
                # if some job failed or was canceled: useless to do the concatenation
                errors = [f"{job} : {exc}" for job, exc in failed]
                errors.extend(f"{job} : canceled" for job in canceled)
                raise RuntimeError(
                    f"some job failed or were canceled. Won't do the concatenation. Exceptions are {' ; '.join(errors)}"
                ) from (failed[0][1] if failed else None)
        return [TomoscanFactory.create_tomo_object_from_identifier(scan_id) for scan_id in scan_ids]

    def dump_stitching_config_as_nx_process(self, file_path: str, data_path: str, overwrite: bool, process_name: str):
        """
        Write the stitching configuration as an NXprocess group named `process_name` under `data_path`.

        :param file_path: HDF5 file to write into (opened in append mode)
        :param data_path: parent group of the NXprocess
        :param overwrite: if True an existing entry is replaced, else data is only added
        :param process_name: name of the NXprocess group
        """
        dict_to_dump = {
            process_name: {
                "config": self._stitching_config.to_dict(),
                "program": "nabu-stitching",
                "version": nabu_version,
                "date": get_datetime(),
            },
            f"{process_name}@NX_class": "NXprocess",
        }

        dicttonx(
            dict_to_dump,
            h5file=file_path,
            h5path=data_path,
            update_mode="replace" if overwrite else "add",
            mode="a",
        )

    @property
    def stitching_config(self) -> SingleAxisStitchingConfiguration:
        return self._stitching_config

    @staticmethod
    def _get_nx_process_location(data_path: str) -> tuple:
        """
        Return (process_name, parent_path) used to dump the NXprocess next to `data_path`.

        The process is named after the last component of `data_path` (suffixed with
        "_stitching") and stored in its parent group, or in "/" when `data_path` has no
        parent (e.g. "entry" or "/entry").
        """
        parent_path, _, name = data_path.rpartition("/")
        return f"{name}_stitching", parent_path or "/"

    def process(self) -> None:
        """
        Main function: wait for the partial results and concatenate them into the final output.

        The stitching configuration is also dumped as an NXprocess next to the result when the
        output is an NXtomo (pre-processing) or an HDF5Volume (post-processing).

        :raises RuntimeError: if some job failed / was canceled or if an expected output file is missing
        :raises TypeError: if the type of the stitching configuration is not handled
        """
        if isinstance(self._stitching_config, PreProcessedSingleAxisStitchingConfiguration):
            # 1: case of a pre-processing stitching: concatenate the partial NXtomo
            with self.follow_progress():
                scans = self.retrieve_tomo_objects()
            _logger.info("all job succeeded. Concatenate results")
            nx_tomos = []
            for scan in scans:
                if not os.path.exists(scan.master_file):
                    raise RuntimeError(
                        f"output file not created ({scan.master_file}). Stitching failed. "
                        "Please check slurm .out files to have more information. Most likely the slurm configuration is invalid. "
                        "(partition name not existing...)"
                    )
                nx_tomos.append(
                    NXtomo().load(
                        file_path=scan.master_file,
                        data_path=scan.entry,
                    )
                )
            final_nx_tomo = NXtomo.concatenate(nx_tomos)
            final_nx_tomo.save(
                file_path=self.stitching_config.output_file_path,
                data_path=self.stitching_config.output_data_path,
                overwrite=self.stitching_config.overwrite_results,
            )

            # dump NXprocess if possible
            process_name, data_path = self._get_nx_process_location(self.stitching_config.output_data_path)
            self.dump_stitching_config_as_nx_process(
                file_path=self.stitching_config.output_file_path,
                data_path=data_path,
                process_name=process_name,
                overwrite=self.stitching_config.overwrite_results,
            )

        elif isinstance(self.stitching_config, PostProcessedSingleAxisStitchingConfiguration):
            # 2: case of a post-processing stitching: concatenate the partial volumes
            with self.follow_progress():
                outputs_sub_volumes = self.retrieve_tomo_objects()
            _logger.info("all job succeeded. Concatenate results")
            concatenate_volumes(
                output_volume=self.stitching_config.output_volume,
                volumes=tuple(outputs_sub_volumes),
                axis=1,
            )

            if isinstance(self.stitching_config.output_volume, HDF5Volume):
                process_name, data_path = self._get_nx_process_location(
                    self.stitching_config.output_volume.metadata_url.data_path()
                )
                self.dump_stitching_config_as_nx_process(
                    file_path=self.stitching_config.output_volume.metadata_url.file_path(),
                    data_path=data_path,
                    process_name=process_name,
                    overwrite=self.stitching_config.overwrite_results,
                )
        else:
            raise TypeError(f"stitching_config type ({type(self.stitching_config)}) not handled")

    def follow_progress(self) -> AbstractContextManager:
        """Return a context manager updating `progress_bars` from each job's slurm output file."""
        return SlurmStitchingFollowerContext(
            output_files_to_progress_bars={
                job._get_output_file_path(): progress_bar for (job, progress_bar) in self.progress_bars.items()
            }
        )
|
204
|
+
|
205
|
+
|
206
|
+
class SlurmStitchingFollowerContext(AbstractContextManager):
    """
    Context manager giving user feedback on a stitching running on slurm.

    Entering starts a background thread that updates the given progress bars from the
    slurm output files; exiting stops that thread and closes every progress bar.

    :param output_files_to_progress_bars: maps each slurm output file path to the progress bar to update
    """

    def __init__(self, output_files_to_progress_bars: dict):
        follower = SlurmStitchingFollowerThread(file_to_progress_bar=output_files_to_progress_bars)
        self._update_thread = follower

    def __enter__(self) -> None:
        self._update_thread.start()

    def __exit__(self, *args, **kwargs):
        follower = self._update_thread
        follower.join(timeout=1.5)
        # close to clean the display (bars presumably created with leave=False)
        for bar in follower.file_to_progress_bar.values():
            bar.close()
|
219
|
+
|
220
|
+
|
221
|
+
class SlurmStitchingFollowerThread(Thread):
    """
    Thread following the progression of stitching slurm job(s).

    Every `delay_time` seconds, read the last line of each slurm .out file. If it is a tqdm
    line of the volume stitching, deduce the advancement from it and update the matching
    progress bar.

    :param file_to_progress_bar: maps each slurm .out file path to the progress bar to update
    :param delay_time: time (in seconds) to wait between two polls of the files
    """

    def __init__(self, file_to_progress_bar: dict, delay_time: float = 0.5) -> None:
        super().__init__()
        self._file_to_progress_bar = file_to_progress_bar
        self._wait_time = delay_time
        self._stop_run = False
        self._first_run = True

    @property
    def file_to_progress_bar(self) -> dict:
        return self._file_to_progress_bar

    @staticmethod
    def _read_last_line(file_path: str) -> Optional[str]:
        """Return the last line of `file_path`, or None if the file is missing or empty."""
        if not os.path.exists(file_path):
            return None
        with open(file_path, "r") as log_file:
            all_lines = log_file.readlines()
        return all_lines[-1] if all_lines else None

    def run(self) -> None:
        while not self._stop_run:
            for log_path, bar in self._file_to_progress_bar.items():
                if self._first_run:
                    # make sure every progress bar is displayed at least once
                    bar.refresh()

                tail = self._read_last_line(log_path)
                if tail is None:
                    continue
                percent = self.cast_progress_line_from_log(line=tail)
                if percent is None:
                    continue
                bar.n = percent
                bar.refresh()

            self._first_run = False
            sleep(self._wait_time)

    def join(self, timeout: Union[float, None] = None) -> None:
        # request the polling loop to stop before waiting for the thread
        self._stop_run = True
        return super().join(timeout)

    @staticmethod
    def cast_progress_line_from_log(line: str) -> Optional[float]:
        """
        Try to retrieve the advancement (in percentage) from a log line.

        Return None if the line is not a volume-stitching progress line or if the value
        preceding the first '%' cannot be converted to a float.
        """
        if PROGRESS_BAR_STITCH_VOL_DESC in line and "%" in line:
            before_percent = line.partition("%")[0]
            candidate = before_percent.rsplit(" ", 1)[-1]
            try:
                return float(candidate)
            except ValueError:
                _logger.debug(f"Failed to retrieve advancement from log file. Value got is {candidate}")
        return None
|
@@ -0,0 +1,21 @@
|
|
1
|
+
import pytest
|
2
|
+
|
3
|
+
from nabu.stitching.stitcher.single_axis import PROGRESS_BAR_STITCH_VOL_DESC
|
4
|
+
from nabu.stitching.utils.post_processing import SlurmStitchingFollowerThread
|
5
|
+
|
6
|
+
|
7
|
+
@pytest.mark.parametrize(
    "test_case",
    (
        # (log line, expected advancement)
        ("dump configuration: 100%|", None),
        ("stitching : 100%|", None),
        (f"{PROGRESS_BAR_STITCH_VOL_DESC}: 42%", 42.0),
        # tqdm pads the percentage with spaces
        (f"{PROGRESS_BAR_STITCH_VOL_DESC}:   5%|", 5.0),
        (f"{PROGRESS_BAR_STITCH_VOL_DESC}: 56% toto: 23%", 56.0),
        ("", None),
        ("my%", None),
    ),
)
def test_SlurmStitchingFollowerContext(test_case):
    """
    Test that the advancement can be read back from the log lines created by tqdm.

    Exercises `SlurmStitchingFollowerThread.cast_progress_line_from_log`, which the
    SlurmStitchingFollowerContext relies on to follow slurm jobs.
    """
    str_to_test, expected_result = test_case
    assert SlurmStitchingFollowerThread.cast_progress_line_from_log(str_to_test) == expected_result
|
@@ -3,22 +3,18 @@ from typing import Optional, Union
|
|
3
3
|
import logging
|
4
4
|
import functools
|
5
5
|
import numpy
|
6
|
-
from scipy.ndimage import affine_transform
|
7
6
|
from tomoscan.scanbase import TomoScanBase
|
8
7
|
from tomoscan.volumebase import VolumeBase
|
9
|
-
from nxtomo.utils.transformation import build_matrix,
|
8
|
+
from nxtomo.utils.transformation import build_matrix, DetYFlipTransformation
|
10
9
|
from silx.utils.enum import Enum as _Enum
|
11
10
|
from scipy.fft import rfftn as local_fftn
|
12
11
|
from scipy.fft import irfftn as local_ifftn
|
13
|
-
from
|
14
|
-
from
|
15
|
-
from
|
16
|
-
from .
|
17
|
-
from .
|
18
|
-
from
|
19
|
-
from ..estimation.alignment import AlignmentBase
|
20
|
-
from ..resources.dataset_analyzer import HDF5DatasetAnalyzer
|
21
|
-
from ..resources.nxflatfield import update_dataset_info_flats_darks
|
12
|
+
from ..overlap import OverlapStitchingStrategy, ImageStichOverlapKernel
|
13
|
+
from ..alignment import AlignmentAxis1, AlignmentAxis2, PaddedRawData
|
14
|
+
from ...misc import fourier_filters
|
15
|
+
from ...estimation.alignment import AlignmentBase
|
16
|
+
from ...resources.dataset_analyzer import HDF5DatasetAnalyzer
|
17
|
+
from ...resources.nxflatfield import update_dataset_info_flats_darks
|
22
18
|
|
23
19
|
try:
|
24
20
|
import itk
|
@@ -66,39 +62,22 @@ class ShiftAlgorithm(_Enum):
|
|
66
62
|
return super().from_value(value=value)
|
67
63
|
|
68
64
|
|
69
|
-
def test_overlap_stitching_strategy(overlap_1, overlap_2, stitching_strategies):
|
70
|
-
"""
|
71
|
-
stitch the two ovrelap with all the requested strategies.
|
72
|
-
Return a dictionary with stitching strategy as key and a result dict as value.
|
73
|
-
result dict keys are: 'weights_overlap_1', 'weights_overlap_2', 'stiching'
|
74
|
-
"""
|
75
|
-
res = {}
|
76
|
-
for strategy in stitching_strategies:
|
77
|
-
s = OverlapStitchingStrategy.from_value(strategy)
|
78
|
-
stitcher = ZStichOverlapKernel(
|
79
|
-
stitching_strategy=s,
|
80
|
-
frame_width=overlap_1.shape[1],
|
81
|
-
)
|
82
|
-
stiched_overlap, w1, w2 = stitcher.stitch(overlap_1, overlap_2, check_input=True)
|
83
|
-
res[s.value] = {
|
84
|
-
"stitching": stiched_overlap,
|
85
|
-
"weights_overlap_1": w1,
|
86
|
-
"weights_overlap_2": w2,
|
87
|
-
}
|
88
|
-
return res
|
89
|
-
|
90
|
-
|
91
65
|
def find_frame_relative_shifts(
|
92
66
|
overlap_upper_frame: numpy.ndarray,
|
93
67
|
overlap_lower_frame: numpy.ndarray,
|
94
|
-
estimated_shifts,
|
68
|
+
estimated_shifts: tuple,
|
69
|
+
overlap_axis: int,
|
95
70
|
x_cross_correlation_function=None,
|
96
71
|
y_cross_correlation_function=None,
|
97
72
|
x_shifts_params: Optional[dict] = None,
|
98
73
|
y_shifts_params: Optional[dict] = None,
|
99
74
|
):
|
75
|
+
"""
|
76
|
+
:param overlap_axis: axis in [0, 1] on which the overlap exists. In image space. So 0 is aka y and 1 as x
|
77
|
+
"""
|
78
|
+
if not overlap_axis in (0, 1):
|
79
|
+
raise ValueError(f"overlap_axis should be in (0, 1). Get {overlap_axis}")
|
100
80
|
from nabu.stitching.config import (
|
101
|
-
KEY_WINDOW_SIZE,
|
102
81
|
KEY_LOW_PASS_FILTER,
|
103
82
|
KEY_HIGH_PASS_FILTER,
|
104
83
|
) # avoid cyclic import
|
@@ -188,6 +167,7 @@ def find_frame_relative_shifts(
|
|
188
167
|
def find_volumes_relative_shifts(
|
189
168
|
upper_volume: VolumeBase,
|
190
169
|
lower_volume: VolumeBase,
|
170
|
+
overlap_axis: int,
|
191
171
|
estimated_shifts,
|
192
172
|
dim_axis_1: int,
|
193
173
|
dtype,
|
@@ -210,6 +190,15 @@ def find_volumes_relative_shifts(
|
|
210
190
|
|
211
191
|
if x_shifts_params is None:
|
212
192
|
x_shifts_params = {}
|
193
|
+
# convert from overlap_axis (3D acquisition space) to overlap_axis_proj_space.
|
194
|
+
if overlap_axis == 1:
|
195
|
+
raise NotImplementedError("finding projection shift along axis 1 is not handled for projections")
|
196
|
+
elif overlap_axis == 0:
|
197
|
+
overlap_axis_proj_space = 0
|
198
|
+
elif overlap_axis == 2:
|
199
|
+
overlap_axis_proj_space = 1
|
200
|
+
else:
|
201
|
+
raise ValueError(f"Stitching is done in 3D space. Expect axis to be in [0,2]. Get {overlap_axis}")
|
213
202
|
|
214
203
|
alignment_axis_2 = AlignmentAxis2.from_value(alignment_axis_2)
|
215
204
|
alignment_axis_1 = AlignmentAxis1.from_value(alignment_axis_1)
|
@@ -299,6 +288,7 @@ def find_volumes_relative_shifts(
|
|
299
288
|
y_cross_correlation_function=y_cross_correlation_function,
|
300
289
|
x_shifts_params=x_shifts_params,
|
301
290
|
y_shifts_params=y_shifts_params,
|
291
|
+
overlap_axis=overlap_axis_proj_space,
|
302
292
|
)
|
303
293
|
|
304
294
|
|
@@ -308,7 +298,8 @@ from nabu.pipeline.estimators import estimate_cor
|
|
308
298
|
def find_projections_relative_shifts(
|
309
299
|
upper_scan: TomoScanBase,
|
310
300
|
lower_scan: TomoScanBase,
|
311
|
-
estimated_shifts,
|
301
|
+
estimated_shifts: tuple,
|
302
|
+
axis: int,
|
312
303
|
flip_ud_upper_frame: bool = False,
|
313
304
|
flip_ud_lower_frame: bool = False,
|
314
305
|
projection_for_shift: Union[int, str] = "middle",
|
@@ -326,13 +317,16 @@ def find_projections_relative_shifts(
|
|
326
317
|
|
327
318
|
:param TomoScanBase scan_0:
|
328
319
|
:param TomoScanBase scan_1:
|
329
|
-
:param
|
320
|
+
:param tuple estimated_shifts: 'a priori' shift estimation
|
321
|
+
:param int axis: axis on which the overlap / stitching is happening. In the 3D space (sample, detector referential)
|
322
|
+
:param bool flip_ud_upper_frame: is the upper frame flipped
|
323
|
+
:param bool flip_ud_lower_frame: is the lower frame flipped
|
330
324
|
:param Union[int,str] projection_for_shift: index fo the projection to use (in projection space or in scan space ?. For now in projection) or str. If str must be in (`middle`, `first`, `last`)
|
325
|
+
:param bool invert_order: are projections inverted between the two scans (case if rotation angle are inverted)
|
331
326
|
:param str x_cross_correlation_function: optional method to refine x shift from computing cross correlation. For now valid values are: ("skimage", "nabu-fft")
|
332
327
|
:param str y_cross_correlation_function: optional method to refine y shift from computing cross correlation. For now valid values are: ("skimage", "nabu-fft")
|
333
|
-
:param
|
334
|
-
:param
|
335
|
-
:param tuple estimated_shifts: 'a priori' shift estimation
|
328
|
+
:param x_shifts_params: parameters to find the shift over x
|
329
|
+
:param y_shifts_params: parameters to find the shift over y
|
336
330
|
:return: relative shift of scan_1 with scan_0 as reference: (y_shift, x_shift)
|
337
331
|
:rtype: tuple
|
338
332
|
|
@@ -342,8 +336,18 @@ def find_projections_relative_shifts(
|
|
342
336
|
x_shifts_params = {}
|
343
337
|
if y_shifts_params is None:
|
344
338
|
y_shifts_params = {}
|
345
|
-
|
346
|
-
|
339
|
+
|
340
|
+
# convert from overlap_axis (3D acquisition space) to overlap_axis_proj_space.
|
341
|
+
if axis == 1:
|
342
|
+
axis_proj_space = 1
|
343
|
+
elif axis == 0:
|
344
|
+
axis_proj_space = 0
|
345
|
+
elif axis == 2:
|
346
|
+
raise NotImplementedError(
|
347
|
+
"finding projection shift along axis 1 (x-ray direction) is not handled for projections"
|
348
|
+
)
|
349
|
+
else:
|
350
|
+
raise ValueError(f"Stitching is done in 3D space. Expect axis to be in [0,2]. Get {axis}")
|
347
351
|
|
348
352
|
x_cross_correlation_function = ShiftAlgorithm.from_value(x_cross_correlation_function)
|
349
353
|
y_cross_correlation_function = ShiftAlgorithm.from_value(y_cross_correlation_function)
|
@@ -432,11 +436,11 @@ def find_projections_relative_shifts(
|
|
432
436
|
|
433
437
|
upper_scan_transformations = list(upper_scan.get_detector_transformations(tuple()))
|
434
438
|
if flip_ud_upper_frame:
|
435
|
-
upper_scan_transformations.append(
|
439
|
+
upper_scan_transformations.append(DetYFlipTransformation(flip=True))
|
436
440
|
upper_scan_trans_matrix = build_matrix(upper_scan_transformations)
|
437
441
|
lower_scan_transformations = list(lower_scan.get_detector_transformations(tuple()))
|
438
442
|
if flip_ud_lower_frame:
|
439
|
-
lower_scan_transformations.append(
|
443
|
+
lower_scan_transformations.append(DetYFlipTransformation(flip=True))
|
440
444
|
lower_scan_trans_matrix = build_matrix(lower_scan_transformations)
|
441
445
|
upper_proj = get_flat_fielded_proj(
|
442
446
|
upper_scan,
|
@@ -453,16 +457,30 @@ def find_projections_relative_shifts(
|
|
453
457
|
|
454
458
|
from nabu.stitching.config import KEY_WINDOW_SIZE # avoid cyclic import
|
455
459
|
|
456
|
-
|
457
|
-
|
458
|
-
end_overlap = min(estimated_shifts[0] // 2 + w_window_size // 2, min(upper_proj.shape[0], lower_proj.shape[0]))
|
459
|
-
if start_overlap == 0:
|
460
|
-
overlap_upper_frame = upper_proj[-end_overlap:]
|
460
|
+
if axis_proj_space == 0:
|
461
|
+
w_window_size = int(y_shifts_params.get(KEY_WINDOW_SIZE, 400))
|
461
462
|
else:
|
462
|
-
|
463
|
-
|
463
|
+
w_window_size = int(x_shifts_params.get(KEY_WINDOW_SIZE, 400))
|
464
|
+
start_overlap = max(estimated_shifts[axis_proj_space] // 2 - w_window_size // 2, 0)
|
465
|
+
end_overlap = min(
|
466
|
+
estimated_shifts[axis_proj_space] // 2 + w_window_size // 2,
|
467
|
+
min(upper_proj.shape[axis_proj_space], lower_proj.shape[axis_proj_space]),
|
468
|
+
)
|
469
|
+
o_upper_sel = numpy.array(range(-end_overlap, -start_overlap))
|
470
|
+
overlap_upper_frame = numpy.take_along_axis(
|
471
|
+
upper_proj,
|
472
|
+
o_upper_sel[:, None] if axis_proj_space == 0 else o_upper_sel[None, :],
|
473
|
+
axis=axis_proj_space,
|
474
|
+
)
|
475
|
+
o_lower_sel = numpy.array(range(start_overlap, end_overlap))
|
476
|
+
overlap_lower_frame = numpy.take_along_axis(
|
477
|
+
lower_proj,
|
478
|
+
o_lower_sel[:, None] if axis_proj_space == 0 else o_upper_sel[None, :],
|
479
|
+
axis=axis_proj_space,
|
480
|
+
)
|
481
|
+
|
464
482
|
if not overlap_upper_frame.shape == overlap_lower_frame.shape:
|
465
|
-
raise ValueError(f"Fail to get
|
483
|
+
raise ValueError(f"Fail to get consistent overlap ({overlap_upper_frame.shape} vs {overlap_lower_frame.shape})")
|
466
484
|
|
467
485
|
return find_frame_relative_shifts(
|
468
486
|
overlap_upper_frame=overlap_upper_frame,
|
@@ -472,6 +490,7 @@ def find_projections_relative_shifts(
|
|
472
490
|
y_cross_correlation_function=y_cross_correlation_function,
|
473
491
|
x_shifts_params=x_shifts_params,
|
474
492
|
y_shifts_params=y_shifts_params,
|
493
|
+
overlap_axis=axis_proj_space,
|
475
494
|
)
|
476
495
|
|
477
496
|
|
@@ -561,3 +580,11 @@ def find_shift_with_itk(img1: numpy.ndarray, img2: numpy.ndarray) -> tuple:
|
|
561
580
|
translation_along_y = final_parameters.GetElement(1)
|
562
581
|
|
563
582
|
return numpy.round(translation_along_y), numpy.round(translation_along_x)
|
583
|
+
|
584
|
+
|
585
|
+
def from_slice_to_n_elements(slice_: Union[slice, tuple]):
    """
    Return the number of elements in a slice or in a tuple.

    For a slice, the count is the number of indices it selects from a long-enough sequence,
    i.e. ``len(range(start, stop, step))`` with `start` defaulting to 0 and `step` to 1.
    The result is always an int, including when the span is not a multiple of the step
    (e.g. ``slice(0, 5, 2)`` -> 3).

    :param slice_: slice (with an explicit `stop`) or any sized sequence (tuple, list...)
    :raises ValueError: if `slice_` is a slice without `stop` (the count is then unknown)
    """
    if isinstance(slice_, slice):
        if slice_.stop is None:
            raise ValueError(f"Cannot compute the number of elements of a slice without stop ({slice_})")
        # 'or' keeps the previous handling of None (and 0) start / step values
        return len(range(slice_.start or 0, slice_.stop, slice_.step or 1))
    return len(slice_)
|
@@ -0,0 +1,27 @@
|
|
1
|
+
from tomoscan.identifier import BaseIdentifier
|
2
|
+
from nabu.stitching.stitcher.y_stitcher import PreProcessingYStitcher as PreProcessYStitcher
|
3
|
+
from nabu.stitching.config import PreProcessedYStitchingConfiguration
|
4
|
+
|
5
|
+
|
6
|
+
def y_stitching(configuration: PreProcessedYStitchingConfiguration, progress=None) -> BaseIdentifier:
    """
    Run a stitching from the provided configuration.

    The stitching is applied along axis 1 (aka y) of the 3D space::

                  axis 0
                    ^
                    |
        x-ray       |
        -------->   ------> axis 2
                   /
                  /
              axis 1

    :param configuration: stitching configuration. Only PreProcessedYStitchingConfiguration is handled.
    :param progress: optional progress object, forwarded to the stitcher
    :return: result of the stitcher's `stitch` call (an identifier, per the return annotation)
    :raises TypeError: if `configuration` is not a PreProcessedYStitchingConfiguration
    """
    if not isinstance(configuration, PreProcessedYStitchingConfiguration):
        raise TypeError(
            f"configuration is expected to be in {(PreProcessedYStitchingConfiguration, )}. {type(configuration)} provided"
        )
    y_stitcher = PreProcessYStitcher(configuration=configuration, progress=progress)
    return y_stitcher.stitch()
|