nabu 2024.2.14__py3-none-any.whl → 2025.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/doc_config.py +32 -0
- nabu/__init__.py +1 -1
- nabu/app/bootstrap_stitching.py +4 -2
- nabu/app/cast_volume.py +16 -14
- nabu/app/cli_configs.py +102 -9
- nabu/app/compare_volumes.py +1 -1
- nabu/app/composite_cor.py +2 -4
- nabu/app/diag_to_pix.py +5 -6
- nabu/app/diag_to_rot.py +10 -11
- nabu/app/double_flatfield.py +18 -5
- nabu/app/estimate_motion.py +75 -0
- nabu/app/multicor.py +28 -15
- nabu/app/parse_reconstruction_log.py +1 -0
- nabu/app/pcaflats.py +122 -0
- nabu/app/prepare_weights_double.py +1 -2
- nabu/app/reconstruct.py +1 -7
- nabu/app/reconstruct_helical.py +5 -9
- nabu/app/reduce_dark_flat.py +5 -4
- nabu/app/rotate.py +3 -1
- nabu/app/stitching.py +7 -2
- nabu/app/tests/test_reduce_dark_flat.py +2 -2
- nabu/app/validator.py +1 -4
- nabu/cuda/convolution.py +1 -1
- nabu/cuda/fft.py +1 -1
- nabu/cuda/medfilt.py +1 -1
- nabu/cuda/padding.py +1 -1
- nabu/cuda/src/backproj.cu +6 -6
- nabu/cuda/src/cone.cu +4 -0
- nabu/cuda/src/hierarchical_backproj.cu +14 -0
- nabu/cuda/utils.py +2 -2
- nabu/estimation/alignment.py +17 -31
- nabu/estimation/cor.py +27 -33
- nabu/estimation/cor_sino.py +2 -8
- nabu/estimation/focus.py +4 -8
- nabu/estimation/motion.py +557 -0
- nabu/estimation/tests/test_alignment.py +2 -0
- nabu/estimation/tests/test_motion_estimation.py +471 -0
- nabu/estimation/tests/test_tilt.py +1 -1
- nabu/estimation/tilt.py +6 -5
- nabu/estimation/translation.py +47 -1
- nabu/io/cast_volume.py +108 -18
- nabu/io/detector_distortion.py +5 -6
- nabu/io/reader.py +45 -6
- nabu/io/reader_helical.py +5 -4
- nabu/io/tests/test_cast_volume.py +2 -2
- nabu/io/tests/test_readers.py +41 -38
- nabu/io/tests/test_remove_volume.py +152 -0
- nabu/io/tests/test_writers.py +2 -2
- nabu/io/utils.py +8 -4
- nabu/io/writer.py +1 -2
- nabu/misc/fftshift.py +1 -1
- nabu/misc/fourier_filters.py +1 -1
- nabu/misc/histogram.py +1 -1
- nabu/misc/histogram_cuda.py +1 -1
- nabu/misc/padding_base.py +1 -1
- nabu/misc/rotation.py +1 -1
- nabu/misc/rotation_cuda.py +1 -1
- nabu/misc/tests/test_binning.py +1 -1
- nabu/misc/transpose.py +1 -1
- nabu/misc/unsharp.py +1 -1
- nabu/misc/unsharp_cuda.py +1 -1
- nabu/misc/unsharp_opencl.py +1 -1
- nabu/misc/utils.py +1 -1
- nabu/opencl/fft.py +1 -1
- nabu/opencl/padding.py +1 -1
- nabu/opencl/src/backproj.cl +6 -6
- nabu/opencl/utils.py +8 -8
- nabu/pipeline/config.py +2 -2
- nabu/pipeline/config_validators.py +46 -46
- nabu/pipeline/datadump.py +3 -3
- nabu/pipeline/estimators.py +271 -11
- nabu/pipeline/fullfield/chunked.py +103 -67
- nabu/pipeline/fullfield/chunked_cuda.py +5 -2
- nabu/pipeline/fullfield/computations.py +4 -1
- nabu/pipeline/fullfield/dataset_validator.py +0 -1
- nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
- nabu/pipeline/fullfield/nabu_config.py +36 -17
- nabu/pipeline/fullfield/processconfig.py +41 -7
- nabu/pipeline/fullfield/reconstruction.py +14 -10
- nabu/pipeline/helical/dataset_validator.py +3 -4
- nabu/pipeline/helical/fbp.py +4 -4
- nabu/pipeline/helical/filtering.py +5 -4
- nabu/pipeline/helical/gridded_accumulator.py +10 -11
- nabu/pipeline/helical/helical_chunked_regridded.py +1 -0
- nabu/pipeline/helical/helical_reconstruction.py +12 -9
- nabu/pipeline/helical/helical_utils.py +1 -2
- nabu/pipeline/helical/nabu_config.py +2 -1
- nabu/pipeline/helical/span_strategy.py +1 -0
- nabu/pipeline/helical/weight_balancer.py +2 -3
- nabu/pipeline/params.py +20 -3
- nabu/pipeline/tests/__init__.py +0 -0
- nabu/pipeline/tests/test_estimators.py +240 -3
- nabu/pipeline/utils.py +1 -1
- nabu/pipeline/writer.py +1 -1
- nabu/preproc/alignment.py +0 -10
- nabu/preproc/ccd.py +53 -3
- nabu/preproc/ctf.py +8 -8
- nabu/preproc/ctf_cuda.py +1 -1
- nabu/preproc/double_flatfield_cuda.py +2 -2
- nabu/preproc/double_flatfield_variable_region.py +0 -1
- nabu/preproc/flatfield.py +307 -2
- nabu/preproc/flatfield_cuda.py +1 -2
- nabu/preproc/flatfield_variable_region.py +3 -3
- nabu/preproc/phase.py +2 -4
- nabu/preproc/phase_cuda.py +2 -2
- nabu/preproc/shift.py +4 -2
- nabu/preproc/shift_cuda.py +0 -1
- nabu/preproc/tests/test_ctf.py +4 -4
- nabu/preproc/tests/test_double_flatfield.py +1 -1
- nabu/preproc/tests/test_flatfield.py +1 -1
- nabu/preproc/tests/test_paganin.py +1 -3
- nabu/preproc/tests/test_pcaflats.py +154 -0
- nabu/preproc/tests/test_vshift.py +4 -1
- nabu/processing/azim.py +9 -5
- nabu/processing/convolution_cuda.py +6 -4
- nabu/processing/fft_base.py +7 -3
- nabu/processing/fft_cuda.py +25 -164
- nabu/processing/fft_opencl.py +28 -6
- nabu/processing/fftshift.py +1 -1
- nabu/processing/histogram.py +1 -1
- nabu/processing/muladd.py +0 -1
- nabu/processing/padding_base.py +1 -1
- nabu/processing/padding_cuda.py +0 -2
- nabu/processing/processing_base.py +12 -6
- nabu/processing/rotation_cuda.py +3 -1
- nabu/processing/tests/test_fft.py +2 -64
- nabu/processing/tests/test_fftshift.py +1 -1
- nabu/processing/tests/test_medfilt.py +1 -3
- nabu/processing/tests/test_padding.py +1 -1
- nabu/processing/tests/test_roll.py +1 -1
- nabu/processing/tests/test_rotation.py +4 -2
- nabu/processing/unsharp_opencl.py +1 -1
- nabu/reconstruction/astra.py +245 -0
- nabu/reconstruction/cone.py +39 -9
- nabu/reconstruction/fbp.py +7 -0
- nabu/reconstruction/fbp_base.py +36 -5
- nabu/reconstruction/filtering.py +59 -25
- nabu/reconstruction/filtering_cuda.py +22 -21
- nabu/reconstruction/filtering_opencl.py +10 -14
- nabu/reconstruction/hbp.py +26 -13
- nabu/reconstruction/mlem.py +55 -16
- nabu/reconstruction/projection.py +3 -5
- nabu/reconstruction/sinogram.py +1 -1
- nabu/reconstruction/sinogram_cuda.py +0 -1
- nabu/reconstruction/tests/test_cone.py +37 -2
- nabu/reconstruction/tests/test_deringer.py +4 -4
- nabu/reconstruction/tests/test_fbp.py +36 -15
- nabu/reconstruction/tests/test_filtering.py +27 -7
- nabu/reconstruction/tests/test_halftomo.py +28 -2
- nabu/reconstruction/tests/test_mlem.py +94 -64
- nabu/reconstruction/tests/test_projector.py +7 -2
- nabu/reconstruction/tests/test_reconstructor.py +1 -1
- nabu/reconstruction/tests/test_sino_normalization.py +0 -1
- nabu/resources/dataset_analyzer.py +210 -24
- nabu/resources/gpu.py +4 -4
- nabu/resources/logger.py +4 -4
- nabu/resources/nxflatfield.py +103 -37
- nabu/resources/tests/test_dataset_analyzer.py +37 -0
- nabu/resources/tests/test_extract.py +11 -0
- nabu/resources/tests/test_nxflatfield.py +5 -5
- nabu/resources/utils.py +16 -10
- nabu/stitching/alignment.py +8 -11
- nabu/stitching/config.py +44 -35
- nabu/stitching/definitions.py +2 -2
- nabu/stitching/frame_composition.py +8 -10
- nabu/stitching/overlap.py +4 -4
- nabu/stitching/sample_normalization.py +5 -5
- nabu/stitching/slurm_utils.py +2 -2
- nabu/stitching/stitcher/base.py +2 -0
- nabu/stitching/stitcher/dumper/base.py +0 -1
- nabu/stitching/stitcher/dumper/postprocessing.py +1 -1
- nabu/stitching/stitcher/post_processing.py +11 -9
- nabu/stitching/stitcher/pre_processing.py +37 -31
- nabu/stitching/stitcher/single_axis.py +2 -3
- nabu/stitching/stitcher_2D.py +2 -1
- nabu/stitching/tests/test_config.py +10 -11
- nabu/stitching/tests/test_sample_normalization.py +1 -1
- nabu/stitching/tests/test_slurm_utils.py +1 -2
- nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
- nabu/stitching/tests/test_z_postprocessing_stitching.py +3 -3
- nabu/stitching/tests/test_z_preprocessing_stitching.py +27 -24
- nabu/stitching/utils/tests/__init__.py +0 -0
- nabu/stitching/utils/tests/test_post-processing.py +1 -0
- nabu/stitching/utils/utils.py +16 -18
- nabu/tests.py +0 -3
- nabu/testutils.py +62 -9
- nabu/utils.py +50 -20
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/METADATA +7 -7
- nabu-2025.1.0.dist-info/RECORD +328 -0
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/WHEEL +1 -1
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/entry_points.txt +2 -1
- nabu/app/correct_rot.py +0 -70
- nabu/io/tests/test_detector_distortion.py +0 -178
- nabu-2024.2.14.dist-info/RECORD +0 -317
- /nabu/{stitching → app}/tests/__init__.py +0 -0
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/licenses/LICENSE +0 -0
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/top_level.txt +0 -0
    
        nabu/io/cast_volume.py
    CHANGED
    
    | @@ -1,10 +1,12 @@ | |
