nabu 2024.2.14__py3-none-any.whl → 2025.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/doc_config.py +32 -0
- nabu/__init__.py +1 -1
- nabu/app/bootstrap_stitching.py +4 -2
- nabu/app/cast_volume.py +16 -14
- nabu/app/cli_configs.py +102 -9
- nabu/app/compare_volumes.py +1 -1
- nabu/app/composite_cor.py +2 -4
- nabu/app/diag_to_pix.py +5 -6
- nabu/app/diag_to_rot.py +10 -11
- nabu/app/double_flatfield.py +18 -5
- nabu/app/estimate_motion.py +75 -0
- nabu/app/multicor.py +28 -15
- nabu/app/parse_reconstruction_log.py +1 -0
- nabu/app/pcaflats.py +122 -0
- nabu/app/prepare_weights_double.py +1 -2
- nabu/app/reconstruct.py +1 -7
- nabu/app/reconstruct_helical.py +5 -9
- nabu/app/reduce_dark_flat.py +5 -4
- nabu/app/rotate.py +3 -1
- nabu/app/stitching.py +7 -2
- nabu/app/tests/test_reduce_dark_flat.py +2 -2
- nabu/app/validator.py +1 -4
- nabu/cuda/convolution.py +1 -1
- nabu/cuda/fft.py +1 -1
- nabu/cuda/medfilt.py +1 -1
- nabu/cuda/padding.py +1 -1
- nabu/cuda/src/backproj.cu +6 -6
- nabu/cuda/src/cone.cu +4 -0
- nabu/cuda/src/hierarchical_backproj.cu +14 -0
- nabu/cuda/utils.py +2 -2
- nabu/estimation/alignment.py +17 -31
- nabu/estimation/cor.py +27 -33
- nabu/estimation/cor_sino.py +2 -8
- nabu/estimation/focus.py +4 -8
- nabu/estimation/motion.py +557 -0
- nabu/estimation/tests/test_alignment.py +2 -0
- nabu/estimation/tests/test_motion_estimation.py +471 -0
- nabu/estimation/tests/test_tilt.py +1 -1
- nabu/estimation/tilt.py +6 -5
- nabu/estimation/translation.py +47 -1
- nabu/io/cast_volume.py +108 -18
- nabu/io/detector_distortion.py +5 -6
- nabu/io/reader.py +45 -6
- nabu/io/reader_helical.py +5 -4
- nabu/io/tests/test_cast_volume.py +2 -2
- nabu/io/tests/test_readers.py +41 -38
- nabu/io/tests/test_remove_volume.py +152 -0
- nabu/io/tests/test_writers.py +2 -2
- nabu/io/utils.py +8 -4
- nabu/io/writer.py +1 -2
- nabu/misc/fftshift.py +1 -1
- nabu/misc/fourier_filters.py +1 -1
- nabu/misc/histogram.py +1 -1
- nabu/misc/histogram_cuda.py +1 -1
- nabu/misc/padding_base.py +1 -1
- nabu/misc/rotation.py +1 -1
- nabu/misc/rotation_cuda.py +1 -1
- nabu/misc/tests/test_binning.py +1 -1
- nabu/misc/transpose.py +1 -1
- nabu/misc/unsharp.py +1 -1
- nabu/misc/unsharp_cuda.py +1 -1
- nabu/misc/unsharp_opencl.py +1 -1
- nabu/misc/utils.py +1 -1
- nabu/opencl/fft.py +1 -1
- nabu/opencl/padding.py +1 -1
- nabu/opencl/src/backproj.cl +6 -6
- nabu/opencl/utils.py +8 -8
- nabu/pipeline/config.py +2 -2
- nabu/pipeline/config_validators.py +46 -46
- nabu/pipeline/datadump.py +3 -3
- nabu/pipeline/estimators.py +271 -11
- nabu/pipeline/fullfield/chunked.py +103 -67
- nabu/pipeline/fullfield/chunked_cuda.py +5 -2
- nabu/pipeline/fullfield/computations.py +4 -1
- nabu/pipeline/fullfield/dataset_validator.py +0 -1
- nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
- nabu/pipeline/fullfield/nabu_config.py +36 -17
- nabu/pipeline/fullfield/processconfig.py +41 -7
- nabu/pipeline/fullfield/reconstruction.py +14 -10
- nabu/pipeline/helical/dataset_validator.py +3 -4
- nabu/pipeline/helical/fbp.py +4 -4
- nabu/pipeline/helical/filtering.py +5 -4
- nabu/pipeline/helical/gridded_accumulator.py +10 -11
- nabu/pipeline/helical/helical_chunked_regridded.py +1 -0
- nabu/pipeline/helical/helical_reconstruction.py +12 -9
- nabu/pipeline/helical/helical_utils.py +1 -2
- nabu/pipeline/helical/nabu_config.py +2 -1
- nabu/pipeline/helical/span_strategy.py +1 -0
- nabu/pipeline/helical/weight_balancer.py +2 -3
- nabu/pipeline/params.py +20 -3
- nabu/pipeline/tests/__init__.py +0 -0
- nabu/pipeline/tests/test_estimators.py +240 -3
- nabu/pipeline/utils.py +1 -1
- nabu/pipeline/writer.py +1 -1
- nabu/preproc/alignment.py +0 -10
- nabu/preproc/ccd.py +53 -3
- nabu/preproc/ctf.py +8 -8
- nabu/preproc/ctf_cuda.py +1 -1
- nabu/preproc/double_flatfield_cuda.py +2 -2
- nabu/preproc/double_flatfield_variable_region.py +0 -1
- nabu/preproc/flatfield.py +307 -2
- nabu/preproc/flatfield_cuda.py +1 -2
- nabu/preproc/flatfield_variable_region.py +3 -3
- nabu/preproc/phase.py +2 -4
- nabu/preproc/phase_cuda.py +2 -2
- nabu/preproc/shift.py +4 -2
- nabu/preproc/shift_cuda.py +0 -1
- nabu/preproc/tests/test_ctf.py +4 -4
- nabu/preproc/tests/test_double_flatfield.py +1 -1
- nabu/preproc/tests/test_flatfield.py +1 -1
- nabu/preproc/tests/test_paganin.py +1 -3
- nabu/preproc/tests/test_pcaflats.py +154 -0
- nabu/preproc/tests/test_vshift.py +4 -1
- nabu/processing/azim.py +9 -5
- nabu/processing/convolution_cuda.py +6 -4
- nabu/processing/fft_base.py +7 -3
- nabu/processing/fft_cuda.py +25 -164
- nabu/processing/fft_opencl.py +28 -6
- nabu/processing/fftshift.py +1 -1
- nabu/processing/histogram.py +1 -1
- nabu/processing/muladd.py +0 -1
- nabu/processing/padding_base.py +1 -1
- nabu/processing/padding_cuda.py +0 -2
- nabu/processing/processing_base.py +12 -6
- nabu/processing/rotation_cuda.py +3 -1
- nabu/processing/tests/test_fft.py +2 -64
- nabu/processing/tests/test_fftshift.py +1 -1
- nabu/processing/tests/test_medfilt.py +1 -3
- nabu/processing/tests/test_padding.py +1 -1
- nabu/processing/tests/test_roll.py +1 -1
- nabu/processing/tests/test_rotation.py +4 -2
- nabu/processing/unsharp_opencl.py +1 -1
- nabu/reconstruction/astra.py +245 -0
- nabu/reconstruction/cone.py +39 -9
- nabu/reconstruction/fbp.py +7 -0
- nabu/reconstruction/fbp_base.py +36 -5
- nabu/reconstruction/filtering.py +59 -25
- nabu/reconstruction/filtering_cuda.py +22 -21
- nabu/reconstruction/filtering_opencl.py +10 -14
- nabu/reconstruction/hbp.py +26 -13
- nabu/reconstruction/mlem.py +55 -16
- nabu/reconstruction/projection.py +3 -5
- nabu/reconstruction/sinogram.py +1 -1
- nabu/reconstruction/sinogram_cuda.py +0 -1
- nabu/reconstruction/tests/test_cone.py +37 -2
- nabu/reconstruction/tests/test_deringer.py +4 -4
- nabu/reconstruction/tests/test_fbp.py +36 -15
- nabu/reconstruction/tests/test_filtering.py +27 -7
- nabu/reconstruction/tests/test_halftomo.py +28 -2
- nabu/reconstruction/tests/test_mlem.py +94 -64
- nabu/reconstruction/tests/test_projector.py +7 -2
- nabu/reconstruction/tests/test_reconstructor.py +1 -1
- nabu/reconstruction/tests/test_sino_normalization.py +0 -1
- nabu/resources/dataset_analyzer.py +210 -24
- nabu/resources/gpu.py +4 -4
- nabu/resources/logger.py +4 -4
- nabu/resources/nxflatfield.py +103 -37
- nabu/resources/tests/test_dataset_analyzer.py +37 -0
- nabu/resources/tests/test_extract.py +11 -0
- nabu/resources/tests/test_nxflatfield.py +5 -5
- nabu/resources/utils.py +16 -10
- nabu/stitching/alignment.py +8 -11
- nabu/stitching/config.py +44 -35
- nabu/stitching/definitions.py +2 -2
- nabu/stitching/frame_composition.py +8 -10
- nabu/stitching/overlap.py +4 -4
- nabu/stitching/sample_normalization.py +5 -5
- nabu/stitching/slurm_utils.py +2 -2
- nabu/stitching/stitcher/base.py +2 -0
- nabu/stitching/stitcher/dumper/base.py +0 -1
- nabu/stitching/stitcher/dumper/postprocessing.py +1 -1
- nabu/stitching/stitcher/post_processing.py +11 -9
- nabu/stitching/stitcher/pre_processing.py +37 -31
- nabu/stitching/stitcher/single_axis.py +2 -3
- nabu/stitching/stitcher_2D.py +2 -1
- nabu/stitching/tests/test_config.py +10 -11
- nabu/stitching/tests/test_sample_normalization.py +1 -1
- nabu/stitching/tests/test_slurm_utils.py +1 -2
- nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
- nabu/stitching/tests/test_z_postprocessing_stitching.py +3 -3
- nabu/stitching/tests/test_z_preprocessing_stitching.py +27 -24
- nabu/stitching/utils/tests/__init__.py +0 -0
- nabu/stitching/utils/tests/test_post-processing.py +1 -0
- nabu/stitching/utils/utils.py +16 -18
- nabu/tests.py +0 -3
- nabu/testutils.py +62 -9
- nabu/utils.py +50 -20
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/METADATA +7 -7
- nabu-2025.1.0.dist-info/RECORD +328 -0
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/WHEEL +1 -1
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/entry_points.txt +2 -1
- nabu/app/correct_rot.py +0 -70
- nabu/io/tests/test_detector_distortion.py +0 -178
- nabu-2024.2.14.dist-info/RECORD +0 -317
- /nabu/{stitching → app}/tests/__init__.py +0 -0
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/licenses/LICENSE +0 -0
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/top_level.txt +0 -0
|
@@ -2,7 +2,8 @@ import numpy
|
|
|
2
2
|
import logging
|
|
3
3
|
import h5py
|
|
4
4
|
import os
|
|
5
|
-
|
|
5
|
+
import pint
|
|
6
|
+
from collections.abc import Iterable
|
|
6
7
|
from silx.io.url import DataUrl
|
|
7
8
|
from silx.io.utils import get_data
|
|
8
9
|
from datetime import datetime
|
|
@@ -26,8 +27,8 @@ from nabu.stitching.config import (
|
|
|
26
27
|
from nabu.stitching.utils import find_projections_relative_shifts
|
|
27
28
|
from functools import lru_cache as cache
|
|
28
29
|
from .single_axis import SingleAxisStitcher
|
|
29
|
-
from pyunitsystem.metricsystem import MetricSystem
|
|
30
30
|
|
|
31
|
+
_ureg = pint.get_application_registry()
|
|
31
32
|
|
|
32
33
|
_logger = logging.getLogger(__name__)
|
|
33
34
|
|
|
@@ -40,7 +41,6 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
40
41
|
"""
|
|
41
42
|
|
|
42
43
|
def __init__(self, configuration, progress=None) -> None:
|
|
43
|
-
""" """
|
|
44
44
|
if not isinstance(configuration, PreProcessedSingleAxisStitchingConfiguration):
|
|
45
45
|
raise TypeError(
|
|
46
46
|
f"configuration is expected to be an instance of {PreProcessedSingleAxisStitchingConfiguration}. Get {type(configuration)} instead"
|
|
@@ -239,18 +239,18 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
239
239
|
if not scan_0.field_of_view == scan_1.field_of_view:
|
|
240
240
|
raise ValueError(f"{scan_0} and {scan_1} have different field of view")
|
|
241
241
|
# check distance
|
|
242
|
-
if scan_0.
|
|
242
|
+
if scan_0.sample_detector_distance is None:
|
|
243
243
|
_logger.warning(f"no distance found for {scan_0}")
|
|
244
|
-
elif not numpy.isclose(scan_0.
|
|
244
|
+
elif not numpy.isclose(scan_0.sample_detector_distance, scan_1.sample_detector_distance, rtol=10e-3):
|
|
245
245
|
raise ValueError(f"{scan_0} and {scan_1} have different sample / detector distance")
|
|
246
246
|
# check pixel size
|
|
247
|
-
if not numpy.isclose(scan_0.
|
|
247
|
+
if not numpy.isclose(scan_0.sample_x_pixel_size, scan_1.sample_x_pixel_size):
|
|
248
248
|
raise ValueError(
|
|
249
|
-
f"{scan_0} and {scan_1} have different x pixel size. {scan_0.
|
|
249
|
+
f"{scan_0} and {scan_1} have different x pixel size. {scan_0.sample_x_pixel_size} vs {scan_1.sample_x_pixel_size}"
|
|
250
250
|
)
|
|
251
|
-
if not numpy.isclose(scan_0.
|
|
251
|
+
if not numpy.isclose(scan_0.sample_y_pixel_size, scan_1.sample_y_pixel_size):
|
|
252
252
|
raise ValueError(
|
|
253
|
-
f"{scan_0} and {scan_1} have different y pixel size. {scan_0.
|
|
253
|
+
f"{scan_0} and {scan_1} have different y pixel size. {scan_0.sample_y_pixel_size} vs {scan_1.sample_y_pixel_size}"
|
|
254
254
|
)
|
|
255
255
|
|
|
256
256
|
for scan in self.series:
|
|
@@ -292,7 +292,7 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
292
292
|
axis_N_pos_px = []
|
|
293
293
|
for scan, pos_in_mm in zip(self.series, pos_as_mm):
|
|
294
294
|
pixel_size_m = self.configuration.pixel_size or scan.pixel_size
|
|
295
|
-
axis_N_pos_px.append((pos_in_mm
|
|
295
|
+
axis_N_pos_px.append((pos_in_mm * _ureg.millimeter).to_base_units().magnitude / pixel_size_m)
|
|
296
296
|
return axis_N_pos_px
|
|
297
297
|
else:
|
|
298
298
|
# deduce from motor position and pixel size
|
|
@@ -472,7 +472,7 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
472
472
|
"""
|
|
473
473
|
nx_tomo = NXtomo()
|
|
474
474
|
|
|
475
|
-
nx_tomo.energy = self.series[0].energy
|
|
475
|
+
nx_tomo.energy = self.series[0].energy * _ureg.keV
|
|
476
476
|
start_times = list(filter(None, [scan.start_time for scan in self.series]))
|
|
477
477
|
end_times = list(filter(None, [scan.end_time for scan in self.series]))
|
|
478
478
|
|
|
@@ -496,9 +496,9 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
496
496
|
|
|
497
497
|
# handle detector (without frames)
|
|
498
498
|
nx_tomo.instrument.detector.field_of_view = self.series[0].field_of_view
|
|
499
|
-
nx_tomo.instrument.detector.distance = self.series[0].
|
|
500
|
-
nx_tomo.instrument.detector.x_pixel_size = self.series[0].x_pixel_size
|
|
501
|
-
nx_tomo.instrument.detector.y_pixel_size = self.series[0].y_pixel_size
|
|
499
|
+
nx_tomo.instrument.detector.distance = self.series[0].sample_detector_distance * _ureg.meter
|
|
500
|
+
nx_tomo.instrument.detector.x_pixel_size = self.series[0].x_pixel_size * _ureg.meter
|
|
501
|
+
nx_tomo.instrument.detector.y_pixel_size = self.series[0].y_pixel_size * _ureg.meter
|
|
502
502
|
nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
|
|
503
503
|
nx_tomo.instrument.detector.tomo_n = n_proj
|
|
504
504
|
# note: stitching process insure un-flipping of frames. So make sure transformations is defined as an empty set
|
|
@@ -506,13 +506,13 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
506
506
|
|
|
507
507
|
if isinstance(self.series[0], NXtomoScan):
|
|
508
508
|
# note: first scan is always the reference as order to read data (so no rotation_angle inversion here)
|
|
509
|
-
rotation_angle = numpy.asarray(self.series[0].rotation_angle)
|
|
509
|
+
rotation_angle = numpy.asarray(self.series[0].rotation_angle) * _ureg.degree
|
|
510
510
|
nx_tomo.sample.rotation_angle = rotation_angle[
|
|
511
511
|
numpy.asarray(self.series[0].image_key_control) == ImageKey.PROJECTION.value
|
|
512
512
|
]
|
|
513
513
|
elif isinstance(self.series[0], EDFTomoScan):
|
|
514
|
-
nx_tomo.sample.rotation_angle =
|
|
515
|
-
start=0, stop=self.series[0].scan_range, num=self.series[0].tomo_n
|
|
514
|
+
nx_tomo.sample.rotation_angle = (
|
|
515
|
+
numpy.linspace(start=0, stop=self.series[0].scan_range, num=self.series[0].tomo_n) * _ureg.degree
|
|
516
516
|
)
|
|
517
517
|
else:
|
|
518
518
|
raise NotImplementedError(
|
|
@@ -526,12 +526,15 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
526
526
|
if isinstance(slices, slice):
|
|
527
527
|
return array[slices.start : slices.stop : 1]
|
|
528
528
|
elif isinstance(slices, Iterable):
|
|
529
|
-
return
|
|
529
|
+
return [array[index] for index in slices]
|
|
530
530
|
else:
|
|
531
|
-
raise RuntimeError("slices must be instance of a slice or of an iterable")
|
|
531
|
+
raise RuntimeError("slices must be instance of a slice or of an iterable") # noqa: TRY004
|
|
532
532
|
|
|
533
|
-
nx_tomo.sample.rotation_angle =
|
|
534
|
-
|
|
533
|
+
nx_tomo.sample.rotation_angle = (
|
|
534
|
+
apply_slices_selection(
|
|
535
|
+
array=nx_tomo.sample.rotation_angle.to_base_units().magnitude, slices=self._slices_to_stitch
|
|
536
|
+
)
|
|
537
|
+
* _ureg.degree
|
|
535
538
|
)
|
|
536
539
|
|
|
537
540
|
# handle sample
|
|
@@ -560,7 +563,7 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
560
563
|
# note: if at least one has missing values the numpy.Array(x_translation) with create an error as well
|
|
561
564
|
x_translation = [0.0] * n_proj
|
|
562
565
|
_logger.warning("Unable to fin input nxtomo x_translation values. Set it to 0.0")
|
|
563
|
-
nx_tomo.sample.x_translation = x_translation
|
|
566
|
+
nx_tomo.sample.x_translation = x_translation * _ureg.meter
|
|
564
567
|
|
|
565
568
|
y_translation = [
|
|
566
569
|
get_sample_translation_for_projs(scan, "y_translation")
|
|
@@ -575,7 +578,7 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
575
578
|
else:
|
|
576
579
|
y_translation = [0.0] * n_proj
|
|
577
580
|
_logger.warning("Unable to fin input nxtomo y_translation values. Set it to 0.0")
|
|
578
|
-
nx_tomo.sample.y_translation = y_translation
|
|
581
|
+
nx_tomo.sample.y_translation = y_translation * _ureg.meter
|
|
579
582
|
z_translation = [
|
|
580
583
|
get_sample_translation_for_projs(scan, "z_translation")
|
|
581
584
|
for scan in self.series
|
|
@@ -589,7 +592,7 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
589
592
|
else:
|
|
590
593
|
z_translation = [0.0] * n_proj
|
|
591
594
|
_logger.warning("Unable to fin input nxtomo z_translation values. Set it to 0.0")
|
|
592
|
-
nx_tomo.sample.z_translation = z_translation
|
|
595
|
+
nx_tomo.sample.z_translation = z_translation * _ureg.meter
|
|
593
596
|
|
|
594
597
|
nx_tomo.sample.name = self.series[0].sample_name
|
|
595
598
|
|
|
@@ -794,7 +797,7 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
794
797
|
):
|
|
795
798
|
i_frame = 0
|
|
796
799
|
_, set_of_compacted_slices = get_compacted_dataslices(scan_urls, return_url_set=True)
|
|
797
|
-
for
|
|
800
|
+
for url in set_of_compacted_slices.values():
|
|
798
801
|
scan = scans[i_scan]
|
|
799
802
|
url = DataUrl(
|
|
800
803
|
file_path=url.file_path(),
|
|
@@ -886,6 +889,9 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
886
889
|
"""
|
|
887
890
|
make sure reduced dark and flats are existing otherwise compute them
|
|
888
891
|
"""
|
|
892
|
+
# TODO
|
|
893
|
+
# ruff: noqa: SIM105, S110
|
|
894
|
+
# --
|
|
889
895
|
for scan in self.series:
|
|
890
896
|
try:
|
|
891
897
|
reduced_darks, darks_infos = scan.load_reduced_darks(return_info=True)
|
|
@@ -896,7 +902,7 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
896
902
|
try:
|
|
897
903
|
# if we don't have write in the folder containing the .nx for example
|
|
898
904
|
scan.save_reduced_darks(reduced_darks, darks_infos=darks_infos)
|
|
899
|
-
except Exception
|
|
905
|
+
except Exception:
|
|
900
906
|
pass
|
|
901
907
|
scan.set_reduced_darks(reduced_darks, darks_infos=darks_infos)
|
|
902
908
|
|
|
@@ -909,7 +915,7 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
909
915
|
try:
|
|
910
916
|
# if we don't have write in the folder containing the .nx for example
|
|
911
917
|
scan.save_reduced_flats(reduced_flats, flats_infos=flats_infos)
|
|
912
|
-
except Exception
|
|
918
|
+
except Exception:
|
|
913
919
|
pass
|
|
914
920
|
scan.set_reduced_flats(reduced_flats, flats_infos=flats_infos)
|
|
915
921
|
|
|
@@ -979,7 +985,7 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
979
985
|
):
|
|
980
986
|
i_frame = 0
|
|
981
987
|
_, set_of_compacted_slices = get_compacted_dataslices(scan_urls, return_url_set=True)
|
|
982
|
-
for
|
|
988
|
+
for url in set_of_compacted_slices.values():
|
|
983
989
|
scan = scans[i_scan]
|
|
984
990
|
url = DataUrl(
|
|
985
991
|
file_path=url.file_path(),
|
|
@@ -1000,12 +1006,12 @@ class PreProcessingStitching(SingleAxisStitcher):
|
|
|
1000
1006
|
|
|
1001
1007
|
missing = []
|
|
1002
1008
|
if len(scan.reduced_flats) == 0:
|
|
1003
|
-
missing
|
|
1009
|
+
missing.append("flats")
|
|
1004
1010
|
if len(scan.reduced_darks) == 0:
|
|
1005
|
-
missing
|
|
1011
|
+
missing.append("darks")
|
|
1006
1012
|
|
|
1007
1013
|
if len(missing) > 0:
|
|
1008
|
-
_logger.warning(f"missing {'and'.join(missing)}. Unable to do flat field correction")
|
|
1014
|
+
_logger.warning(f"missing {' and '.join(missing)}. Unable to do flat field correction")
|
|
1009
1015
|
ff_arrays = None
|
|
1010
1016
|
data = raw_radios
|
|
1011
1017
|
else:
|
|
@@ -1,8 +1,7 @@
|
|
|
1
|
-
import h5py
|
|
2
1
|
import numpy
|
|
3
2
|
import logging
|
|
4
|
-
from
|
|
5
|
-
from
|
|
3
|
+
from typing import Optional, Union
|
|
4
|
+
from collections.abc import Iterable
|
|
6
5
|
from tomoscan.series import Series
|
|
7
6
|
from tomoscan.identifier import BaseIdentifier
|
|
8
7
|
from nabu.stitching.stitcher.base import _StitcherBase, get_obj_constant_side_length
|
nabu/stitching/stitcher_2D.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
# ruff: noqa: N999
|
|
1
2
|
import numpy
|
|
2
3
|
from math import ceil
|
|
3
4
|
from typing import Union, Optional
|
|
@@ -19,7 +20,7 @@ def stitch_raw_frames(
|
|
|
19
20
|
pad_mode="constant",
|
|
20
21
|
new_unstitched_axis_size: Optional[int] = None,
|
|
21
22
|
) -> numpy.ndarray:
|
|
22
|
-
"""
|
|
23
|
+
r"""
|
|
23
24
|
stitches raw frames (already shifted and flat fielded !!!) together using
|
|
24
25
|
raw stitching (no pixel interpolation, y_overlap_in_px is expected to be a int).
|
|
25
26
|
Sttiching depends on the kernel used.
|
|
@@ -10,7 +10,7 @@ from nabu.stitching.overlap import OverlapStitchingStrategy
|
|
|
10
10
|
from nabu.stitching import config as stiching_config
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
_stitching_types = list(stiching_config.StitchingType
|
|
13
|
+
_stitching_types = [st.value for st in list(stiching_config.StitchingType)]
|
|
14
14
|
_stitching_types.append(None)
|
|
15
15
|
|
|
16
16
|
|
|
@@ -37,18 +37,16 @@ def test_stitching_config(stitching_type, option_level):
|
|
|
37
37
|
|
|
38
38
|
assert "stitching" in config
|
|
39
39
|
assert "type" in config["stitching"]
|
|
40
|
-
stitching_type = stiching_config.StitchingType
|
|
40
|
+
stitching_type = stiching_config.StitchingType(config["stitching"]["type"])
|
|
41
41
|
if stitching_type is stiching_config.StitchingType.Z_POSTPROC:
|
|
42
42
|
assert isinstance(
|
|
43
43
|
stiching_config.dict_to_config_obj(config),
|
|
44
44
|
stiching_config.PostProcessedSingleAxisStitchingConfiguration,
|
|
45
45
|
)
|
|
46
|
-
elif
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
)
|
|
51
|
-
elif stitching_type is stiching_config.StitchingType.Y_PREPROC:
|
|
46
|
+
elif (
|
|
47
|
+
stitching_type is stiching_config.StitchingType.Z_PREPROC
|
|
48
|
+
or stitching_type is stiching_config.StitchingType.Y_PREPROC
|
|
49
|
+
):
|
|
52
50
|
assert isinstance(
|
|
53
51
|
stiching_config.dict_to_config_obj(config),
|
|
54
52
|
stiching_config.PreProcessedSingleAxisStitchingConfiguration,
|
|
@@ -84,7 +82,7 @@ def test_stitching_config(stitching_type, option_level):
|
|
|
84
82
|
assert isinstance(config_class_instance.to_dict(), dict)
|
|
85
83
|
|
|
86
84
|
|
|
87
|
-
@pytest.mark.parametrize("stitching_strategy", OverlapStitchingStrategy
|
|
85
|
+
@pytest.mark.parametrize("stitching_strategy", [oss for oss in OverlapStitchingStrategy])
|
|
88
86
|
@pytest.mark.parametrize("overwrite_results", (True, "False", 0, "1"))
|
|
89
87
|
@pytest.mark.parametrize(
|
|
90
88
|
"axis_shifts",
|
|
@@ -92,7 +90,7 @@ def test_stitching_config(stitching_type, option_level):
|
|
|
92
90
|
"",
|
|
93
91
|
None,
|
|
94
92
|
"None",
|
|
95
|
-
"",
|
|
93
|
+
"", # noqa: PT014
|
|
96
94
|
"skimage",
|
|
97
95
|
"nabu-fft",
|
|
98
96
|
),
|
|
@@ -176,7 +174,8 @@ def test_PreProcessedZStitchingConfiguration(
|
|
|
176
174
|
|
|
177
175
|
from_dict = stiching_config.PreProcessedZStitchingConfiguration.from_dict(pre_process_config.to_dict())
|
|
178
176
|
# workaround for scans because a new object is created each time
|
|
179
|
-
|
|
177
|
+
# ???
|
|
178
|
+
pre_process_config.settle_inputs # noqa: B018
|
|
180
179
|
assert len(from_dict.input_scans) == len(pre_process_config.input_scans)
|
|
181
180
|
from_dict.input_scans = None
|
|
182
181
|
pre_process_config.input_scans = None
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import os
|
|
2
|
-
import numpy
|
|
3
2
|
import pytest
|
|
4
3
|
from tomoscan.esrf import NXtomoScan
|
|
5
4
|
from tomoscan.esrf.volume import HDF5Volume
|
|
@@ -14,7 +13,7 @@ from nabu.stitching.slurm_utils import (
|
|
|
14
13
|
from tomoscan.esrf.mock import MockNXtomo
|
|
15
14
|
|
|
16
15
|
try:
|
|
17
|
-
import sluurp
|
|
16
|
+
import sluurp # noqa: F401
|
|
18
17
|
except ImportError:
|
|
19
18
|
has_sluurp = False
|
|
20
19
|
else:
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import pytest
|
|
3
3
|
import numpy
|
|
4
|
+
import pint
|
|
4
5
|
from tqdm import tqdm
|
|
5
6
|
|
|
6
7
|
from nabu.stitching.y_stitching import y_stitching
|
|
@@ -9,6 +10,8 @@ from nxtomo.application.nxtomo import NXtomo
|
|
|
9
10
|
from nxtomo.nxobject.nxdetector import ImageKey
|
|
10
11
|
from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
|
|
11
12
|
|
|
13
|
+
_ureg = pint.UnitRegistry()
|
|
14
|
+
|
|
12
15
|
|
|
13
16
|
def build_nxtomos(output_dir, flip_lr, flip_ud) -> tuple:
|
|
14
17
|
r"""
|
|
@@ -52,10 +55,10 @@ def build_nxtomos(output_dir, flip_lr, flip_ud) -> tuple:
|
|
|
52
55
|
|
|
53
56
|
n_projs = 3
|
|
54
57
|
nx_tomo = NXtomo()
|
|
55
|
-
nx_tomo.sample.x_translation = [0] * (n_projs + 2)
|
|
56
|
-
nx_tomo.sample.y_translation = [frame_y_position] * (n_projs + 2)
|
|
57
|
-
nx_tomo.sample.z_translation = [0] * (n_projs + 2)
|
|
58
|
-
nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=(n_projs + 2), endpoint=False)
|
|
58
|
+
nx_tomo.sample.x_translation = ([0] * (n_projs + 2)) * _ureg.meter
|
|
59
|
+
nx_tomo.sample.y_translation = ([frame_y_position] * (n_projs + 2)) * _ureg.meter
|
|
60
|
+
nx_tomo.sample.z_translation = ([0] * (n_projs + 2)) * _ureg.meter
|
|
61
|
+
nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=(n_projs + 2), endpoint=False) * _ureg.degree
|
|
59
62
|
nx_tomo.instrument.detector.image_key_control = (
|
|
60
63
|
ImageKey.DARK_FIELD,
|
|
61
64
|
ImageKey.FLAT_FIELD,
|
|
@@ -63,10 +66,10 @@ def build_nxtomos(output_dir, flip_lr, flip_ud) -> tuple:
|
|
|
63
66
|
ImageKey.PROJECTION,
|
|
64
67
|
ImageKey.PROJECTION,
|
|
65
68
|
)
|
|
66
|
-
nx_tomo.instrument.detector.x_pixel_size = 1.0
|
|
67
|
-
nx_tomo.instrument.detector.y_pixel_size = 1.0
|
|
68
|
-
nx_tomo.instrument.detector.distance = 2.3
|
|
69
|
-
nx_tomo.energy = 19.2
|
|
69
|
+
nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
|
|
70
|
+
nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
|
|
71
|
+
nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
|
|
72
|
+
nx_tomo.energy = 19.2 * _ureg.keV
|
|
70
73
|
nx_tomo.instrument.detector.data = numpy.stack(
|
|
71
74
|
(
|
|
72
75
|
my_dark_data,
|
|
@@ -395,7 +395,7 @@ def test_vol_z_stitching_with_alignment_axis_2(tmp_path, alignment_axis_2):
|
|
|
395
395
|
axis_2_params={"img_reg_method": ShiftAlgorithm.NONE},
|
|
396
396
|
slice_for_cross_correlation="middle",
|
|
397
397
|
voxel_size=None,
|
|
398
|
-
alignment_axis_2=AlignmentAxis2
|
|
398
|
+
alignment_axis_2=AlignmentAxis2(alignment_axis_2),
|
|
399
399
|
)
|
|
400
400
|
|
|
401
401
|
stitcher = PostProcessZStitcher(z_stich_config, progress=None)
|
|
@@ -512,7 +512,7 @@ def test_vol_z_stitching_with_alignment_axis_1(tmp_path, alignment_axis_1):
|
|
|
512
512
|
axis_2_params={"img_reg_method": ShiftAlgorithm.NONE},
|
|
513
513
|
slice_for_cross_correlation="middle",
|
|
514
514
|
voxel_size=None,
|
|
515
|
-
alignment_axis_1=AlignmentAxis1
|
|
515
|
+
alignment_axis_1=AlignmentAxis1(alignment_axis_1),
|
|
516
516
|
)
|
|
517
517
|
|
|
518
518
|
stitcher = PostProcessZStitcher(z_stich_config, progress=None)
|
|
@@ -770,6 +770,6 @@ def test_data_duplication(tmp_path, data_duplication):
|
|
|
770
770
|
if not data_duplication:
|
|
771
771
|
# make sure an error is raised if we try to ask for no data duplication and if we get some flips
|
|
772
772
|
z_stich_config.flip_ud = (False, True, False)
|
|
773
|
-
with pytest.raises(ValueError):
|
|
773
|
+
with pytest.raises(ValueError): # noqa: PT012
|
|
774
774
|
stitcher = PostProcessZStitcherNoDD(z_stich_config, progress=None)
|
|
775
775
|
stitcher.stitch()
|
|
@@ -1,15 +1,16 @@
|
|
|
1
1
|
import os
|
|
2
|
+
import pint
|
|
2
3
|
from silx.image.phantomgenerator import PhantomGenerator
|
|
3
4
|
from scipy.ndimage import shift as scipy_shift
|
|
4
5
|
import numpy
|
|
5
6
|
import pytest
|
|
6
7
|
from nabu.stitching.config import PreProcessedZStitchingConfiguration
|
|
7
8
|
from nabu.stitching.config import KEY_IMG_REG_METHOD
|
|
8
|
-
from nabu.stitching.overlap import
|
|
9
|
+
from nabu.stitching.overlap import OverlapStitchingStrategy
|
|
9
10
|
from nabu.stitching.z_stitching import (
|
|
10
11
|
PreProcessZStitcher,
|
|
11
12
|
)
|
|
12
|
-
from nabu.stitching.stitcher_2D import
|
|
13
|
+
from nabu.stitching.stitcher_2D import get_overlap_areas
|
|
13
14
|
from nxtomo.nxobject.nxdetector import ImageKey
|
|
14
15
|
from nxtomo.utils.transformation import DetYFlipTransformation, DetZFlipTransformation
|
|
15
16
|
from nxtomo.application.nxtomo import NXtomo
|
|
@@ -17,6 +18,8 @@ from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
|
|
|
17
18
|
from nabu.stitching.utils import ShiftAlgorithm
|
|
18
19
|
import h5py
|
|
19
20
|
|
|
21
|
+
_ureg = pint.get_application_registry()
|
|
22
|
+
|
|
20
23
|
|
|
21
24
|
_stitching_configurations = (
|
|
22
25
|
# simple case where shifts are provided
|
|
@@ -82,13 +85,13 @@ def test_PreProcessZStitcher(tmp_path, dtype, configuration):
|
|
|
82
85
|
scans = []
|
|
83
86
|
for (i_frame, frame), z_pos in zip(enumerate(frames), z_position):
|
|
84
87
|
nx_tomo = NXtomo()
|
|
85
|
-
nx_tomo.sample.z_translation = [z_pos] * n_proj
|
|
86
|
-
nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False)
|
|
88
|
+
nx_tomo.sample.z_translation = ([z_pos] * n_proj) * _ureg.meter
|
|
89
|
+
nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False) * _ureg.degree
|
|
87
90
|
nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
|
|
88
|
-
nx_tomo.instrument.detector.x_pixel_size = 1.0
|
|
89
|
-
nx_tomo.instrument.detector.y_pixel_size = 1.0
|
|
90
|
-
nx_tomo.instrument.detector.distance = 2.3
|
|
91
|
-
nx_tomo.energy = 19.2
|
|
91
|
+
nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
|
|
92
|
+
nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
|
|
93
|
+
nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
|
|
94
|
+
nx_tomo.energy = 19.2 * _ureg.keV
|
|
92
95
|
nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj)
|
|
93
96
|
|
|
94
97
|
file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx")
|
|
@@ -160,8 +163,8 @@ def test_PreProcessZStitcher(tmp_path, dtype, configuration):
|
|
|
160
163
|
)
|
|
161
164
|
|
|
162
165
|
# check also other metadata are here
|
|
163
|
-
assert created_nx_tomo.instrument.detector.distance
|
|
164
|
-
assert created_nx_tomo.energy
|
|
166
|
+
assert created_nx_tomo.instrument.detector.distance == 2.3 * _ureg.meter
|
|
167
|
+
assert created_nx_tomo.energy == 19.2 * _ureg.keV
|
|
165
168
|
numpy.testing.assert_array_equal(
|
|
166
169
|
created_nx_tomo.instrument.detector.image_key_control,
|
|
167
170
|
numpy.asarray([ImageKey.PROJECTION.PROJECTION] * n_proj),
|
|
@@ -228,13 +231,13 @@ def build_nxtomos(output_dir) -> tuple:
|
|
|
228
231
|
scans = []
|
|
229
232
|
for (i_frame, frame), z_pos in zip(enumerate(frames), z_positions):
|
|
230
233
|
nx_tomo = NXtomo()
|
|
231
|
-
nx_tomo.sample.z_translation = [z_pos] * n_projs
|
|
232
|
-
nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_projs, endpoint=False)
|
|
234
|
+
nx_tomo.sample.z_translation = [z_pos] * n_projs * _ureg.meter
|
|
235
|
+
nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_projs, endpoint=False) * _ureg.degree
|
|
233
236
|
nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_projs
|
|
234
|
-
nx_tomo.instrument.detector.x_pixel_size = 1.0
|
|
235
|
-
nx_tomo.instrument.detector.y_pixel_size = 1.0
|
|
236
|
-
nx_tomo.instrument.detector.distance = 2.3
|
|
237
|
-
nx_tomo.energy = 19.2
|
|
237
|
+
nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
|
|
238
|
+
nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
|
|
239
|
+
nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
|
|
240
|
+
nx_tomo.energy = 19.2 * _ureg.keV
|
|
238
241
|
nx_tomo.instrument.detector.data = frame
|
|
239
242
|
|
|
240
243
|
file_path = os.path.join(output_dir, f"nxtomo_{i_frame}.nx")
|
|
@@ -301,11 +304,11 @@ def test_DistributePreProcessZStitcher(tmp_path, configuration_dist):
|
|
|
301
304
|
)
|
|
302
305
|
|
|
303
306
|
if complete:
|
|
304
|
-
len(final_nx_tomo.instrument.detector.data) ==
|
|
307
|
+
assert len(final_nx_tomo.instrument.detector.data) == 100
|
|
305
308
|
# test middle
|
|
306
309
|
numpy.testing.assert_array_almost_equal(raw_data[1], final_nx_tomo.instrument.detector.data[1, :, :])
|
|
307
310
|
else:
|
|
308
|
-
len(final_nx_tomo.instrument.detector.data) == 3
|
|
311
|
+
assert len(final_nx_tomo.instrument.detector.data) == 3
|
|
309
312
|
# test middle
|
|
310
313
|
numpy.testing.assert_array_almost_equal(raw_data[49], final_nx_tomo.instrument.detector.data[1, :, :])
|
|
311
314
|
# in the case of first, middle and last frames
|
|
@@ -373,15 +376,15 @@ def test_frame_flip(tmp_path):
|
|
|
373
376
|
scans = []
|
|
374
377
|
for (i_frame, frame), z_pos, x_flip, y_flip in zip(enumerate(frames), z_position, x_flips, y_flips):
|
|
375
378
|
nx_tomo = NXtomo()
|
|
376
|
-
nx_tomo.sample.z_translation = [z_pos] * n_proj
|
|
377
|
-
nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False)
|
|
379
|
+
nx_tomo.sample.z_translation = [z_pos] * n_proj * _ureg.meter
|
|
380
|
+
nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False) * _ureg.degree
|
|
378
381
|
nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
|
|
379
|
-
nx_tomo.instrument.detector.x_pixel_size = 1.0
|
|
380
|
-
nx_tomo.instrument.detector.y_pixel_size = 1.0
|
|
381
|
-
nx_tomo.instrument.detector.distance = 2.3
|
|
382
|
+
nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
|
|
383
|
+
nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
|
|
384
|
+
nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
|
|
382
385
|
nx_tomo.instrument.detector.transformations.add_transformation(DetZFlipTransformation(flip=x_flip))
|
|
383
386
|
nx_tomo.instrument.detector.transformations.add_transformation(DetYFlipTransformation(flip=y_flip))
|
|
384
|
-
nx_tomo.energy = 19.2
|
|
387
|
+
nx_tomo.energy = 19.2 * _ureg.keV
|
|
385
388
|
nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj)
|
|
386
389
|
|
|
387
390
|
file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx")
|
|
File without changes
|
nabu/stitching/utils/utils.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from enum import Enum
|
|
1
2
|
from packaging.version import parse as parse_version
|
|
2
3
|
from typing import Optional, Union
|
|
3
4
|
import logging
|
|
@@ -6,10 +7,8 @@ import numpy
|
|
|
6
7
|
from tomoscan.scanbase import TomoScanBase
|
|
7
8
|
from tomoscan.volumebase import VolumeBase
|
|
8
9
|
from nxtomo.utils.transformation import build_matrix, DetYFlipTransformation
|
|
9
|
-
from silx.utils.enum import Enum as _Enum
|
|
10
10
|
from scipy.fft import rfftn as local_fftn
|
|
11
11
|
from scipy.fft import irfftn as local_ifftn
|
|
12
|
-
from ..overlap import OverlapStitchingStrategy, ImageStichOverlapKernel
|
|
13
12
|
from ..alignment import AlignmentAxis1, AlignmentAxis2, PaddedRawData
|
|
14
13
|
from ...misc import fourier_filters
|
|
15
14
|
from ...estimation.alignment import AlignmentBase
|
|
@@ -37,7 +36,7 @@ else:
|
|
|
37
36
|
__has_sk_phase_correlation__ = True
|
|
38
37
|
|
|
39
38
|
|
|
40
|
-
class ShiftAlgorithm(
|
|
39
|
+
class ShiftAlgorithm(Enum):
|
|
41
40
|
"""All generic shift search algorithm"""
|
|
42
41
|
|
|
43
42
|
NABU_FFT = "nabu-fft"
|
|
@@ -59,7 +58,7 @@ class ShiftAlgorithm(_Enum):
|
|
|
59
58
|
if value in ("", None):
|
|
60
59
|
return ShiftAlgorithm.NONE
|
|
61
60
|
else:
|
|
62
|
-
return super().
|
|
61
|
+
return super().__new__(cls, value)
|
|
63
62
|
|
|
64
63
|
|
|
65
64
|
def find_frame_relative_shifts(
|
|
@@ -73,9 +72,9 @@ def find_frame_relative_shifts(
|
|
|
73
72
|
y_shifts_params: Optional[dict] = None,
|
|
74
73
|
):
|
|
75
74
|
"""
|
|
76
|
-
:param overlap_axis: axis in [0, 1] on which the overlap exists. In image space. So 0 is aka y and 1 as x
|
|
75
|
+
:param overlap_axis: axis in [0, 1] on which the overlap exists. In image space. So 0 is aka y and 1 as x.
|
|
77
76
|
"""
|
|
78
|
-
if not
|
|
77
|
+
if overlap_axis not in (0, 1):
|
|
79
78
|
raise ValueError(f"overlap_axis should be in (0, 1). Get {overlap_axis}")
|
|
80
79
|
from nabu.stitching.config import (
|
|
81
80
|
KEY_LOW_PASS_FILTER,
|
|
@@ -146,7 +145,7 @@ def find_frame_relative_shifts(
|
|
|
146
145
|
}
|
|
147
146
|
|
|
148
147
|
res_algo = {}
|
|
149
|
-
for shift_alg in set((x_cross_correlation_function, y_cross_correlation_function)):
|
|
148
|
+
for shift_alg in set((x_cross_correlation_function, y_cross_correlation_function)): # noqa: C405
|
|
150
149
|
if shift_alg not in shift_methods:
|
|
151
150
|
raise ValueError(f"requested image alignment function not handled ({shift_alg})")
|
|
152
151
|
try:
|
|
@@ -200,8 +199,8 @@ def find_volumes_relative_shifts(
|
|
|
200
199
|
else:
|
|
201
200
|
raise ValueError(f"Stitching is done in 3D space. Expect axis to be in [0,2]. Get {overlap_axis}")
|
|
202
201
|
|
|
203
|
-
alignment_axis_2 = AlignmentAxis2
|
|
204
|
-
alignment_axis_1 = AlignmentAxis1
|
|
202
|
+
alignment_axis_2 = AlignmentAxis2(alignment_axis_2)
|
|
203
|
+
alignment_axis_1 = AlignmentAxis1(alignment_axis_1)
|
|
205
204
|
assert dim_axis_1 > 0, "dim_axis_1 <= 0"
|
|
206
205
|
|
|
207
206
|
if isinstance(slice_for_shift, str):
|
|
@@ -249,7 +248,7 @@ def find_volumes_relative_shifts(
|
|
|
249
248
|
|
|
250
249
|
w_window_size = int(y_shifts_params.get(KEY_WINDOW_SIZE, 400))
|
|
251
250
|
start_overlap = max(estimated_shifts[0] // 2 - w_window_size // 2, 0)
|
|
252
|
-
end_overlap = min(estimated_shifts[0] // 2 + w_window_size // 2,
|
|
251
|
+
end_overlap = min(estimated_shifts[0] // 2 + w_window_size // 2, upper_frame.shape[0], lower_frame.shape[0])
|
|
253
252
|
|
|
254
253
|
if start_overlap == 0:
|
|
255
254
|
overlap_upper_frame = upper_frame[-end_overlap:]
|
|
@@ -385,12 +384,10 @@ def find_projections_relative_shifts(
|
|
|
385
384
|
cor_options=cor_options,
|
|
386
385
|
)
|
|
387
386
|
|
|
388
|
-
estimated_shifts =
|
|
389
|
-
[
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
]
|
|
393
|
-
)
|
|
387
|
+
estimated_shifts = [
|
|
388
|
+
estimated_shifts[0],
|
|
389
|
+
(lower_scan_pos - upper_scan_pos),
|
|
390
|
+
]
|
|
394
391
|
x_cross_correlation_function = ShiftAlgorithm.NONE
|
|
395
392
|
|
|
396
393
|
# } else we will compute shift from the flat projections
|
|
@@ -464,7 +461,8 @@ def find_projections_relative_shifts(
|
|
|
464
461
|
start_overlap = max(estimated_shifts[axis_proj_space] // 2 - w_window_size // 2, 0)
|
|
465
462
|
end_overlap = min(
|
|
466
463
|
estimated_shifts[axis_proj_space] // 2 + w_window_size // 2,
|
|
467
|
-
|
|
464
|
+
upper_proj.shape[axis_proj_space],
|
|
465
|
+
lower_proj.shape[axis_proj_space],
|
|
468
466
|
)
|
|
469
467
|
o_upper_sel = numpy.array(range(-end_overlap, -start_overlap))
|
|
470
468
|
overlap_upper_frame = numpy.take_along_axis(
|
|
@@ -502,7 +500,7 @@ def find_shift_correlate(img1, img2, padding_mode="reflect"):
|
|
|
502
500
|
padding_mode,
|
|
503
501
|
)
|
|
504
502
|
|
|
505
|
-
img_shape =
|
|
503
|
+
img_shape = cc.shape # Because cc.shape can differ from img_2.shape (e.g. in case of odd nb of cols)
|
|
506
504
|
cc_vs = numpy.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
|
|
507
505
|
cc_hs = numpy.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])
|
|
508
506
|
|