nabu 2024.2.13__py3-none-any.whl → 2025.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/doc_config.py +32 -0
- nabu/__init__.py +1 -1
- nabu/app/bootstrap_stitching.py +4 -2
- nabu/app/cast_volume.py +16 -14
- nabu/app/cli_configs.py +102 -9
- nabu/app/compare_volumes.py +1 -1
- nabu/app/composite_cor.py +2 -4
- nabu/app/diag_to_pix.py +5 -6
- nabu/app/diag_to_rot.py +10 -11
- nabu/app/double_flatfield.py +18 -5
- nabu/app/estimate_motion.py +75 -0
- nabu/app/multicor.py +28 -15
- nabu/app/parse_reconstruction_log.py +1 -0
- nabu/app/pcaflats.py +122 -0
- nabu/app/prepare_weights_double.py +1 -2
- nabu/app/reconstruct.py +1 -7
- nabu/app/reconstruct_helical.py +5 -9
- nabu/app/reduce_dark_flat.py +5 -4
- nabu/app/rotate.py +3 -1
- nabu/app/stitching.py +7 -2
- nabu/app/tests/test_reduce_dark_flat.py +2 -2
- nabu/app/validator.py +1 -4
- nabu/cuda/convolution.py +1 -1
- nabu/cuda/fft.py +1 -1
- nabu/cuda/medfilt.py +1 -1
- nabu/cuda/padding.py +1 -1
- nabu/cuda/src/backproj.cu +6 -6
- nabu/cuda/src/cone.cu +4 -0
- nabu/cuda/src/hierarchical_backproj.cu +14 -0
- nabu/cuda/utils.py +2 -2
- nabu/estimation/alignment.py +17 -31
- nabu/estimation/cor.py +27 -33
- nabu/estimation/cor_sino.py +2 -8
- nabu/estimation/focus.py +4 -8
- nabu/estimation/motion.py +557 -0
- nabu/estimation/tests/test_alignment.py +2 -0
- nabu/estimation/tests/test_motion_estimation.py +471 -0
- nabu/estimation/tests/test_tilt.py +1 -1
- nabu/estimation/tilt.py +6 -5
- nabu/estimation/translation.py +47 -1
- nabu/io/cast_volume.py +108 -18
- nabu/io/detector_distortion.py +5 -6
- nabu/io/reader.py +45 -6
- nabu/io/reader_helical.py +5 -4
- nabu/io/tests/test_cast_volume.py +2 -2
- nabu/io/tests/test_readers.py +41 -38
- nabu/io/tests/test_remove_volume.py +152 -0
- nabu/io/tests/test_writers.py +2 -2
- nabu/io/utils.py +8 -4
- nabu/io/writer.py +1 -2
- nabu/misc/fftshift.py +1 -1
- nabu/misc/fourier_filters.py +1 -1
- nabu/misc/histogram.py +1 -1
- nabu/misc/histogram_cuda.py +1 -1
- nabu/misc/padding_base.py +1 -1
- nabu/misc/rotation.py +1 -1
- nabu/misc/rotation_cuda.py +1 -1
- nabu/misc/tests/test_binning.py +1 -1
- nabu/misc/transpose.py +1 -1
- nabu/misc/unsharp.py +1 -1
- nabu/misc/unsharp_cuda.py +1 -1
- nabu/misc/unsharp_opencl.py +1 -1
- nabu/misc/utils.py +1 -1
- nabu/opencl/fft.py +1 -1
- nabu/opencl/padding.py +1 -1
- nabu/opencl/src/backproj.cl +6 -6
- nabu/opencl/utils.py +8 -8
- nabu/pipeline/config.py +2 -2
- nabu/pipeline/config_validators.py +46 -46
- nabu/pipeline/datadump.py +3 -3
- nabu/pipeline/estimators.py +271 -11
- nabu/pipeline/fullfield/chunked.py +103 -67
- nabu/pipeline/fullfield/chunked_cuda.py +5 -2
- nabu/pipeline/fullfield/computations.py +4 -1
- nabu/pipeline/fullfield/dataset_validator.py +0 -1
- nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
- nabu/pipeline/fullfield/nabu_config.py +36 -17
- nabu/pipeline/fullfield/processconfig.py +41 -7
- nabu/pipeline/fullfield/reconstruction.py +14 -10
- nabu/pipeline/helical/dataset_validator.py +3 -4
- nabu/pipeline/helical/fbp.py +4 -4
- nabu/pipeline/helical/filtering.py +5 -4
- nabu/pipeline/helical/gridded_accumulator.py +10 -11
- nabu/pipeline/helical/helical_chunked_regridded.py +1 -0
- nabu/pipeline/helical/helical_reconstruction.py +12 -9
- nabu/pipeline/helical/helical_utils.py +1 -2
- nabu/pipeline/helical/nabu_config.py +2 -1
- nabu/pipeline/helical/span_strategy.py +1 -0
- nabu/pipeline/helical/weight_balancer.py +2 -3
- nabu/pipeline/params.py +20 -3
- nabu/pipeline/tests/__init__.py +0 -0
- nabu/pipeline/tests/test_estimators.py +240 -3
- nabu/pipeline/utils.py +1 -1
- nabu/pipeline/writer.py +1 -1
- nabu/preproc/alignment.py +0 -10
- nabu/preproc/ccd.py +53 -3
- nabu/preproc/ctf.py +8 -8
- nabu/preproc/ctf_cuda.py +1 -1
- nabu/preproc/double_flatfield_cuda.py +2 -2
- nabu/preproc/double_flatfield_variable_region.py +0 -1
- nabu/preproc/flatfield.py +307 -2
- nabu/preproc/flatfield_cuda.py +1 -2
- nabu/preproc/flatfield_variable_region.py +3 -3
- nabu/preproc/phase.py +2 -4
- nabu/preproc/phase_cuda.py +2 -2
- nabu/preproc/shift.py +4 -2
- nabu/preproc/shift_cuda.py +0 -1
- nabu/preproc/tests/test_ctf.py +4 -4
- nabu/preproc/tests/test_double_flatfield.py +1 -1
- nabu/preproc/tests/test_flatfield.py +1 -1
- nabu/preproc/tests/test_paganin.py +1 -3
- nabu/preproc/tests/test_pcaflats.py +154 -0
- nabu/preproc/tests/test_vshift.py +4 -1
- nabu/processing/azim.py +9 -5
- nabu/processing/convolution_cuda.py +6 -4
- nabu/processing/fft_base.py +7 -3
- nabu/processing/fft_cuda.py +25 -164
- nabu/processing/fft_opencl.py +28 -6
- nabu/processing/fftshift.py +1 -1
- nabu/processing/histogram.py +1 -1
- nabu/processing/muladd.py +0 -1
- nabu/processing/padding_base.py +1 -1
- nabu/processing/padding_cuda.py +0 -2
- nabu/processing/processing_base.py +12 -6
- nabu/processing/rotation_cuda.py +3 -1
- nabu/processing/tests/test_fft.py +2 -64
- nabu/processing/tests/test_fftshift.py +1 -1
- nabu/processing/tests/test_medfilt.py +1 -3
- nabu/processing/tests/test_padding.py +1 -1
- nabu/processing/tests/test_roll.py +1 -1
- nabu/processing/tests/test_rotation.py +4 -2
- nabu/processing/unsharp_opencl.py +1 -1
- nabu/reconstruction/astra.py +245 -0
- nabu/reconstruction/cone.py +39 -9
- nabu/reconstruction/fbp.py +14 -0
- nabu/reconstruction/fbp_base.py +40 -8
- nabu/reconstruction/fbp_opencl.py +8 -0
- nabu/reconstruction/filtering.py +59 -25
- nabu/reconstruction/filtering_cuda.py +22 -21
- nabu/reconstruction/filtering_opencl.py +10 -14
- nabu/reconstruction/hbp.py +26 -13
- nabu/reconstruction/mlem.py +55 -16
- nabu/reconstruction/projection.py +3 -5
- nabu/reconstruction/sinogram.py +1 -1
- nabu/reconstruction/sinogram_cuda.py +0 -1
- nabu/reconstruction/tests/test_cone.py +37 -2
- nabu/reconstruction/tests/test_deringer.py +4 -4
- nabu/reconstruction/tests/test_fbp.py +36 -15
- nabu/reconstruction/tests/test_filtering.py +27 -7
- nabu/reconstruction/tests/test_halftomo.py +28 -2
- nabu/reconstruction/tests/test_mlem.py +94 -64
- nabu/reconstruction/tests/test_projector.py +7 -2
- nabu/reconstruction/tests/test_reconstructor.py +1 -1
- nabu/reconstruction/tests/test_sino_normalization.py +0 -1
- nabu/resources/dataset_analyzer.py +210 -24
- nabu/resources/gpu.py +4 -4
- nabu/resources/logger.py +4 -4
- nabu/resources/nxflatfield.py +103 -37
- nabu/resources/tests/test_dataset_analyzer.py +37 -0
- nabu/resources/tests/test_extract.py +11 -0
- nabu/resources/tests/test_nxflatfield.py +5 -5
- nabu/resources/utils.py +16 -10
- nabu/stitching/alignment.py +8 -11
- nabu/stitching/config.py +44 -35
- nabu/stitching/definitions.py +2 -2
- nabu/stitching/frame_composition.py +8 -10
- nabu/stitching/overlap.py +4 -4
- nabu/stitching/sample_normalization.py +5 -5
- nabu/stitching/slurm_utils.py +2 -2
- nabu/stitching/stitcher/base.py +2 -0
- nabu/stitching/stitcher/dumper/base.py +0 -1
- nabu/stitching/stitcher/dumper/postprocessing.py +1 -1
- nabu/stitching/stitcher/post_processing.py +11 -9
- nabu/stitching/stitcher/pre_processing.py +37 -31
- nabu/stitching/stitcher/single_axis.py +2 -3
- nabu/stitching/stitcher_2D.py +2 -1
- nabu/stitching/tests/test_config.py +10 -11
- nabu/stitching/tests/test_sample_normalization.py +1 -1
- nabu/stitching/tests/test_slurm_utils.py +1 -2
- nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
- nabu/stitching/tests/test_z_postprocessing_stitching.py +3 -3
- nabu/stitching/tests/test_z_preprocessing_stitching.py +27 -24
- nabu/stitching/utils/tests/__init__.py +0 -0
- nabu/stitching/utils/tests/test_post-processing.py +1 -0
- nabu/stitching/utils/utils.py +16 -18
- nabu/tests.py +0 -3
- nabu/testutils.py +62 -9
- nabu/utils.py +50 -20
- {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/METADATA +7 -7
- nabu-2025.1.0.dist-info/RECORD +328 -0
- {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/WHEEL +1 -1
- {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/entry_points.txt +2 -1
- nabu/app/correct_rot.py +0 -70
- nabu/io/tests/test_detector_distortion.py +0 -178
- nabu-2024.2.13.dist-info/RECORD +0 -317
- /nabu/{stitching → app}/tests/__init__.py +0 -0
- {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/licenses/LICENSE +0 -0
- {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/top_level.txt +0 -0
| @@ -2,7 +2,8 @@ import numpy | |
| 2 2 | 
             
            import logging
         | 
| 3 3 | 
             
            import h5py
         | 
| 4 4 | 
             
            import os
         | 
| 5 | 
            -
             | 
| 5 | 
            +
            import pint
         | 
| 6 | 
            +
            from collections.abc import Iterable
         | 
| 6 7 | 
             
            from silx.io.url import DataUrl
         | 
| 7 8 | 
             
            from silx.io.utils import get_data
         | 
| 8 9 | 
             
            from datetime import datetime
         | 
| @@ -26,8 +27,8 @@ from nabu.stitching.config import ( | |
| 26 27 | 
             
            from nabu.stitching.utils import find_projections_relative_shifts
         | 
| 27 28 | 
             
            from functools import lru_cache as cache
         | 
| 28 29 | 
             
            from .single_axis import SingleAxisStitcher
         | 
| 29 | 
            -
            from pyunitsystem.metricsystem import MetricSystem
         | 
| 30 30 |  | 
| 31 | 
            +
            _ureg = pint.get_application_registry()
         | 
| 31 32 |  | 
| 32 33 | 
             
            _logger = logging.getLogger(__name__)
         | 
| 33 34 |  | 
| @@ -40,7 +41,6 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 40 41 | 
             
                """
         | 
| 41 42 |  | 
| 42 43 | 
             
                def __init__(self, configuration, progress=None) -> None:
         | 
| 43 | 
            -
                    """ """
         | 
| 44 44 | 
             
                    if not isinstance(configuration, PreProcessedSingleAxisStitchingConfiguration):
         | 
| 45 45 | 
             
                        raise TypeError(
         | 
| 46 46 | 
             
                            f"configuration is expected to be an instance of {PreProcessedSingleAxisStitchingConfiguration}. Get {type(configuration)} instead"
         | 
| @@ -239,18 +239,18 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 239 239 | 
             
                        if not scan_0.field_of_view == scan_1.field_of_view:
         | 
| 240 240 | 
             
                            raise ValueError(f"{scan_0} and {scan_1} have different field of view")
         | 
| 241 241 | 
             
                        # check distance
         | 
| 242 | 
            -
                        if scan_0. | 
| 242 | 
            +
                        if scan_0.sample_detector_distance is None:
         | 
| 243 243 | 
             
                            _logger.warning(f"no distance found for {scan_0}")
         | 
| 244 | 
            -
                        elif not numpy.isclose(scan_0. | 
| 244 | 
            +
                        elif not numpy.isclose(scan_0.sample_detector_distance, scan_1.sample_detector_distance, rtol=10e-3):
         | 
| 245 245 | 
             
                            raise ValueError(f"{scan_0} and {scan_1} have different sample / detector distance")
         | 
| 246 246 | 
             
                        # check pixel size
         | 
| 247 | 
            -
                        if not numpy.isclose(scan_0. | 
| 247 | 
            +
                        if not numpy.isclose(scan_0.sample_x_pixel_size, scan_1.sample_x_pixel_size):
         | 
| 248 248 | 
             
                            raise ValueError(
         | 
| 249 | 
            -
                                f"{scan_0} and {scan_1} have different x pixel size. {scan_0. | 
| 249 | 
            +
                                f"{scan_0} and {scan_1} have different x pixel size. {scan_0.sample_x_pixel_size} vs {scan_1.sample_x_pixel_size}"
         | 
| 250 250 | 
             
                            )
         | 
| 251 | 
            -
                        if not numpy.isclose(scan_0. | 
| 251 | 
            +
                        if not numpy.isclose(scan_0.sample_y_pixel_size, scan_1.sample_y_pixel_size):
         | 
| 252 252 | 
             
                            raise ValueError(
         | 
| 253 | 
            -
                                f"{scan_0} and {scan_1} have different y pixel size. {scan_0. | 
| 253 | 
            +
                                f"{scan_0} and {scan_1} have different y pixel size. {scan_0.sample_y_pixel_size} vs {scan_1.sample_y_pixel_size}"
         | 
| 254 254 | 
             
                            )
         | 
| 255 255 |  | 
| 256 256 | 
             
                    for scan in self.series:
         | 
| @@ -292,7 +292,7 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 292 292 | 
             
                            axis_N_pos_px = []
         | 
| 293 293 | 
             
                            for scan, pos_in_mm in zip(self.series, pos_as_mm):
         | 
| 294 294 | 
             
                                pixel_size_m = self.configuration.pixel_size or scan.pixel_size
         | 
| 295 | 
            -
                                axis_N_pos_px.append((pos_in_mm  | 
| 295 | 
            +
                                axis_N_pos_px.append((pos_in_mm * _ureg.millimeter).to_base_units().magnitude / pixel_size_m)
         | 
| 296 296 | 
             
                            return axis_N_pos_px
         | 
| 297 297 | 
             
                        else:
         | 
| 298 298 | 
             
                            # deduce from motor position and pixel size
         | 
| @@ -472,7 +472,7 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 472 472 | 
             
                    """
         | 
| 473 473 | 
             
                    nx_tomo = NXtomo()
         | 
| 474 474 |  | 
| 475 | 
            -
                    nx_tomo.energy = self.series[0].energy
         | 
| 475 | 
            +
                    nx_tomo.energy = self.series[0].energy * _ureg.keV
         | 
| 476 476 | 
             
                    start_times = list(filter(None, [scan.start_time for scan in self.series]))
         | 
| 477 477 | 
             
                    end_times = list(filter(None, [scan.end_time for scan in self.series]))
         | 
| 478 478 |  | 
| @@ -496,9 +496,9 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 496 496 |  | 
| 497 497 | 
             
                    # handle detector (without frames)
         | 
| 498 498 | 
             
                    nx_tomo.instrument.detector.field_of_view = self.series[0].field_of_view
         | 
| 499 | 
            -
                    nx_tomo.instrument.detector.distance = self.series[0]. | 
| 500 | 
            -
                    nx_tomo.instrument.detector.x_pixel_size = self.series[0].x_pixel_size
         | 
| 501 | 
            -
                    nx_tomo.instrument.detector.y_pixel_size = self.series[0].y_pixel_size
         | 
| 499 | 
            +
                    nx_tomo.instrument.detector.distance = self.series[0].sample_detector_distance * _ureg.meter
         | 
| 500 | 
            +
                    nx_tomo.instrument.detector.x_pixel_size = self.series[0].x_pixel_size * _ureg.meter
         | 
| 501 | 
            +
                    nx_tomo.instrument.detector.y_pixel_size = self.series[0].y_pixel_size * _ureg.meter
         | 
| 502 502 | 
             
                    nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
         | 
| 503 503 | 
             
                    nx_tomo.instrument.detector.tomo_n = n_proj
         | 
| 504 504 | 
             
                    # note: stitching process insure un-flipping of frames. So make sure transformations is defined as an empty set
         | 
| @@ -506,13 +506,13 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 506 506 |  | 
| 507 507 | 
             
                    if isinstance(self.series[0], NXtomoScan):
         | 
| 508 508 | 
             
                        # note: first scan is always the reference as order to read data (so no rotation_angle inversion here)
         | 
| 509 | 
            -
                        rotation_angle = numpy.asarray(self.series[0].rotation_angle)
         | 
| 509 | 
            +
                        rotation_angle = numpy.asarray(self.series[0].rotation_angle) * _ureg.degree
         | 
| 510 510 | 
             
                        nx_tomo.sample.rotation_angle = rotation_angle[
         | 
| 511 511 | 
             
                            numpy.asarray(self.series[0].image_key_control) == ImageKey.PROJECTION.value
         | 
| 512 512 | 
             
                        ]
         | 
| 513 513 | 
             
                    elif isinstance(self.series[0], EDFTomoScan):
         | 
| 514 | 
            -
                        nx_tomo.sample.rotation_angle =  | 
| 515 | 
            -
                            start=0, stop=self.series[0].scan_range, num=self.series[0].tomo_n
         | 
| 514 | 
            +
                        nx_tomo.sample.rotation_angle = (
         | 
| 515 | 
            +
                            numpy.linspace(start=0, stop=self.series[0].scan_range, num=self.series[0].tomo_n) * _ureg.degree
         | 
| 516 516 | 
             
                        )
         | 
| 517 517 | 
             
                    else:
         | 
| 518 518 | 
             
                        raise NotImplementedError(
         | 
| @@ -526,12 +526,15 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 526 526 | 
             
                        if isinstance(slices, slice):
         | 
| 527 527 | 
             
                            return array[slices.start : slices.stop : 1]
         | 
| 528 528 | 
             
                        elif isinstance(slices, Iterable):
         | 
| 529 | 
            -
                            return  | 
| 529 | 
            +
                            return [array[index] for index in slices]
         | 
| 530 530 | 
             
                        else:
         | 
| 531 | 
            -
                            raise RuntimeError("slices must be instance of a slice or of an iterable")
         | 
| 531 | 
            +
                            raise RuntimeError("slices must be instance of a slice or of an iterable")  # noqa: TRY004
         | 
| 532 532 |  | 
| 533 | 
            -
                    nx_tomo.sample.rotation_angle =  | 
| 534 | 
            -
                         | 
| 533 | 
            +
                    nx_tomo.sample.rotation_angle = (
         | 
| 534 | 
            +
                        apply_slices_selection(
         | 
| 535 | 
            +
                            array=nx_tomo.sample.rotation_angle.to_base_units().magnitude, slices=self._slices_to_stitch
         | 
| 536 | 
            +
                        )
         | 
| 537 | 
            +
                        * _ureg.degree
         | 
| 535 538 | 
             
                    )
         | 
| 536 539 |  | 
| 537 540 | 
             
                    # handle sample
         | 
| @@ -560,7 +563,7 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 560 563 | 
             
                            # note: if at least one has missing values the numpy.Array(x_translation) with create an error as well
         | 
| 561 564 | 
             
                            x_translation = [0.0] * n_proj
         | 
| 562 565 | 
             
                            _logger.warning("Unable to fin input nxtomo x_translation values. Set it to 0.0")
         | 
| 563 | 
            -
                        nx_tomo.sample.x_translation = x_translation
         | 
| 566 | 
            +
                        nx_tomo.sample.x_translation = x_translation * _ureg.meter
         | 
| 564 567 |  | 
| 565 568 | 
             
                        y_translation = [
         | 
| 566 569 | 
             
                            get_sample_translation_for_projs(scan, "y_translation")
         | 
| @@ -575,7 +578,7 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 575 578 | 
             
                        else:
         | 
| 576 579 | 
             
                            y_translation = [0.0] * n_proj
         | 
| 577 580 | 
             
                            _logger.warning("Unable to fin input nxtomo y_translation values. Set it to 0.0")
         | 
| 578 | 
            -
                        nx_tomo.sample.y_translation = y_translation
         | 
| 581 | 
            +
                        nx_tomo.sample.y_translation = y_translation * _ureg.meter
         | 
| 579 582 | 
             
                        z_translation = [
         | 
| 580 583 | 
             
                            get_sample_translation_for_projs(scan, "z_translation")
         | 
| 581 584 | 
             
                            for scan in self.series
         | 
| @@ -589,7 +592,7 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 589 592 | 
             
                        else:
         | 
| 590 593 | 
             
                            z_translation = [0.0] * n_proj
         | 
| 591 594 | 
             
                            _logger.warning("Unable to fin input nxtomo z_translation values. Set it to 0.0")
         | 
| 592 | 
            -
                        nx_tomo.sample.z_translation = z_translation
         | 
| 595 | 
            +
                        nx_tomo.sample.z_translation = z_translation * _ureg.meter
         | 
| 593 596 |  | 
| 594 597 | 
             
                        nx_tomo.sample.name = self.series[0].sample_name
         | 
| 595 598 |  | 
| @@ -794,7 +797,7 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 794 797 | 
             
                    ):
         | 
| 795 798 | 
             
                        i_frame = 0
         | 
| 796 799 | 
             
                        _, set_of_compacted_slices = get_compacted_dataslices(scan_urls, return_url_set=True)
         | 
| 797 | 
            -
                        for  | 
| 800 | 
            +
                        for url in set_of_compacted_slices.values():
         | 
| 798 801 | 
             
                            scan = scans[i_scan]
         | 
| 799 802 | 
             
                            url = DataUrl(
         | 
| 800 803 | 
             
                                file_path=url.file_path(),
         | 
| @@ -886,6 +889,9 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 886 889 | 
             
                    """
         | 
| 887 890 | 
             
                    make sure reduced dark and flats are existing otherwise compute them
         | 
| 888 891 | 
             
                    """
         | 
| 892 | 
            +
                    # TODO
         | 
| 893 | 
            +
                    # ruff: noqa: SIM105, S110
         | 
| 894 | 
            +
                    # --
         | 
| 889 895 | 
             
                    for scan in self.series:
         | 
| 890 896 | 
             
                        try:
         | 
| 891 897 | 
             
                            reduced_darks, darks_infos = scan.load_reduced_darks(return_info=True)
         | 
| @@ -896,7 +902,7 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 896 902 | 
             
                            try:
         | 
| 897 903 | 
             
                                # if we don't have write in the folder containing the .nx for example
         | 
| 898 904 | 
             
                                scan.save_reduced_darks(reduced_darks, darks_infos=darks_infos)
         | 
| 899 | 
            -
                            except Exception | 
| 905 | 
            +
                            except Exception:
         | 
| 900 906 | 
             
                                pass
         | 
| 901 907 | 
             
                        scan.set_reduced_darks(reduced_darks, darks_infos=darks_infos)
         | 
| 902 908 |  | 
| @@ -909,7 +915,7 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 909 915 | 
             
                            try:
         | 
| 910 916 | 
             
                                # if we don't have write in the folder containing the .nx for example
         | 
| 911 917 | 
             
                                scan.save_reduced_flats(reduced_flats, flats_infos=flats_infos)
         | 
| 912 | 
            -
                            except Exception | 
| 918 | 
            +
                            except Exception:
         | 
| 913 919 | 
             
                                pass
         | 
| 914 920 | 
             
                        scan.set_reduced_flats(reduced_flats, flats_infos=flats_infos)
         | 
| 915 921 |  | 
| @@ -979,7 +985,7 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 979 985 | 
             
                    ):
         | 
| 980 986 | 
             
                        i_frame = 0
         | 
| 981 987 | 
             
                        _, set_of_compacted_slices = get_compacted_dataslices(scan_urls, return_url_set=True)
         | 
| 982 | 
            -
                        for  | 
| 988 | 
            +
                        for url in set_of_compacted_slices.values():
         | 
| 983 989 | 
             
                            scan = scans[i_scan]
         | 
| 984 990 | 
             
                            url = DataUrl(
         | 
| 985 991 | 
             
                                file_path=url.file_path(),
         | 
| @@ -1000,12 +1006,12 @@ class PreProcessingStitching(SingleAxisStitcher): | |
| 1000 1006 |  | 
| 1001 1007 | 
             
                            missing = []
         | 
| 1002 1008 | 
             
                            if len(scan.reduced_flats) == 0:
         | 
| 1003 | 
            -
                                missing | 
| 1009 | 
            +
                                missing.append("flats")
         | 
| 1004 1010 | 
             
                            if len(scan.reduced_darks) == 0:
         | 
| 1005 | 
            -
                                missing | 
| 1011 | 
            +
                                missing.append("darks")
         | 
| 1006 1012 |  | 
| 1007 1013 | 
             
                            if len(missing) > 0:
         | 
| 1008 | 
            -
                                _logger.warning(f"missing {'and'.join(missing)}. Unable to do flat field correction")
         | 
| 1014 | 
            +
                                _logger.warning(f"missing {' and '.join(missing)}. Unable to do flat field correction")
         | 
| 1009 1015 | 
             
                                ff_arrays = None
         | 
| 1010 1016 | 
             
                                data = raw_radios
         | 
| 1011 1017 | 
             
                            else:
         | 
| @@ -1,8 +1,7 @@ | |
| 1 | 
            -
            import h5py
         | 
| 2 1 | 
             
            import numpy
         | 
| 3 2 | 
             
            import logging
         | 
| 4 | 
            -
            from  | 
| 5 | 
            -
            from  | 
| 3 | 
            +
            from typing import Optional, Union
         | 
| 4 | 
            +
            from collections.abc import Iterable
         | 
| 6 5 | 
             
            from tomoscan.series import Series
         | 
| 7 6 | 
             
            from tomoscan.identifier import BaseIdentifier
         | 
| 8 7 | 
             
            from nabu.stitching.stitcher.base import _StitcherBase, get_obj_constant_side_length
         | 
    
        nabu/stitching/stitcher_2D.py
    CHANGED
    
    | @@ -1,3 +1,4 @@ | |
| 1 | 
            +
            # ruff: noqa: N999
         | 
| 1 2 | 
             
            import numpy
         | 
| 2 3 | 
             
            from math import ceil
         | 
| 3 4 | 
             
            from typing import Union, Optional
         | 
| @@ -19,7 +20,7 @@ def stitch_raw_frames( | |
| 19 20 | 
             
                pad_mode="constant",
         | 
| 20 21 | 
             
                new_unstitched_axis_size: Optional[int] = None,
         | 
| 21 22 | 
             
            ) -> numpy.ndarray:
         | 
| 22 | 
            -
                """
         | 
| 23 | 
            +
                r"""
         | 
| 23 24 | 
             
                stitches raw frames (already shifted and flat fielded !!!) together using
         | 
| 24 25 | 
             
                raw stitching (no pixel interpolation, y_overlap_in_px is expected to be a int).
         | 
| 25 26 | 
             
                Sttiching depends on the kernel used.
         | 
| @@ -10,7 +10,7 @@ from nabu.stitching.overlap import OverlapStitchingStrategy | |
| 10 10 | 
             
            from nabu.stitching import config as stiching_config
         | 
| 11 11 |  | 
| 12 12 |  | 
| 13 | 
            -
            _stitching_types = list(stiching_config.StitchingType | 
| 13 | 
            +
            _stitching_types = [st.value for st in list(stiching_config.StitchingType)]
         | 
| 14 14 | 
             
            _stitching_types.append(None)
         | 
| 15 15 |  | 
| 16 16 |  | 
| @@ -37,18 +37,16 @@ def test_stitching_config(stitching_type, option_level): | |
| 37 37 |  | 
| 38 38 | 
             
                    assert "stitching" in config
         | 
| 39 39 | 
             
                    assert "type" in config["stitching"]
         | 
| 40 | 
            -
                    stitching_type = stiching_config.StitchingType | 
| 40 | 
            +
                    stitching_type = stiching_config.StitchingType(config["stitching"]["type"])
         | 
| 41 41 | 
             
                    if stitching_type is stiching_config.StitchingType.Z_POSTPROC:
         | 
| 42 42 | 
             
                        assert isinstance(
         | 
| 43 43 | 
             
                            stiching_config.dict_to_config_obj(config),
         | 
| 44 44 | 
             
                            stiching_config.PostProcessedSingleAxisStitchingConfiguration,
         | 
| 45 45 | 
             
                        )
         | 
| 46 | 
            -
                    elif  | 
| 47 | 
            -
                         | 
| 48 | 
            -
             | 
| 49 | 
            -
             | 
| 50 | 
            -
                        )
         | 
| 51 | 
            -
                    elif stitching_type is stiching_config.StitchingType.Y_PREPROC:
         | 
| 46 | 
            +
                    elif (
         | 
| 47 | 
            +
                        stitching_type is stiching_config.StitchingType.Z_PREPROC
         | 
| 48 | 
            +
                        or stitching_type is stiching_config.StitchingType.Y_PREPROC
         | 
| 49 | 
            +
                    ):
         | 
| 52 50 | 
             
                        assert isinstance(
         | 
| 53 51 | 
             
                            stiching_config.dict_to_config_obj(config),
         | 
| 54 52 | 
             
                            stiching_config.PreProcessedSingleAxisStitchingConfiguration,
         | 
| @@ -84,7 +82,7 @@ def test_stitching_config(stitching_type, option_level): | |
| 84 82 | 
             
                    assert isinstance(config_class_instance.to_dict(), dict)
         | 
| 85 83 |  | 
| 86 84 |  | 
| 87 | 
            -
            @pytest.mark.parametrize("stitching_strategy", OverlapStitchingStrategy | 
| 85 | 
            +
            @pytest.mark.parametrize("stitching_strategy", [oss for oss in OverlapStitchingStrategy])
         | 
| 88 86 | 
             
            @pytest.mark.parametrize("overwrite_results", (True, "False", 0, "1"))
         | 
| 89 87 | 
             
            @pytest.mark.parametrize(
         | 
| 90 88 | 
             
                "axis_shifts",
         | 
| @@ -92,7 +90,7 @@ def test_stitching_config(stitching_type, option_level): | |
| 92 90 | 
             
                    "",
         | 
| 93 91 | 
             
                    None,
         | 
| 94 92 | 
             
                    "None",
         | 
| 95 | 
            -
                    "",
         | 
| 93 | 
            +
                    "",  # noqa: PT014
         | 
| 96 94 | 
             
                    "skimage",
         | 
| 97 95 | 
             
                    "nabu-fft",
         | 
| 98 96 | 
             
                ),
         | 
| @@ -176,7 +174,8 @@ def test_PreProcessedZStitchingConfiguration( | |
| 176 174 |  | 
| 177 175 | 
             
                from_dict = stiching_config.PreProcessedZStitchingConfiguration.from_dict(pre_process_config.to_dict())
         | 
| 178 176 | 
             
                # workaround for scans because a new object is created each time
         | 
| 179 | 
            -
                 | 
| 177 | 
            +
                # ???
         | 
| 178 | 
            +
                pre_process_config.settle_inputs  # noqa: B018
         | 
| 180 179 | 
             
                assert len(from_dict.input_scans) == len(pre_process_config.input_scans)
         | 
| 181 180 | 
             
                from_dict.input_scans = None
         | 
| 182 181 | 
             
                pre_process_config.input_scans = None
         | 
| @@ -1,5 +1,4 @@ | |
| 1 1 | 
             
            import os
         | 
| 2 | 
            -
            import numpy
         | 
| 3 2 | 
             
            import pytest
         | 
| 4 3 | 
             
            from tomoscan.esrf import NXtomoScan
         | 
| 5 4 | 
             
            from tomoscan.esrf.volume import HDF5Volume
         | 
| @@ -14,7 +13,7 @@ from nabu.stitching.slurm_utils import ( | |
| 14 13 | 
             
            from tomoscan.esrf.mock import MockNXtomo
         | 
| 15 14 |  | 
| 16 15 | 
             
            try:
         | 
| 17 | 
            -
                import sluurp
         | 
| 16 | 
            +
                import sluurp  # noqa: F401
         | 
| 18 17 | 
             
            except ImportError:
         | 
| 19 18 | 
             
                has_sluurp = False
         | 
| 20 19 | 
             
            else:
         | 
| @@ -1,6 +1,7 @@ | |
| 1 1 | 
             
            import os
         | 
| 2 2 | 
             
            import pytest
         | 
| 3 3 | 
             
            import numpy
         | 
| 4 | 
            +
            import pint
         | 
| 4 5 | 
             
            from tqdm import tqdm
         | 
| 5 6 |  | 
| 6 7 | 
             
            from nabu.stitching.y_stitching import y_stitching
         | 
| @@ -9,6 +10,8 @@ from nxtomo.application.nxtomo import NXtomo | |
| 9 10 | 
             
            from nxtomo.nxobject.nxdetector import ImageKey
         | 
| 10 11 | 
             
            from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
         | 
| 11 12 |  | 
| 13 | 
            +
            _ureg = pint.UnitRegistry()
         | 
| 14 | 
            +
             | 
| 12 15 |  | 
| 13 16 | 
             
            def build_nxtomos(output_dir, flip_lr, flip_ud) -> tuple:
         | 
| 14 17 | 
             
                r"""
         | 
| @@ -52,10 +55,10 @@ def build_nxtomos(output_dir, flip_lr, flip_ud) -> tuple: | |
| 52 55 |  | 
| 53 56 | 
             
                    n_projs = 3
         | 
| 54 57 | 
             
                    nx_tomo = NXtomo()
         | 
| 55 | 
            -
                    nx_tomo.sample.x_translation = [0] * (n_projs + 2)
         | 
| 56 | 
            -
                    nx_tomo.sample.y_translation = [frame_y_position] * (n_projs + 2)
         | 
| 57 | 
            -
                    nx_tomo.sample.z_translation = [0] * (n_projs + 2)
         | 
| 58 | 
            -
                    nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=(n_projs + 2), endpoint=False)
         | 
| 58 | 
            +
                    nx_tomo.sample.x_translation = ([0] * (n_projs + 2)) * _ureg.meter
         | 
| 59 | 
            +
                    nx_tomo.sample.y_translation = ([frame_y_position] * (n_projs + 2)) * _ureg.meter
         | 
| 60 | 
            +
                    nx_tomo.sample.z_translation = ([0] * (n_projs + 2)) * _ureg.meter
         | 
| 61 | 
            +
                    nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=(n_projs + 2), endpoint=False) * _ureg.degree
         | 
| 59 62 | 
             
                    nx_tomo.instrument.detector.image_key_control = (
         | 
| 60 63 | 
             
                        ImageKey.DARK_FIELD,
         | 
| 61 64 | 
             
                        ImageKey.FLAT_FIELD,
         | 
| @@ -63,10 +66,10 @@ def build_nxtomos(output_dir, flip_lr, flip_ud) -> tuple: | |
| 63 66 | 
             
                        ImageKey.PROJECTION,
         | 
| 64 67 | 
             
                        ImageKey.PROJECTION,
         | 
| 65 68 | 
             
                    )
         | 
| 66 | 
            -
                    nx_tomo.instrument.detector.x_pixel_size = 1.0
         | 
| 67 | 
            -
                    nx_tomo.instrument.detector.y_pixel_size = 1.0
         | 
| 68 | 
            -
                    nx_tomo.instrument.detector.distance = 2.3
         | 
| 69 | 
            -
                    nx_tomo.energy = 19.2
         | 
| 69 | 
            +
                    nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
         | 
| 70 | 
            +
                    nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
         | 
| 71 | 
            +
                    nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
         | 
| 72 | 
            +
                    nx_tomo.energy = 19.2 * _ureg.keV
         | 
| 70 73 | 
             
                    nx_tomo.instrument.detector.data = numpy.stack(
         | 
| 71 74 | 
             
                        (
         | 
| 72 75 | 
             
                            my_dark_data,
         | 
| @@ -395,7 +395,7 @@ def test_vol_z_stitching_with_alignment_axis_2(tmp_path, alignment_axis_2): | |
| 395 395 | 
             
                    axis_2_params={"img_reg_method": ShiftAlgorithm.NONE},
         | 
| 396 396 | 
             
                    slice_for_cross_correlation="middle",
         | 
| 397 397 | 
             
                    voxel_size=None,
         | 
| 398 | 
            -
                    alignment_axis_2=AlignmentAxis2 | 
| 398 | 
            +
                    alignment_axis_2=AlignmentAxis2(alignment_axis_2),
         | 
| 399 399 | 
             
                )
         | 
| 400 400 |  | 
| 401 401 | 
             
                stitcher = PostProcessZStitcher(z_stich_config, progress=None)
         | 
| @@ -512,7 +512,7 @@ def test_vol_z_stitching_with_alignment_axis_1(tmp_path, alignment_axis_1): | |
| 512 512 | 
             
                    axis_2_params={"img_reg_method": ShiftAlgorithm.NONE},
         | 
| 513 513 | 
             
                    slice_for_cross_correlation="middle",
         | 
| 514 514 | 
             
                    voxel_size=None,
         | 
| 515 | 
            -
                    alignment_axis_1=AlignmentAxis1 | 
| 515 | 
            +
                    alignment_axis_1=AlignmentAxis1(alignment_axis_1),
         | 
| 516 516 | 
             
                )
         | 
| 517 517 |  | 
| 518 518 | 
             
                stitcher = PostProcessZStitcher(z_stich_config, progress=None)
         | 
| @@ -770,6 +770,6 @@ def test_data_duplication(tmp_path, data_duplication): | |
| 770 770 | 
             
                if not data_duplication:
         | 
| 771 771 | 
             
                    # make sure an error is raised if we try to ask for no data duplication and if we get some flips
         | 
| 772 772 | 
             
                    z_stich_config.flip_ud = (False, True, False)
         | 
| 773 | 
            -
                    with pytest.raises(ValueError):
         | 
| 773 | 
            +
                    with pytest.raises(ValueError):  # noqa: PT012
         | 
| 774 774 | 
             
                        stitcher = PostProcessZStitcherNoDD(z_stich_config, progress=None)
         | 
| 775 775 | 
             
                        stitcher.stitch()
         | 
| @@ -1,15 +1,16 @@ | |
| 1 1 | 
             
            import os
         | 
| 2 | 
            +
            import pint
         | 
| 2 3 | 
             
            from silx.image.phantomgenerator import PhantomGenerator
         | 
| 3 4 | 
             
            from scipy.ndimage import shift as scipy_shift
         | 
| 4 5 | 
             
            import numpy
         | 
| 5 6 | 
             
            import pytest
         | 
| 6 7 | 
             
            from nabu.stitching.config import PreProcessedZStitchingConfiguration
         | 
| 7 8 | 
             
            from nabu.stitching.config import KEY_IMG_REG_METHOD
         | 
| 8 | 
            -
            from nabu.stitching.overlap import  | 
| 9 | 
            +
            from nabu.stitching.overlap import OverlapStitchingStrategy
         | 
| 9 10 | 
             
            from nabu.stitching.z_stitching import (
         | 
| 10 11 | 
             
                PreProcessZStitcher,
         | 
| 11 12 | 
             
            )
         | 
| 12 | 
            -
            from nabu.stitching.stitcher_2D import  | 
| 13 | 
            +
            from nabu.stitching.stitcher_2D import get_overlap_areas
         | 
| 13 14 | 
             
            from nxtomo.nxobject.nxdetector import ImageKey
         | 
| 14 15 | 
             
            from nxtomo.utils.transformation import DetYFlipTransformation, DetZFlipTransformation
         | 
| 15 16 | 
             
            from nxtomo.application.nxtomo import NXtomo
         | 
| @@ -17,6 +18,8 @@ from tomoscan.esrf.scan.nxtomoscan import NXtomoScan | |
| 17 18 | 
             
            from nabu.stitching.utils import ShiftAlgorithm
         | 
| 18 19 | 
             
            import h5py
         | 
| 19 20 |  | 
| 21 | 
            +
            _ureg = pint.get_application_registry()
         | 
| 22 | 
            +
             | 
| 20 23 |  | 
| 21 24 | 
             
            _stitching_configurations = (
         | 
| 22 25 | 
             
                # simple case where shifts are provided
         | 
| @@ -82,13 +85,13 @@ def test_PreProcessZStitcher(tmp_path, dtype, configuration): | |
| 82 85 | 
             
                scans = []
         | 
| 83 86 | 
             
                for (i_frame, frame), z_pos in zip(enumerate(frames), z_position):
         | 
| 84 87 | 
             
                    nx_tomo = NXtomo()
         | 
| 85 | 
            -
                    nx_tomo.sample.z_translation = [z_pos] * n_proj
         | 
| 86 | 
            -
                    nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False)
         | 
| 88 | 
            +
                    nx_tomo.sample.z_translation = ([z_pos] * n_proj) * _ureg.meter
         | 
| 89 | 
            +
                    nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False) * _ureg.degree
         | 
| 87 90 | 
             
                    nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
         | 
| 88 | 
            -
                    nx_tomo.instrument.detector.x_pixel_size = 1.0
         | 
| 89 | 
            -
                    nx_tomo.instrument.detector.y_pixel_size = 1.0
         | 
| 90 | 
            -
                    nx_tomo.instrument.detector.distance = 2.3
         | 
| 91 | 
            -
                    nx_tomo.energy = 19.2
         | 
| 91 | 
            +
                    nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
         | 
| 92 | 
            +
                    nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
         | 
| 93 | 
            +
                    nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
         | 
| 94 | 
            +
                    nx_tomo.energy = 19.2 * _ureg.keV
         | 
| 92 95 | 
             
                    nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj)
         | 
| 93 96 |  | 
| 94 97 | 
             
                    file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx")
         | 
| @@ -160,8 +163,8 @@ def test_PreProcessZStitcher(tmp_path, dtype, configuration): | |
| 160 163 | 
             
                    )
         | 
| 161 164 |  | 
| 162 165 | 
             
                # check also other metadata are here
         | 
| 163 | 
            -
                assert created_nx_tomo.instrument.detector.distance | 
| 164 | 
            -
                assert created_nx_tomo.energy | 
| 166 | 
            +
                assert created_nx_tomo.instrument.detector.distance == 2.3 * _ureg.meter
         | 
| 167 | 
            +
                assert created_nx_tomo.energy == 19.2 * _ureg.keV
         | 
| 165 168 | 
             
                numpy.testing.assert_array_equal(
         | 
| 166 169 | 
             
                    created_nx_tomo.instrument.detector.image_key_control,
         | 
| 167 170 | 
             
                    numpy.asarray([ImageKey.PROJECTION.PROJECTION] * n_proj),
         | 
| @@ -228,13 +231,13 @@ def build_nxtomos(output_dir) -> tuple: | |
| 228 231 | 
             
                scans = []
         | 
| 229 232 | 
             
                for (i_frame, frame), z_pos in zip(enumerate(frames), z_positions):
         | 
| 230 233 | 
             
                    nx_tomo = NXtomo()
         | 
| 231 | 
            -
                    nx_tomo.sample.z_translation = [z_pos] * n_projs
         | 
| 232 | 
            -
                    nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_projs, endpoint=False)
         | 
| 234 | 
            +
                    nx_tomo.sample.z_translation = [z_pos] * n_projs * _ureg.meter
         | 
| 235 | 
            +
                    nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_projs, endpoint=False) * _ureg.degree
         | 
| 233 236 | 
             
                    nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_projs
         | 
| 234 | 
            -
                    nx_tomo.instrument.detector.x_pixel_size = 1.0
         | 
| 235 | 
            -
                    nx_tomo.instrument.detector.y_pixel_size = 1.0
         | 
| 236 | 
            -
                    nx_tomo.instrument.detector.distance = 2.3
         | 
| 237 | 
            -
                    nx_tomo.energy = 19.2
         | 
| 237 | 
            +
                    nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
         | 
| 238 | 
            +
                    nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
         | 
| 239 | 
            +
                    nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
         | 
| 240 | 
            +
                    nx_tomo.energy = 19.2 * _ureg.keV
         | 
| 238 241 | 
             
                    nx_tomo.instrument.detector.data = frame
         | 
| 239 242 |  | 
| 240 243 | 
             
                    file_path = os.path.join(output_dir, f"nxtomo_{i_frame}.nx")
         | 
| @@ -301,11 +304,11 @@ def test_DistributePreProcessZStitcher(tmp_path, configuration_dist): | |
| 301 304 | 
             
                )
         | 
| 302 305 |  | 
| 303 306 | 
             
                if complete:
         | 
| 304 | 
            -
                    len(final_nx_tomo.instrument.detector.data) ==  | 
| 307 | 
            +
                    assert len(final_nx_tomo.instrument.detector.data) == 100
         | 
| 305 308 | 
             
                    # test middle
         | 
| 306 309 | 
             
                    numpy.testing.assert_array_almost_equal(raw_data[1], final_nx_tomo.instrument.detector.data[1, :, :])
         | 
| 307 310 | 
             
                else:
         | 
| 308 | 
            -
                    len(final_nx_tomo.instrument.detector.data) == 3
         | 
| 311 | 
            +
                    assert len(final_nx_tomo.instrument.detector.data) == 3
         | 
| 309 312 | 
             
                    # test middle
         | 
| 310 313 | 
             
                    numpy.testing.assert_array_almost_equal(raw_data[49], final_nx_tomo.instrument.detector.data[1, :, :])
         | 
| 311 314 | 
             
                # in the case of first, middle and last frames
         | 
| @@ -373,15 +376,15 @@ def test_frame_flip(tmp_path): | |
| 373 376 | 
             
                scans = []
         | 
| 374 377 | 
             
                for (i_frame, frame), z_pos, x_flip, y_flip in zip(enumerate(frames), z_position, x_flips, y_flips):
         | 
| 375 378 | 
             
                    nx_tomo = NXtomo()
         | 
| 376 | 
            -
                    nx_tomo.sample.z_translation = [z_pos] * n_proj
         | 
| 377 | 
            -
                    nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False)
         | 
| 379 | 
            +
                    nx_tomo.sample.z_translation = [z_pos] * n_proj * _ureg.meter
         | 
| 380 | 
            +
                    nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False) * _ureg.degree
         | 
| 378 381 | 
             
                    nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj
         | 
| 379 | 
            -
                    nx_tomo.instrument.detector.x_pixel_size = 1.0
         | 
| 380 | 
            -
                    nx_tomo.instrument.detector.y_pixel_size = 1.0
         | 
| 381 | 
            -
                    nx_tomo.instrument.detector.distance = 2.3
         | 
| 382 | 
            +
                    nx_tomo.instrument.detector.x_pixel_size = 1.0 * _ureg.meter
         | 
| 383 | 
            +
                    nx_tomo.instrument.detector.y_pixel_size = 1.0 * _ureg.meter
         | 
| 384 | 
            +
                    nx_tomo.instrument.detector.distance = 2.3 * _ureg.meter
         | 
| 382 385 | 
             
                    nx_tomo.instrument.detector.transformations.add_transformation(DetZFlipTransformation(flip=x_flip))
         | 
| 383 386 | 
             
                    nx_tomo.instrument.detector.transformations.add_transformation(DetYFlipTransformation(flip=y_flip))
         | 
| 384 | 
            -
                    nx_tomo.energy = 19.2
         | 
| 387 | 
            +
                    nx_tomo.energy = 19.2 * _ureg.keV
         | 
| 385 388 | 
             
                    nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj)
         | 
| 386 389 |  | 
| 387 390 | 
             
                    file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx")
         | 
| 
            File without changes
         | 
    
        nabu/stitching/utils/utils.py
    CHANGED
    
    | @@ -1,3 +1,4 @@ | |
| 1 | 
            +
            from enum import Enum
         | 
| 1 2 | 
             
            from packaging.version import parse as parse_version
         | 
| 2 3 | 
             
            from typing import Optional, Union
         | 
| 3 4 | 
             
            import logging
         | 
| @@ -6,10 +7,8 @@ import numpy | |
| 6 7 | 
             
            from tomoscan.scanbase import TomoScanBase
         | 
| 7 8 | 
             
            from tomoscan.volumebase import VolumeBase
         | 
| 8 9 | 
             
            from nxtomo.utils.transformation import build_matrix, DetYFlipTransformation
         | 
| 9 | 
            -
            from silx.utils.enum import Enum as _Enum
         | 
| 10 10 | 
             
            from scipy.fft import rfftn as local_fftn
         | 
| 11 11 | 
             
            from scipy.fft import irfftn as local_ifftn
         | 
| 12 | 
            -
            from ..overlap import OverlapStitchingStrategy, ImageStichOverlapKernel
         | 
| 13 12 | 
             
            from ..alignment import AlignmentAxis1, AlignmentAxis2, PaddedRawData
         | 
| 14 13 | 
             
            from ...misc import fourier_filters
         | 
| 15 14 | 
             
            from ...estimation.alignment import AlignmentBase
         | 
| @@ -37,7 +36,7 @@ else: | |
| 37 36 | 
             
                __has_sk_phase_correlation__ = True
         | 
| 38 37 |  | 
| 39 38 |  | 
| 40 | 
            -
            class ShiftAlgorithm( | 
| 39 | 
            +
            class ShiftAlgorithm(Enum):
         | 
| 41 40 | 
             
                """All generic shift search algorithm"""
         | 
| 42 41 |  | 
| 43 42 | 
             
                NABU_FFT = "nabu-fft"
         | 
| @@ -59,7 +58,7 @@ class ShiftAlgorithm(_Enum): | |
| 59 58 | 
             
                    if value in ("", None):
         | 
| 60 59 | 
             
                        return ShiftAlgorithm.NONE
         | 
| 61 60 | 
             
                    else:
         | 
| 62 | 
            -
                        return super(). | 
| 61 | 
            +
                        return super().__new__(cls, value)
         | 
| 63 62 |  | 
| 64 63 |  | 
| 65 64 | 
             
            def find_frame_relative_shifts(
         | 
| @@ -73,9 +72,9 @@ def find_frame_relative_shifts( | |
| 73 72 | 
             
                y_shifts_params: Optional[dict] = None,
         | 
| 74 73 | 
             
            ):
         | 
| 75 74 | 
             
                """
         | 
| 76 | 
            -
                :param overlap_axis: axis in [0, 1] on which the overlap exists. In image space. So 0 is aka y and 1 as x
         | 
| 75 | 
            +
                :param overlap_axis: axis in [0, 1] on which the overlap exists. In image space. So 0 is aka y and 1 as x.
         | 
| 77 76 | 
             
                """
         | 
| 78 | 
            -
                if not  | 
| 77 | 
            +
                if overlap_axis not in (0, 1):
         | 
| 79 78 | 
             
                    raise ValueError(f"overlap_axis should be in (0, 1). Get {overlap_axis}")
         | 
| 80 79 | 
             
                from nabu.stitching.config import (
         | 
| 81 80 | 
             
                    KEY_LOW_PASS_FILTER,
         | 
| @@ -146,7 +145,7 @@ def find_frame_relative_shifts( | |
| 146 145 | 
             
                }
         | 
| 147 146 |  | 
| 148 147 | 
             
                res_algo = {}
         | 
| 149 | 
            -
                for shift_alg in set((x_cross_correlation_function, y_cross_correlation_function)):
         | 
| 148 | 
            +
                for shift_alg in set((x_cross_correlation_function, y_cross_correlation_function)):  # noqa: C405
         | 
| 150 149 | 
             
                    if shift_alg not in shift_methods:
         | 
| 151 150 | 
             
                        raise ValueError(f"requested image alignment function not handled ({shift_alg})")
         | 
| 152 151 | 
             
                    try:
         | 
| @@ -200,8 +199,8 @@ def find_volumes_relative_shifts( | |
| 200 199 | 
             
                else:
         | 
| 201 200 | 
             
                    raise ValueError(f"Stitching is done in 3D space. Expect axis to be in [0,2]. Get {overlap_axis}")
         | 
| 202 201 |  | 
| 203 | 
            -
                alignment_axis_2 = AlignmentAxis2 | 
| 204 | 
            -
                alignment_axis_1 = AlignmentAxis1 | 
| 202 | 
            +
                alignment_axis_2 = AlignmentAxis2(alignment_axis_2)
         | 
| 203 | 
            +
                alignment_axis_1 = AlignmentAxis1(alignment_axis_1)
         | 
| 205 204 | 
             
                assert dim_axis_1 > 0, "dim_axis_1 <= 0"
         | 
| 206 205 |  | 
| 207 206 | 
             
                if isinstance(slice_for_shift, str):
         | 
| @@ -249,7 +248,7 @@ def find_volumes_relative_shifts( | |
| 249 248 |  | 
| 250 249 | 
             
                w_window_size = int(y_shifts_params.get(KEY_WINDOW_SIZE, 400))
         | 
| 251 250 | 
             
                start_overlap = max(estimated_shifts[0] // 2 - w_window_size // 2, 0)
         | 
| 252 | 
            -
                end_overlap = min(estimated_shifts[0] // 2 + w_window_size // 2,  | 
| 251 | 
            +
                end_overlap = min(estimated_shifts[0] // 2 + w_window_size // 2, upper_frame.shape[0], lower_frame.shape[0])
         | 
| 253 252 |  | 
| 254 253 | 
             
                if start_overlap == 0:
         | 
| 255 254 | 
             
                    overlap_upper_frame = upper_frame[-end_overlap:]
         | 
| @@ -385,12 +384,10 @@ def find_projections_relative_shifts( | |
| 385 384 | 
             
                        cor_options=cor_options,
         | 
| 386 385 | 
             
                    )
         | 
| 387 386 |  | 
| 388 | 
            -
                    estimated_shifts =  | 
| 389 | 
            -
                        [
         | 
| 390 | 
            -
             | 
| 391 | 
            -
             | 
| 392 | 
            -
                        ]
         | 
| 393 | 
            -
                    )
         | 
| 387 | 
            +
                    estimated_shifts = [
         | 
| 388 | 
            +
                        estimated_shifts[0],
         | 
| 389 | 
            +
                        (lower_scan_pos - upper_scan_pos),
         | 
| 390 | 
            +
                    ]
         | 
| 394 391 | 
             
                    x_cross_correlation_function = ShiftAlgorithm.NONE
         | 
| 395 392 |  | 
| 396 393 | 
             
                # } else we will compute shift from the flat projections
         | 
| @@ -464,7 +461,8 @@ def find_projections_relative_shifts( | |
| 464 461 | 
             
                start_overlap = max(estimated_shifts[axis_proj_space] // 2 - w_window_size // 2, 0)
         | 
| 465 462 | 
             
                end_overlap = min(
         | 
| 466 463 | 
             
                    estimated_shifts[axis_proj_space] // 2 + w_window_size // 2,
         | 
| 467 | 
            -
                     | 
| 464 | 
            +
                    upper_proj.shape[axis_proj_space],
         | 
| 465 | 
            +
                    lower_proj.shape[axis_proj_space],
         | 
| 468 466 | 
             
                )
         | 
| 469 467 | 
             
                o_upper_sel = numpy.array(range(-end_overlap, -start_overlap))
         | 
| 470 468 | 
             
                overlap_upper_frame = numpy.take_along_axis(
         | 
| @@ -502,7 +500,7 @@ def find_shift_correlate(img1, img2, padding_mode="reflect"): | |
| 502 500 | 
             
                    padding_mode,
         | 
| 503 501 | 
             
                )
         | 
| 504 502 |  | 
| 505 | 
            -
                img_shape =  | 
| 503 | 
            +
                img_shape = cc.shape  # Because cc.shape can differ from img_2.shape (e.g. in case of odd nb of cols)
         | 
| 506 504 | 
             
                cc_vs = numpy.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
         | 
| 507 505 | 
             
                cc_hs = numpy.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])
         | 
| 508 506 |  |