| 1 1 | 
             
            import os
         | 
| 2 | 
            -
             | 
| 3 | 
            -
            from  | 
| 4 | 
            -
            from  | 
| 5 | 
            -
             | 
| 2 | 
            +
            import logging
         | 
| 3 | 
            +
            from glob import glob
         | 
| 4 | 
            +
            from shutil import rmtree
         | 
| 5 | 
            +
            import numpy
         | 
| 6 | 
            +
            from silx.io.utils import get_data
         | 
| 7 | 
            +
            from silx.io.url import DataUrl
         | 
| 6 8 | 
             
            from tomoscan.volumebase import VolumeBase
         | 
| 7 | 
            -
            from tomoscan. | 
| 9 | 
            +
            from tomoscan.esrf.volume.singleframebase import VolumeSingleFrameBase
         | 
| 8 10 | 
             
            from tomoscan.esrf.volume import (
         | 
| 9 11 | 
             
                EDFVolume,
         | 
| 10 12 | 
             
                HDF5Volume,
         | 
| @@ -13,17 +15,15 @@ from tomoscan.esrf.volume import ( | |
| 13 15 | 
             
                TIFFVolume,
         | 
| 14 16 | 
             
            )
         | 
| 15 17 | 
             
            from tomoscan.io import HDF5File
         | 
| 16 | 
            -
            from  | 
| 17 | 
            -
            from  | 
| 18 | 
            -
            import  | 
| 19 | 
            -
            from  | 
| 20 | 
            -
            from typing import Optional
         | 
| 21 | 
            -
            import logging
         | 
| 18 | 
            +
            from ..utils import first_generator_item
         | 
| 19 | 
            +
            from ..misc.utils import rescale_data
         | 
| 20 | 
            +
            from ..pipeline.params import files_formats
         | 
| 21 | 
            +
            from .reader import get_hdf5_file_all_virtual_sources, list_hdf5_entries
         | 
| 22 22 |  | 
| 23 23 | 
             
            _logger = logging.getLogger(__name__)
         | 
| 24 24 |  | 
| 25 25 |  | 
| 26 | 
            -
            __all__ = [" | 
| 26 | 
            +
            __all__ = ["cast_volume", "get_default_output_volume"]
         | 
| 27 27 |  | 
| 28 28 | 
             
            _DEFAULT_OUTPUT_DIR = "vol_cast"
         | 
| 29 29 |  | 
| @@ -45,7 +45,7 @@ def get_default_output_volume( | |
| 45 45 | 
             
                if not isinstance(input_volume, VolumeBase):
         | 
| 46 46 | 
             
                    raise TypeError(f"input_volume is expected to be an instance of {VolumeBase}")
         | 
| 47 47 | 
             
                valid_file_formats = set(files_formats.values())
         | 
| 48 | 
            -
                if not  | 
| 48 | 
            +
                if output_type not in valid_file_formats:
         | 
| 49 49 | 
             
                    raise ValueError(f"output_type is not a valid value ({output_type}). Valid values are {valid_file_formats}")
         | 
| 50 50 |  | 
| 51 51 | 
             
                if isinstance(input_volume, (EDFVolume, TIFFVolume, JP2KVolume)):
         | 
| @@ -134,11 +134,12 @@ def cast_volume( | |
| 134 134 | 
             
                output_data_type: numpy.dtype,
         | 
| 135 135 | 
             
                data_min=None,
         | 
| 136 136 | 
             
                data_max=None,
         | 
| 137 | 
            -
                scan | 
| 137 | 
            +
                scan=None,
         | 
| 138 138 | 
             
                rescale_min_percentile=RESCALE_MIN_PERCENTILE,
         | 
| 139 139 | 
             
                rescale_max_percentile=RESCALE_MAX_PERCENTILE,
         | 
| 140 140 | 
             
                save=True,
         | 
| 141 141 | 
             
                store=False,
         | 
| 142 | 
            +
                remove_input_volume: bool = False,
         | 
| 142 143 | 
             
            ) -> VolumeBase:
         | 
| 143 144 | 
             
                """
         | 
| 144 145 | 
             
                cast givent volume to output_volume of 'output_data_type' type
         | 
| @@ -169,6 +170,7 @@ def cast_volume( | |
| 169 170 | 
             
                if not isinstance(output_volume, VolumeBase):
         | 
| 170 171 | 
             
                    raise TypeError(f"output_volume is expected to be a {VolumeBase}. {type(output_volume)} provided")
         | 
| 171 172 |  | 
| 173 | 
            +
                # ruff: noqa: SIM105, S110
         | 
| 172 174 | 
             
                try:
         | 
| 173 175 | 
             
                    output_data_type = numpy.dtype(
         | 
| 174 176 | 
             
                        output_data_type
         | 
| @@ -247,6 +249,14 @@ def cast_volume( | |
| 247 249 | 
             
                except (OSError, KeyError):
         | 
| 248 250 | 
             
                    # if no metadata provided and or saved in disk or if some key are missing
         | 
| 249 251 | 
             
                    pass
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                if save and output_volume.metadata is not None:
         | 
| 254 | 
            +
                    output_volume.save_metadata()
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                if remove_input_volume:
         | 
| 257 | 
            +
                    _logger.info(f"Removing {input_volume.data_url.file_path()}")
         | 
| 258 | 
            +
                    remove_volume(input_volume, check=True)
         | 
| 259 | 
            +
             | 
| 250 260 | 
             
                return output_volume
         | 
| 251 261 |  | 
| 252 262 |  | 
| @@ -283,7 +293,7 @@ def clamp_and_rescale_data( | |
| 283 293 | 
             
                return rescaled_data
         | 
| 284 294 |  | 
| 285 295 |  | 
| 286 | 
            -
            def find_histogram(volume: VolumeBase, scan | 
| 296 | 
            +
            def find_histogram(volume: VolumeBase, scan=None):
         | 
| 287 297 | 
             
                """
         | 
| 288 298 | 
             
                Look for histogram of the provided url. If found one return the DataUrl of the nabu histogram
         | 
| 289 299 | 
             
                """
         | 
| @@ -301,7 +311,7 @@ def find_histogram(volume: VolumeBase, scan: Optional[TomoScanBase] = None) -> O | |
| 301 311 | 
             
                                ]
         | 
| 302 312 | 
             
                            )
         | 
| 303 313 | 
             
                        else:
         | 
| 304 | 
            -
                            data_path = " | 
| 314 | 
            +
                            data_path = f"{volume.url.data_path()}/histogram/results/data"
         | 
| 305 315 | 
             
                    else:
         | 
| 306 316 | 
             
                        # TODO: FIXME: in some case (if the users provides the full data_url and if the 'DATA_DATASET_NAME' is not used we
         | 
| 307 317 | 
             
                        # will endup with an invalid data_path. Hope this case will not happen. Anyway this is a case that we can't handle.)
         | 
| @@ -330,7 +340,7 @@ def find_histogram(volume: VolumeBase, scan: Optional[TomoScanBase] = None) -> O | |
| 330 340 | 
             
                        data_path = getattr(scan, "entry/histogram/results/data", "entry/histogram/results/data")
         | 
| 331 341 | 
             
                    else:
         | 
| 332 342 |  | 
| 333 | 
            -
                        def get_file_entries(file_path: str) | 
| 343 | 
            +
                        def get_file_entries(file_path: str):
         | 
| 334 344 | 
             
                            if os.path.exists(file_path):
         | 
| 335 345 | 
             
                                with HDF5File(file_path, mode="r") as h5s:
         | 
| 336 346 | 
             
                                    return tuple(h5s.keys())
         | 
| @@ -359,7 +369,7 @@ def find_histogram(volume: VolumeBase, scan: Optional[TomoScanBase] = None) -> O | |
| 359 369 | 
             
                    return None
         | 
| 360 370 |  | 
| 361 371 | 
             
                with HDF5File(histogram_file, mode="r") as h5f:
         | 
| 362 | 
            -
                    if not  | 
| 372 | 
            +
                    if data_path not in h5f:
         | 
| 363 373 | 
             
                        _logger.info(f"{data_path} in {histogram_file} not found")
         | 
| 364 374 | 
             
                        return None
         | 
| 365 375 | 
             
                    else:
         | 
| @@ -408,3 +418,83 @@ def _min_max_from_histo(url: DataUrl, rescale_min_percentile: int, rescale_max_p | |
| 408 418 | 
             
                    return _get_hst_saturations(
         | 
| 409 419 | 
             
                        hist, bins, numpy.float32(rescale_min_percentile), numpy.float32(rescale_max_percentile)
         | 
| 410 420 | 
             
                    )
         | 
| 421 | 
            +
             | 
| 422 | 
            +
             | 
| 423 | 
            +
            def _remove_volume_singleframe(volume, check=True):
         | 
| 424 | 
            +
                volume_directory = volume.data_url.file_path()
         | 
| 425 | 
            +
                if check:
         | 
| 426 | 
            +
                    volume_files = set(volume.browse_data_files())
         | 
| 427 | 
            +
                    files_names_pattern = os.path.join(volume_directory, "*." + volume.data_extension)
         | 
| 428 | 
            +
                    files_on_disk = set(glob(files_names_pattern))
         | 
| 429 | 
            +
                    # Don't check strict equality here, as some files on disk might be already removed.
         | 
| 430 | 
            +
                    # i.e, there should be no more files on disk than expected files in the volume
         | 
| 431 | 
            +
                    if not (files_on_disk.issubset(volume_files)):
         | 
| 432 | 
            +
                        raise RuntimeError(f"Unexpected files present in {volume_directory}: {files_on_disk - volume_files}")
         | 
| 433 | 
            +
                    # TODO also check for metadata file(s) ?
         | 
| 434 | 
            +
                rmtree(volume_directory)
         | 
| 435 | 
            +
             | 
| 436 | 
            +
             | 
| 437 | 
            +
            def _remove_volume_multiframe(volume, check=True):
         | 
| 438 | 
            +
                file_path = volume.data_url.file_path()
         | 
| 439 | 
            +
                if check:
         | 
| 440 | 
            +
                    if not (os.path.isfile(file_path)):
         | 
| 441 | 
            +
                        raise RuntimeError(f"Expected a file: {file_path}")
         | 
| 442 | 
            +
                os.remove(file_path)
         | 
| 443 | 
            +
             | 
| 444 | 
            +
             | 
| 445 | 
            +
            def _remove_volume_hdf5(volume, check=True):
         | 
| 446 | 
            +
                file_path = volume.data_url.file_path()
         | 
| 447 | 
            +
                entry = volume.data_url.data_path().lstrip("/").split("/")[0]
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                # Nabu HDF5 reconstructions have a folder alongside the HDF5 file, with the same prefix
         | 
| 450 | 
            +
                # For example the HDF5 file "/path/to/rec.hdf5" has an associated directory "/path/to/rec"
         | 
| 451 | 
            +
                associated_dir, _ = os.path.splitext(os.path.basename(file_path))
         | 
| 452 | 
            +
                associated_dir_abs = os.path.join(os.path.dirname(file_path), associated_dir)
         | 
| 453 | 
            +
             | 
| 454 | 
            +
                with HDF5File(file_path, "r") as f:
         | 
| 455 | 
            +
                    fdesc = f[entry]
         | 
| 456 | 
            +
                    virtual_sources = get_hdf5_file_all_virtual_sources(fdesc, return_only_filenames=True)
         | 
| 457 | 
            +
             | 
| 458 | 
            +
                # TODO check if this is legitimate. Nabu reconstruction will only do one VS (for entry/reconstruction/results/data).
         | 
| 459 | 
            +
                # Bliss/Lima do have multiple VS (flats/darks/projs), but we generally don't want to remove raw data ?
         | 
| 460 | 
            +
                if len(virtual_sources) > 1:
         | 
| 461 | 
            +
                    raise ValueError("Found more than one virtual source - this looks weird. Interrupting.")
         | 
| 462 | 
            +
                #
         | 
| 463 | 
            +
                if len(virtual_sources) > 0:
         | 
| 464 | 
            +
                    h5path, virtual_source_files_paths = first_generator_item(virtual_sources[0].items())
         | 
| 465 | 
            +
                    if len(virtual_source_files_paths) == 1:
         | 
| 466 | 
            +
                        target_dir = os.path.dirname(virtual_source_files_paths[0])
         | 
| 467 | 
            +
                    else:
         | 
| 468 | 
            +
                        target_dir = os.path.commonpath(virtual_source_files_paths)
         | 
| 469 | 
            +
                    target_dir_abs = os.path.join(os.path.dirname(file_path), target_dir)
         | 
| 470 | 
            +
                    if check and (target_dir_abs != associated_dir_abs):
         | 
| 471 | 
            +
                        raise ValueError(
         | 
| 472 | 
            +
                            f"The virtual sources in {file_path}:{h5path} reference the directory {target_dir}, but expected was {associated_dir}"
         | 
| 473 | 
            +
                        )
         | 
| 474 | 
            +
                    if os.path.isdir(target_dir_abs):
         | 
| 475 | 
            +
                        rmtree(associated_dir_abs)
         | 
| 476 | 
            +
                os.remove(file_path)
         | 
| 477 | 
            +
             | 
| 478 | 
            +
             | 
| 479 | 
            +
            def remove_volume(volume, check=True):
         | 
| 480 | 
            +
                """
         | 
| 481 | 
            +
                Remove files belonging to a volume, claim disk space.
         | 
| 482 | 
            +
             | 
| 483 | 
            +
                Parameters
         | 
| 484 | 
            +
                ----------
         | 
| 485 | 
            +
                volume: tomoscan.esrf.volume
         | 
| 486 | 
            +
                    Volume object
         | 
| 487 | 
            +
                check: bool, optional
         | 
| 488 | 
            +
                    Whether to check if the files that would be removed do not have extra other files ; interrupt the operation if so.
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                """
         | 
| 491 | 
            +
                if isinstance(volume, (EDFVolume, JP2KVolume, TIFFVolume)):
         | 
| 492 | 
            +
                    _remove_volume_singleframe(volume, check=check)
         | 
| 493 | 
            +
                elif isinstance(volume, MultiTIFFVolume):
         | 
| 494 | 
            +
                    _remove_volume_multiframe(volume, check=check)
         | 
| 495 | 
            +
                elif isinstance(volume, HDF5Volume):
         | 
| 496 | 
            +
                    if len(list_hdf5_entries(volume.file_path)) > 1:
         | 
| 497 | 
            +
                        raise NotImplementedError("Removing a HDF5 volume with more than one entry is not supported")
         | 
| 498 | 
            +
                    _remove_volume_hdf5(volume, check=check)
         | 
| 499 | 
            +
                else:
         | 
| 500 | 
            +
                    raise TypeError("Unknown type of volume")
         | 
    
        nabu/io/detector_distortion.py
    CHANGED
    
    | @@ -3,7 +3,6 @@ from scipy import sparse | |
| 3 3 |  | 
| 4 4 |  | 
| 5 5 | 
             
            class DetectorDistortionBase:
         | 
| 6 | 
            -
                """ """
         | 
| 7 6 |  | 
| 8 7 | 
             
                def __init__(self, detector_full_shape_vh=(0, 0)):
         | 
| 9 8 | 
             
                    """This is the basis class.
         | 
| @@ -110,7 +109,7 @@ class DetectorDistortionBase: | |
| 110 109 | 
             
                        The sub_region contained (x_start, x_end)={(x_start, x_end)} which would slice the 
         | 
| 111 110 | 
             
                        full horizontal size which is {self.detector_full_shape_vh[1]}
         | 
| 112 111 | 
             
                        """
         | 
| 113 | 
            -
                        raise ValueError()
         | 
| 112 | 
            +
                        raise ValueError(message)
         | 
| 114 113 |  | 
| 115 114 | 
             
                    x_start, x_end = 0, self.detector_full_shape_vh[1]
         | 
| 116 115 |  | 
| @@ -165,7 +164,8 @@ class DetectorDistortionMapsXZ(DetectorDistortionBase): | |
| 165 164 | 
             
                    Pixel (i,j) of the corrected image is obtained by interpolating the raw data at position
         | 
| 166 165 | 
             
                    ( map_z(i,j), map_x(i,j) ).
         | 
| 167 166 |  | 
| 168 | 
            -
                    Parameters | 
| 167 | 
            +
                    Parameters
         | 
| 168 | 
            +
                    ----------
         | 
| 169 169 | 
             
                        map_x : float 2D array
         | 
| 170 170 | 
             
                        map_z : float 2D array
         | 
| 171 171 | 
             
                    """
         | 
| @@ -173,7 +173,6 @@ class DetectorDistortionMapsXZ(DetectorDistortionBase): | |
| 173 173 | 
             
                    self._build_full_transformation(map_x, map_z)
         | 
| 174 174 |  | 
| 175 175 | 
             
                def _build_full_transformation(self, map_x, map_z):
         | 
| 176 | 
            -
                    """ """
         | 
| 177 176 | 
             
                    detector_full_shape_vh = map_x.shape
         | 
| 178 177 | 
             
                    if detector_full_shape_vh != map_z.shape:
         | 
| 179 178 | 
             
                        message = f"""  map_x and map_z must have the same shape
         | 
| @@ -185,7 +184,7 @@ class DetectorDistortionMapsXZ(DetectorDistortionBase): | |
| 185 184 |  | 
| 186 185 | 
             
                    # padding
         | 
| 187 186 | 
             
                    sz, sx = detector_full_shape_vh
         | 
| 188 | 
            -
                    total_detector_npixs = sz * sx
         | 
| 187 | 
            +
                    # total_detector_npixs = sz * sx
         | 
| 189 188 | 
             
                    xs = np.clip(np.array(coordinates[1].flat), [[0]], [[sx - 1]])
         | 
| 190 189 | 
             
                    zs = np.clip(np.array(coordinates[0].flat), [[0]], [[sz - 1]])
         | 
| 191 190 |  | 
| @@ -253,7 +252,7 @@ class DetectorDistortionMapsXZ(DetectorDistortionBase): | |
| 253 252 | 
             
                        The sub_region contained (x_start, x_end)={(x_start, x_end)} which would slice the 
         | 
| 254 253 | 
             
                        full horizontal size which is {self.detector_full_shape_vh[1]}
         | 
| 255 254 | 
             
                        """
         | 
| 256 | 
            -
                        raise ValueError()
         | 
| 255 | 
            +
                        raise ValueError(message)
         | 
| 257 256 |  | 
| 258 257 | 
             
                    x_start, x_end = 0, self.detector_full_shape_vh[1]
         | 
| 259 258 |  | 
    
        nabu/io/reader.py
    CHANGED
    
    | @@ -4,6 +4,7 @@ from math import ceil | |
| 4 4 | 
             
            from multiprocessing.pool import ThreadPool
         | 
| 5 5 | 
             
            from posixpath import sep as posix_sep, join as posix_join
         | 
| 6 6 | 
             
            import numpy as np
         | 
| 7 | 
            +
            from h5py import Dataset
         | 
| 7 8 | 
             
            from silx.io import get_data
         | 
| 8 9 | 
             
            from silx.io.dictdump import h5todict
         | 
| 9 10 | 
             
            from tomoscan.io import HDF5File
         | 
| @@ -555,7 +556,11 @@ class VolReaderBase: | |
| 555 556 | 
             
                        slice_x = None
         | 
| 556 557 | 
             
                    if isinstance(sub_region, (tuple, list)):
         | 
| 557 558 | 
             
                        slice_angle, slice_z, slice_x = sub_region
         | 
| 558 | 
            -
                    self.sub_region = ( | 
| 559 | 
            +
                    self.sub_region = (
         | 
| 560 | 
            +
                        slice_angle if slice_angle is not None else slice(None, None),
         | 
| 561 | 
            +
                        slice_z if slice_z is not None else slice(None, None),
         | 
| 562 | 
            +
                        slice_x if slice_x is not None else slice(None, None),
         | 
| 563 | 
            +
                    )
         | 
| 559 564 |  | 
| 560 565 | 
             
                def _set_processing_function(self, processing_func, processing_func_args, processing_func_kwargs):
         | 
| 561 566 | 
             
                    self.processing_func = processing_func
         | 
| @@ -619,7 +624,7 @@ class NXTomoReader(VolReaderBase): | |
| 619 624 | 
             
                        If provided, this function first argument must be the source buffer (3D array: stack of raw images),
         | 
| 620 625 | 
             
                        and the second argument must be the destination buffer (3D array, stack of output images). It can be None.
         | 
| 621 626 |  | 
| 622 | 
            -
                    Other  | 
| 627 | 
            +
                    Other Parameters
         | 
| 623 628 | 
             
                    ----------------
         | 
| 624 629 | 
             
                    The other parameters are passed to "processing_func" if this parameter is not None.
         | 
| 625 630 |  | 
| @@ -681,9 +686,13 @@ class NXTomoReader(VolReaderBase): | |
| 681 686 | 
             
                        # In this case, we can use h5py read_direct() to avoid extraneous memory consumption
         | 
| 682 687 | 
             
                        image_key_slice = self._image_key_slices[0]
         | 
| 683 688 | 
             
                        # merge image key selection and user selection (if any)
         | 
| 684 | 
            -
                         | 
| 685 | 
            -
             | 
| 686 | 
            -
             | 
| 689 | 
            +
                        angles_slice = self.sub_region[0]
         | 
| 690 | 
            +
                        if isinstance(angles_slice, slice) or angles_slice is None:
         | 
| 691 | 
            +
                            angles_slice = merge_slices(image_key_slice, self.sub_region[0] or slice(None, None))
         | 
| 692 | 
            +
                        else:  # assuming numpy array
         | 
| 693 | 
            +
                            # TODO more elegant
         | 
| 694 | 
            +
                            angles_slice = np.arange(self.data_shape_total[0], dtype=np.uint64)[image_key_slice][angles_slice]
         | 
| 695 | 
            +
                        self._source_selection = (angles_slice,) + self.sub_region[1:]
         | 
| 687 696 | 
             
                    else:
         | 
| 688 697 | 
             
                        user_selection_dim0 = self.sub_region[0]
         | 
| 689 698 | 
             
                        indices = np.arange(self.data_shape_total[0])
         | 
| @@ -793,7 +802,7 @@ class NXDarksFlats: | |
| 793 802 |  | 
| 794 803 | 
             
                def get_reduced_current(self, h5_path="{entry}/control/data", method="median"):
         | 
| 795 804 | 
             
                    current = self.get_raw_current(h5_path=h5_path)
         | 
| 796 | 
            -
                    return {k: self._reduce_func[method](current[k]) for k in current | 
| 805 | 
            +
                    return {k: self._reduce_func[method](current[k]) for k in current}
         | 
| 797 806 |  | 
| 798 807 |  | 
| 799 808 | 
             
            class EDFStackReader(VolReaderBase):
         | 
| @@ -987,6 +996,12 @@ def get_entry_from_h5_path(h5_path): | |
| 987 996 | 
             
                return v[0] or v[1]
         | 
| 988 997 |  | 
| 989 998 |  | 
| 999 | 
            +
            def list_hdf5_entries(fname):
         | 
| 1000 | 
            +
                with HDF5File(fname, "r") as f:
         | 
| 1001 | 
            +
                    entries = list(f.keys())
         | 
| 1002 | 
            +
                return entries
         | 
| 1003 | 
            +
             | 
| 1004 | 
            +
             | 
| 990 1005 | 
             
            def check_virtual_sources_exist(fname, data_path):
         | 
| 991 1006 | 
             
                with HDF5File(fname, "r") as f:
         | 
| 992 1007 | 
             
                    if data_path not in f:
         | 
| @@ -1006,6 +1021,30 @@ def check_virtual_sources_exist(fname, data_path): | |
| 1006 1021 | 
             
                return True
         | 
| 1007 1022 |  | 
| 1008 1023 |  | 
| 1024 | 
            +
            def get_hdf5_file_all_virtual_sources(file_path_or_obj, return_only_filenames=False):
         | 
| 1025 | 
            +
                result = []
         | 
| 1026 | 
            +
             | 
| 1027 | 
            +
                def collect_vsources(name, obj):
         | 
| 1028 | 
            +
                    if isinstance(obj, Dataset) and obj.is_virtual:
         | 
| 1029 | 
            +
                        vs = obj.virtual_sources()
         | 
| 1030 | 
            +
                        if return_only_filenames:
         | 
| 1031 | 
            +
                            vs = [vs_.file_name for vs_ in vs]
         | 
| 1032 | 
            +
                        result.append({name: vs})
         | 
| 1033 | 
            +
             | 
| 1034 | 
            +
                _self_opened_file = False
         | 
| 1035 | 
            +
                if isinstance(file_path_or_obj, str):
         | 
| 1036 | 
            +
                    fdesc = HDF5File(file_path_or_obj, "r")
         | 
| 1037 | 
            +
                    _self_opened_file = True
         | 
| 1038 | 
            +
                else:
         | 
| 1039 | 
            +
                    fdesc = file_path_or_obj
         | 
| 1040 | 
            +
             | 
| 1041 | 
            +
                fdesc.visititems(collect_vsources)
         | 
| 1042 | 
            +
             | 
| 1043 | 
            +
                if _self_opened_file:
         | 
| 1044 | 
            +
                    fdesc.close()
         | 
| 1045 | 
            +
                return result
         | 
| 1046 | 
            +
             | 
| 1047 | 
            +
             | 
| 1009 1048 | 
             
            def import_h5_to_dict(h5file, h5path, asarray=False):
         | 
| 1010 1049 | 
             
                """
         | 
| 1011 1050 | 
             
                Wrapper on top of silx.io.dictdump.dicttoh5 replacing "None" with None
         | 
    
        nabu/io/reader_helical.py
    CHANGED
    
    | @@ -1,4 +1,5 @@ | |
| 1 | 
            -
             | 
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
            from .reader import ChunkReader, get_compacted_dataslices
         | 
| 2 3 |  | 
| 3 4 |  | 
| 4 5 | 
             
            class ChunkReaderHelical(ChunkReader):
         | 
| @@ -101,9 +102,9 @@ class ChunkReaderHelical(ChunkReader): | |
| 101 102 | 
             
                        self._load_multi(sub_total_prange_slice)
         | 
| 102 103 | 
             
                    else:
         | 
| 103 104 | 
             
                        if self.dataset_subsampling > 1:
         | 
| 104 | 
            -
                             | 
| 105 | 
            -
                                 | 
| 106 | 
            -
                            ) | 
| 105 | 
            +
                            raise ValueError(
         | 
| 106 | 
            +
                                "in helical pipeline, load file _load_single has not yet been adapted to angular subsampling"
         | 
| 107 | 
            +
                            )
         | 
| 107 108 | 
             
                        self._load_single(sub_total_prange_slice)
         | 
| 108 109 | 
             
                    self._loaded = True
         | 
| 109 110 |  | 
| @@ -12,7 +12,7 @@ from tomoscan.esrf.volume import ( | |
| 12 12 | 
             
                MultiTIFFVolume,
         | 
| 13 13 | 
             
                TIFFVolume,
         | 
| 14 14 | 
             
            )
         | 
| 15 | 
            -
            from  | 
| 15 | 
            +
            from tomoscan.esrf.volume.jp2kvolume import has_glymur as __have_jp2k__
         | 
| 16 16 | 
             
            from tomoscan.esrf.scan.edfscan import EDFTomoScan
         | 
| 17 17 | 
             
            from tomoscan.esrf.scan.nxtomoscan import NXtomoScan
         | 
| 18 18 | 
             
            import pytest
         | 
| @@ -105,7 +105,7 @@ def test_find_histogram_hdf5_volume(tmp_path): | |
| 105 105 | 
             
                    scheme="silx",
         | 
| 106 106 | 
             
                )
         | 
| 107 107 |  | 
| 108 | 
            -
                assert find_histogram(volume=HDF5Volume(file_path=h5_file, data_path="entry"))  | 
| 108 | 
            +
                assert find_histogram(volume=HDF5Volume(file_path=h5_file, data_path="entry")) is None
         | 
| 109 109 |  | 
| 110 110 |  | 
| 111 111 | 
             
            def test_find_histogram_single_frame_volume(tmp_path):
         | 
    
        nabu/io/tests/test_readers.py
    CHANGED
    
    | @@ -1,41 +1,23 @@ | |
| 1 1 | 
             
            from math import ceil
         | 
| 2 2 | 
             
            from tempfile import TemporaryDirectory
         | 
| 3 | 
            -
            from dataclasses import dataclass
         | 
| 4 3 | 
             
            from tomoscan.io import HDF5File
         | 
| 5 4 | 
             
            import pytest
         | 
| 6 5 | 
             
            import numpy as np
         | 
| 7 6 | 
             
            from nxtomo.application.nxtomo import ImageKey
         | 
| 8 7 | 
             
            from tomoscan.esrf import EDFVolume
         | 
| 9 8 | 
             
            from nabu.pipeline.reader import NXTomoReaderBinning
         | 
| 10 | 
            -
            from nabu.testutils import utilstest, __do_long_tests__, get_file
         | 
| 9 | 
            +
            from nabu.testutils import utilstest, __do_long_tests__, get_file, get_dummy_nxtomo_info
         | 
| 11 10 | 
             
            from nabu.utils import indices_to_slices, merge_slices
         | 
| 12 11 | 
             
            from nabu.io.reader import EDFStackReader, NXTomoReader, NXDarksFlats
         | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
            -
            @dataclass
         | 
| 16 | 
            -
            class SimpleNXTomoDescription:
         | 
| 17 | 
            -
                n_darks: int = 0
         | 
| 18 | 
            -
                n_flats1: int = 0
         | 
| 19 | 
            -
                n_projs: int = 0
         | 
| 20 | 
            -
                n_flats2: int = 0
         | 
| 21 | 
            -
                n_align: int = 0
         | 
| 22 | 
            -
                frame_shape: tuple = None
         | 
| 23 | 
            -
                dtype: np.dtype = np.uint16
         | 
| 12 | 
            +
            from nabu.resources.dataset_analyzer import analyze_dataset
         | 
| 24 13 |  | 
| 25 14 |  | 
| 26 15 | 
             
            @pytest.fixture(scope="class")
         | 
| 27 16 | 
             
            def bootstrap_nx_reader(request):
         | 
| 28 17 | 
             
                cls = request.cls
         | 
| 29 | 
            -
             | 
| 30 | 
            -
             | 
| 31 | 
            -
                cls.nx_data_path = "entry/instrument/detector/data"
         | 
| 32 | 
            -
                cls.data_desc = SimpleNXTomoDescription(
         | 
| 33 | 
            -
                    n_darks=10, n_flats1=11, n_projs=100, n_flats2=11, n_align=12, frame_shape=(11, 10), dtype=np.uint16
         | 
| 18 | 
            +
                cls.nx_fname, cls.data_desc, cls.image_key, cls.projs_vals, cls.darks_vals, cls.flats1_vals, cls.flats2_vals = (
         | 
| 19 | 
            +
                    get_dummy_nxtomo_info()
         | 
| 34 20 | 
             
                )
         | 
| 35 | 
            -
                cls.projs_vals = np.arange(cls.data_desc.n_projs) + cls.data_desc.n_flats1 + cls.data_desc.n_darks
         | 
| 36 | 
            -
                cls.darks_vals = np.arange(cls.data_desc.n_darks)
         | 
| 37 | 
            -
                cls.flats1_vals = np.arange(cls.data_desc.n_darks, cls.data_desc.n_darks + cls.data_desc.n_flats1)
         | 
| 38 | 
            -
                cls.flats2_vals = np.arange(cls.data_desc.n_darks, cls.data_desc.n_darks + cls.data_desc.n_flats2)
         | 
| 39 21 |  | 
| 40 22 | 
             
                yield
         | 
| 41 23 | 
             
                # teardown
         | 
| @@ -45,15 +27,15 @@ def bootstrap_nx_reader(request): | |
| 45 27 | 
             
            class TestNXReader:
         | 
| 46 28 | 
             
                def test_incorrect_path(self):
         | 
| 47 29 | 
             
                    with pytest.raises(FileNotFoundError):
         | 
| 48 | 
            -
                        reader = NXTomoReader("/invalid/path" | 
| 30 | 
            +
                        reader = NXTomoReader("/invalid/path")
         | 
| 49 31 | 
             
                    with pytest.raises(KeyError):
         | 
| 50 | 
            -
                        reader = NXTomoReader(self.nx_fname, "/bad/data/path")
         | 
| 32 | 
            +
                        reader = NXTomoReader(self.nx_fname, "/bad/data/path")  # noqa: F841
         | 
| 51 33 |  | 
| 52 34 | 
             
                def test_simple_reads(self):
         | 
| 53 35 | 
             
                    """
         | 
| 54 36 | 
             
                    Test NXTomoReader with simplest settings
         | 
| 55 37 | 
             
                    """
         | 
| 56 | 
            -
                    reader1 = NXTomoReader(self.nx_fname | 
| 38 | 
            +
                    reader1 = NXTomoReader(self.nx_fname)
         | 
| 57 39 | 
             
                    data1 = reader1.load_data()
         | 
| 58 40 | 
             
                    assert data1.shape == (self.data_desc.n_projs,) + self.data_desc.frame_shape
         | 
| 59 41 | 
             
                    assert np.allclose(data1[:, 0, 0], self.projs_vals)
         | 
| @@ -62,15 +44,15 @@ class TestNXReader: | |
| 62 44 | 
             
                    """
         | 
| 63 45 | 
             
                    Test the data selection using "image_key".
         | 
| 64 46 | 
             
                    """
         | 
| 65 | 
            -
                    reader_projs = NXTomoReader(self.nx_fname,  | 
| 47 | 
            +
                    reader_projs = NXTomoReader(self.nx_fname, image_key=ImageKey.PROJECTION.value)
         | 
| 66 48 | 
             
                    data = reader_projs.load_data()
         | 
| 67 49 | 
             
                    assert np.allclose(data[:, 0, 0], self.projs_vals)
         | 
| 68 50 |  | 
| 69 | 
            -
                    reader_darks = NXTomoReader(self.nx_fname,  | 
| 51 | 
            +
                    reader_darks = NXTomoReader(self.nx_fname, image_key=ImageKey.DARK_FIELD.value)
         | 
| 70 52 | 
             
                    data_darks = reader_darks.load_data()
         | 
| 71 53 | 
             
                    assert np.allclose(data_darks[:, 0, 0], self.darks_vals)
         | 
| 72 54 |  | 
| 73 | 
            -
                    reader_flats = NXTomoReader(self.nx_fname,  | 
| 55 | 
            +
                    reader_flats = NXTomoReader(self.nx_fname, image_key=ImageKey.FLAT_FIELD.value)
         | 
| 74 56 | 
             
                    data_flats = reader_flats.load_data()
         | 
| 75 57 | 
             
                    assert np.allclose(data_flats[:, 0, 0], np.concatenate([self.flats1_vals, self.flats2_vals]))
         | 
| 76 58 |  | 
| @@ -83,10 +65,10 @@ class TestNXReader: | |
| 83 65 | 
             
                    def _check_correct_shape_succeeds(shape, sub_region, test_description=""):
         | 
| 84 66 | 
             
                        err_msg = "Something wrong with the following test:" + test_description
         | 
| 85 67 | 
             
                        data_buffer = np.zeros(shape, dtype="f")
         | 
| 86 | 
            -
                        reader1 = NXTomoReader(self.nx_fname,  | 
| 68 | 
            +
                        reader1 = NXTomoReader(self.nx_fname, sub_region=sub_region)
         | 
| 87 69 | 
             
                        data1 = reader1.load_data(output=data_buffer)
         | 
| 88 70 | 
             
                        assert id(data1) == id(data_buffer), err_msg
         | 
| 89 | 
            -
                        reader2 = NXTomoReader(self.nx_fname,  | 
| 71 | 
            +
                        reader2 = NXTomoReader(self.nx_fname, sub_region=sub_region)
         | 
| 90 72 | 
             
                        data2 = reader2.load_data()
         | 
| 91 73 | 
             
                        assert np.allclose(data1, data2), err_msg
         | 
| 92 74 |  | 
| @@ -120,11 +102,10 @@ class TestNXReader: | |
| 120 102 |  | 
| 121 103 | 
             
                    for test_case in test_cases:
         | 
| 122 104 | 
             
                        for wrong_shape in test_case["wrong_shapes"]:
         | 
| 123 | 
            -
                            with pytest.raises(ValueError):
         | 
| 105 | 
            +
                            with pytest.raises(ValueError):  # noqa: PT012
         | 
| 124 106 | 
             
                                data_buffer_wrong_shape = np.zeros(wrong_shape, dtype="f")
         | 
| 125 107 | 
             
                                reader = NXTomoReader(
         | 
| 126 108 | 
             
                                    self.nx_fname,
         | 
| 127 | 
            -
                                    self.nx_data_path,
         | 
| 128 109 | 
             
                                    sub_region=test_case["sub_region"],
         | 
| 129 110 | 
             
                                )
         | 
| 130 111 | 
             
                                reader.load_data(output=data_buffer_wrong_shape)
         | 
| @@ -148,7 +129,7 @@ class TestNXReader: | |
| 148 129 | 
             
                    ]
         | 
| 149 130 |  | 
| 150 131 | 
             
                    for test_case in test_cases:
         | 
| 151 | 
            -
                        reader = NXTomoReader(self.nx_fname,  | 
| 132 | 
            +
                        reader = NXTomoReader(self.nx_fname, sub_region=test_case["sub_region"])
         | 
| 152 133 | 
             
                        data = reader.load_data()
         | 
| 153 134 | 
             
                        assert data.shape == test_case["expected_shape"]
         | 
| 154 135 | 
             
                        assert np.allclose(data[:, 0, 0], test_case["expected_values"])
         | 
| @@ -156,7 +137,7 @@ class TestNXReader: | |
| 156 137 | 
             
                def test_reading_with_binning_(self):
         | 
| 157 138 | 
             
                    from nabu.pipeline.reader import NXTomoReaderBinning
         | 
| 158 139 |  | 
| 159 | 
            -
                    reader_with_binning = NXTomoReaderBinning((2, 2), self.nx_fname | 
| 140 | 
            +
                    reader_with_binning = NXTomoReaderBinning((2, 2), self.nx_fname)
         | 
| 160 141 | 
             
                    data = reader_with_binning.load_data()
         | 
| 161 142 | 
             
                    assert data.shape == (self.data_desc.n_projs,) + tuple(n // 2 for n in self.data_desc.frame_shape)
         | 
| 162 143 |  | 
| @@ -171,13 +152,12 @@ class TestNXReader: | |
| 171 152 |  | 
| 172 153 | 
             
                    distortion_corrector = DetectorDistortionBase(detector_full_shape_vh=data_desc.frame_shape)
         | 
| 173 154 | 
             
                    distortion_corrector.set_sub_region_transformation(target_sub_region=sub_region_xy)
         | 
| 174 | 
            -
                    adapted_subregion = distortion_corrector.get_adapted_subregion(sub_region_xy)
         | 
| 155 | 
            +
                    # adapted_subregion = distortion_corrector.get_adapted_subregion(sub_region_xy)
         | 
| 175 156 | 
             
                    sub_region = (slice(None, None), slice(*sub_region_xy[2:]), slice(*sub_region_xy[:2]))
         | 
| 176 157 |  | 
| 177 158 | 
             
                    reader_distortion_corr = NXTomoReaderDistortionCorrection(
         | 
| 178 159 | 
             
                        distortion_corrector,
         | 
| 179 160 | 
             
                        self.nx_fname,
         | 
| 180 | 
            -
                        self.nx_data_path,
         | 
| 181 161 | 
             
                        sub_region=sub_region,
         | 
| 182 162 | 
             
                    )
         | 
| 183 163 |  | 
| @@ -220,7 +200,7 @@ class TestNXReader: | |
| 220 200 | 
             
                    for test_case in test_cases:
         | 
| 221 201 | 
             
                        binning = test_case.get("binning", None)
         | 
| 222 202 | 
             
                        reader_cls = NXTomoReader
         | 
| 223 | 
            -
                        init_args = [self.nx_fname | 
| 203 | 
            +
                        init_args = [self.nx_fname]
         | 
| 224 204 | 
             
                        init_kwargs = {"sub_region": test_case["sub_region"]}
         | 
| 225 205 | 
             
                        if binning is not None:
         | 
| 226 206 | 
             
                            reader_cls = NXTomoReaderBinning
         | 
| @@ -231,6 +211,29 @@ class TestNXReader: | |
| 231 211 | 
             
                        assert data.shape == test_case["expected_shape"], err_msg
         | 
| 232 212 | 
             
                        assert np.allclose(data[:, 0, 0], test_case["expected_values"]), err_msg
         | 
| 233 213 |  | 
| 214 | 
            +
                def test_load_exclude_projections(self):
         | 
| 215 | 
            +
                    n_z, n_x = self.data_desc.frame_shape
         | 
| 216 | 
            +
                    # projs_idx = np.where(self.image_key == 0)[0]
         | 
| 217 | 
            +
                    projs_idx = np.arange(self.data_desc.n_projs, dtype=np.int64)
         | 
| 218 | 
            +
                    excluded_projs_idx_1 = projs_idx[10:20]
         | 
| 219 | 
            +
                    excluded_projs_idx_2 = np.concatenate([projs_idx[10:14], projs_idx[50:57]])
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                    set_to_nparray = lambda x: np.array(sorted(list(x)))
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                    projs_idx1 = set_to_nparray(set(projs_idx) - set(excluded_projs_idx_1))
         | 
| 224 | 
            +
                    projs_idx2 = set_to_nparray(set(projs_idx) - set(excluded_projs_idx_2))
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                    sub_regions_to_test = (
         | 
| 227 | 
            +
                        (projs_idx1, None, None),
         | 
| 228 | 
            +
                        (projs_idx1, slice(0, n_z // 2), None),
         | 
| 229 | 
            +
                        (projs_idx2, None, None),
         | 
| 230 | 
            +
                        (projs_idx2, slice(3, n_z // 2), None),
         | 
| 231 | 
            +
                    )
         | 
| 232 | 
            +
                    for sub_region in sub_regions_to_test:
         | 
| 233 | 
            +
                        reader = NXTomoReader(self.nx_fname, sub_region=sub_region)
         | 
| 234 | 
            +
                        data = reader.load_data()
         | 
| 235 | 
            +
                        assert np.allclose(data[:, 0, 0], self.projs_vals[sub_region[0]])
         | 
| 236 | 
            +
             | 
| 234 237 |  | 
| 235 238 | 
             
            @pytest.fixture(scope="class")
         | 
| 236 239 | 
             
            def bootstrap_edf_reader(request):
         | 
| @@ -328,7 +331,7 @@ class TestEDFReader: | |
| 328 331 |  | 
| 329 332 | 
             
                    distortion_corrector = DetectorDistortionBase(detector_full_shape_vh=self.frame_shape)
         | 
| 330 333 | 
             
                    distortion_corrector.set_sub_region_transformation(target_sub_region=sub_region_xy)
         | 
| 331 | 
            -
                    adapted_subregion = distortion_corrector.get_adapted_subregion(sub_region_xy)
         | 
| 334 | 
            +
                    # adapted_subregion = distortion_corrector.get_adapted_subregion(sub_region_xy)
         | 
| 332 335 | 
             
                    sub_region = (slice(None, None), slice(*sub_region_xy[2:]), slice(*sub_region_xy[:2]))
         | 
| 333 336 |  | 
| 334 337 | 
             
                    reader_distortion_corr = EDFStackReaderDistortionCorrection(
         |