nabu 2024.2.13__py3-none-any.whl → 2025.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/doc_config.py +32 -0
- nabu/__init__.py +1 -1
- nabu/app/bootstrap_stitching.py +4 -2
- nabu/app/cast_volume.py +16 -14
- nabu/app/cli_configs.py +102 -9
- nabu/app/compare_volumes.py +1 -1
- nabu/app/composite_cor.py +2 -4
- nabu/app/diag_to_pix.py +5 -6
- nabu/app/diag_to_rot.py +10 -11
- nabu/app/double_flatfield.py +18 -5
- nabu/app/estimate_motion.py +75 -0
- nabu/app/multicor.py +28 -15
- nabu/app/parse_reconstruction_log.py +1 -0
- nabu/app/pcaflats.py +122 -0
- nabu/app/prepare_weights_double.py +1 -2
- nabu/app/reconstruct.py +1 -7
- nabu/app/reconstruct_helical.py +5 -9
- nabu/app/reduce_dark_flat.py +5 -4
- nabu/app/rotate.py +3 -1
- nabu/app/stitching.py +7 -2
- nabu/app/tests/test_reduce_dark_flat.py +2 -2
- nabu/app/validator.py +1 -4
- nabu/cuda/convolution.py +1 -1
- nabu/cuda/fft.py +1 -1
- nabu/cuda/medfilt.py +1 -1
- nabu/cuda/padding.py +1 -1
- nabu/cuda/src/backproj.cu +6 -6
- nabu/cuda/src/cone.cu +4 -0
- nabu/cuda/src/hierarchical_backproj.cu +14 -0
- nabu/cuda/utils.py +2 -2
- nabu/estimation/alignment.py +17 -31
- nabu/estimation/cor.py +27 -33
- nabu/estimation/cor_sino.py +2 -8
- nabu/estimation/focus.py +4 -8
- nabu/estimation/motion.py +557 -0
- nabu/estimation/tests/test_alignment.py +2 -0
- nabu/estimation/tests/test_motion_estimation.py +471 -0
- nabu/estimation/tests/test_tilt.py +1 -1
- nabu/estimation/tilt.py +6 -5
- nabu/estimation/translation.py +47 -1
- nabu/io/cast_volume.py +108 -18
- nabu/io/detector_distortion.py +5 -6
- nabu/io/reader.py +45 -6
- nabu/io/reader_helical.py +5 -4
- nabu/io/tests/test_cast_volume.py +2 -2
- nabu/io/tests/test_readers.py +41 -38
- nabu/io/tests/test_remove_volume.py +152 -0
- nabu/io/tests/test_writers.py +2 -2
- nabu/io/utils.py +8 -4
- nabu/io/writer.py +1 -2
- nabu/misc/fftshift.py +1 -1
- nabu/misc/fourier_filters.py +1 -1
- nabu/misc/histogram.py +1 -1
- nabu/misc/histogram_cuda.py +1 -1
- nabu/misc/padding_base.py +1 -1
- nabu/misc/rotation.py +1 -1
- nabu/misc/rotation_cuda.py +1 -1
- nabu/misc/tests/test_binning.py +1 -1
- nabu/misc/transpose.py +1 -1
- nabu/misc/unsharp.py +1 -1
- nabu/misc/unsharp_cuda.py +1 -1
- nabu/misc/unsharp_opencl.py +1 -1
- nabu/misc/utils.py +1 -1
- nabu/opencl/fft.py +1 -1
- nabu/opencl/padding.py +1 -1
- nabu/opencl/src/backproj.cl +6 -6
- nabu/opencl/utils.py +8 -8
- nabu/pipeline/config.py +2 -2
- nabu/pipeline/config_validators.py +46 -46
- nabu/pipeline/datadump.py +3 -3
- nabu/pipeline/estimators.py +271 -11
- nabu/pipeline/fullfield/chunked.py +103 -67
- nabu/pipeline/fullfield/chunked_cuda.py +5 -2
- nabu/pipeline/fullfield/computations.py +4 -1
- nabu/pipeline/fullfield/dataset_validator.py +0 -1
- nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
- nabu/pipeline/fullfield/nabu_config.py +36 -17
- nabu/pipeline/fullfield/processconfig.py +41 -7
- nabu/pipeline/fullfield/reconstruction.py +14 -10
- nabu/pipeline/helical/dataset_validator.py +3 -4
- nabu/pipeline/helical/fbp.py +4 -4
- nabu/pipeline/helical/filtering.py +5 -4
- nabu/pipeline/helical/gridded_accumulator.py +10 -11
- nabu/pipeline/helical/helical_chunked_regridded.py +1 -0
- nabu/pipeline/helical/helical_reconstruction.py +12 -9
- nabu/pipeline/helical/helical_utils.py +1 -2
- nabu/pipeline/helical/nabu_config.py +2 -1
- nabu/pipeline/helical/span_strategy.py +1 -0
- nabu/pipeline/helical/weight_balancer.py +2 -3
- nabu/pipeline/params.py +20 -3
- nabu/pipeline/tests/__init__.py +0 -0
- nabu/pipeline/tests/test_estimators.py +240 -3
- nabu/pipeline/utils.py +1 -1
- nabu/pipeline/writer.py +1 -1
- nabu/preproc/alignment.py +0 -10
- nabu/preproc/ccd.py +53 -3
- nabu/preproc/ctf.py +8 -8
- nabu/preproc/ctf_cuda.py +1 -1
- nabu/preproc/double_flatfield_cuda.py +2 -2
- nabu/preproc/double_flatfield_variable_region.py +0 -1
- nabu/preproc/flatfield.py +307 -2
- nabu/preproc/flatfield_cuda.py +1 -2
- nabu/preproc/flatfield_variable_region.py +3 -3
- nabu/preproc/phase.py +2 -4
- nabu/preproc/phase_cuda.py +2 -2
- nabu/preproc/shift.py +4 -2
- nabu/preproc/shift_cuda.py +0 -1
- nabu/preproc/tests/test_ctf.py +4 -4
- nabu/preproc/tests/test_double_flatfield.py +1 -1
- nabu/preproc/tests/test_flatfield.py +1 -1
- nabu/preproc/tests/test_paganin.py +1 -3
- nabu/preproc/tests/test_pcaflats.py +154 -0
- nabu/preproc/tests/test_vshift.py +4 -1
- nabu/processing/azim.py +9 -5
- nabu/processing/convolution_cuda.py +6 -4
- nabu/processing/fft_base.py +7 -3
- nabu/processing/fft_cuda.py +25 -164
- nabu/processing/fft_opencl.py +28 -6
- nabu/processing/fftshift.py +1 -1
- nabu/processing/histogram.py +1 -1
- nabu/processing/muladd.py +0 -1
- nabu/processing/padding_base.py +1 -1
- nabu/processing/padding_cuda.py +0 -2
- nabu/processing/processing_base.py +12 -6
- nabu/processing/rotation_cuda.py +3 -1
- nabu/processing/tests/test_fft.py +2 -64
- nabu/processing/tests/test_fftshift.py +1 -1
- nabu/processing/tests/test_medfilt.py +1 -3
- nabu/processing/tests/test_padding.py +1 -1
- nabu/processing/tests/test_roll.py +1 -1
- nabu/processing/tests/test_rotation.py +4 -2
- nabu/processing/unsharp_opencl.py +1 -1
- nabu/reconstruction/astra.py +245 -0
- nabu/reconstruction/cone.py +39 -9
- nabu/reconstruction/fbp.py +14 -0
- nabu/reconstruction/fbp_base.py +40 -8
- nabu/reconstruction/fbp_opencl.py +8 -0
- nabu/reconstruction/filtering.py +59 -25
- nabu/reconstruction/filtering_cuda.py +22 -21
- nabu/reconstruction/filtering_opencl.py +10 -14
- nabu/reconstruction/hbp.py +26 -13
- nabu/reconstruction/mlem.py +55 -16
- nabu/reconstruction/projection.py +3 -5
- nabu/reconstruction/sinogram.py +1 -1
- nabu/reconstruction/sinogram_cuda.py +0 -1
- nabu/reconstruction/tests/test_cone.py +37 -2
- nabu/reconstruction/tests/test_deringer.py +4 -4
- nabu/reconstruction/tests/test_fbp.py +36 -15
- nabu/reconstruction/tests/test_filtering.py +27 -7
- nabu/reconstruction/tests/test_halftomo.py +28 -2
- nabu/reconstruction/tests/test_mlem.py +94 -64
- nabu/reconstruction/tests/test_projector.py +7 -2
- nabu/reconstruction/tests/test_reconstructor.py +1 -1
- nabu/reconstruction/tests/test_sino_normalization.py +0 -1
- nabu/resources/dataset_analyzer.py +210 -24
- nabu/resources/gpu.py +4 -4
- nabu/resources/logger.py +4 -4
- nabu/resources/nxflatfield.py +103 -37
- nabu/resources/tests/test_dataset_analyzer.py +37 -0
- nabu/resources/tests/test_extract.py +11 -0
- nabu/resources/tests/test_nxflatfield.py +5 -5
- nabu/resources/utils.py +16 -10
- nabu/stitching/alignment.py +8 -11
- nabu/stitching/config.py +44 -35
- nabu/stitching/definitions.py +2 -2
- nabu/stitching/frame_composition.py +8 -10
- nabu/stitching/overlap.py +4 -4
- nabu/stitching/sample_normalization.py +5 -5
- nabu/stitching/slurm_utils.py +2 -2
- nabu/stitching/stitcher/base.py +2 -0
- nabu/stitching/stitcher/dumper/base.py +0 -1
- nabu/stitching/stitcher/dumper/postprocessing.py +1 -1
- nabu/stitching/stitcher/post_processing.py +11 -9
- nabu/stitching/stitcher/pre_processing.py +37 -31
- nabu/stitching/stitcher/single_axis.py +2 -3
- nabu/stitching/stitcher_2D.py +2 -1
- nabu/stitching/tests/test_config.py +10 -11
- nabu/stitching/tests/test_sample_normalization.py +1 -1
- nabu/stitching/tests/test_slurm_utils.py +1 -2
- nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
- nabu/stitching/tests/test_z_postprocessing_stitching.py +3 -3
- nabu/stitching/tests/test_z_preprocessing_stitching.py +27 -24
- nabu/stitching/utils/tests/__init__.py +0 -0
- nabu/stitching/utils/tests/test_post-processing.py +1 -0
- nabu/stitching/utils/utils.py +16 -18
- nabu/tests.py +0 -3
- nabu/testutils.py +62 -9
- nabu/utils.py +50 -20
- {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/METADATA +7 -7
- nabu-2025.1.0.dist-info/RECORD +328 -0
- {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/WHEEL +1 -1
- {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/entry_points.txt +2 -1
- nabu/app/correct_rot.py +0 -70
- nabu/io/tests/test_detector_distortion.py +0 -178
- nabu-2024.2.13.dist-info/RECORD +0 -317
- /nabu/{stitching → app}/tests/__init__.py +0 -0
- {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/licenses/LICENSE +0 -0
- {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/top_level.txt +0 -0
    
        nabu/resources/utils.py
    CHANGED
    
    | @@ -1,8 +1,9 @@ | |
| 1 1 | 
             
            from ast import literal_eval
         | 
| 2 2 | 
             
            import numpy as np
         | 
| 3 | 
            +
            import pint
         | 
| 3 4 | 
             
            from psutil import virtual_memory, cpu_count
         | 
| 4 | 
            -
             | 
| 5 | 
            -
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            _ureg = pint.get_application_registry()
         | 
| 6 7 |  | 
| 7 8 |  | 
| 8 9 | 
             
            def get_values_from_file(fname, n_values=None, shape=None, sep=None, any_size=False):
         | 
| @@ -163,12 +164,17 @@ def get_quantities_and_units(string, sep=";"): | |
| 163 164 | 
             
                    value, unit = value_and_unit.split()
         | 
| 164 165 | 
             
                    val = float(value)
         | 
| 165 166 | 
             
                    # Convert to SI
         | 
| 166 | 
            -
                     | 
| 167 | 
            -
                         | 
| 168 | 
            -
             | 
| 169 | 
            -
             | 
| 170 | 
            -
             | 
| 171 | 
            -
                         | 
| 172 | 
            -
             | 
| 173 | 
            -
                     | 
| 167 | 
            +
                    if unit.lower() == "kev":
         | 
| 168 | 
            +
                        current_unit = _ureg.keV
         | 
| 169 | 
            +
                    elif unit.lower() == "ev":
         | 
| 170 | 
            +
                        current_unit = _ureg.eV
         | 
| 171 | 
            +
                    else:
         | 
| 172 | 
            +
                        current_unit = _ureg(unit)
         | 
| 173 | 
            +
                    # handle energies (to move to keV)
         | 
| 174 | 
            +
                    if _ureg.keV.dimensionality == current_unit.dimensionality:
         | 
| 175 | 
            +
                        result[quantity_name] = (val * current_unit).to(_ureg.keV).magnitude
         | 
| 176 | 
            +
                    elif _ureg.meter.dimensionality == current_unit.dimensionality:
         | 
| 177 | 
            +
                        result[quantity_name] = (val * current_unit).to_base_units().magnitude
         | 
| 178 | 
            +
                    else:
         | 
| 179 | 
            +
                        raise ValueError(f"Cannot convert: {unit}")
         | 
| 174 180 | 
             
                return result
         | 
    
        nabu/stitching/alignment.py
    CHANGED
    
    | @@ -1,13 +1,10 @@ | |
| 1 | 
            +
            from enum import Enum
         | 
| 1 2 | 
             
            import h5py
         | 
| 2 3 | 
             
            import numpy
         | 
| 3 4 | 
             
            from typing import Union
         | 
| 4 | 
            -
            from silx.utils.enum import Enum as _Enum
         | 
| 5 | 
            -
            from tomoscan.volumebase import VolumeBase
         | 
| 6 | 
            -
            from tomoscan.esrf.volume.hdf5volume import HDF5Volume
         | 
| 7 | 
            -
            from nabu.io.utils import DatasetReader
         | 
| 8 5 |  | 
| 9 6 |  | 
| 10 | 
            -
            class AlignmentAxis2( | 
| 7 | 
            +
            class AlignmentAxis2(Enum):
         | 
| 11 8 | 
             
                """Specific alignment named to help users orienting themself with specific name"""
         | 
| 12 9 |  | 
| 13 10 | 
             
                CENTER = "center"
         | 
| @@ -15,7 +12,7 @@ class AlignmentAxis2(_Enum): | |
| 15 12 | 
             
                RIGTH = "right"
         | 
| 16 13 |  | 
| 17 14 |  | 
| 18 | 
            -
            class AlignmentAxis1( | 
| 15 | 
            +
            class AlignmentAxis1(Enum):
         | 
| 19 16 | 
             
                """Specific alignment named to help users orienting themself with specific name"""
         | 
| 20 17 |  | 
| 21 18 | 
             
                FRONT = "front"
         | 
| @@ -23,7 +20,7 @@ class AlignmentAxis1(_Enum): | |
| 23 20 | 
             
                BACK = "back"
         | 
| 24 21 |  | 
| 25 22 |  | 
| 26 | 
            -
            class _Alignment( | 
| 23 | 
            +
            class _Alignment(Enum):
         | 
| 27 24 | 
             
                """Internal alignment to be used for 2D alignment"""
         | 
| 28 25 |  | 
| 29 26 | 
             
                LOWER_BOUNDARY = "lower boundary"
         | 
| @@ -32,7 +29,7 @@ class _Alignment(_Enum): | |
| 32 29 |  | 
| 33 30 | 
             
                @classmethod
         | 
| 34 31 | 
             
                def from_value(cls, value):
         | 
| 35 | 
            -
                    # cast the AlignmentAxis1 and AlignmentAxis2 values to fit the generic definition
         | 
| 32 | 
            +
                    # cast the AlignmentAxis1 and AlignmentAxis2 values to fit the generic definition.
         | 
| 36 33 | 
             
                    if value in ("front", "left", AlignmentAxis1.FRONT, AlignmentAxis2.LEFT):
         | 
| 37 34 | 
             
                        return _Alignment.LOWER_BOUNDARY
         | 
| 38 35 | 
             
                    elif value in ("back", "right", AlignmentAxis1.BACK, AlignmentAxis2.RIGTH):
         | 
| @@ -40,7 +37,7 @@ class _Alignment(_Enum): | |
| 40 37 | 
             
                    elif value in (AlignmentAxis1.CENTER, AlignmentAxis2.CENTER):
         | 
| 41 38 | 
             
                        return _Alignment.CENTER
         | 
| 42 39 | 
             
                    else:
         | 
| 43 | 
            -
                        return super(). | 
| 40 | 
            +
                        return super().__new__(cls, value)
         | 
| 44 41 |  | 
| 45 42 |  | 
| 46 43 | 
             
            def align_frame(
         | 
| @@ -106,7 +103,7 @@ def align_horizontally(data: numpy.ndarray, alignment: AlignmentAxis2, new_width | |
| 106 103 | 
             
                :param HAlignment alignment: alignment strategy
         | 
| 107 104 | 
             
                :param int new_width: output data width
         | 
| 108 105 | 
             
                """
         | 
| 109 | 
            -
                alignment = AlignmentAxis2 | 
| 106 | 
            +
                alignment = AlignmentAxis2(alignment).value
         | 
| 110 107 | 
             
                return align_frame(
         | 
| 111 108 | 
             
                    data=data, alignment=alignment, new_aligned_axis_size=new_width, pad_mode=pad_mode, alignment_axis=1
         | 
| 112 109 | 
             
                )
         | 
| @@ -151,7 +148,7 @@ class PaddedRawData: | |
| 151 148 | 
             
                @property
         | 
| 152 149 | 
             
                def shape(self):
         | 
| 153 150 | 
             
                    if self._shape is None:
         | 
| 154 | 
            -
                        self._shape = tuple(
         | 
| 151 | 
            +
                        self._shape = tuple(  # noqa: C409
         | 
| 155 152 | 
             
                            (
         | 
| 156 153 | 
             
                                self._raw_data_shape[0],
         | 
| 157 154 | 
             
                                numpy.sum(
         | 
    
        nabu/stitching/config.py
    CHANGED
    
    | @@ -1,8 +1,9 @@ | |
| 1 | 
            +
            import pint
         | 
| 1 2 | 
             
            from math import ceil
         | 
| 2 | 
            -
            from typing import  | 
| 3 | 
            +
            from typing import Optional, Union
         | 
| 4 | 
            +
            from collections.abc import Iterable, Sized
         | 
| 3 5 | 
             
            from dataclasses import dataclass
         | 
| 4 6 | 
             
            import numpy
         | 
| 5 | 
            -
            from pyunitsystem.metricsystem import MetricSystem
         | 
| 6 7 | 
             
            from nxtomo.paths import nxtomo
         | 
| 7 8 | 
             
            from tomoscan.factory import Factory
         | 
| 8 9 | 
             
            from tomoscan.identifier import VolumeIdentifier, ScanIdentifier
         | 
| @@ -12,13 +13,14 @@ from ..pipeline.config_validators import ( | |
| 12 13 | 
             
                convert_to_bool,
         | 
| 13 14 | 
             
            )
         | 
| 14 15 | 
             
            from ..utils import concatenate_dict, convert_str_to_tuple
         | 
| 15 | 
            -
            from ..io.utils import get_output_volume
         | 
| 16 16 | 
             
            from .overlap import OverlapStitchingStrategy
         | 
| 17 17 | 
             
            from .utils.utils import ShiftAlgorithm
         | 
| 18 18 | 
             
            from .definitions import StitchingType
         | 
| 19 19 | 
             
            from .alignment import AlignmentAxis1, AlignmentAxis2
         | 
| 20 | 
            -
            from pyunitsystem.metricsystem import MetricSystem
         | 
| 21 20 |  | 
| 21 | 
            +
            _ureg = pint.get_application_registry()
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            # ruff: noqa: S105
         | 
| 22 24 |  | 
| 23 25 | 
             
            KEY_IMG_REG_METHOD = "img_reg_method"
         | 
| 24 26 |  | 
| @@ -128,6 +130,8 @@ SLURM_MODULES_TO_LOADS = "modules" | |
| 128 130 |  | 
| 129 131 | 
             
            SLURM_CLEAN_SCRIPTS = "clean_scripts"
         | 
| 130 132 |  | 
| 133 | 
            +
            SLURM_JOB_NAME = "job_name"
         | 
| 134 | 
            +
             | 
| 131 135 | 
             
            # normalization by sample
         | 
| 132 136 |  | 
| 133 137 | 
             
            NORMALIZATION_BY_SAMPLE_SECTION = "normalization_by_sample"
         | 
| @@ -205,7 +209,7 @@ def _str_to_dict(my_str: Union[str, dict]): | |
| 205 209 |  | 
| 206 210 |  | 
| 207 211 | 
             
            def _dict_to_str(ddict: dict):
         | 
| 208 | 
            -
                return ";".join([f"{ | 
| 212 | 
            +
                return ";".join([f"{key!s}={value!s}" for key, value in ddict.items()])
         | 
| 209 213 |  | 
| 210 214 |  | 
| 211 215 | 
             
            def str_to_shifts(my_str: Optional[str]) -> Union[str, tuple]:
         | 
| @@ -218,7 +222,7 @@ def str_to_shifts(my_str: Optional[str]) -> Union[str, tuple]: | |
| 218 222 | 
             
                    if my_str == "":
         | 
| 219 223 | 
             
                        return None
         | 
| 220 224 | 
             
                    try:
         | 
| 221 | 
            -
                        shift = ShiftAlgorithm | 
| 225 | 
            +
                        shift = ShiftAlgorithm(my_str)
         | 
| 222 226 | 
             
                    except ValueError:
         | 
| 223 227 | 
             
                        shifts_as_str = filter(None, my_str.replace(";", ",").split(","))
         | 
| 224 228 | 
             
                        return [float(shift) for shift in shifts_as_str]
         | 
| @@ -235,8 +239,8 @@ def _valid_stitching_kernels_params(my_dict: Union[dict, str]): | |
| 235 239 | 
             
                    my_dict = _str_to_dict(my_str=my_dict)
         | 
| 236 240 |  | 
| 237 241 | 
             
                valid_keys = (KEY_THRESHOLD_FREQUENCY, KEY_SIDE)
         | 
| 238 | 
            -
                for key in my_dict | 
| 239 | 
            -
                    if not  | 
| 242 | 
            +
                for key in my_dict:
         | 
| 243 | 
            +
                    if key not in valid_keys:
         | 
| 240 244 | 
             
                        raise KeyError(f"{key} is a unrecognized key")
         | 
| 241 245 | 
             
                return my_dict
         | 
| 242 246 |  | 
| @@ -253,8 +257,8 @@ def _valid_shifts_params(my_dict: Union[dict, str]): | |
| 253 257 | 
             
                    KEY_LOW_PASS_FILTER,
         | 
| 254 258 | 
             
                    KEY_SIDE,
         | 
| 255 259 | 
             
                )
         | 
| 256 | 
            -
                for key in my_dict | 
| 257 | 
            -
                    if not  | 
| 260 | 
            +
                for key in my_dict:
         | 
| 261 | 
            +
                    if key not in valid_keys:
         | 
| 258 262 | 
             
                        raise KeyError(f"{key} is a unrecognized key")
         | 
| 259 263 | 
             
                return my_dict
         | 
| 260 264 |  | 
| @@ -334,7 +338,7 @@ class NormalizationBySample: | |
| 334 338 |  | 
| 335 339 | 
             
                @method.setter
         | 
| 336 340 | 
             
                def method(self, method: Union[Method, str]) -> None:
         | 
| 337 | 
            -
                    self._method = Method | 
| 341 | 
            +
                    self._method = Method(method)
         | 
| 338 342 |  | 
| 339 343 | 
             
                @property
         | 
| 340 344 | 
             
                def margin(self) -> int:
         | 
| @@ -351,7 +355,7 @@ class NormalizationBySample: | |
| 351 355 |  | 
| 352 356 | 
             
                @side.setter
         | 
| 353 357 | 
             
                def side(self, side: Union[SampleSide, str]):
         | 
| 354 | 
            -
                    self._side = SampleSide | 
| 358 | 
            +
                    self._side = SampleSide(side)
         | 
| 355 359 |  | 
| 356 360 | 
             
                @property
         | 
| 357 361 | 
             
                def width(self) -> int:
         | 
| @@ -401,16 +405,16 @@ class NormalizationBySample: | |
| 401 405 | 
             
                        NORMALIZATION_BY_SAMPLE_WIDTH: self.width,
         | 
| 402 406 | 
             
                    }
         | 
| 403 407 |  | 
| 404 | 
            -
                def __eq__(self,  | 
| 405 | 
            -
                    if not isinstance( | 
| 408 | 
            +
                def __eq__(self, value: object, /) -> bool:
         | 
| 409 | 
            +
                    if not isinstance(value, NormalizationBySample):
         | 
| 406 410 | 
             
                        return False
         | 
| 407 411 | 
             
                    else:
         | 
| 408 | 
            -
                        return self.to_dict() ==  | 
| 412 | 
            +
                        return self.to_dict() == value.to_dict()
         | 
| 409 413 |  | 
| 410 414 |  | 
| 411 415 | 
             
            @dataclass
         | 
| 412 416 | 
             
            class SlurmConfig:
         | 
| 413 | 
            -
                "configuration for slurm jobs"
         | 
| 417 | 
            +
                """configuration for slurm jobs"""
         | 
| 414 418 |  | 
| 415 419 | 
             
                partition: str = ""  # note: must stay empty to make by default we don't use slurm (use by the  configuration file)
         | 
| 416 420 | 
             
                mem: str = "128"
         | 
| @@ -421,6 +425,7 @@ class SlurmConfig: | |
| 421 425 | 
             
                clean_script: bool = ""
         | 
| 422 426 | 
             
                n_tasks: int = 1
         | 
| 423 427 | 
             
                n_cpu_per_task: int = 4
         | 
| 428 | 
            +
                job_name: str = ""
         | 
| 424 429 |  | 
| 425 430 | 
             
                def __post_init__(self) -> None:
         | 
| 426 431 | 
             
                    # make sure either 'modules' or 'preprocessing_command' is provided
         | 
| @@ -430,7 +435,7 @@ class SlurmConfig: | |
| 430 435 | 
             
                        )
         | 
| 431 436 |  | 
| 432 437 | 
             
                def to_dict(self) -> dict:
         | 
| 433 | 
            -
                    "dump configuration to dict"
         | 
| 438 | 
            +
                    """dump configuration to dict"""
         | 
| 434 439 | 
             
                    return {
         | 
| 435 440 | 
             
                        SLURM_PARTITION: self.partition if self.partition is not None else "",
         | 
| 436 441 | 
             
                        SLURM_MEM: self.mem,
         | 
| @@ -441,6 +446,7 @@ class SlurmConfig: | |
| 441 446 | 
             
                        SLURM_CLEAN_SCRIPTS: self.clean_script,
         | 
| 442 447 | 
             
                        SLURM_NUMBER_OF_TASKS: self.n_tasks,
         | 
| 443 448 | 
             
                        SLURM_COR_PER_TASKS: self.n_cpu_per_task,
         | 
| 449 | 
            +
                        SLURM_JOB_NAME: self.job_name,
         | 
| 444 450 | 
             
                    }
         | 
| 445 451 |  | 
| 446 452 | 
             
                @staticmethod
         | 
| @@ -457,18 +463,21 @@ class SlurmConfig: | |
| 457 463 | 
             
                        preprocessing_command=config.get(SLURM_PREPROCESSING_COMMAND, ""),
         | 
| 458 464 | 
             
                        modules_to_load=convert_str_to_tuple(config.get(SLURM_MODULES_TO_LOADS, "")),
         | 
| 459 465 | 
             
                        clean_script=convert_to_bool(config.get(SLURM_CLEAN_SCRIPTS, False))[0],
         | 
| 466 | 
            +
                        job_name=config.get(SLURM_JOB_NAME, ""),
         | 
| 460 467 | 
             
                    )
         | 
| 461 468 |  | 
| 462 469 |  | 
| 463 | 
            -
            def _cast_shift_to_str(shifts: Union[tuple, str, None]) -> str:
         | 
| 470 | 
            +
            def _cast_shift_to_str(shifts: Union[tuple, numpy.ndarray, str, None]) -> str:
         | 
| 464 471 | 
             
                if shifts is None:
         | 
| 465 472 | 
             
                    return ""
         | 
| 466 473 | 
             
                elif isinstance(shifts, ShiftAlgorithm):
         | 
| 467 474 | 
             
                    return shifts.value
         | 
| 468 475 | 
             
                elif isinstance(shifts, str):
         | 
| 469 476 | 
             
                    return shifts
         | 
| 470 | 
            -
                elif isinstance(shifts, (tuple, list)):
         | 
| 477 | 
            +
                elif isinstance(shifts, (tuple, list, numpy.ndarray)):
         | 
| 471 478 | 
             
                    return ";".join([str(value) for value in shifts])
         | 
| 479 | 
            +
                else:
         | 
| 480 | 
            +
                    raise TypeError(f"unexpected type: {type(shifts)}")
         | 
| 472 481 |  | 
| 473 482 |  | 
| 474 483 | 
             
            @dataclass
         | 
| @@ -541,12 +550,12 @@ class StitchingConfiguration: | |
| 541 550 | 
             
                        STITCHING_SECTION: {
         | 
| 542 551 | 
             
                            STITCHING_TYPE_FIELD: {
         | 
| 543 552 | 
             
                                "default": StitchingType.Z_PREPROC.value,
         | 
| 544 | 
            -
                                "help": f"stitching to be applied. Must be in {StitchingType | 
| 553 | 
            +
                                "help": f"stitching to be applied. Must be in {[st.value for st in StitchingType]}",
         | 
| 545 554 | 
             
                                "type": "required",
         | 
| 546 555 | 
             
                            },
         | 
| 547 556 | 
             
                            STITCHING_STRATEGY_FIELD: {
         | 
| 548 557 | 
             
                                "default": "cosinus weights",
         | 
| 549 | 
            -
                                "help": f"Policy to apply to compute the overlap area. Must be in {OverlapStitchingStrategy | 
| 558 | 
            +
                                "help": f"Policy to apply to compute the overlap area. Must be in {[ov.value for ov in OverlapStitchingStrategy]}.",
         | 
| 550 559 | 
             
                                "type": "required",
         | 
| 551 560 | 
             
                            },
         | 
| 552 561 | 
             
                            CROSS_CORRELATION_SLICE_FIELD: {
         | 
| @@ -626,7 +635,7 @@ class StitchingConfiguration: | |
| 626 635 | 
             
                            },
         | 
| 627 636 | 
             
                            ALIGNMENT_AXIS_2_FIELD: {
         | 
| 628 637 | 
             
                                "default": "center",
         | 
| 629 | 
            -
                                "help": f"In case frame have different frame widths how to align them (so along volume axis 2). Valid keys are {AlignmentAxis2 | 
| 638 | 
            +
                                "help": f"In case frame have different frame widths how to align them (so along volume axis 2). Valid keys are {[aa.value for aa in AlignmentAxis2]}",
         | 
| 630 639 | 
             
                                "type": "advanced",
         | 
| 631 640 | 
             
                            },
         | 
| 632 641 | 
             
                            PAD_MODE_FIELD: {
         | 
| @@ -748,7 +757,7 @@ class StitchingConfiguration: | |
| 748 757 | 
             
                            AXIS_2_POS_PX: _cast_shift_to_str(self.axis_2_pos_px),
         | 
| 749 758 | 
             
                            AXIS_2_POS_MM: _cast_shift_to_str(self.axis_2_pos_mm),
         | 
| 750 759 | 
             
                            AXIS_2_PARAMS: _dict_to_str(self.axis_2_params or {}),
         | 
| 751 | 
            -
                            STITCHING_STRATEGY_FIELD: OverlapStitchingStrategy | 
| 760 | 
            +
                            STITCHING_STRATEGY_FIELD: OverlapStitchingStrategy(self.stitching_strategy).value,
         | 
| 752 761 | 
             
                            FLIP_UD: self.flip_ud,
         | 
| 753 762 | 
             
                            FLIP_LR: self.flip_lr,
         | 
| 754 763 | 
             
                            RESCALE_FRAMES: self.rescale_frames,
         | 
| @@ -927,7 +936,7 @@ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigurat | |
| 927 936 | 
             
                    if self.pixel_size is None:
         | 
| 928 937 | 
             
                        pixel_size_mm = ""
         | 
| 929 938 | 
             
                    else:
         | 
| 930 | 
            -
                        pixel_size_mm = self.pixel_size *  | 
| 939 | 
            +
                        pixel_size_mm = (self.pixel_size * _ureg.meter).to(_ureg.millimeter).magnitude
         | 
| 931 940 | 
             
                    return concatenate_dict(
         | 
| 932 941 | 
             
                        super().to_dict(),
         | 
| 933 942 | 
             
                        {
         | 
| @@ -991,10 +1000,10 @@ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigurat | |
| 991 1000 | 
             
                    if pixel_size == "":
         | 
| 992 1001 | 
             
                        pixel_size = None
         | 
| 993 1002 | 
             
                    else:
         | 
| 994 | 
            -
                        pixel_size = float(pixel_size)  | 
| 1003 | 
            +
                        pixel_size = (float(pixel_size) * _ureg.millimeter).to_base_units().magnitude
         | 
| 995 1004 |  | 
| 996 1005 | 
             
                    return cls(
         | 
| 997 | 
            -
                        stitching_strategy=OverlapStitchingStrategy | 
| 1006 | 
            +
                        stitching_strategy=OverlapStitchingStrategy(
         | 
| 998 1007 | 
             
                            config[STITCHING_SECTION].get(
         | 
| 999 1008 | 
             
                                STITCHING_STRATEGY_FIELD,
         | 
| 1000 1009 | 
             
                                OverlapStitchingStrategy.COSINUS_WEIGHTS,
         | 
| @@ -1035,7 +1044,7 @@ class PreProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigurat | |
| 1035 1044 | 
             
                                config[STITCHING_SECTION].get(STITCHING_KERNELS_EXTRA_PARAMS, {}),
         | 
| 1036 1045 | 
             
                            )
         | 
| 1037 1046 | 
             
                        ),
         | 
| 1038 | 
            -
                        alignment_axis_2=AlignmentAxis2 | 
| 1047 | 
            +
                        alignment_axis_2=AlignmentAxis2(
         | 
| 1039 1048 | 
             
                            config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER)
         | 
| 1040 1049 | 
             
                        ),
         | 
| 1041 1050 | 
             
                        pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"),
         | 
| @@ -1156,11 +1165,11 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura | |
| 1156 1165 | 
             
                    if voxel_size == "":
         | 
| 1157 1166 | 
             
                        voxel_size = None
         | 
| 1158 1167 | 
             
                    else:
         | 
| 1159 | 
            -
                        voxel_size = float(voxel_size) *  | 
| 1168 | 
            +
                        voxel_size = (float(voxel_size) * _ureg.millimeter).to_base_units().magnitude
         | 
| 1160 1169 |  | 
| 1161 1170 | 
             
                    # on the next section the one with a default value qre the optional one
         | 
| 1162 1171 | 
             
                    return cls(
         | 
| 1163 | 
            -
                        stitching_strategy=OverlapStitchingStrategy | 
| 1172 | 
            +
                        stitching_strategy=OverlapStitchingStrategy(
         | 
| 1164 1173 | 
             
                            config[STITCHING_SECTION].get(
         | 
| 1165 1174 | 
             
                                STITCHING_STRATEGY_FIELD,
         | 
| 1166 1175 | 
             
                                OverlapStitchingStrategy.COSINUS_WEIGHTS,
         | 
| @@ -1191,10 +1200,10 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura | |
| 1191 1200 | 
             
                                config[STITCHING_SECTION].get(STITCHING_KERNELS_EXTRA_PARAMS, {}),
         | 
| 1192 1201 | 
             
                            )
         | 
| 1193 1202 | 
             
                        ),
         | 
| 1194 | 
            -
                        alignment_axis_1=AlignmentAxis1 | 
| 1203 | 
            +
                        alignment_axis_1=AlignmentAxis1(
         | 
| 1195 1204 | 
             
                            config[STITCHING_SECTION].get(ALIGNMENT_AXIS_1_FIELD, AlignmentAxis1.CENTER)
         | 
| 1196 1205 | 
             
                        ),
         | 
| 1197 | 
            -
                        alignment_axis_2=AlignmentAxis2 | 
| 1206 | 
            +
                        alignment_axis_2=AlignmentAxis2(
         | 
| 1198 1207 | 
             
                            config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER)
         | 
| 1199 1208 | 
             
                        ),
         | 
| 1200 1209 | 
             
                        pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"),
         | 
| @@ -1208,7 +1217,7 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura | |
| 1208 1217 | 
             
                    if self.voxel_size is None:
         | 
| 1209 1218 | 
             
                        voxel_size_mm = ""
         | 
| 1210 1219 | 
             
                    else:
         | 
| 1211 | 
            -
                        voxel_size_mm = numpy.array(self.voxel_size | 
| 1220 | 
            +
                        voxel_size_mm = numpy.array((self.voxel_size * _ureg.meter).to(_ureg.millimeter).magnitude)
         | 
| 1212 1221 |  | 
| 1213 1222 | 
             
                    return concatenate_dict(
         | 
| 1214 1223 | 
             
                        super().to_dict(),
         | 
| @@ -1243,7 +1252,7 @@ class PostProcessedSingleAxisStitchingConfiguration(SingleAxisStitchingConfigura | |
| 1243 1252 | 
             
                            STITCHING_SECTION: {
         | 
| 1244 1253 | 
             
                                ALIGNMENT_AXIS_1_FIELD: {
         | 
| 1245 1254 | 
             
                                    "default": "center",
         | 
| 1246 | 
            -
                                    "help": f"alignment to apply over axis 1 if needed. Valid values are {AlignmentAxis1 | 
| 1255 | 
            +
                                    "help": f"alignment to apply over axis 1 if needed. Valid values are {[aa for aa in AlignmentAxis1]}",
         | 
| 1247 1256 | 
             
                                    "type": "advanced",
         | 
| 1248 1257 | 
             
                                }
         | 
| 1249 1258 | 
             
                            },
         | 
| @@ -1274,7 +1283,7 @@ def dict_to_config_obj(config: dict): | |
| 1274 1283 | 
             
                if stitching_type is None:
         | 
| 1275 1284 | 
             
                    raise ValueError("Unable to find stitching type from config dict")
         | 
| 1276 1285 | 
             
                else:
         | 
| 1277 | 
            -
                    stitching_type = StitchingType | 
| 1286 | 
            +
                    stitching_type = StitchingType(stitching_type)
         | 
| 1278 1287 | 
             
                    if stitching_type is StitchingType.Z_POSTPROC:
         | 
| 1279 1288 | 
             
                        return PostProcessedZStitchingConfiguration.from_dict(config)
         | 
| 1280 1289 | 
             
                    elif stitching_type is StitchingType.Z_PREPROC:
         | 
| @@ -1297,7 +1306,7 @@ def get_default_stitching_config(stitching_type: Optional[Union[StitchingType, s | |
| 1297 1306 | 
             
                if stitching_type is None:
         | 
| 1298 1307 | 
             
                    return concatenate_dict(z_postproc_stitching_config, z_preproc_stitching_config)
         | 
| 1299 1308 |  | 
| 1300 | 
            -
                stitching_type = StitchingType | 
| 1309 | 
            +
                stitching_type = StitchingType(stitching_type)
         | 
| 1301 1310 | 
             
                if stitching_type is StitchingType.Z_POSTPROC:
         | 
| 1302 1311 | 
             
                    return z_postproc_stitching_config
         | 
| 1303 1312 | 
             
                elif stitching_type is StitchingType.Z_PREPROC:
         | 
    
        nabu/stitching/definitions.py
    CHANGED
    
    
| @@ -31,7 +31,7 @@ class FrameComposition: | |
| 31 31 | 
             
                        )
         | 
| 32 32 |  | 
| 33 33 | 
             
                def compose(self, output_frame: numpy.ndarray, input_frames: tuple):
         | 
| 34 | 
            -
                    if  | 
| 34 | 
            +
                    if output_frame.ndim not in (2, 3):
         | 
| 35 35 | 
             
                        raise TypeError(
         | 
| 36 36 | 
             
                            f"output_frame is expected to be 2D (gray scale) or 3D (RGB(A)) and not {output_frame.ndim}"
         | 
| 37 37 | 
             
                        )
         | 
| @@ -74,9 +74,10 @@ class FrameComposition: | |
| 74 74 | 
             
                    local_start_indices.extend(
         | 
| 75 75 | 
             
                        [ceil(key_line[1] + kernel.overlap_size / 2) for (key_line, kernel) in zip(key_lines, overlap_kernels)]
         | 
| 76 76 | 
             
                    )
         | 
| 77 | 
            -
                    local_end_indices =  | 
| 78 | 
            -
                         | 
| 79 | 
            -
                     | 
| 77 | 
            +
                    local_end_indices = [
         | 
| 78 | 
            +
                        ceil(key_line[0] - kernel.overlap_size / 2) for (key_line, kernel) in zip(key_lines, overlap_kernels)
         | 
| 79 | 
            +
                    ]
         | 
| 80 | 
            +
             | 
| 80 81 | 
             
                    local_end_indices.append(frames[-1].shape[stitching_axis])
         | 
| 81 82 |  | 
| 82 83 | 
             
                    for (
         | 
| @@ -155,9 +156,6 @@ class FrameComposition: | |
| 155 156 | 
             
                        print(
         | 
| 156 157 | 
             
                            f"stitch_frame[{stitch_global_start}:{stitch_global_end}] = stitched_frame_{i_frame}[{stitch_local_start}:{stitch_local_end}]"
         | 
| 157 158 | 
             
                        )
         | 
| 158 | 
            -
                     | 
| 159 | 
            -
             | 
| 160 | 
            -
             | 
| 161 | 
            -
                        print(
         | 
| 162 | 
            -
                            f"stitch_frame[{raw_global_start}:{raw_global_end}] = frame_{i_frame}[{raw_local_start}:{raw_local_end}]"
         | 
| 163 | 
            -
                        )
         | 
| 159 | 
            +
                    i_frame += 1
         | 
| 160 | 
            +
                    raw_local_start, raw_local_end, raw_global_start, raw_global_end = list(raw_composition.browse())[-1]
         | 
| 161 | 
            +
                    print(f"stitch_frame[{raw_global_start}:{raw_global_end}] = frame_{i_frame}[{raw_local_start}:{raw_local_end}]")
         | 
    
        nabu/stitching/overlap.py
    CHANGED
    
    | @@ -1,7 +1,7 @@ | |
| 1 1 | 
             
            import numpy
         | 
| 2 2 | 
             
            import logging
         | 
| 3 3 | 
             
            from typing import Optional, Union
         | 
| 4 | 
            -
            from  | 
| 4 | 
            +
            from enum import Enum
         | 
| 5 5 | 
             
            from nabu.misc import fourier_filters
         | 
| 6 6 | 
             
            from scipy.fft import rfftn as local_fftn
         | 
| 7 7 | 
             
            from scipy.fft import irfftn as local_ifftn
         | 
| @@ -10,7 +10,7 @@ from tomoscan.utils.geometry import BoundingBox1D | |
| 10 10 | 
             
            _logger = logging.getLogger(__name__)
         | 
| 11 11 |  | 
| 12 12 |  | 
| 13 | 
            -
            class OverlapStitchingStrategy( | 
| 13 | 
            +
            class OverlapStitchingStrategy(Enum):
         | 
| 14 14 | 
             
                MEAN = "mean"
         | 
| 15 15 | 
             
                COSINUS_WEIGHTS = "cosinus weights"
         | 
| 16 16 | 
             
                LINEAR_WEIGHTS = "linear weights"
         | 
| @@ -64,7 +64,7 @@ class ImageStichOverlapKernel(OverlapKernelBase): | |
| 64 64 | 
             
                            f"frame_width is expected to be a positive int, {frame_unstitched_axis_size} - not {frame_unstitched_axis_size} ({type(frame_unstitched_axis_size)})"
         | 
| 65 65 | 
             
                        )
         | 
| 66 66 |  | 
| 67 | 
            -
                    if not  | 
| 67 | 
            +
                    if stitching_axis not in (0, 1):
         | 
| 68 68 | 
             
                        raise ValueError(
         | 
| 69 69 | 
             
                            "stitching_axis is expected to be the axis along which stitching must be done. It should be '0' or '1'"
         | 
| 70 70 | 
             
                        )
         | 
| @@ -72,7 +72,7 @@ class ImageStichOverlapKernel(OverlapKernelBase): | |
| 72 72 | 
             
                    self._stitching_axis = stitching_axis
         | 
| 73 73 | 
             
                    self._overlap_size = abs(overlap_size)
         | 
| 74 74 | 
             
                    self._frame_unstitched_axis_size = frame_unstitched_axis_size
         | 
| 75 | 
            -
                    self._stitching_strategy = OverlapStitchingStrategy | 
| 75 | 
            +
                    self._stitching_strategy = OverlapStitchingStrategy(stitching_strategy)
         | 
| 76 76 | 
             
                    self._weights_img_1 = None
         | 
| 77 77 | 
             
                    self._weights_img_2 = None
         | 
| 78 78 | 
             
                    if extra_params is None:
         | 
| @@ -1,13 +1,13 @@ | |
| 1 | 
            +
            from enum import Enum
         | 
| 1 2 | 
             
            import numpy
         | 
| 2 | 
            -
            from silx.utils.enum import Enum as _Enum
         | 
| 3 3 |  | 
| 4 4 |  | 
| 5 | 
            -
            class SampleSide( | 
| 5 | 
            +
            class SampleSide(Enum):
         | 
| 6 6 | 
             
                LEFT = "left"
         | 
| 7 7 | 
             
                RIGHT = "right"
         | 
| 8 8 |  | 
| 9 9 |  | 
| 10 | 
            -
            class Method( | 
| 10 | 
            +
            class Method(Enum):
         | 
| 11 11 | 
             
                MEAN = "mean"
         | 
| 12 12 | 
             
                MEDIAN = "median"
         | 
| 13 13 |  | 
| @@ -28,8 +28,8 @@ def normalize_frame( | |
| 28 28 | 
             
                    raise TypeError(f"Frame is expected to be a 2D numpy array.")
         | 
| 29 29 | 
             
                if frame.ndim != 2:
         | 
| 30 30 | 
             
                    raise TypeError(f"Frame is expected to be a 2D numpy array. Get {frame.ndim}D")
         | 
| 31 | 
            -
                side = SampleSide | 
| 32 | 
            -
                method = Method | 
| 31 | 
            +
                side = SampleSide(side)
         | 
| 32 | 
            +
                method = Method(method)
         | 
| 33 33 |  | 
| 34 34 | 
             
                if frame.shape[1] < sample_width + margin_before_sample:
         | 
| 35 35 | 
             
                    raise ValueError(
         | 
    
        nabu/stitching/slurm_utils.py
    CHANGED
    
    | @@ -177,7 +177,7 @@ def split_slices(slices: Union[slice, tuple], n_parts: int): | |
| 177 177 | 
             
                    raise TypeError(f"slices type ({type(slices)}) is not handled. Must be a slice or an Iterable")
         | 
| 178 178 |  | 
| 179 179 |  | 
| 180 | 
            -
            def get_working_directory(obj: TomoObject) -> Optional[str]:
         | 
| 180 | 
            +
            def get_working_directory(obj: TomoObject) -> Optional[str]:  # noqa: PLR0911
         | 
| 181 181 | 
             
                """
         | 
| 182 182 | 
             
                return working directory for a specific TomoObject
         | 
| 183 183 | 
             
                """
         | 
| @@ -201,4 +201,4 @@ def get_working_directory(obj: TomoObject) -> Optional[str]: | |
| 201 201 | 
             
                    else:
         | 
| 202 202 | 
             
                        return os.path.abspath(os.path.dirname(obj.master_file))
         | 
| 203 203 | 
             
                else:
         | 
| 204 | 
            -
                    raise RuntimeError(f"obj type not handled ({type(obj)})")
         | 
| 204 | 
            +
                    raise RuntimeError(f"obj type not handled ({type(obj)})")  # noqa: TRY004
         | 
    
        nabu/stitching/stitcher/base.py
    CHANGED
    
    | @@ -21,6 +21,8 @@ def get_obj_constant_side_length(obj: Union[NXtomoScan, VolumeBase], axis: int) | |
| 21 21 | 
             
                        return obj.dim_1
         | 
| 22 22 | 
             
                    elif axis in (1, 2):
         | 
| 23 23 | 
             
                        return obj.dim_2
         | 
| 24 | 
            +
                    else:
         | 
| 25 | 
            +
                        raise ValueError(f"Axis ({axis}) not handled. Should be in (0, 1, 2)")
         | 
| 24 26 | 
             
                elif isinstance(obj, VolumeBase) and axis == 0:
         | 
| 25 27 | 
             
                    return obj.get_volume_shape()[-1]
         | 
| 26 28 | 
             
                else:
         | 
| @@ -96,7 +96,7 @@ class OutputVolumeContext(AbstractContextManager): | |
| 96 96 | 
             
                    if self._file_handler is not None:
         | 
| 97 97 | 
             
                        return self._file_handler.close()
         | 
| 98 98 | 
             
                    else:
         | 
| 99 | 
            -
                        self._volume.save_data()
         | 
| 99 | 
            +
                        self._volume.save_data()  # noqa: RET503
         | 
| 100 100 |  | 
| 101 101 |  | 
| 102 102 | 
             
            class OutputVolumeNoDDContext(OutputVolumeContext):
         | 
| @@ -2,6 +2,7 @@ import logging | |
| 2 2 | 
             
            import numpy
         | 
| 3 3 | 
             
            import os
         | 
| 4 4 | 
             
            import h5py
         | 
| 5 | 
            +
            import pint
         | 
| 5 6 | 
             
            from typing import Union
         | 
| 6 7 | 
             
            from nabu.stitching.config import PostProcessedSingleAxisStitchingConfiguration
         | 
| 7 8 | 
             
            from nabu.stitching.alignment import AlignmentAxis1
         | 
| @@ -13,11 +14,9 @@ from tomoscan.esrf import NXtomoScan | |
| 13 14 | 
             
            from tomoscan.series import Series
         | 
| 14 15 | 
             
            from tomoscan.volumebase import VolumeBase
         | 
| 15 16 | 
             
            from tomoscan.esrf.volume import HDF5Volume
         | 
| 16 | 
            -
            from  | 
| 17 | 
            +
            from collections.abc import Iterable
         | 
| 17 18 | 
             
            from contextlib import AbstractContextManager
         | 
| 18 | 
            -
            from pyunitsystem.metricsystem import MetricSystem
         | 
| 19 19 | 
             
            from nabu.stitching.config import (
         | 
| 20 | 
            -
                PostProcessedSingleAxisStitchingConfiguration,
         | 
| 21 20 | 
             
                KEY_IMG_REG_METHOD,
         | 
| 22 21 | 
             
            )
         | 
| 23 22 | 
             
            from nabu.stitching.utils.utils import find_volumes_relative_shifts
         | 
| @@ -26,6 +25,8 @@ from .single_axis import SingleAxisStitcher | |
| 26 25 |  | 
| 27 26 | 
             
            _logger = logging.getLogger(__name__)
         | 
| 28 27 |  | 
| 28 | 
            +
            _ureg = pint.get_application_registry()
         | 
| 29 | 
            +
             | 
| 29 30 |  | 
| 30 31 | 
             
            class FlippingValueError(ValueError):
         | 
| 31 32 | 
             
                pass
         | 
| @@ -267,7 +268,7 @@ class PostProcessingStitching(SingleAxisStitcher): | |
| 267 268 | 
             
                            axis_N_pos_px = []
         | 
| 268 269 | 
             
                            for volume, pos_in_mm in zip(self.series, pos_as_mm):
         | 
| 269 270 | 
             
                                voxel_size_m = self.configuration.voxel_size or volume.voxel_size
         | 
| 270 | 
            -
                                axis_N_pos_px.append((pos_in_mm  | 
| 271 | 
            +
                                axis_N_pos_px.append((pos_in_mm * _ureg.millimeter).to_base_units().magnitude / voxel_size_m[0])
         | 
| 271 272 | 
             
                            return axis_N_pos_px
         | 
| 272 273 | 
             
                        else:
         | 
| 273 274 | 
             
                            # deduce from motor position and pixel size
         | 
| @@ -426,7 +427,7 @@ class PostProcessingStitching(SingleAxisStitcher): | |
| 426 427 |  | 
| 427 428 | 
             
                    bunch_size = 50
         | 
| 428 429 | 
             
                    # how many frame to we stitch between two read from disk / save to disk
         | 
| 429 | 
            -
                    with self.dumper.OutputDatasetContext(**output_dataset_args):
         | 
| 430 | 
            +
                    with self.dumper.OutputDatasetContext(**output_dataset_args):  # noqa: SIM117
         | 
| 430 431 | 
             
                        # note: output_dataset is a HDF5 dataset if final volume is an HDF5 volume else is a numpy array
         | 
| 431 432 | 
             
                        with _RawDatasetsContext(
         | 
| 432 433 | 
             
                            self._input_volumes,
         | 
| @@ -528,7 +529,8 @@ class _RawDatasetsContext(AbstractContextManager): | |
| 528 529 | 
             
                            else:
         | 
| 529 530 | 
             
                                data = volume.load_data(store=False)
         | 
| 530 531 | 
             
                                if data is None:
         | 
| 531 | 
            -
                                     | 
| 532 | 
            +
                                    # TODO
         | 
| 533 | 
            +
                                    raise ValueError(f"No data found for volume {volume.get_identifier()}")  # noqa: TRY301
         | 
| 532 534 | 
             
                            if axis_1_need_padding:
         | 
| 533 535 | 
             
                                data = self.add_padding(data=data, axis_1_dim=axis_1_dim, alignment=self.alignment_axis_1)
         | 
| 534 536 | 
             
                            datasets.append(data)
         | 
| @@ -536,7 +538,7 @@ class _RawDatasetsContext(AbstractContextManager): | |
| 536 538 | 
             
                        # if some errors happen during loading HDF5
         | 
| 537 539 | 
             
                        for file_handled in self.__file_handlers:
         | 
| 538 540 | 
             
                            file_handled.close()
         | 
| 539 | 
            -
                        raise e
         | 
| 541 | 
            +
                        raise e  # noqa: TRY201
         | 
| 540 542 |  | 
| 541 543 | 
             
                    return datasets
         | 
| 542 544 |  | 
| @@ -544,11 +546,11 @@ class _RawDatasetsContext(AbstractContextManager): | |
| 544 546 | 
             
                    success = True
         | 
| 545 547 | 
             
                    for file_handler in self.__file_handlers:
         | 
| 546 548 | 
             
                        success = success and file_handler.close()
         | 
| 547 | 
            -
                    if exc_type is None:
         | 
| 549 | 
            +
                    if exc_type is None:  # noqa: RET503
         | 
| 548 550 | 
             
                        return success
         | 
| 549 551 |  | 
| 550 552 | 
             
                def add_padding(self, data: Union[h5py.Dataset, numpy.ndarray], axis_1_dim, alignment: AlignmentAxis1):
         | 
| 551 | 
            -
                    alignment = AlignmentAxis1 | 
| 553 | 
            +
                    alignment = AlignmentAxis1(alignment)
         | 
| 552 554 | 
             
                    if alignment is AlignmentAxis1.BACK:
         | 
| 553 555 | 
             
                        axis_1_pad_width = (axis_1_dim - data.shape[1], 0)
         | 
| 554 556 | 
             
                    elif alignment is AlignmentAxis1.CENTER:
         |