nabu 2023.2.1__py3-none-any.whl → 2024.1.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/conf.py +1 -1
- doc/doc_config.py +32 -0
- nabu/__init__.py +2 -1
- nabu/app/bootstrap_stitching.py +1 -1
- nabu/app/cli_configs.py +122 -2
- nabu/app/composite_cor.py +27 -2
- nabu/app/correct_rot.py +70 -0
- nabu/app/create_distortion_map_from_poly.py +42 -18
- nabu/app/diag_to_pix.py +358 -0
- nabu/app/diag_to_rot.py +449 -0
- nabu/app/generate_header.py +4 -3
- nabu/app/histogram.py +2 -2
- nabu/app/multicor.py +6 -1
- nabu/app/parse_reconstruction_log.py +151 -0
- nabu/app/prepare_weights_double.py +83 -22
- nabu/app/reconstruct.py +5 -1
- nabu/app/reconstruct_helical.py +7 -0
- nabu/app/reduce_dark_flat.py +6 -3
- nabu/app/rotate.py +4 -4
- nabu/app/stitching.py +16 -2
- nabu/app/tests/test_reduce_dark_flat.py +18 -2
- nabu/app/validator.py +4 -4
- nabu/cuda/convolution.py +8 -376
- nabu/cuda/fft.py +4 -0
- nabu/cuda/kernel.py +4 -4
- nabu/cuda/medfilt.py +5 -158
- nabu/cuda/padding.py +5 -71
- nabu/cuda/processing.py +23 -2
- nabu/cuda/src/ElementOp.cu +78 -0
- nabu/cuda/src/backproj.cu +28 -2
- nabu/cuda/src/fourier_wavelets.cu +2 -2
- nabu/cuda/src/normalization.cu +23 -0
- nabu/cuda/src/padding.cu +2 -2
- nabu/cuda/src/transpose.cu +16 -0
- nabu/cuda/utils.py +39 -0
- nabu/estimation/alignment.py +10 -1
- nabu/estimation/cor.py +808 -38
- nabu/estimation/cor_sino.py +7 -9
- nabu/estimation/tests/test_cor.py +85 -3
- nabu/io/reader.py +26 -18
- nabu/io/tests/test_cast_volume.py +3 -3
- nabu/io/tests/test_detector_distortion.py +3 -3
- nabu/io/tiffwriter_zmm.py +2 -2
- nabu/io/utils.py +14 -4
- nabu/io/writer.py +5 -3
- nabu/misc/fftshift.py +6 -0
- nabu/misc/histogram.py +5 -285
- nabu/misc/histogram_cuda.py +8 -104
- nabu/misc/kernel_base.py +3 -121
- nabu/misc/padding_base.py +5 -69
- nabu/misc/processing_base.py +3 -107
- nabu/misc/rotation.py +5 -62
- nabu/misc/rotation_cuda.py +5 -65
- nabu/misc/transpose.py +6 -0
- nabu/misc/unsharp.py +3 -78
- nabu/misc/unsharp_cuda.py +5 -52
- nabu/misc/unsharp_opencl.py +8 -85
- nabu/opencl/fft.py +6 -0
- nabu/opencl/kernel.py +21 -6
- nabu/opencl/padding.py +5 -72
- nabu/opencl/processing.py +27 -5
- nabu/opencl/src/backproj.cl +3 -3
- nabu/opencl/src/fftshift.cl +65 -12
- nabu/opencl/src/padding.cl +2 -2
- nabu/opencl/src/roll.cl +96 -0
- nabu/opencl/src/transpose.cl +16 -0
- nabu/pipeline/config_validators.py +63 -3
- nabu/pipeline/dataset_validator.py +2 -2
- nabu/pipeline/estimators.py +193 -35
- nabu/pipeline/fullfield/chunked.py +34 -17
- nabu/pipeline/fullfield/chunked_cuda.py +7 -5
- nabu/pipeline/fullfield/computations.py +48 -13
- nabu/pipeline/fullfield/nabu_config.py +13 -13
- nabu/pipeline/fullfield/processconfig.py +10 -5
- nabu/pipeline/fullfield/reconstruction.py +1 -2
- nabu/pipeline/helical/fbp.py +5 -0
- nabu/pipeline/helical/filtering.py +12 -9
- nabu/pipeline/helical/gridded_accumulator.py +179 -33
- nabu/pipeline/helical/helical_chunked_regridded.py +262 -151
- nabu/pipeline/helical/helical_chunked_regridded_cuda.py +4 -11
- nabu/pipeline/helical/helical_reconstruction.py +56 -18
- nabu/pipeline/helical/span_strategy.py +1 -1
- nabu/pipeline/helical/tests/test_accumulator.py +4 -0
- nabu/pipeline/params.py +23 -2
- nabu/pipeline/processconfig.py +3 -8
- nabu/pipeline/tests/test_chunk_reader.py +78 -0
- nabu/pipeline/tests/test_estimators.py +120 -2
- nabu/pipeline/utils.py +25 -0
- nabu/pipeline/writer.py +2 -0
- nabu/preproc/ccd_cuda.py +9 -7
- nabu/preproc/ctf.py +21 -26
- nabu/preproc/ctf_cuda.py +25 -25
- nabu/preproc/double_flatfield.py +14 -2
- nabu/preproc/double_flatfield_cuda.py +7 -11
- nabu/preproc/flatfield_cuda.py +23 -27
- nabu/preproc/phase.py +19 -24
- nabu/preproc/phase_cuda.py +21 -21
- nabu/preproc/shift_cuda.py +58 -28
- nabu/preproc/tests/test_ctf.py +5 -5
- nabu/preproc/tests/test_double_flatfield.py +2 -2
- nabu/preproc/tests/test_vshift.py +13 -2
- nabu/processing/__init__.py +0 -0
- nabu/processing/convolution_cuda.py +375 -0
- nabu/processing/fft_base.py +163 -0
- nabu/processing/fft_cuda.py +256 -0
- nabu/processing/fft_opencl.py +54 -0
- nabu/processing/fftshift.py +134 -0
- nabu/processing/histogram.py +286 -0
- nabu/processing/histogram_cuda.py +103 -0
- nabu/processing/kernel_base.py +126 -0
- nabu/processing/medfilt_cuda.py +159 -0
- nabu/processing/muladd.py +29 -0
- nabu/processing/muladd_cuda.py +68 -0
- nabu/processing/padding_base.py +71 -0
- nabu/processing/padding_cuda.py +75 -0
- nabu/processing/padding_opencl.py +77 -0
- nabu/processing/processing_base.py +123 -0
- nabu/processing/roll_opencl.py +64 -0
- nabu/processing/rotation.py +63 -0
- nabu/processing/rotation_cuda.py +66 -0
- nabu/processing/tests/__init__.py +0 -0
- nabu/processing/tests/test_fft.py +268 -0
- nabu/processing/tests/test_fftshift.py +71 -0
- nabu/{misc → processing}/tests/test_histogram.py +2 -4
- nabu/{cuda → processing}/tests/test_medfilt.py +1 -1
- nabu/processing/tests/test_muladd.py +54 -0
- nabu/{cuda → processing}/tests/test_padding.py +119 -75
- nabu/processing/tests/test_roll.py +63 -0
- nabu/{misc → processing}/tests/test_rotation.py +3 -2
- nabu/processing/tests/test_transpose.py +72 -0
- nabu/{misc → processing}/tests/test_unsharp.py +41 -8
- nabu/processing/transpose.py +126 -0
- nabu/processing/unsharp.py +79 -0
- nabu/processing/unsharp_cuda.py +53 -0
- nabu/processing/unsharp_opencl.py +75 -0
- nabu/reconstruction/fbp.py +34 -10
- nabu/reconstruction/fbp_base.py +35 -16
- nabu/reconstruction/fbp_opencl.py +7 -12
- nabu/reconstruction/filtering.py +2 -2
- nabu/reconstruction/filtering_cuda.py +13 -14
- nabu/reconstruction/filtering_opencl.py +3 -4
- nabu/reconstruction/projection.py +2 -0
- nabu/reconstruction/rings.py +158 -1
- nabu/reconstruction/rings_cuda.py +218 -58
- nabu/reconstruction/sinogram_cuda.py +16 -12
- nabu/reconstruction/tests/test_deringer.py +116 -14
- nabu/reconstruction/tests/test_fbp.py +22 -31
- nabu/reconstruction/tests/test_filtering.py +11 -2
- nabu/resources/dataset_analyzer.py +89 -26
- nabu/resources/nxflatfield.py +2 -2
- nabu/resources/tests/test_nxflatfield.py +1 -1
- nabu/resources/utils.py +9 -2
- nabu/stitching/alignment.py +184 -0
- nabu/stitching/config.py +241 -39
- nabu/stitching/definitions.py +6 -0
- nabu/stitching/frame_composition.py +4 -2
- nabu/stitching/overlap.py +99 -3
- nabu/stitching/sample_normalization.py +60 -0
- nabu/stitching/slurm_utils.py +10 -10
- nabu/stitching/tests/test_alignment.py +99 -0
- nabu/stitching/tests/test_config.py +16 -1
- nabu/stitching/tests/test_overlap.py +68 -2
- nabu/stitching/tests/test_sample_normalization.py +49 -0
- nabu/stitching/tests/test_slurm_utils.py +5 -5
- nabu/stitching/tests/test_utils.py +3 -33
- nabu/stitching/tests/test_z_stitching.py +391 -22
- nabu/stitching/utils.py +144 -202
- nabu/stitching/z_stitching.py +309 -126
- nabu/testutils.py +18 -0
- nabu/thirdparty/tomocupy_remove_stripe.py +586 -0
- nabu/utils.py +32 -6
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/LICENSE +1 -1
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/METADATA +5 -5
- nabu-2024.1.0rc3.dist-info/RECORD +296 -0
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/WHEEL +1 -1
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/entry_points.txt +5 -1
- nabu/conftest.py +0 -14
- nabu/opencl/fftshift.py +0 -92
- nabu/opencl/tests/test_fftshift.py +0 -55
- nabu/opencl/tests/test_padding.py +0 -84
- nabu-2023.2.1.dist-info/RECORD +0 -252
- /nabu/cuda/src/{fftshift.cu → dfi_fftshift.cu} +0 -0
- {nabu-2023.2.1.dist-info → nabu-2024.1.0rc3.dist-info}/top_level.txt +0 -0
nabu/stitching/z_stitching.py
CHANGED
@@ -37,30 +37,33 @@ from math import ceil
|
|
37
37
|
from contextlib import AbstractContextManager
|
38
38
|
import h5py
|
39
39
|
import logging
|
40
|
-
|
41
40
|
from scipy.ndimage import shift as shift_scipy
|
41
|
+
from functools import lru_cache as cache
|
42
42
|
|
43
43
|
from silx.io.utils import get_data
|
44
44
|
from silx.io.url import DataUrl
|
45
45
|
from silx.io.dictdump import dicttonx
|
46
46
|
|
47
|
+
from nxtomo.nxobject.nxdetector import ImageKey
|
48
|
+
from nxtomo.nxobject.nxtransformations import NXtransformations
|
49
|
+
from nxtomo.paths.nxtomo import get_paths as _get_nexus_paths
|
50
|
+
from nxtomo.utils.transformation import build_matrix, LRDetTransformation, UDDetTransformation
|
51
|
+
|
47
52
|
from tomoscan.io import HDF5File
|
48
53
|
from tomoscan.esrf.scan.utils import cwd_context
|
49
54
|
from tomoscan.identifier import BaseIdentifier
|
50
|
-
from tomoscan.esrf import
|
55
|
+
from tomoscan.esrf import NXtomoScan, EDFTomoScan
|
51
56
|
from tomoscan.volumebase import VolumeBase
|
52
57
|
from tomoscan.esrf.volume import HDF5Volume
|
53
58
|
from tomoscan.serie import Serie
|
54
|
-
from tomoscan.esrf.scan.hdf5scan import ImageKey
|
55
|
-
from tomoscan.nexus.paths.nxtomo import get_paths as _get_nexus_paths
|
56
59
|
from tomoscan.factory import Factory as TomoscanFactory
|
57
60
|
from tomoscan.utils.volume import concatenate as concatenate_volumes
|
58
61
|
from tomoscan.esrf.scan.utils import (
|
59
62
|
get_compacted_dataslices,
|
60
63
|
) # this version has a 'return_url_set' needed here. At one point they should be merged together
|
61
|
-
from
|
64
|
+
from pyunitsystem.metricsystem import MetricSystem
|
62
65
|
|
63
|
-
from
|
66
|
+
from nxtomo.application.nxtomo import NXtomo
|
64
67
|
from silx.io.dictdump import dicttonx
|
65
68
|
|
66
69
|
from nabu.io.utils import DatasetReader
|
@@ -76,16 +79,19 @@ from nabu.stitching.config import (
|
|
76
79
|
KEY_RESCALE_MAX_PERCENTILES,
|
77
80
|
KEY_THRESHOLD_FREQUENCY,
|
78
81
|
)
|
82
|
+
from nabu.stitching.alignment import align_horizontally, AlignmentAxis1
|
79
83
|
from nabu.utils import Progress
|
80
84
|
from nabu import version as nabu_version
|
81
85
|
from nabu.io.writer import get_datetime
|
82
86
|
from .overlap import (
|
83
87
|
ZStichOverlapKernel,
|
88
|
+
check_overlaps,
|
84
89
|
)
|
85
90
|
from .. import version as nabu_version
|
86
91
|
from nabu.io.writer import get_datetime
|
87
92
|
from nabu.misc.utils import rescale_data
|
88
|
-
|
93
|
+
from nabu.stitching.alignment import PaddedRawData
|
94
|
+
from nabu.stitching.sample_normalization import normalize_frame as normalize_frame_by_sample
|
89
95
|
|
90
96
|
_logger = logging.getLogger(__name__)
|
91
97
|
|
@@ -127,6 +133,9 @@ class ZStitcher:
|
|
127
133
|
self._axis_2_rel_shifts = []
|
128
134
|
# shift between upper and lower frames
|
129
135
|
|
136
|
+
self._stitching_width = None
|
137
|
+
# stitching width: larger volume width. Other volume will be pad
|
138
|
+
|
130
139
|
# z serie must be defined from daughter class
|
131
140
|
assert hasattr(self, "_z_serie")
|
132
141
|
|
@@ -184,6 +193,7 @@ class ZStitcher:
|
|
184
193
|
estimated_shifts_axis_0.insert(0, 0)
|
185
194
|
|
186
195
|
final_pos = {}
|
196
|
+
previous_shift = 0
|
187
197
|
for tomo_obj, pos_axis_0, pos_axis_2, final_shift_axis_0, estimated_shift_axis_0, final_shift_axis_2 in zip(
|
188
198
|
self.z_serie,
|
189
199
|
self.configuration.axis_0_pos_px,
|
@@ -194,10 +204,11 @@ class ZStitcher:
|
|
194
204
|
):
|
195
205
|
# warning estimated_shift is the estimatation from the overlap. So playes no role here
|
196
206
|
final_pos[tomo_obj.get_identifier().to_str()] = (
|
197
|
-
pos_axis_0
|
207
|
+
pos_axis_0 - (final_shift_axis_0 - estimated_shift_axis_0) + previous_shift,
|
198
208
|
None, # axis 1 is not handled for now
|
199
209
|
pos_axis_2 + final_shift_axis_2,
|
200
210
|
)
|
211
|
+
previous_shift += final_shift_axis_0 - estimated_shift_axis_0
|
201
212
|
return final_pos
|
202
213
|
|
203
214
|
def from_abs_pos_to_rel_pos(self, abs_position: tuple):
|
@@ -278,21 +289,17 @@ class ZStitcher:
|
|
278
289
|
else:
|
279
290
|
overlap_size = int(overlap_size)
|
280
291
|
|
281
|
-
|
292
|
+
self._stitching_width = max([get_obj_width(obj) for obj in self.z_serie])
|
293
|
+
|
294
|
+
for axis_0_shift in self._axis_0_rel_shifts:
|
282
295
|
if overlap_size == -1:
|
283
296
|
height = abs(axis_0_shift)
|
284
297
|
else:
|
285
298
|
height = overlap_size
|
286
|
-
if isinstance(obj, HDF5TomoScan):
|
287
|
-
frame_width = obj.dim_1
|
288
|
-
elif isinstance(obj, VolumeBase):
|
289
|
-
frame_width = obj.get_volume_shape()[2]
|
290
|
-
else:
|
291
|
-
raise TypeError(f"obj type ({type(obj)}) is not handled")
|
292
299
|
|
293
300
|
self._overlap_kernels.append(
|
294
301
|
ZStichOverlapKernel(
|
295
|
-
frame_width=
|
302
|
+
frame_width=self._stitching_width,
|
296
303
|
stitching_strategy=self.configuration.stitching_strategy,
|
297
304
|
overlap_size=height,
|
298
305
|
extra_params=self.configuration.stitching_kernels_extra_params,
|
@@ -385,6 +392,7 @@ class ZStitcher:
|
|
385
392
|
"""
|
386
393
|
rescale_frames if requested by the configuration
|
387
394
|
"""
|
395
|
+
_logger.info("apply rescale frames")
|
388
396
|
|
389
397
|
def cast_percentile(percentile) -> int:
|
390
398
|
if isinstance(percentile, str):
|
@@ -407,9 +415,27 @@ class ZStitcher:
|
|
407
415
|
|
408
416
|
return tuple([rescale(data) for data in frames])
|
409
417
|
|
418
|
+
def normalize_frame_by_sample(self, frames: tuple):
|
419
|
+
"""
|
420
|
+
normalize frame from a sample picked on the left or the right
|
421
|
+
"""
|
422
|
+
_logger.info("apply normalization by a sample")
|
423
|
+
return tuple(
|
424
|
+
[
|
425
|
+
normalize_frame_by_sample(
|
426
|
+
frame=frame,
|
427
|
+
side=self.configuration.normalization_by_sample.side,
|
428
|
+
method=self.configuration.normalization_by_sample.method,
|
429
|
+
margin_before_sample=self.configuration.normalization_by_sample.margin,
|
430
|
+
sample_width=self.configuration.normalization_by_sample.width,
|
431
|
+
)
|
432
|
+
for frame in frames
|
433
|
+
]
|
434
|
+
)
|
435
|
+
|
410
436
|
@staticmethod
|
411
437
|
def stitch_frames(
|
412
|
-
frames: tuple,
|
438
|
+
frames: Union[tuple, numpy.ndarray],
|
413
439
|
x_relative_shifts: tuple,
|
414
440
|
y_relative_shifts: tuple,
|
415
441
|
output_dtype: numpy.ndarray,
|
@@ -421,6 +447,9 @@ class ZStitcher:
|
|
421
447
|
shift_mode="nearest",
|
422
448
|
i_frame=None,
|
423
449
|
return_composition_cls=False,
|
450
|
+
alignment="center",
|
451
|
+
pad_mode="constant",
|
452
|
+
new_width: Optional[int] = None,
|
424
453
|
) -> numpy.ndarray:
|
425
454
|
"""
|
426
455
|
shift frames according to provided `shifts` (as y, x tuples) then stitch all the shifted frames together and
|
@@ -444,6 +473,21 @@ class ZStitcher:
|
|
444
473
|
f"expect to have the same number of y_relative_shifts ({len(y_relative_shifts)}) and y_overlap ({len(overlap_kernels)})"
|
445
474
|
)
|
446
475
|
|
476
|
+
relative_positions = [(0, 0)]
|
477
|
+
for y_rel_pos, x_rel_pos in zip(y_relative_shifts, x_relative_shifts):
|
478
|
+
relative_positions.append(
|
479
|
+
(
|
480
|
+
y_rel_pos + relative_positions[-1][0],
|
481
|
+
x_rel_pos + relative_positions[-1][1],
|
482
|
+
)
|
483
|
+
)
|
484
|
+
check_overlaps(
|
485
|
+
frames=tuple(frames),
|
486
|
+
positions=tuple(relative_positions),
|
487
|
+
axis=0,
|
488
|
+
raise_error=False,
|
489
|
+
)
|
490
|
+
|
447
491
|
def check_frame_is_2d(frame):
|
448
492
|
if frame.ndim != 2:
|
449
493
|
raise ValueError(f"2D frame expected when {frame.ndim}D provided")
|
@@ -462,6 +506,7 @@ class ZStitcher:
|
|
462
506
|
data.append(frame)
|
463
507
|
else:
|
464
508
|
raise TypeError(f"frames are expected to be DataUrl or 2D numpy array. Not {type(frame)}")
|
509
|
+
|
465
510
|
# step 1: shift each frames (except the first one)
|
466
511
|
x_shifted_data = [data[0]]
|
467
512
|
for frame, x_relative_shift in zip(data[1:], x_relative_shifts):
|
@@ -492,6 +537,9 @@ class ZStitcher:
|
|
492
537
|
check_inputs=check_inputs,
|
493
538
|
output_dtype=output_dtype,
|
494
539
|
return_composition_cls=return_composition_cls,
|
540
|
+
alignment=alignment,
|
541
|
+
pad_mode=pad_mode,
|
542
|
+
new_width=new_width,
|
495
543
|
)
|
496
544
|
if return_composition_cls:
|
497
545
|
stitched_frame, _ = res
|
@@ -507,6 +555,24 @@ class ZStitcher:
|
|
507
555
|
)
|
508
556
|
return res
|
509
557
|
|
558
|
+
@staticmethod
|
559
|
+
@cache(maxsize=None)
|
560
|
+
def _get_UD_flip_matrix():
|
561
|
+
return UDDetTransformation().as_matrix()
|
562
|
+
|
563
|
+
@staticmethod
|
564
|
+
@cache(maxsize=None)
|
565
|
+
def _get_LR_flip_matrix():
|
566
|
+
return LRDetTransformation().as_matrix()
|
567
|
+
|
568
|
+
@staticmethod
|
569
|
+
@cache(maxsize=None)
|
570
|
+
def _get_UD_AND_LR_flip_matrix():
|
571
|
+
return numpy.matmul(
|
572
|
+
ZStitcher._get_UD_flip_matrix(),
|
573
|
+
ZStitcher._get_LR_flip_matrix(),
|
574
|
+
)
|
575
|
+
|
510
576
|
|
511
577
|
class PreProcessZStitcher(ZStitcher):
|
512
578
|
def __init__(self, configuration, progress=None) -> None:
|
@@ -681,13 +747,6 @@ class PreProcessZStitcher(ZStitcher):
|
|
681
747
|
if isinstance(axis_pos_px, Iterable) and len(axis_pos_px) != (n_scans):
|
682
748
|
raise ValueError(f"{axis_name} expect {n_scans} shift defined. Get {len(axis_pos_px)}")
|
683
749
|
|
684
|
-
for scan in self.z_serie:
|
685
|
-
if scan.x_flipped is None or scan.y_flipped is None:
|
686
|
-
_logger.warning(
|
687
|
-
f"Found at least one scan with no frame flips information ({scan}). Will consider those are unflipped. Might end up with some inverted frame errors."
|
688
|
-
)
|
689
|
-
break
|
690
|
-
|
691
750
|
self._reading_orders = []
|
692
751
|
# the first scan will define the expected reading orderd, and expected flip.
|
693
752
|
# if all scan are flipped then we will keep it this way
|
@@ -697,8 +756,8 @@ class PreProcessZStitcher(ZStitcher):
|
|
697
756
|
for scan_0, scan_1 in zip(self.z_serie[0:-1], self.z_serie[1:]):
|
698
757
|
if len(scan_0.projections) != len(scan_1.projections):
|
699
758
|
raise ValueError(f"{scan_0} and {scan_1} have a different number of projections")
|
700
|
-
if isinstance(scan_0,
|
701
|
-
# check rotation (only of is an
|
759
|
+
if isinstance(scan_0, NXtomoScan) and isinstance(scan_1, NXtomoScan):
|
760
|
+
# check rotation (only of is an NXtomoScan)
|
702
761
|
scan_0_angles = numpy.asarray(scan_0.rotation_angle)
|
703
762
|
scan_0_projections_angles = scan_0_angles[
|
704
763
|
numpy.asarray(scan_0.image_key_control) == ImageKey.PROJECTION.value
|
@@ -749,8 +808,8 @@ class PreProcessZStitcher(ZStitcher):
|
|
749
808
|
)
|
750
809
|
|
751
810
|
for scan in self.z_serie:
|
752
|
-
# check x, y and z translation are constant (only if is an
|
753
|
-
if isinstance(scan,
|
811
|
+
# check x, y and z translation are constant (only if is an NXtomoScan)
|
812
|
+
if isinstance(scan, NXtomoScan):
|
754
813
|
if scan.x_translation is not None and not numpy.isclose(
|
755
814
|
min(scan.x_translation), max(scan.x_translation)
|
756
815
|
):
|
@@ -1002,23 +1061,43 @@ class PreProcessZStitcher(ZStitcher):
|
|
1002
1061
|
darks=scan.reduced_darks,
|
1003
1062
|
radios_indices=radio_indices,
|
1004
1063
|
radios_srcurrent=scan.electric_current[radio_indices] if has_reduced_metadata else None,
|
1005
|
-
flats_srcurrent=
|
1006
|
-
|
1007
|
-
|
1064
|
+
flats_srcurrent=(
|
1065
|
+
scan.reduced_flats_infos.machine_electric_current if has_reduced_metadata else None
|
1066
|
+
),
|
1008
1067
|
)
|
1009
1068
|
# note: we need to cast radios to float 32. Darks and flats are cast to anyway
|
1010
1069
|
data = ff_arrays.normalize_radios(raw_radios.astype(numpy.float32))
|
1011
1070
|
|
1012
|
-
|
1013
|
-
|
1071
|
+
transformations = list(scans[i_scan].get_detector_transformations(tuple()))
|
1072
|
+
if scan_flip_lr:
|
1073
|
+
transformations.append(LRDetTransformation())
|
1074
|
+
if scan_flip_ud:
|
1075
|
+
transformations.append(UDDetTransformation())
|
1076
|
+
|
1077
|
+
transformation_matrix_det_space = build_matrix(transformations)
|
1078
|
+
if transformation_matrix_det_space is None or numpy.allclose(
|
1079
|
+
transformation_matrix_det_space, numpy.identity(3)
|
1080
|
+
):
|
1081
|
+
flip_ud = False
|
1082
|
+
flip_lr = False
|
1083
|
+
elif numpy.array_equal(transformation_matrix_det_space, ZStitcher._get_UD_flip_matrix()):
|
1084
|
+
flip_ud = True
|
1085
|
+
flip_lr = False
|
1086
|
+
elif numpy.allclose(transformation_matrix_det_space, ZStitcher._get_LR_flip_matrix()):
|
1087
|
+
flip_ud = False
|
1088
|
+
flip_lr = True
|
1089
|
+
elif numpy.allclose(transformation_matrix_det_space, ZStitcher._get_UD_AND_LR_flip_matrix()):
|
1090
|
+
flip_ud = True
|
1091
|
+
flip_lr = True
|
1092
|
+
else:
|
1093
|
+
raise ValueError("case not handled... For now only handle up-down flip as left-right flip")
|
1094
|
+
|
1014
1095
|
for frame in data:
|
1015
|
-
f_frame = frame
|
1016
|
-
if flip_lr:
|
1017
|
-
f_frame = numpy.fliplr(f_frame)
|
1018
1096
|
if flip_ud:
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1097
|
+
frame = numpy.flipud(frame)
|
1098
|
+
if flip_lr:
|
1099
|
+
frame = numpy.fliplr(frame)
|
1100
|
+
all_scan_final_data[i_frame, i_scan] = frame
|
1022
1101
|
i_frame += 1
|
1023
1102
|
|
1024
1103
|
return all_scan_final_data
|
@@ -1032,7 +1111,9 @@ class PreProcessZStitcher(ZStitcher):
|
|
1032
1111
|
):
|
1033
1112
|
upper_scan_pos = upper_scan_axis_0_pos - upper_scan.dim_2 / 2
|
1034
1113
|
lower_scan_high_pos = lower_scan_axis_0_pos + lower_scan.dim_2 / 2
|
1035
|
-
|
1114
|
+
# simple test of overlap. More complete test are runned by check_overlaps later
|
1115
|
+
if lower_scan_high_pos <= upper_scan_pos:
|
1116
|
+
raise ValueError(f"no overlap found between {upper_scan} and {lower_scan}")
|
1036
1117
|
self._axis_0_estimated_shifts.append(
|
1037
1118
|
int(lower_scan_high_pos - upper_scan_pos) # overlap are expected to be int for now
|
1038
1119
|
)
|
@@ -1074,11 +1155,10 @@ class PreProcessZStitcher(ZStitcher):
|
|
1074
1155
|
nx_tomo.instrument.detector.y_pixel_size = self.z_serie[0].y_pixel_size
|
1075
1156
|
nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
|
1076
1157
|
nx_tomo.instrument.detector.tomo_n = n_proj
|
1077
|
-
# note: stitching process insure unflipping of frames
|
1078
|
-
nx_tomo.instrument.detector.
|
1079
|
-
nx_tomo.instrument.detector.y_flipped = False
|
1158
|
+
# note: stitching process insure unflipping of frames. So make sure transformations is defined as an empty set
|
1159
|
+
nx_tomo.instrument.detector.transformations = NXtransformations()
|
1080
1160
|
|
1081
|
-
if isinstance(self.z_serie[0],
|
1161
|
+
if isinstance(self.z_serie[0], NXtomoScan):
|
1082
1162
|
# note: first scan is always the reference as order to read data (so no rotation_angle inversion here)
|
1083
1163
|
rotation_angle = numpy.asarray(self.z_serie[0].rotation_angle)
|
1084
1164
|
nx_tomo.sample.rotation_angle = rotation_angle[
|
@@ -1091,8 +1171,8 @@ class PreProcessZStitcher(ZStitcher):
|
|
1091
1171
|
else:
|
1092
1172
|
raise NotImplementedError(
|
1093
1173
|
f"scan type ({type(self.z_serie[0])} is not handled)",
|
1094
|
-
|
1095
|
-
isinstance(self.z_serie[0],
|
1174
|
+
NXtomoScan,
|
1175
|
+
isinstance(self.z_serie[0], NXtomoScan),
|
1096
1176
|
)
|
1097
1177
|
|
1098
1178
|
# do a sub selection of the rotation angle if a we are only computing a part of the slices
|
@@ -1110,7 +1190,7 @@ class PreProcessZStitcher(ZStitcher):
|
|
1110
1190
|
|
1111
1191
|
# handle sample
|
1112
1192
|
n_frames = n_proj
|
1113
|
-
if False not in [isinstance(scan,
|
1193
|
+
if False not in [isinstance(scan, NXtomoScan) for scan in self.z_serie]:
|
1114
1194
|
# we consider the new x, y and z position to be at the center of the one created
|
1115
1195
|
x_translation = [scan.x_translation for scan in self.z_serie if scan.x_translation is not None]
|
1116
1196
|
nx_tomo.sample.x_translation = [numpy.asarray(x_translation).mean()] * n_frames
|
@@ -1128,7 +1208,7 @@ class PreProcessZStitcher(ZStitcher):
|
|
1128
1208
|
numpy.asarray([scan.dim_2 for scan in self.z_serie]).sum()
|
1129
1209
|
- numpy.asarray([abs(overlap) for overlap in self._axis_0_rel_shifts]).sum()
|
1130
1210
|
),
|
1131
|
-
self.
|
1211
|
+
self._stitching_width,
|
1132
1212
|
)
|
1133
1213
|
|
1134
1214
|
# get expected output dataset first (just in case output and input files are the same)
|
@@ -1151,11 +1231,15 @@ class PreProcessZStitcher(ZStitcher):
|
|
1151
1231
|
overwrite=self.configuration.overwrite_results,
|
1152
1232
|
)
|
1153
1233
|
|
1234
|
+
transformation_matrices = {
|
1235
|
+
scan.get_identifier()
|
1236
|
+
.to_str()
|
1237
|
+
.center(80, "-"): numpy.array2string(build_matrix(scan.get_detector_transformations(tuple())))
|
1238
|
+
for scan in self.z_serie
|
1239
|
+
}
|
1154
1240
|
_logger.info(
|
1155
|
-
|
1156
|
-
|
1157
|
-
_logger.info(
|
1158
|
-
f"scan y flipped are {','.join([str(scan.get_y_flipped(default=False)) for scan in self.z_serie])}"
|
1241
|
+
"scan detector transformation matrices are:\n"
|
1242
|
+
"\n".join(["/n".join(item) for item in transformation_matrices.items()])
|
1159
1243
|
)
|
1160
1244
|
|
1161
1245
|
_logger.info(
|
@@ -1214,6 +1298,9 @@ class PreProcessZStitcher(ZStitcher):
|
|
1214
1298
|
):
|
1215
1299
|
if self.configuration.rescale_frames:
|
1216
1300
|
data_frames = self.rescale_frames(data_frames)
|
1301
|
+
if self.configuration.normalization_by_sample.is_active():
|
1302
|
+
data_frames = self.normalize_frame_by_sample(data_frames)
|
1303
|
+
|
1217
1304
|
sf = ZStitcher.stitch_frames(
|
1218
1305
|
frames=data_frames,
|
1219
1306
|
x_relative_shifts=self._axis_2_rel_shifts,
|
@@ -1225,6 +1312,10 @@ class PreProcessZStitcher(ZStitcher):
|
|
1225
1312
|
dump_frame_fct=self._dump_frame,
|
1226
1313
|
return_composition_cls=store_composition if i_proj == 0 else False,
|
1227
1314
|
stitching_axis=0,
|
1315
|
+
pad_mode=self.configuration.pad_mode,
|
1316
|
+
alignment=self.configuration.alignment_axis_2,
|
1317
|
+
new_width=self._stitching_width,
|
1318
|
+
check_inputs=i_proj == 0, # on process check on the first iteration
|
1228
1319
|
)
|
1229
1320
|
if i_proj == 0 and store_composition:
|
1230
1321
|
_, self._frame_composition = sf
|
@@ -1277,6 +1368,7 @@ class PreProcessZStitcher(ZStitcher):
|
|
1277
1368
|
class PostProcessZStitcher(ZStitcher):
|
1278
1369
|
def __init__(self, configuration, progress: Progress = None) -> None:
|
1279
1370
|
self._input_volumes = configuration.input_volumes
|
1371
|
+
self.__output_data_type = None
|
1280
1372
|
|
1281
1373
|
self._z_serie = Serie("z-serie", iterable=self._input_volumes, use_identifiers=False)
|
1282
1374
|
super().__init__(configuration, progress)
|
@@ -1333,7 +1425,7 @@ class PostProcessZStitcher(ZStitcher):
|
|
1333
1425
|
if scan_location is not None:
|
1334
1426
|
# this work around (until most volume have position metadata) works only for Hdf5volume
|
1335
1427
|
with cwd_context(os.path.dirname(volume.file_path)):
|
1336
|
-
o_scan =
|
1428
|
+
o_scan = NXtomoScan(scan_location, scan_entry)
|
1337
1429
|
bb_acqui = o_scan.get_bounding_box(axis=None)
|
1338
1430
|
# for next step volume position will be required.
|
1339
1431
|
# if you can find it set it directly
|
@@ -1414,7 +1506,7 @@ class PostProcessZStitcher(ZStitcher):
|
|
1414
1506
|
# deduce from position given in configuration and pixel size
|
1415
1507
|
axis_N_pos_px = []
|
1416
1508
|
for volume, pos_in_mm in zip(self.z_serie, pos_as_mm):
|
1417
|
-
voxel_size_m = self.configuration.voxel_size or volume.
|
1509
|
+
voxel_size_m = self.configuration.voxel_size or volume.voxel_size
|
1418
1510
|
axis_N_pos_px.append((pos_in_mm / MetricSystem.MILLIMETER.value) / voxel_size_m[0])
|
1419
1511
|
return axis_N_pos_px
|
1420
1512
|
else:
|
@@ -1422,7 +1514,7 @@ class PostProcessZStitcher(ZStitcher):
|
|
1422
1514
|
axis_N_pos_px = []
|
1423
1515
|
base_position_m = self.z_serie[0].get_bounding_box(axis=axis).min
|
1424
1516
|
for volume in self.z_serie:
|
1425
|
-
voxel_size_m = self.configuration.voxel_size or volume.
|
1517
|
+
voxel_size_m = self.configuration.voxel_size or volume.voxel_size
|
1426
1518
|
volume_axis_bb = volume.get_bounding_box(axis=axis)
|
1427
1519
|
axis_N_mean_pos_m = (volume_axis_bb.max - volume_axis_bb.min) / 2 + volume_axis_bb.min
|
1428
1520
|
axis_N_mean_rel_pos_m = axis_N_mean_pos_m - base_position_m
|
@@ -1464,6 +1556,7 @@ class PostProcessZStitcher(ZStitcher):
|
|
1464
1556
|
slice_for_shift = self.configuration.slice_for_cross_correlation or "middle"
|
1465
1557
|
y_rel_shifts = self._axis_0_estimated_shifts
|
1466
1558
|
x_rel_shifts = self.from_abs_pos_to_rel_pos(self.configuration.axis_2_pos_px)
|
1559
|
+
dim_axis_1 = max([volume.get_volume_shape()[1] for volume in self.z_serie])
|
1467
1560
|
|
1468
1561
|
final_rel_shifts = []
|
1469
1562
|
for (
|
@@ -1488,6 +1581,8 @@ class PostProcessZStitcher(ZStitcher):
|
|
1488
1581
|
found_shift_y, found_shift_x = find_volumes_relative_shifts(
|
1489
1582
|
upper_volume=upper_volume,
|
1490
1583
|
lower_volume=lower_volume,
|
1584
|
+
dtype=self.get_output_data_type(),
|
1585
|
+
dim_axis_1=dim_axis_1,
|
1491
1586
|
slice_for_shift=slice_for_shift,
|
1492
1587
|
x_cross_correlation_function=x_cross_algo,
|
1493
1588
|
y_cross_correlation_function=y_cross_algo,
|
@@ -1496,6 +1591,8 @@ class PostProcessZStitcher(ZStitcher):
|
|
1496
1591
|
estimated_shifts=(y_rel_shift, x_rel_shift),
|
1497
1592
|
flip_ud_lower_frame=flip_ud_lower,
|
1498
1593
|
flip_ud_upper_frame=flip_ud_upper,
|
1594
|
+
alignment_axis_1=self.configuration.alignment_axis_1,
|
1595
|
+
alignment_axis_2=self.configuration.alignment_axis_2,
|
1499
1596
|
)
|
1500
1597
|
final_rel_shifts.append(
|
1501
1598
|
(found_shift_y, found_shift_x),
|
@@ -1504,10 +1601,10 @@ class PostProcessZStitcher(ZStitcher):
|
|
1504
1601
|
# set back values. Now position should start at 0
|
1505
1602
|
self._axis_0_rel_shifts = [final_shift[0] for final_shift in final_rel_shifts]
|
1506
1603
|
self._axis_2_rel_shifts = [final_shift[1] for final_shift in final_rel_shifts]
|
1507
|
-
_logger.info(f"axis 2 relative shifts (x in radio ref) to be used will be {self.
|
1508
|
-
print(f"axis 2 relative shifts (x in radio ref) to be used will be {self.
|
1509
|
-
_logger.info(f"axis 0 relative shifts (y in radio ref) y to be used will be {self.
|
1510
|
-
print(f"axis 0 relative shifts (y in radio ref) y to be used will be {self.
|
1604
|
+
_logger.info(f"axis 2 relative shifts (x in radio ref) to be used will be {self._axis_2_rel_shifts}")
|
1605
|
+
print(f"axis 2 relative shifts (x in radio ref) to be used will be {self._axis_2_rel_shifts}")
|
1606
|
+
_logger.info(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_0_rel_shifts}")
|
1607
|
+
print(f"axis 0 relative shifts (y in radio ref) y to be used will be {self._axis_0_rel_shifts}")
|
1511
1608
|
|
1512
1609
|
def _dump_stitching_configuration(self):
|
1513
1610
|
voxel_size = self._input_volumes[0].voxel_size
|
@@ -1583,22 +1680,27 @@ class PostProcessZStitcher(ZStitcher):
|
|
1583
1680
|
):
|
1584
1681
|
raise ValueError(f"expect {n_volumes} overlap defined. Get {len(self.configuration.axis_2_pos_mm)}")
|
1585
1682
|
|
1586
|
-
yz_shape = None
|
1587
|
-
for volume in self.configuration.input_volumes:
|
1588
|
-
assert isinstance(volume, VolumeBase)
|
1589
|
-
volume_shape = volume.get_volume_shape()
|
1590
|
-
if volume_shape is None:
|
1591
|
-
raise ValueError("Unable to load volume shape (probably no data found from {volume.get_identifier()}")
|
1592
|
-
if yz_shape is None:
|
1593
|
-
yz_shape = volume_shape[1:]
|
1594
|
-
elif yz_shape != volume_shape[1:]:
|
1595
|
-
raise ValueError("Input volumes have incoherent (yz) shapes. Unable to stitch it together")
|
1596
|
-
|
1597
1683
|
self._reading_orders = []
|
1598
1684
|
# the first scan will define the expected reading orderd, and expected flip.
|
1599
1685
|
# if all scan are flipped then we will keep it this way
|
1600
1686
|
self._reading_orders.append(1)
|
1601
1687
|
|
1688
|
+
def get_output_data_type(self):
|
1689
|
+
if self.__output_data_type is None:
|
1690
|
+
|
1691
|
+
def find_output_data_type():
|
1692
|
+
first_vol = self._input_volumes[0]
|
1693
|
+
if first_vol.data is not None:
|
1694
|
+
return first_vol.data.dtype
|
1695
|
+
elif isinstance(first_vol, HDF5Volume):
|
1696
|
+
with DatasetReader(first_vol.data_url) as vol_dataset:
|
1697
|
+
return vol_dataset.dtype
|
1698
|
+
else:
|
1699
|
+
return first_vol.load_data(store=False).dtype
|
1700
|
+
|
1701
|
+
self.__output_data_type = find_output_data_type()
|
1702
|
+
return self.__output_data_type
|
1703
|
+
|
1602
1704
|
def _create_stitched_volume(self, store_composition: bool):
|
1603
1705
|
overlap_kernels = self._overlap_kernels
|
1604
1706
|
self._slices_to_stitch, n_slices = self.configuration.settle_slices()
|
@@ -1614,20 +1716,10 @@ class PostProcessZStitcher(ZStitcher):
|
|
1614
1716
|
- numpy.asarray([abs(overlap) for overlap in self._axis_0_rel_shifts]).sum(),
|
1615
1717
|
),
|
1616
1718
|
n_slices,
|
1617
|
-
self.
|
1719
|
+
self._stitching_width,
|
1618
1720
|
)
|
1619
1721
|
|
1620
|
-
|
1621
|
-
first_vol = self._input_volumes[0]
|
1622
|
-
if first_vol.data is not None:
|
1623
|
-
return first_vol.data.dtype
|
1624
|
-
elif isinstance(first_vol, HDF5Volume):
|
1625
|
-
with DatasetReader(first_vol.data_url) as vol_dataset:
|
1626
|
-
return vol_dataset.dtype
|
1627
|
-
else:
|
1628
|
-
return first_vol.load_data(store=False).dtype
|
1629
|
-
|
1630
|
-
data_type = get_output_data_type()
|
1722
|
+
data_type = self.get_output_data_type()
|
1631
1723
|
|
1632
1724
|
if self.progress:
|
1633
1725
|
self.progress.set_max_advancement(final_volume_shape[1])
|
@@ -1641,7 +1733,10 @@ class PostProcessZStitcher(ZStitcher):
|
|
1641
1733
|
volume=final_volume, volume_shape=final_volume_shape, dtype=data_type
|
1642
1734
|
) as output_dataset:
|
1643
1735
|
# note: output_dataset is a HDF5 dataset if final volume is an HDF5 volume else is a numpy array
|
1644
|
-
with PostProcessZStitcher._RawDatasetsContext(
|
1736
|
+
with PostProcessZStitcher._RawDatasetsContext(
|
1737
|
+
self._input_volumes,
|
1738
|
+
alignment_axis_1=self.configuration.alignment_axis_1,
|
1739
|
+
) as raw_datasets:
|
1645
1740
|
# note: raw_datasets can be numpy arrays or HDF5 dataset (in the case of HDF5Volume)
|
1646
1741
|
# to speed up we read by bunch of dataset. For numpy array this doesn't change anything
|
1647
1742
|
# but for HDF5 dataset this can speed up a lot the processing (depending on HDF5 dataset chuncks)
|
@@ -1659,6 +1754,9 @@ class PostProcessZStitcher(ZStitcher):
|
|
1659
1754
|
):
|
1660
1755
|
if self.configuration.rescale_frames:
|
1661
1756
|
data_frames = self.rescale_frames(data_frames)
|
1757
|
+
if self.configuration.normalization_by_sample.is_active():
|
1758
|
+
data_frames = self.normalize_frame_by_sample(data_frames)
|
1759
|
+
|
1662
1760
|
sf = ZStitcher.stitch_frames(
|
1663
1761
|
frames=data_frames,
|
1664
1762
|
x_relative_shifts=self._axis_2_rel_shifts,
|
@@ -1670,6 +1768,7 @@ class PostProcessZStitcher(ZStitcher):
|
|
1670
1768
|
output_dtype=data_type,
|
1671
1769
|
return_composition_cls=store_composition if y_index == 0 else False,
|
1672
1770
|
stitching_axis=0,
|
1771
|
+
check_inputs=y_index == 0, # on process check on the first iteration
|
1673
1772
|
)
|
1674
1773
|
if y_index == 0 and store_composition:
|
1675
1774
|
_, self._frame_composition = sf
|
@@ -1680,7 +1779,12 @@ class PostProcessZStitcher(ZStitcher):
|
|
1680
1779
|
|
1681
1780
|
@staticmethod
|
1682
1781
|
def _get_bunch_of_data(
|
1683
|
-
bunch_start: int,
|
1782
|
+
bunch_start: int,
|
1783
|
+
bunch_end: int,
|
1784
|
+
step: int,
|
1785
|
+
volumes: tuple,
|
1786
|
+
flip_lr_arr: bool,
|
1787
|
+
flip_ud_arr: bool,
|
1684
1788
|
):
|
1685
1789
|
"""
|
1686
1790
|
goal is to load contiguous frames as much as possible...
|
@@ -1768,7 +1872,7 @@ class PostProcessZStitcher(ZStitcher):
|
|
1768
1872
|
If the volume is of another type then it will be loaded in memory then used (more memory consuming)
|
1769
1873
|
"""
|
1770
1874
|
|
1771
|
-
def __init__(self, volumes: tuple) -> None:
|
1875
|
+
def __init__(self, volumes: tuple, alignment_axis_1) -> None:
|
1772
1876
|
super().__init__()
|
1773
1877
|
for volume in volumes:
|
1774
1878
|
if not isinstance(volume, VolumeBase):
|
@@ -1778,25 +1882,36 @@ class PostProcessZStitcher(ZStitcher):
|
|
1778
1882
|
|
1779
1883
|
self._volumes = volumes
|
1780
1884
|
self.__file_handlers = []
|
1885
|
+
self._alignment_axis_1 = alignment_axis_1
|
1886
|
+
|
1887
|
+
@property
|
1888
|
+
def alignment_axis_1(self):
|
1889
|
+
return self._alignment_axis_1
|
1781
1890
|
|
1782
1891
|
def __enter__(self):
|
1783
1892
|
# handle the specific case of HDF5. Goal: avoid getting the full stitched volume in memory
|
1784
1893
|
datasets = []
|
1894
|
+
shapes = {volume.get_volume_shape()[1] for volume in self._volumes}
|
1895
|
+
axis_1_dim = max(shapes)
|
1896
|
+
axis_1_need_padding = len(shapes) > 1
|
1897
|
+
|
1785
1898
|
try:
|
1786
1899
|
for volume in self._volumes:
|
1787
1900
|
if volume.data is not None:
|
1788
|
-
|
1901
|
+
data = volume.data
|
1789
1902
|
elif isinstance(volume, HDF5Volume):
|
1790
1903
|
file_handler = HDF5File(filename=volume.data_url.file_path(), mode="r")
|
1791
1904
|
dataset = file_handler[volume.data_url.data_path()]
|
1792
|
-
|
1905
|
+
data = dataset
|
1793
1906
|
self.__file_handlers.append(file_handler)
|
1794
1907
|
# for other file format: load the full dataset in memory
|
1795
1908
|
else:
|
1796
1909
|
data = volume.load_data(store=False)
|
1797
1910
|
if data is None:
|
1798
1911
|
raise ValueError(f"No data found for volume {volume.get_identifier()}")
|
1799
|
-
|
1912
|
+
if axis_1_need_padding:
|
1913
|
+
data = self.add_padding(data=data, axis_1_dim=axis_1_dim, alignment=self.alignment_axis_1)
|
1914
|
+
datasets.append(data)
|
1800
1915
|
except Exception as e:
|
1801
1916
|
# if some errors happen during loading HDF5
|
1802
1917
|
for file_handled in self.__file_handlers:
|
@@ -1811,16 +1926,36 @@ class PostProcessZStitcher(ZStitcher):
|
|
1811
1926
|
success = success and file_handler.close()
|
1812
1927
|
return success
|
1813
1928
|
|
1929
|
+
def add_padding(self, data: Union[h5py.Dataset, numpy.ndarray], axis_1_dim, alignment: AlignmentAxis1):
|
1930
|
+
alignment = AlignmentAxis1.from_value(alignment)
|
1931
|
+
if alignment is AlignmentAxis1.BACK:
|
1932
|
+
axis_1_pad_width = (axis_1_dim - data.shape[1], 0)
|
1933
|
+
elif alignment is AlignmentAxis1.CENTER:
|
1934
|
+
half_width = int((axis_1_dim - data.shape[1]) / 2)
|
1935
|
+
axis_1_pad_width = (half_width, axis_1_dim - data.shape[1] - half_width)
|
1936
|
+
elif alignment is AlignmentAxis1.FRONT:
|
1937
|
+
axis_1_pad_width = (0, axis_1_dim - data.shape[1])
|
1938
|
+
else:
|
1939
|
+
raise ValueError(f"alignment {alignment} is not handled")
|
1940
|
+
|
1941
|
+
return PaddedRawData(
|
1942
|
+
data=data,
|
1943
|
+
axis_1_pad_width=axis_1_pad_width,
|
1944
|
+
)
|
1945
|
+
|
1814
1946
|
|
1815
1947
|
def stitch_vertically_raw_frames(
|
1816
1948
|
frames: tuple,
|
1817
1949
|
key_lines: tuple,
|
1818
|
-
overlap_kernels:
|
1950
|
+
overlap_kernels: Union[ZStichOverlapKernel, tuple],
|
1819
1951
|
output_dtype: numpy.dtype = numpy.float32,
|
1820
1952
|
check_inputs=True,
|
1821
1953
|
raw_frames_compositions: Optional[ZFrameComposition] = None,
|
1822
1954
|
overlap_frames_compositions: Optional[ZFrameComposition] = None,
|
1823
1955
|
return_composition_cls=False,
|
1956
|
+
alignment="center",
|
1957
|
+
pad_mode="constant",
|
1958
|
+
new_width: Optional[int] = None,
|
1824
1959
|
) -> numpy.ndarray:
|
1825
1960
|
"""
|
1826
1961
|
stitches raw frames (already shifted and flat fielded !!!) together using
|
@@ -1865,8 +2000,7 @@ def stitch_vertically_raw_frames(
|
|
1865
2000
|
for frame_0, frame_1 in zip(frames[:-1], frames[1:]):
|
1866
2001
|
if not (frame_0.ndim == frame_1.ndim == 2):
|
1867
2002
|
raise ValueError("Frames are expected to be 2D")
|
1868
|
-
|
1869
|
-
raise ValueError("Both projections are expected to have the same width")
|
2003
|
+
|
1870
2004
|
for frame_0, frame_1, kernel in zip(frames[:-1], frames[1:], overlap_kernels):
|
1871
2005
|
if frame_0.shape[0] < kernel.overlap_size:
|
1872
2006
|
raise ValueError(
|
@@ -1886,6 +2020,20 @@ def stitch_vertically_raw_frames(
|
|
1886
2020
|
elif value < 0:
|
1887
2021
|
raise ValueError(f"key lines are expected to be positive values. Get {value} as key line value")
|
1888
2022
|
|
2023
|
+
if new_width is None:
|
2024
|
+
new_width = max([frame.shape[-1] for frame in frames])
|
2025
|
+
frames = tuple(
|
2026
|
+
[
|
2027
|
+
align_horizontally(
|
2028
|
+
data=frame,
|
2029
|
+
alignment=alignment,
|
2030
|
+
new_width=new_width,
|
2031
|
+
pad_mode=pad_mode,
|
2032
|
+
)
|
2033
|
+
for frame in frames
|
2034
|
+
]
|
2035
|
+
)
|
2036
|
+
|
1889
2037
|
# step 1: create numpy array that will contain stitching
|
1890
2038
|
# if raw composition doesn't exists create it
|
1891
2039
|
if raw_frames_compositions is None:
|
@@ -1899,7 +2047,7 @@ def stitch_vertically_raw_frames(
|
|
1899
2047
|
stitched_projection_shape = (
|
1900
2048
|
# here we only handle frames because shift are already done
|
1901
2049
|
int(new_frame_height),
|
1902
|
-
|
2050
|
+
new_width,
|
1903
2051
|
)
|
1904
2052
|
stitch_array = numpy.empty(stitched_projection_shape, dtype=output_dtype)
|
1905
2053
|
|
@@ -1966,18 +2114,67 @@ class StitchingPostProcAggregation:
|
|
1966
2114
|
|
1967
2115
|
This is the goal of this class.
|
1968
2116
|
Please be careful with API. This is already inheriting from a tomwer class
|
2117
|
+
|
2118
|
+
:param ZStitchingConfiguration stitching_config: configuration of the stitching configuration
|
2119
|
+
:param Optional[tuple] futures: futures that just runned
|
2120
|
+
:param Optional[tuple] existing_objs: futures that just runned
|
2121
|
+
:param
|
1969
2122
|
"""
|
1970
2123
|
|
1971
|
-
def __init__(
|
1972
|
-
|
1973
|
-
|
2124
|
+
def __init__(
|
2125
|
+
self,
|
2126
|
+
stitching_config: ZStitchingConfiguration,
|
2127
|
+
futures: Optional[tuple] = None,
|
2128
|
+
existing_objs_ids: Optional[tuple] = None,
|
2129
|
+
) -> None:
|
2130
|
+
if not isinstance(stitching_config, (ZStitchingConfiguration)):
|
2131
|
+
raise TypeError(f"stitching_config should be an instance of {ZStitchingConfiguration}")
|
2132
|
+
if not ((existing_objs_ids is None) ^ (futures is None)):
|
2133
|
+
raise ValueError("Either existing_objs or futures should be provided (can't provide both)")
|
1974
2134
|
self._futures = futures
|
1975
2135
|
self._stitching_config = stitching_config
|
2136
|
+
self._existing_objs_ids = existing_objs_ids
|
1976
2137
|
|
1977
2138
|
@property
|
1978
2139
|
def futures(self):
|
2140
|
+
# TODO: deprecate it ?
|
1979
2141
|
return self._futures
|
1980
2142
|
|
2143
|
+
def retrieve_tomo_objects(self) -> tuple():
|
2144
|
+
"""
|
2145
|
+
Return tomo objects to be stitched together. Either from future or from existing_objs
|
2146
|
+
"""
|
2147
|
+
if self._existing_objs_ids is not None:
|
2148
|
+
scan_ids = self._existing_objs_ids
|
2149
|
+
else:
|
2150
|
+
results = {}
|
2151
|
+
_logger.info(f"wait for slurm job to be completed")
|
2152
|
+
for obj_id, future in self.futures.items():
|
2153
|
+
results[obj_id] = future.result()
|
2154
|
+
|
2155
|
+
failed = tuple(
|
2156
|
+
filter(
|
2157
|
+
lambda x: x.exception() is not None,
|
2158
|
+
self.futures.values(),
|
2159
|
+
)
|
2160
|
+
)
|
2161
|
+
if len(failed) > 0:
|
2162
|
+
# if some job failed: unseless to do the concatenation
|
2163
|
+
exceptions = " ; ".join([f"{job} : {job.exception()}" for job in failed])
|
2164
|
+
raise RuntimeError(f"some job failed. Won't do the concatenation. Exceptiosn are {exceptions}")
|
2165
|
+
|
2166
|
+
canceled = tuple(
|
2167
|
+
filter(
|
2168
|
+
lambda x: x.cancelled(),
|
2169
|
+
self.futures.values(),
|
2170
|
+
)
|
2171
|
+
)
|
2172
|
+
if len(canceled) > 0:
|
2173
|
+
# if some job canceled: unseless to do the concatenation
|
2174
|
+
raise RuntimeError(f"some job failed. Won't do the concatenation. Jobs are {' ; '.join(canceled)}")
|
2175
|
+
scan_ids = results.keys()
|
2176
|
+
return [TomoscanFactory.create_tomo_object_from_identifier(scan_id) for scan_id in scan_ids]
|
2177
|
+
|
1981
2178
|
def dump_stiching_config_as_nx_process(self, file_path: str, data_path: str, overwrite: bool, process_name: str):
|
1982
2179
|
dict_to_dump = {
|
1983
2180
|
process_name: {
|
@@ -2005,40 +2202,14 @@ class StitchingPostProcAggregation:
|
|
2005
2202
|
"""
|
2006
2203
|
main function
|
2007
2204
|
"""
|
2008
|
-
# retrive results (and wait if some processing are not finished)
|
2009
|
-
results = {}
|
2010
|
-
_logger.info(f"wait for slurm job to be completed")
|
2011
|
-
for obj_id, future in self.futures.items():
|
2012
|
-
results[obj_id] = future.result()
|
2013
|
-
|
2014
|
-
failed = tuple(
|
2015
|
-
filter(
|
2016
|
-
lambda x: x.exception() is not None,
|
2017
|
-
self.futures.values(),
|
2018
|
-
)
|
2019
|
-
)
|
2020
|
-
if len(failed) > 0:
|
2021
|
-
# if some job failed: unseless to do the concatenation
|
2022
|
-
exceptions = " ; ".join([f"{job} : {job.exception()}" for job in failed])
|
2023
|
-
raise RuntimeError(f"some job failed. Won't do the concatenation. Exceptiosn are {exceptions}")
|
2024
|
-
|
2025
|
-
canceled = tuple(
|
2026
|
-
filter(
|
2027
|
-
lambda x: x.cancelled(),
|
2028
|
-
self.futures.values(),
|
2029
|
-
)
|
2030
|
-
)
|
2031
|
-
if len(canceled) > 0:
|
2032
|
-
# if some job canceled: unseless to do the concatenation
|
2033
|
-
raise RuntimeError(f"some job failed. Won't do the concatenation. Jobs are {' ; '.join(canceled)}")
|
2034
2205
|
|
2035
2206
|
# concatenate result
|
2036
2207
|
_logger.info("all job succeeded. Concatenate results")
|
2037
2208
|
if isinstance(self._stitching_config, PreProcessedZStitchingConfiguration):
|
2038
2209
|
# 1: case of a pre-processing stitching
|
2210
|
+
scans = self.retrieve_tomo_objects()
|
2039
2211
|
nx_tomos = []
|
2040
|
-
for
|
2041
|
-
scan = TomoscanFactory.create_tomo_object_from_identifier(result)
|
2212
|
+
for scan in scans:
|
2042
2213
|
nx_tomos.append(
|
2043
2214
|
NXtomo().load(
|
2044
2215
|
file_path=scan.master_file,
|
@@ -2069,9 +2240,7 @@ class StitchingPostProcAggregation:
|
|
2069
2240
|
|
2070
2241
|
elif isinstance(self.stitching_config, PostProcessedZStitchingConfiguration):
|
2071
2242
|
# 2: case of a post-processing stitching
|
2072
|
-
outputs_sub_volumes =
|
2073
|
-
TomoscanFactory.create_tomo_object_from_identifier(result) for result in results.keys()
|
2074
|
-
]
|
2243
|
+
outputs_sub_volumes = self.retrieve_tomo_objects()
|
2075
2244
|
concatenate_volumes(
|
2076
2245
|
output_volume=self.stitching_config.output_volume,
|
2077
2246
|
volumes=tuple(outputs_sub_volumes),
|
@@ -2092,3 +2261,17 @@ class StitchingPostProcAggregation:
|
|
2092
2261
|
process_name=process_name,
|
2093
2262
|
overwrite=self.stitching_config.overwrite_results,
|
2094
2263
|
)
|
2264
|
+
else:
|
2265
|
+
raise TypeError(f"stitching_config type ({type(self.stitching_config)}) not handled")
|
2266
|
+
|
2267
|
+
|
2268
|
+
def get_obj_width(obj: Union[NXtomoScan, VolumeBase]) -> int:
|
2269
|
+
"""
|
2270
|
+
return tomo object width
|
2271
|
+
"""
|
2272
|
+
if isinstance(obj, NXtomoScan):
|
2273
|
+
return obj.dim_1
|
2274
|
+
elif isinstance(obj, VolumeBase):
|
2275
|
+
return obj.get_volume_shape()[-1]
|
2276
|
+
else:
|
2277
|
+
raise TypeError(f"obj type ({type(obj)}) is not handled")
|