nabu 2024.2.13__py3-none-any.whl → 2025.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/doc_config.py +32 -0
- nabu/__init__.py +1 -1
- nabu/app/bootstrap_stitching.py +4 -2
- nabu/app/cast_volume.py +16 -14
- nabu/app/cli_configs.py +102 -9
- nabu/app/compare_volumes.py +1 -1
- nabu/app/composite_cor.py +2 -4
- nabu/app/diag_to_pix.py +5 -6
- nabu/app/diag_to_rot.py +10 -11
- nabu/app/double_flatfield.py +18 -5
- nabu/app/estimate_motion.py +75 -0
- nabu/app/multicor.py +28 -15
- nabu/app/parse_reconstruction_log.py +1 -0
- nabu/app/pcaflats.py +122 -0
- nabu/app/prepare_weights_double.py +1 -2
- nabu/app/reconstruct.py +1 -7
- nabu/app/reconstruct_helical.py +5 -9
- nabu/app/reduce_dark_flat.py +5 -4
- nabu/app/rotate.py +3 -1
- nabu/app/stitching.py +7 -2
- nabu/app/tests/test_reduce_dark_flat.py +2 -2
- nabu/app/validator.py +1 -4
- nabu/cuda/convolution.py +1 -1
- nabu/cuda/fft.py +1 -1
- nabu/cuda/medfilt.py +1 -1
- nabu/cuda/padding.py +1 -1
- nabu/cuda/src/backproj.cu +6 -6
- nabu/cuda/src/cone.cu +4 -0
- nabu/cuda/src/hierarchical_backproj.cu +14 -0
- nabu/cuda/utils.py +2 -2
- nabu/estimation/alignment.py +17 -31
- nabu/estimation/cor.py +27 -33
- nabu/estimation/cor_sino.py +2 -8
- nabu/estimation/focus.py +4 -8
- nabu/estimation/motion.py +557 -0
- nabu/estimation/tests/test_alignment.py +2 -0
- nabu/estimation/tests/test_motion_estimation.py +471 -0
- nabu/estimation/tests/test_tilt.py +1 -1
- nabu/estimation/tilt.py +6 -5
- nabu/estimation/translation.py +47 -1
- nabu/io/cast_volume.py +108 -18
- nabu/io/detector_distortion.py +5 -6
- nabu/io/reader.py +45 -6
- nabu/io/reader_helical.py +5 -4
- nabu/io/tests/test_cast_volume.py +2 -2
- nabu/io/tests/test_readers.py +41 -38
- nabu/io/tests/test_remove_volume.py +152 -0
- nabu/io/tests/test_writers.py +2 -2
- nabu/io/utils.py +8 -4
- nabu/io/writer.py +1 -2
- nabu/misc/fftshift.py +1 -1
- nabu/misc/fourier_filters.py +1 -1
- nabu/misc/histogram.py +1 -1
- nabu/misc/histogram_cuda.py +1 -1
- nabu/misc/padding_base.py +1 -1
- nabu/misc/rotation.py +1 -1
- nabu/misc/rotation_cuda.py +1 -1
- nabu/misc/tests/test_binning.py +1 -1
- nabu/misc/transpose.py +1 -1
- nabu/misc/unsharp.py +1 -1
- nabu/misc/unsharp_cuda.py +1 -1
- nabu/misc/unsharp_opencl.py +1 -1
- nabu/misc/utils.py +1 -1
- nabu/opencl/fft.py +1 -1
- nabu/opencl/padding.py +1 -1
- nabu/opencl/src/backproj.cl +6 -6
- nabu/opencl/utils.py +8 -8
- nabu/pipeline/config.py +2 -2
- nabu/pipeline/config_validators.py +46 -46
- nabu/pipeline/datadump.py +3 -3
- nabu/pipeline/estimators.py +271 -11
- nabu/pipeline/fullfield/chunked.py +103 -67
- nabu/pipeline/fullfield/chunked_cuda.py +5 -2
- nabu/pipeline/fullfield/computations.py +4 -1
- nabu/pipeline/fullfield/dataset_validator.py +0 -1
- nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
- nabu/pipeline/fullfield/nabu_config.py +36 -17
- nabu/pipeline/fullfield/processconfig.py +41 -7
- nabu/pipeline/fullfield/reconstruction.py +14 -10
- nabu/pipeline/helical/dataset_validator.py +3 -4
- nabu/pipeline/helical/fbp.py +4 -4
- nabu/pipeline/helical/filtering.py +5 -4
- nabu/pipeline/helical/gridded_accumulator.py +10 -11
- nabu/pipeline/helical/helical_chunked_regridded.py +1 -0
- nabu/pipeline/helical/helical_reconstruction.py +12 -9
- nabu/pipeline/helical/helical_utils.py +1 -2
- nabu/pipeline/helical/nabu_config.py +2 -1
- nabu/pipeline/helical/span_strategy.py +1 -0
- nabu/pipeline/helical/weight_balancer.py +2 -3
- nabu/pipeline/params.py +20 -3
- nabu/pipeline/tests/__init__.py +0 -0
- nabu/pipeline/tests/test_estimators.py +240 -3
- nabu/pipeline/utils.py +1 -1
- nabu/pipeline/writer.py +1 -1
- nabu/preproc/alignment.py +0 -10
- nabu/preproc/ccd.py +53 -3
- nabu/preproc/ctf.py +8 -8
- nabu/preproc/ctf_cuda.py +1 -1
- nabu/preproc/double_flatfield_cuda.py +2 -2
- nabu/preproc/double_flatfield_variable_region.py +0 -1
- nabu/preproc/flatfield.py +307 -2
- nabu/preproc/flatfield_cuda.py +1 -2
- nabu/preproc/flatfield_variable_region.py +3 -3
- nabu/preproc/phase.py +2 -4
- nabu/preproc/phase_cuda.py +2 -2
- nabu/preproc/shift.py +4 -2
- nabu/preproc/shift_cuda.py +0 -1
- nabu/preproc/tests/test_ctf.py +4 -4
- nabu/preproc/tests/test_double_flatfield.py +1 -1
- nabu/preproc/tests/test_flatfield.py +1 -1
- nabu/preproc/tests/test_paganin.py +1 -3
- nabu/preproc/tests/test_pcaflats.py +154 -0
- nabu/preproc/tests/test_vshift.py +4 -1
- nabu/processing/azim.py +9 -5
- nabu/processing/convolution_cuda.py +6 -4
- nabu/processing/fft_base.py +7 -3
- nabu/processing/fft_cuda.py +25 -164
- nabu/processing/fft_opencl.py +28 -6
- nabu/processing/fftshift.py +1 -1
- nabu/processing/histogram.py +1 -1
- nabu/processing/muladd.py +0 -1
- nabu/processing/padding_base.py +1 -1
- nabu/processing/padding_cuda.py +0 -2
- nabu/processing/processing_base.py +12 -6
- nabu/processing/rotation_cuda.py +3 -1
- nabu/processing/tests/test_fft.py +2 -64
- nabu/processing/tests/test_fftshift.py +1 -1
- nabu/processing/tests/test_medfilt.py +1 -3
- nabu/processing/tests/test_padding.py +1 -1
- nabu/processing/tests/test_roll.py +1 -1
- nabu/processing/tests/test_rotation.py +4 -2
- nabu/processing/unsharp_opencl.py +1 -1
- nabu/reconstruction/astra.py +245 -0
- nabu/reconstruction/cone.py +39 -9
- nabu/reconstruction/fbp.py +14 -0
- nabu/reconstruction/fbp_base.py +40 -8
- nabu/reconstruction/fbp_opencl.py +8 -0
- nabu/reconstruction/filtering.py +59 -25
- nabu/reconstruction/filtering_cuda.py +22 -21
- nabu/reconstruction/filtering_opencl.py +10 -14
- nabu/reconstruction/hbp.py +26 -13
- nabu/reconstruction/mlem.py +55 -16
- nabu/reconstruction/projection.py +3 -5
- nabu/reconstruction/sinogram.py +1 -1
- nabu/reconstruction/sinogram_cuda.py +0 -1
- nabu/reconstruction/tests/test_cone.py +37 -2
- nabu/reconstruction/tests/test_deringer.py +4 -4
- nabu/reconstruction/tests/test_fbp.py +36 -15
- nabu/reconstruction/tests/test_filtering.py +27 -7
- nabu/reconstruction/tests/test_halftomo.py +28 -2
- nabu/reconstruction/tests/test_mlem.py +94 -64
- nabu/reconstruction/tests/test_projector.py +7 -2
- nabu/reconstruction/tests/test_reconstructor.py +1 -1
- nabu/reconstruction/tests/test_sino_normalization.py +0 -1
- nabu/resources/dataset_analyzer.py +210 -24
- nabu/resources/gpu.py +4 -4
- nabu/resources/logger.py +4 -4
- nabu/resources/nxflatfield.py +103 -37
- nabu/resources/tests/test_dataset_analyzer.py +37 -0
- nabu/resources/tests/test_extract.py +11 -0
- nabu/resources/tests/test_nxflatfield.py +5 -5
- nabu/resources/utils.py +16 -10
- nabu/stitching/alignment.py +8 -11
- nabu/stitching/config.py +44 -35
- nabu/stitching/definitions.py +2 -2
- nabu/stitching/frame_composition.py +8 -10
- nabu/stitching/overlap.py +4 -4
- nabu/stitching/sample_normalization.py +5 -5
- nabu/stitching/slurm_utils.py +2 -2
- nabu/stitching/stitcher/base.py +2 -0
- nabu/stitching/stitcher/dumper/base.py +0 -1
- nabu/stitching/stitcher/dumper/postprocessing.py +1 -1
- nabu/stitching/stitcher/post_processing.py +11 -9
- nabu/stitching/stitcher/pre_processing.py +37 -31
- nabu/stitching/stitcher/single_axis.py +2 -3
- nabu/stitching/stitcher_2D.py +2 -1
- nabu/stitching/tests/test_config.py +10 -11
- nabu/stitching/tests/test_sample_normalization.py +1 -1
- nabu/stitching/tests/test_slurm_utils.py +1 -2
- nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
- nabu/stitching/tests/test_z_postprocessing_stitching.py +3 -3
- nabu/stitching/tests/test_z_preprocessing_stitching.py +27 -24
- nabu/stitching/utils/tests/__init__.py +0 -0
- nabu/stitching/utils/tests/test_post-processing.py +1 -0
- nabu/stitching/utils/utils.py +16 -18
- nabu/tests.py +0 -3
- nabu/testutils.py +62 -9
- nabu/utils.py +50 -20
- {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/METADATA +7 -7
- nabu-2025.1.0.dist-info/RECORD +328 -0
- {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/WHEEL +1 -1
- {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/entry_points.txt +2 -1
- nabu/app/correct_rot.py +0 -70
- nabu/io/tests/test_detector_distortion.py +0 -178
- nabu-2024.2.13.dist-info/RECORD +0 -317
- /nabu/{stitching → app}/tests/__init__.py +0 -0
- {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/licenses/LICENSE +0 -0
- {nabu-2024.2.13.dist-info → nabu-2025.1.0.dist-info}/top_level.txt +0 -0
    
        nabu/pipeline/datadump.py
    CHANGED
    
    | @@ -126,7 +126,7 @@ class DataDumpManager: | |
| 126 126 | 
             
                    read_opts = self.processing_options["read_chunk"]
         | 
| 127 127 | 
             
                    if read_opts.get("process_file", None) is None:
         | 
| 128 128 | 
             
                        return None
         | 
| 129 | 
            -
                    dump_start_z, dump_end_z = read_opts["dump_start_z"], read_opts["dump_end_z"]
         | 
| 129 | 
            +
                    dump_start_z, dump_end_z = read_opts["dump_start_z"], read_opts["dump_end_z"]  # noqa: F841
         | 
| 130 130 | 
             
                    relative_start_z = self.z_min - dump_start_z
         | 
| 131 131 | 
             
                    relative_end_z = relative_start_z + self.delta_z
         | 
| 132 132 | 
             
                    # When using binning, every step after "read" results in smaller-sized data.
         | 
| @@ -139,7 +139,7 @@ class DataDumpManager: | |
| 139 139 |  | 
| 140 140 | 
             
                def _check_resume_from_step(self):
         | 
| 141 141 | 
             
                    read_opts = self.processing_options["read_chunk"]
         | 
| 142 | 
            -
                    expected_radios_shape = get_hdf5_dataset_shape(
         | 
| 142 | 
            +
                    expected_radios_shape = get_hdf5_dataset_shape(  # noqa: F841
         | 
| 143 143 | 
             
                        read_opts["process_file"],
         | 
| 144 144 | 
             
                        read_opts["process_h5_path"],
         | 
| 145 145 | 
             
                        sub_region=self.get_read_dump_subregion(),
         | 
| @@ -151,7 +151,7 @@ class DataDumpManager: | |
| 151 151 | 
             
                        return
         | 
| 152 152 | 
             
                    writer = self.data_dump[step_name]
         | 
| 153 153 | 
             
                    self.logger.info("Dumping data to %s" % writer.fname)
         | 
| 154 | 
            -
                    if __has_pycuda__:
         | 
| 154 | 
            +
                    if __has_pycuda__:  # noqa: SIM102
         | 
| 155 155 | 
             
                        if isinstance(data, garray.GPUArray):
         | 
| 156 156 | 
             
                            data = data.get()
         | 
| 157 157 |  | 
    
        nabu/pipeline/estimators.py
    CHANGED
    
    | @@ -9,6 +9,7 @@ import scipy.fft  #  pylint: disable=E0611 | |
| 9 9 | 
             
            from silx.io import get_data
         | 
| 10 10 | 
             
            import math
         | 
| 11 11 | 
             
            from scipy import ndimage as nd
         | 
| 12 | 
            +
             | 
| 12 13 | 
             
            from ..preproc.flatfield import FlatField
         | 
| 13 14 | 
             
            from ..estimation.cor import (
         | 
| 14 15 | 
             
                CenterOfRotation,
         | 
| @@ -17,8 +18,10 @@ from ..estimation.cor import ( | |
| 17 18 | 
             
                CenterOfRotationGrowingWindow,
         | 
| 18 19 | 
             
                CenterOfRotationOctaveAccurate,
         | 
| 19 20 | 
             
            )
         | 
| 21 | 
            +
            from .. import version as nabu_version
         | 
| 20 22 | 
             
            from ..estimation.cor_sino import SinoCorInterface, CenterOfRotationFourierAngles, CenterOfRotationVo
         | 
| 21 23 | 
             
            from ..estimation.tilt import CameraTilt
         | 
| 24 | 
            +
            from ..estimation.motion import MotionEstimation
         | 
| 22 25 | 
             
            from ..estimation.utils import is_fullturn_scan
         | 
| 23 26 | 
             
            from ..resources.logger import LoggerOrPrint
         | 
| 24 27 | 
             
            from ..resources.utils import extract_parameters
         | 
| @@ -298,7 +301,7 @@ class SinoCORFinder(CORFinderBase): | |
| 298 301 | 
             
                    """
         | 
| 299 302 | 
             
                    Initialize a SinoCORFinder object.
         | 
| 300 303 |  | 
| 301 | 
            -
                    Other  | 
| 304 | 
            +
                    Other Parameters
         | 
| 302 305 | 
             
                    ----------------
         | 
| 303 306 | 
             
                    The following keys can be set in cor_options.
         | 
| 304 307 |  | 
| @@ -341,7 +344,7 @@ class SinoCORFinder(CORFinderBase): | |
| 341 344 | 
             
                            self.angles = self.dataset_info.rotation_angles[::subsampling]
         | 
| 342 345 | 
             
                            self.subsampling = subsampling
         | 
| 343 346 | 
             
                    else:  # Angular step
         | 
| 344 | 
            -
                        raise NotImplementedError | 
| 347 | 
            +
                        raise NotImplementedError
         | 
| 345 348 |  | 
| 346 349 | 
             
                def _load_raw_sinogram(self):
         | 
| 347 350 | 
             
                    if self.slice_idx is None:
         | 
| @@ -441,7 +444,7 @@ class CompositeCORFinder(CORFinderBase): | |
| 441 444 |  | 
| 442 445 | 
             
                search_methods = {
         | 
| 443 446 | 
             
                    "composite-coarse-to-fine": {
         | 
| 444 | 
            -
                        "class": CenterOfRotation,  #  | 
| 447 | 
            +
                        "class": CenterOfRotation,  # Not used. Everything is done in the find_cor() func.
         | 
| 445 448 | 
             
                    }
         | 
| 446 449 | 
             
                }
         | 
| 447 450 | 
             
                _default_cor_options = {"low_pass": 0.4, "high_pass": 10, "side": "near", "near_pos": 0, "near_width": 40}
         | 
| @@ -542,7 +545,7 @@ class CompositeCORFinder(CORFinderBase): | |
| 542 545 | 
             
                    else:
         | 
| 543 546 | 
             
                        my_flats = None
         | 
| 544 547 |  | 
| 545 | 
            -
                    if my_flats is not None and len(list(my_flats.keys())):
         | 
| 548 | 
            +
                    if my_flats is not None and len(list(my_flats.keys())) > 0:
         | 
| 546 549 | 
             
                        self.use_flat = True
         | 
| 547 550 | 
             
                        self.flatfield = FlatField(
         | 
| 548 551 | 
             
                            (len(self.absolute_indices), self.sy, self.sx),
         | 
| @@ -750,15 +753,15 @@ class CompositeCORFinder(CORFinderBase): | |
| 750 753 | 
             
                            my_blurred_radio1 = np.fliplr(blurred_radio1)
         | 
| 751 754 | 
             
                            my_blurred_radio2 = np.fliplr(blurred_radio2)
         | 
| 752 755 |  | 
| 753 | 
            -
                        common_left = np.fliplr(my_radio1[:, ovsd_sx - my_z :])[:, : - | 
| 756 | 
            +
                        common_left = np.fliplr(my_radio1[:, ovsd_sx - my_z :])[:, : -math.ceil(self.ovs * self.high_pass * 2)]
         | 
| 754 757 | 
             
                        # adopt a 'safe' margin considering high_pass value (possibly float)
         | 
| 755 | 
            -
                        common_right = my_radio2[:, ovsd_sx - my_z : - | 
| 758 | 
            +
                        common_right = my_radio2[:, ovsd_sx - my_z : -math.ceil(self.ovs * self.high_pass * 2)]
         | 
| 756 759 |  | 
| 757 760 | 
             
                        common_blurred_left = np.fliplr(my_blurred_radio1[:, ovsd_sx - my_z :])[
         | 
| 758 | 
            -
                            :, : - | 
| 761 | 
            +
                            :, : -math.ceil(self.ovs * self.high_pass * 2)
         | 
| 759 762 | 
             
                        ]
         | 
| 760 763 | 
             
                        # adopt a 'safe' margin considering high_pass value (possibly float)
         | 
| 761 | 
            -
                        common_blurred_right = my_blurred_radio2[:, ovsd_sx - my_z : - | 
| 764 | 
            +
                        common_blurred_right = my_blurred_radio2[:, ovsd_sx - my_z : -math.ceil(self.ovs * self.high_pass * 2)]
         | 
| 762 765 |  | 
| 763 766 | 
             
                        if common_right.size == 0:
         | 
| 764 767 | 
             
                            continue
         | 
| @@ -786,7 +789,7 @@ class CompositeCORFinder(CORFinderBase): | |
| 786 789 | 
             
                    elif self.norm_order == 1:
         | 
| 787 790 | 
             
                        return self.error_metric_l1(common_right, common_left, common_blurred_right, common_blurred_left)
         | 
| 788 791 | 
             
                    else:
         | 
| 789 | 
            -
                         | 
| 792 | 
            +
                        raise RuntimeError("this cannot happen")
         | 
| 790 793 |  | 
| 791 794 | 
             
                def error_metric_l2(self, common_right, common_left):
         | 
| 792 795 | 
             
                    common = common_right - common_left
         | 
| @@ -818,10 +821,10 @@ def oversample(radio, ovs_s): | |
| 818 821 | 
             
                # Pre-initialisation: The original data falls exactly on the following strided positions in the new data array.
         | 
| 819 822 | 
             
                result[:: ovs_s[0], :: ovs_s[1]] = radio
         | 
| 820 823 |  | 
| 821 | 
            -
                for k in range( | 
| 824 | 
            +
                for k in range(ovs_s[0]):
         | 
| 822 825 | 
             
                    # interpolation coefficient for axis 0
         | 
| 823 826 | 
             
                    g = k / ovs_s[0]
         | 
| 824 | 
            -
                    for i in range( | 
| 827 | 
            +
                    for i in range(ovs_s[1]):
         | 
| 825 828 | 
             
                        if i == 0 and k == 0:
         | 
| 826 829 | 
             
                            # this case subset was already exactly matched from before the present double loop,
         | 
| 827 830 | 
             
                            # in the pre-initialisation line.
         | 
| @@ -989,3 +992,260 @@ class DetectorTiltEstimator: | |
| 989 992 |  | 
| 990 993 | 
             
            # alias
         | 
| 991 994 | 
             
            TiltFinder = DetectorTiltEstimator
         | 
| 995 | 
            +
             | 
| 996 | 
            +
             | 
| 997 | 
            +
            def estimate_translations(dataset_info, do_flatfield=True):
                """
                Placeholder for a functional entry point to translation estimation.

                Not implemented yet: both arguments are ignored and None is returned.
                The actual logic lives in the `TranslationsEstimator` class.
                """
                return None
         | 
| 998 | 
            +
             | 
| 999 | 
            +
             | 
| 1000 | 
            +
            class TranslationsEstimator:
         | 
| 1001 | 
            +
             | 
| 1002 | 
            +
                _default_extra_options = {
         | 
| 1003 | 
            +
                    "window_size": 300,
         | 
| 1004 | 
            +
                }
         | 
| 1005 | 
            +
             | 
| 1006 | 
            +
                def __init__(
                    self,
                    dataset_info,
                    do_flatfield=True,
                    rot_center=None,
                    halftomo_side=None,
                    angular_subsampling=10,
                    deg_xy=2,
                    deg_z=2,
                    shifts_estimator="phase_cross_correlation",
                    radios_filter=None,
                    extra_options=None,
                ):
                    """
                    Set up the estimator of sample translations.

                    Parameters
                    ----------
                    dataset_info:
                        Dataset description object. Its `logger` and `is_360` attributes are read here.
                    do_flatfield: bool
                        Stored as `self.do_flatfield`.
                    rot_center: float or None
                        Center of rotation. When None, `self._estimate_cor` is set to True.
                    halftomo_side: False, None, str or scalar
                        Forwarded to `_configure_halftomo()`.
                    angular_subsampling: int
                        Stored as `self.angular_subsampling`.
                    deg_xy, deg_z: int
                        Stored as `self._deg_xy` / `self._deg_z` (presumably degrees of the
                        fitted motion model - TODO confirm).
                    shifts_estimator: str
                        Name of the shifts estimation method, stored as `self._shifts_estimator`.
                    radios_filter:
                        Stored as `self.radios_filter`.
                    extra_options: dict or None
                        Overrides for the class-level `_default_extra_options`.
                    """
                    # Options and dataset-related attributes come first:
                    # _configure_halftomo() below reads extra_options, do_360 and dataset_info.
                    self._configure_extra_options(extra_options)
                    self.logger = LoggerOrPrint(dataset_info.logger)
                    self.dataset_info = dataset_info
                    self.do_360 = dataset_info.is_360

                    # Processing settings
                    self.do_flatfield = do_flatfield
                    self.angular_subsampling = angular_subsampling
                    self.radios_filter = radios_filter
                    self._deg_xy, self._deg_z = deg_xy, deg_z
                    self._shifts_estimator = shifts_estimator
                    self._shifts_estimator_kwargs = {}

                    # Center of rotation: it must be known before the half-tomo window is configured
                    self._cor = rot_center
                    self._estimate_cor = rot_center is None
                    self._configure_halftomo(halftomo_side)

                    # Data and results, filled later
                    self.radios = None
                    self.sample_shifts_xy = None
                    self.sample_shifts_z = None
| 1036 | 
            +
             | 
| 1037 | 
            +
                def _configure_extra_options(self, extra_options):
         | 
| 1038 | 
            +
                    self.extra_options = self._default_extra_options.copy()
         | 
| 1039 | 
            +
                    self.extra_options.update(extra_options or {})
         | 
| 1040 | 
            +
             | 
| 1041 | 
            +
                def _configure_halftomo(self, halftomo_side):
         | 
| 1042 | 
            +
                    if halftomo_side is False:
         | 
| 1043 | 
            +
                        # Force disable halftomo
         | 
| 1044 | 
            +
                        self.halftomo_side = False
         | 
| 1045 | 
            +
                        return
         | 
| 1046 | 
            +
                    self._start_x = None
         | 
| 1047 | 
            +
                    self._end_x = None
         | 
| 1048 | 
            +
                    if (halftomo_side is not None) and not (self.do_360):
         | 
| 1049 | 
            +
                        raise ValueError(
         | 
| 1050 | 
            +
                            "Expected 360° dataset for half-tomography, but this dataset does not look like a 360° dataset"
         | 
| 1051 | 
            +
                        )
         | 
| 1052 | 
            +
                    if halftomo_side is None:
         | 
| 1053 | 
            +
                        if self.dataset_info.is_halftomo:
         | 
| 1054 | 
            +
                            halftomo_side = "right"
         | 
| 1055 | 
            +
                        else:
         | 
| 1056 | 
            +
                            self.halftomo_side = False
         | 
| 1057 | 
            +
                            return
         | 
| 1058 | 
            +
                    self.halftomo_side = halftomo_side
         | 
| 1059 | 
            +
                    window_size = self.extra_options["window_size"]
         | 
| 1060 | 
            +
                    if self._cor is not None:
         | 
| 1061 | 
            +
                        # In this case we look for shifts around the CoR
         | 
| 1062 | 
            +
                        self._start_x = int(self._cor - window_size / 2)
         | 
| 1063 | 
            +
                        self._end_x = int(self._cor + window_size / 2)
         | 
| 1064 | 
            +
                    elif halftomo_side == "right":
         | 
| 1065 | 
            +
                        self._start_x = -window_size
         | 
| 1066 | 
            +
                        self._end_x = None
         | 
| 1067 | 
            +
                    elif halftomo_side == "left":
         | 
| 1068 | 
            +
                        self._start_x = 0
         | 
| 1069 | 
            +
                        self._end_x = window_size
         | 
| 1070 | 
            +
                    elif is_scalar(halftomo_side):
         | 
| 1071 | 
            +
                        # Expect approximate location of CoR, relative to left-most column
         | 
| 1072 | 
            +
                        self._start_x = int(halftomo_side - window_size / 2)
         | 
| 1073 | 
            +
                        self._end_x = int(halftomo_side + window_size / 2)
         | 
| 1074 | 
            +
                    else:
         | 
| 1075 | 
            +
                        raise ValueError(
         | 
| 1076 | 
            +
                            f"Expected 'halftomo_side' to be either 'left', 'right', or an integer (got {halftomo_side})"
         | 
| 1077 | 
            +
                        )
         | 
| 1078 | 
            +
                    self.logger.debug(f"[MotionEstimation] Half-tomo looking at [{self._start_x}:{self._end_x}]")
         | 
| 1079 | 
            +
                    # For half-tomo, skimage.registration.phase_cross_correlation might look a bit too far away
         | 
| 1080 | 
            +
                    if (
         | 
| 1081 | 
            +
                        self._shifts_estimator == "phase_cross_correlation"
         | 
| 1082 | 
            +
                        and self._shifts_estimator_kwargs.get("overlap_ratio", 0.3) >= 0.3
         | 
| 1083 | 
            +
                    ):
         | 
| 1084 | 
            +
                        self._shifts_estimator_kwargs.update({"overlap_ratio": 0.2})
         | 
| 1085 | 
            +
                    #
         | 
| 1086 | 
            +
             | 
| 1087 | 
            +
                def _load_data(self):
         | 
| 1088 | 
            +
                    self.logger.debug("[MotionEstimation] reading data")
         | 
| 1089 | 
            +
                    if self.do_360:
         | 
| 1090 | 
            +
                        """
         | 
| 1091 | 
            +
                        In this case we compare pair of opposite projections.
         | 
| 1092 | 
            +
                        If rotation angles are arbitrary, we should do something like
         | 
| 1093 | 
            +
                          for angle in dataset_info.rotation_angles:
         | 
| 1094 | 
            +
                              img, angle_deg, idx = dataset_info.get_image_at_angle(
         | 
| 1095 | 
            +
                                  np.degrees(angle)+180, return_angle_and_index=True
         | 
| 1096 | 
            +
                              )
         | 
| 1097 | 
            +
                        Most of the time (always ?), the dataset was acquired with a circular trajectory,
         | 
| 1098 | 
            +
                        so we can use angles:
         | 
| 1099 | 
            +
                            dataset_info.rotation_angles[::self.angular_subsampling]
         | 
| 1100 | 
            +
                        which amounts to reading one radio out of "angular_subsampling"
         | 
| 1101 | 
            +
                        """
         | 
| 1102 | 
            +
             | 
| 1103 | 
            +
                        # TODO account for more general rotation angles. The following will only work for circular trajectory and ordered angles
         | 
| 1104 | 
            +
                        self._reader = self.dataset_info.get_reader(
         | 
| 1105 | 
            +
                            sub_region=(slice(None, None, self.angular_subsampling), slice(None), slice(None))
         | 
| 1106 | 
            +
                        )
         | 
| 1107 | 
            +
                        self.radios = self._reader.load_data()
         | 
| 1108 | 
            +
                        self.angles = self.dataset_info.rotation_angles[:: self.angular_subsampling]
         | 
| 1109 | 
            +
                        self._radios_idx = self._reader.get_frames_indices()
         | 
| 1110 | 
            +
                        self.logger.debug("[MotionEstimation] This is a 360° scan, will use pairs of opposite projections")
         | 
| 1111 | 
            +
                    else:
         | 
| 1112 | 
            +
                        """
         | 
| 1113 | 
            +
                        In this case we use the "return projections", i.e special projections acquired at several angles
         | 
| 1114 | 
            +
                        (eg. [180, 90, 0]) before ending the scan
         | 
| 1115 | 
            +
                        """
         | 
| 1116 | 
            +
                        return_projs, return_angles_deg, return_idx = self.dataset_info.get_alignment_projections()
         | 
| 1117 | 
            +
                        self._angles_return = np.radians(return_angles_deg)
         | 
| 1118 | 
            +
                        self._radios_return = return_projs
         | 
| 1119 | 
            +
                        self._radios_idx_return = return_idx
         | 
| 1120 | 
            +
             | 
| 1121 | 
            +
                        projs = []
         | 
| 1122 | 
            +
                        angles_rad = []
         | 
| 1123 | 
            +
                        projs_idx = []
         | 
| 1124 | 
            +
                        for angle_deg in return_angles_deg:
         | 
| 1125 | 
            +
                            proj, rot_angle_deg, proj_idx = self.dataset_info.get_image_at_angle(
         | 
| 1126 | 
            +
                                angle_deg, image_type="projection", return_angle_and_index=True
         | 
| 1127 | 
            +
                            )
         | 
| 1128 | 
            +
                            projs.append(proj)
         | 
| 1129 | 
            +
                            angles_rad.append(np.radians(rot_angle_deg))
         | 
| 1130 | 
            +
                            projs_idx.append(proj_idx)
         | 
| 1131 | 
            +
                        self._radios_outwards = np.array(projs)
         | 
| 1132 | 
            +
                        self._angles_outward = np.array(angles_rad)
         | 
| 1133 | 
            +
                        self._radios_idx_outwards = np.array(projs_idx)
         | 
| 1134 | 
            +
                        self.logger.debug("[MotionEstimation] This is a 180° scan, will use 'return projections'")
         | 
| 1135 | 
            +
             | 
| 1136 | 
            +
                def _apply_flatfield(self):
         | 
| 1137 | 
            +
                    if not (self.do_flatfield):
         | 
| 1138 | 
            +
                        return
         | 
| 1139 | 
            +
                    self.logger.debug("[MotionEstimation] flatfield")
         | 
| 1140 | 
            +
                    if self.do_360:
         | 
| 1141 | 
            +
                        self._flatfield = FlatField(
         | 
| 1142 | 
            +
                            self.radios.shape,
         | 
| 1143 | 
            +
                            flats=self.dataset_info.flats,
         | 
| 1144 | 
            +
                            darks=self.dataset_info.darks,
         | 
| 1145 | 
            +
                            radios_indices=self._radios_idx,
         | 
| 1146 | 
            +
                        )
         | 
| 1147 | 
            +
                        self._flatfield.normalize_radios(self.radios)
         | 
| 1148 | 
            +
                    else:
         | 
| 1149 | 
            +
                        # 180 + return projs
         | 
| 1150 | 
            +
                        self._flatfield_outwards = FlatField(
         | 
| 1151 | 
            +
                            self._radios_outwards.shape,
         | 
| 1152 | 
            +
                            flats=self.dataset_info.flats,
         | 
| 1153 | 
            +
                            darks=self.dataset_info.darks,
         | 
| 1154 | 
            +
                            radios_indices=self._radios_idx_outwards,
         | 
| 1155 | 
            +
                        )
         | 
| 1156 | 
            +
                        self._flatfield_outwards.normalize_radios(self._radios_outwards)
         | 
| 1157 | 
            +
                        self._flatfield_return = FlatField(
         | 
| 1158 | 
            +
                            self._radios_return.shape,
         | 
| 1159 | 
            +
                            flats=self.dataset_info.flats,
         | 
| 1160 | 
            +
                            darks=self.dataset_info.darks,
         | 
| 1161 | 
            +
                            radios_indices=self._radios_idx_return,
         | 
| 1162 | 
            +
                        )
         | 
| 1163 | 
            +
                        self._flatfield_outwards.normalize_radios(self._radios_return)
         | 
| 1164 | 
            +
             | 
| 1165 | 
            +
                def estimate_motion(self):
         | 
| 1166 | 
            +
                    self._load_data()
         | 
| 1167 | 
            +
                    self._apply_flatfield()
         | 
| 1168 | 
            +
                    if self.radios_filter is not None:
         | 
| 1169 | 
            +
                        self.logger.debug("[MotionEstimation] applying radios filter")
         | 
| 1170 | 
            +
                        self.radios_filter(self.radios)
         | 
| 1171 | 
            +
             | 
| 1172 | 
            +
                    n_projs_tot = self.dataset_info.n_angles
         | 
| 1173 | 
            +
                    if self.do_360:
         | 
| 1174 | 
            +
                        n_a = self.radios.shape[0]
         | 
| 1175 | 
            +
                        # See notes above - this works only for circular trajectory / ordered angles
         | 
| 1176 | 
            +
                        projs_stack1 = self.radios[: n_a // 2]
         | 
| 1177 | 
            +
                        projs_stack2 = self.radios[n_a // 2 :]
         | 
| 1178 | 
            +
                        angles1 = self.angles[: n_a // 2]
         | 
| 1179 | 
            +
                        angles2 = self.angles[n_a // 2 :]
         | 
| 1180 | 
            +
                        indices1 = (self._radios_idx - self._radios_idx[0])[: n_a // 2]
         | 
| 1181 | 
            +
                        indices2 = (self._radios_idx - self._radios_idx[0])[n_a // 2 :]
         | 
| 1182 | 
            +
                    else:
         | 
| 1183 | 
            +
                        projs_stack1 = self._radios_outwards
         | 
| 1184 | 
            +
                        projs_stack2 = self._radios_return
         | 
| 1185 | 
            +
                        angles1 = self._angles_outward
         | 
| 1186 | 
            +
                        angles2 = self._angles_return
         | 
| 1187 | 
            +
                        indices1 = self._radios_idx_outwards - self._radios_idx_outwards.min()
         | 
| 1188 | 
            +
                        indices2 = self._radios_idx_return - self._radios_idx_outwards.min()
         | 
| 1189 | 
            +
             | 
| 1190 | 
            +
                    if self._start_x is not None:
         | 
| 1191 | 
            +
                        # Compute Motion Estimation on subset of images (eg. for half-tomo)
         | 
| 1192 | 
            +
                        projs_stack1 = projs_stack1[..., self._start_x : self._end_x]
         | 
| 1193 | 
            +
                        projs_stack2 = projs_stack2[..., self._start_x : self._end_x]
         | 
| 1194 | 
            +
             | 
| 1195 | 
            +
                    self.motion_estimator = MotionEstimation(
         | 
| 1196 | 
            +
                        projs_stack1,
         | 
| 1197 | 
            +
                        projs_stack2,
         | 
| 1198 | 
            +
                        angles1,
         | 
| 1199 | 
            +
                        angles2,
         | 
| 1200 | 
            +
                        indices1,
         | 
| 1201 | 
            +
                        indices2,
         | 
| 1202 | 
            +
                        n_projs_tot,
         | 
| 1203 | 
            +
                        shifts_estimator=self._shifts_estimator,
         | 
| 1204 | 
            +
                        shifts_estimator_kwargs=self._shifts_estimator_kwargs,
         | 
| 1205 | 
            +
                    )
         | 
| 1206 | 
            +
             | 
| 1207 | 
            +
                    self.logger.debug("[MotionEstimation] estimating shifts")
         | 
| 1208 | 
            +
             | 
| 1209 | 
            +
                    estimated_shifts_v = self.motion_estimator.estimate_vertical_motion(degree=self._deg_z)
         | 
| 1210 | 
            +
                    estimated_shifts_h, cor = self.motion_estimator.estimate_horizontal_motion(degree=self._deg_xy, cor=self._cor)
         | 
| 1211 | 
            +
                    if self._start_x is not None:
         | 
| 1212 | 
            +
                        cor += (self._start_x % self.radios.shape[-1]) + (projs_stack1.shape[-1] - 1) / 2.0
         | 
| 1213 | 
            +
             | 
| 1214 | 
            +
                    self.sample_shifts_xy = estimated_shifts_h
         | 
| 1215 | 
            +
                    self.sample_shifts_z = estimated_shifts_v
         | 
| 1216 | 
            +
                    if self._cor is None:
         | 
| 1217 | 
            +
                        self.logger.info(
         | 
| 1218 | 
            +
                            "[MotionEstimation] Estimated center of rotation (relative to left-most pixel): %.2f" % cor
         | 
| 1219 | 
            +
                        )
         | 
| 1220 | 
            +
                    return estimated_shifts_h, estimated_shifts_v, cor
         | 
| 1221 | 
            +
             | 
| 1222 | 
            +
                def generate_translations_movements_file(self, filename, fmt="%.3f", only=None):
         | 
| 1223 | 
            +
                    if self.sample_shifts_xy is None:
         | 
| 1224 | 
            +
                        raise RuntimeError("Need to run estimate_motion() first")
         | 
| 1225 | 
            +
             | 
| 1226 | 
            +
                    angles = self.dataset_info.rotation_angles
         | 
| 1227 | 
            +
                    cor = self._cor or 0
         | 
| 1228 | 
            +
                    txy_est_all_angles = self.motion_estimator.apply_fit_horiz(angles=angles)
         | 
| 1229 | 
            +
                    tz_est_all_angles = self.motion_estimator.apply_fit_vertic(angles=angles)
         | 
| 1230 | 
            +
                    estimated_shifts_vu_all_angles = self.motion_estimator.convert_sample_motion_to_detector_shifts(
         | 
| 1231 | 
            +
                        txy_est_all_angles, tz_est_all_angles, angles, cor=cor
         | 
| 1232 | 
            +
                    )
         | 
| 1233 | 
            +
                    estimated_shifts_vu_all_angles[:, 1] -= cor
         | 
| 1234 | 
            +
                    correct_shifts_uv = -estimated_shifts_vu_all_angles[:, ::-1]
         | 
| 1235 | 
            +
             | 
| 1236 | 
            +
                    if only is not None:
         | 
| 1237 | 
            +
                        if only == "horizontal":
         | 
| 1238 | 
            +
                            correct_shifts_uv[:, 1] = 0
         | 
| 1239 | 
            +
                        elif only == "vertical":
         | 
| 1240 | 
            +
                            correct_shifts_uv[:, 0] = 0
         | 
| 1241 | 
            +
                        else:
         | 
| 1242 | 
            +
                            raise ValueError("Expected 'only' to be either None, 'horizontal' or 'vertical'")
         | 
| 1243 | 
            +
             | 
| 1244 | 
            +
                    header = f"Generated by nabu {nabu_version} : {str(self)}"
         | 
| 1245 | 
            +
                    np.savetxt(filename, correct_shifts_uv, fmt=fmt, header=header)
         | 
| 1246 | 
            +
             | 
| 1247 | 
            +
                def __str__(self):
         | 
| 1248 | 
            +
                    ret = f"{self.__class__.__name__}(do_flatfield={self.do_flatfield}, rot_center={self._cor}, angular_subsampling={self.angular_subsampling})"
         | 
| 1249 | 
            +
                    if self.sample_shifts_xy is not None:
         | 
| 1250 | 
            +
                        ret += f", shifts_estimator={self.motion_estimator.shifts_estimator}"
         | 
| 1251 | 
            +
                    return ret
         | 
| @@ -1,6 +1,5 @@ | |
| 1 1 | 
             
            from os import path
         | 
| 2 2 | 
             
            from time import time
         | 
| 3 | 
            -
            from math import ceil
         | 
| 4 3 | 
             
            import numpy as np
         | 
| 5 4 | 
             
            from silx.io.url import DataUrl
         | 
| 6 5 |  | 
| @@ -10,7 +9,7 @@ from ...resources.utils import extract_parameters | |
| 10 9 | 
             
            from ...misc.binning import binning as image_binning
         | 
| 11 10 | 
             
            from ...io.reader import EDFStackReader, HDF5Loader, NXTomoReader
         | 
| 12 11 | 
             
            from ...preproc.ccd import Log, CCDFilter
         | 
| 13 | 
            -
            from ...preproc.flatfield import FlatField
         | 
| 12 | 
            +
            from ...preproc.flatfield import FlatField, PCAFlatsNormalizer
         | 
| 14 13 | 
             
            from ...preproc.distortion import DistortionCorrection
         | 
| 15 14 | 
             
            from ...preproc.shift import VerticalShift
         | 
| 16 15 | 
             
            from ...preproc.double_flatfield import DoubleFlatField
         | 
| @@ -18,7 +17,7 @@ from ...preproc.phase import PaganinPhaseRetrieval | |
| 18 17 | 
             
            from ...preproc.ctf import CTFPhaseRetrieval, GeoPars
         | 
| 19 18 | 
             
            from ...reconstruction.sinogram import SinoNormalization
         | 
| 20 19 | 
             
            from ...reconstruction.filtering import SinoFilter
         | 
| 21 | 
            -
            from ...reconstruction.mlem import  | 
| 20 | 
            +
            from ...reconstruction.mlem import MLEMReconstructor
         | 
| 22 21 | 
             
            from ...processing.rotation import Rotation
         | 
| 23 22 | 
             
            from ...reconstruction.rings import MunchDeringer, SinoMeanDeringer, VoDeringer
         | 
| 24 23 | 
             
            from ...processing.unsharp import UnsharpMask
         | 
| @@ -45,6 +44,7 @@ class ChunkedPipeline: | |
| 45 44 |  | 
| 46 45 | 
             
                backend = "numpy"
         | 
| 47 46 | 
             
                FlatFieldClass = FlatField
         | 
| 47 | 
            +
                PCAFlatFieldClass = PCAFlatsNormalizer
         | 
| 48 48 | 
             
                DoubleFlatFieldClass = DoubleFlatField
         | 
| 49 49 | 
             
                CCDCorrectionClass = CCDFilter
         | 
| 50 50 | 
             
                PaganinPhaseRetrievalClass = PaganinPhaseRetrieval
         | 
| @@ -99,7 +99,6 @@ class ChunkedPipeline: | |
| 99 99 |  | 
| 100 100 | 
             
                    Notes
         | 
| 101 101 | 
             
                    ------
         | 
| 102 | 
            -
             | 
| 103 102 | 
             
                    Using `margin` results in a lesser number of reconstructed slices.
         | 
| 104 103 | 
             
                    More specifically, if `margin = (V, H)`, then there will be `delta_z - 2*V`
         | 
| 105 104 | 
             
                    reconstructed slices (if the sub-region is in the middle of the volume)
         | 
| @@ -127,10 +126,10 @@ class ChunkedPipeline: | |
| 127 126 | 
             
                    if len(chunk_shape) != 3:
         | 
| 128 127 | 
             
                        raise ValueError("Expected chunk_shape to be a tuple of length 3 in the form (n_z, n_y, n_x)")
         | 
| 129 128 | 
             
                    self.chunk_shape = tuple(int(c) for c in chunk_shape)  # cast to int, as numpy.int64 can make pycuda crash
         | 
| 130 | 
            -
                     | 
| 129 | 
            +
                    ss_start = getattr(self.process_config, "subsampling_start", 0)
         | 
| 131 130 | 
             
                    # (n_a, n_z, n_x)
         | 
| 132 131 | 
             
                    self.radios_shape = (
         | 
| 133 | 
            -
                         | 
| 132 | 
            +
                        np.arange(self.chunk_shape[0])[ss_start :: self.process_config.subsampling_factor].size,
         | 
| 134 133 | 
             
                        self.chunk_shape[1] // self.process_config.binning[1],
         | 
| 135 134 | 
             
                        self.chunk_shape[2] // self.process_config.binning[0],
         | 
| 136 135 | 
             
                    )
         | 
| @@ -175,7 +174,7 @@ class ChunkedPipeline: | |
| 175 174 | 
             
                        Data volume sub-region, in the form ((start_a, end_a), (start_z, end_z), (start_x, end_x))
         | 
| 176 175 | 
             
                        where the data volume has a layout (angles, Z, X)
         | 
| 177 176 | 
             
                    """
         | 
| 178 | 
            -
                    n_angles = self.dataset_info.n_angles
         | 
| 177 | 
            +
                    # n_angles = self.dataset_info.n_angles
         | 
| 179 178 | 
             
                    n_x, n_z = self.dataset_info.radio_dims
         | 
| 180 179 | 
             
                    c_a, c_z, c_x = self.chunk_shape
         | 
| 181 180 | 
             
                    if sub_region is None:
         | 
| @@ -190,7 +189,7 @@ class ChunkedPipeline: | |
| 190 189 | 
             
                        # check sub-region
         | 
| 191 190 | 
             
                        for i, start_end in enumerate(sub_region):
         | 
| 192 191 | 
             
                            start, end = start_end
         | 
| 193 | 
            -
                            if start is not None and end is not None:
         | 
| 192 | 
            +
                            if start is not None and end is not None:  # noqa: SIM102
         | 
| 194 193 | 
             
                                if end - start != self.chunk_shape[i]:
         | 
| 195 194 | 
             
                                    raise ValueError(
         | 
| 196 195 | 
             
                                        "Invalid (start, end)=(%d, %d) for sub-region (dimension %d): chunk shape is %s, but %d-%d=%d != %d"
         | 
| @@ -340,13 +339,28 @@ class ChunkedPipeline: | |
| 340 339 | 
             
                        subs_z = None
         | 
| 341 340 | 
             
                        subs_x = None
         | 
| 342 341 | 
             
                        angular_sub_region = slice(*(self.sub_region[0]))
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                        # exclude(subsample(.)) != subsample(exclude(.))
         | 
| 344 | 
            +
                        # Here we want the latter: first exclude the user-defined angular range, and then subsample the remaining indices
         | 
| 345 | 
            +
                        if len(self.dataset_info.get_excluded_projections_indices()) > 0:
         | 
| 346 | 
            +
                            angular_sub_region = np.array(
         | 
| 347 | 
            +
                                [
         | 
| 348 | 
            +
                                    self.dataset_info.index_to_proj_number(i)
         | 
| 349 | 
            +
                                    for i in sorted(list(self.dataset_info.projections.keys()))
         | 
| 350 | 
            +
                                ]
         | 
| 351 | 
            +
                            )
         | 
| 343 352 | 
             
                        if self.process_config.subsampling_factor:
         | 
| 344 353 | 
             
                            subs_angles = self.process_config.subsampling_factor
         | 
| 345 | 
            -
                             | 
| 346 | 
            -
             | 
| 347 | 
            -
                                 | 
| 348 | 
            -
             | 
| 349 | 
            -
             | 
| 354 | 
            +
                            start = getattr(self.process_config, "subsampling_start", 0) + self.sub_region[0][0]
         | 
| 355 | 
            +
                            if isinstance(angular_sub_region, slice):
         | 
| 356 | 
            +
                                angular_sub_region = slice(
         | 
| 357 | 
            +
                                    start,
         | 
| 358 | 
            +
                                    self.sub_region[0][1],
         | 
| 359 | 
            +
                                    subs_angles,
         | 
| 360 | 
            +
                                )
         | 
| 361 | 
            +
                            else:
         | 
| 362 | 
            +
                                angular_sub_region = angular_sub_region[start::subs_angles]
         | 
| 363 | 
            +
             | 
| 350 364 | 
             
                        reader_sub_region = (
         | 
| 351 365 | 
             
                            angular_sub_region,
         | 
| 352 366 | 
             
                            slice(*(self.sub_region[1]) + ((subs_z,) if subs_z else ())),
         | 
| @@ -363,7 +377,7 @@ class ChunkedPipeline: | |
| 363 377 | 
             
                        if self.dataset_info.kind == "nx":
         | 
| 364 378 | 
             
                            self.chunk_reader = NXTomoReader(
         | 
| 365 379 | 
             
                                self.dataset_info.dataset_hdf5_url.file_path(),
         | 
| 366 | 
            -
                                self.dataset_info.dataset_hdf5_url.data_path(),
         | 
| 380 | 
            +
                                data_path=self.dataset_info.dataset_hdf5_url.data_path(),
         | 
| 367 381 | 
             
                                sub_region=reader_sub_region,
         | 
| 368 382 | 
             
                                image_key=0,
         | 
| 369 383 | 
             
                                **other_reader_kwargs,
         | 
| @@ -394,50 +408,68 @@ class ChunkedPipeline: | |
| 394 408 |  | 
| 395 409 | 
             
                @use_options("flatfield", "flatfield")
         | 
| 396 410 | 
             
                def _init_flatfield(self):
         | 
| 397 | 
            -
                     | 
| 398 | 
            -
             | 
| 399 | 
            -
             | 
| 400 | 
            -
             | 
| 401 | 
            -
             | 
| 402 | 
            -
             | 
| 403 | 
            -
             | 
| 404 | 
            -
             | 
| 405 | 
            -
             | 
| 406 | 
            -
                         | 
| 407 | 
            -
             | 
| 408 | 
            -
             | 
| 409 | 
            -
             | 
| 410 | 
            -
             | 
| 411 | 
            -
             | 
| 412 | 
            -
                        self. | 
| 413 | 
            -
             | 
| 414 | 
            -
             | 
| 415 | 
            -
             | 
| 416 | 
            -
             | 
| 417 | 
            -
             | 
| 418 | 
            -
                             | 
| 419 | 
            -
             | 
| 411 | 
            +
                    if self.processing_options["flatfield"]:
         | 
| 412 | 
            +
                        self._ff_options = self.processing_options["flatfield"].copy()
         | 
| 413 | 
            +
             | 
| 414 | 
            +
                        # This won't work when resuming from a step (i.e before FF), because we rely on H5Loader()
         | 
| 415 | 
            +
                        # which re-compacts the data. When data is re-compacted, we have to know the original radios positions.
         | 
| 416 | 
            +
                        # These positions can be saved in the "file_dump" metadata, but it is not loaded for now
         | 
| 417 | 
            +
                        # (the process_config object is re-built from scratch every time)
         | 
| 418 | 
            +
                        self._ff_options["projs_indices"] = self.chunk_reader.get_frames_indices()
         | 
| 419 | 
            +
             | 
| 420 | 
            +
                        if self._ff_options.get("normalize_srcurrent", False):
         | 
| 421 | 
            +
                            a_start_idx, a_end_idx = self.sub_region[0]
         | 
| 422 | 
            +
                            subs = self.process_config.subsampling_factor
         | 
| 423 | 
            +
                            self._ff_options["radios_srcurrent"] = self._ff_options["radios_srcurrent"][a_start_idx:a_end_idx:subs]
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                        distortion_correction = None
         | 
| 426 | 
            +
                        if self._ff_options["do_flat_distortion"]:
         | 
| 427 | 
            +
                            self.logger.info("Flats distortion correction will be applied")
         | 
| 428 | 
            +
                            self.FlatFieldClass = FlatField  # no GPU implementation available, force this backend
         | 
| 429 | 
            +
                            estimation_kwargs = {}
         | 
| 430 | 
            +
                            estimation_kwargs.update(self._ff_options["flat_distortion_params"])
         | 
| 431 | 
            +
                            estimation_kwargs["logger"] = self.logger
         | 
| 432 | 
            +
                            distortion_correction = DistortionCorrection(
         | 
| 433 | 
            +
                                estimation_method="fft-correlation",
         | 
| 434 | 
            +
                                estimation_kwargs=estimation_kwargs,
         | 
| 435 | 
            +
                                correction_method="interpn",
         | 
| 436 | 
            +
                            )
         | 
| 420 437 |  | 
| 421 | 
            -
             | 
| 422 | 
            -
             | 
| 423 | 
            -
             | 
| 424 | 
            -
             | 
| 425 | 
            -
             | 
| 426 | 
            -
             | 
| 427 | 
            -
             | 
| 428 | 
            -
             | 
| 438 | 
            +
                        if self.processing_options["flatfield"]["method"].lower() != "pca":
         | 
| 439 | 
            +
                            # Reduced darks/flats are loaded, but we have to crop them on the current sub-region
         | 
| 440 | 
            +
                            # and possibly do apply some pre-processing (binning, distortion correction, ...)
         | 
| 441 | 
            +
                            darks_flats = load_darks_flats(
         | 
| 442 | 
            +
                                self.dataset_info,
         | 
| 443 | 
            +
                                self.sub_region[1:],
         | 
| 444 | 
            +
                                processing_func=self._ff_processing_function,
         | 
| 445 | 
            +
                                processing_func_args=self._ff_processing_function_args,
         | 
| 446 | 
            +
                            )
         | 
| 429 447 |  | 
| 430 | 
            -
             | 
| 431 | 
            -
             | 
| 432 | 
            -
             | 
| 433 | 
            -
             | 
| 434 | 
            -
             | 
| 435 | 
            -
             | 
| 436 | 
            -
             | 
| 437 | 
            -
             | 
| 438 | 
            -
             | 
| 439 | 
            -
             | 
| 440 | 
            -
             | 
| 448 | 
            +
                            # FlatField parameter "radios_indices" must account for subsampling
         | 
| 449 | 
            +
                            self.flatfield = self.FlatFieldClass(
         | 
| 450 | 
            +
                                self.radios_shape,
         | 
| 451 | 
            +
                                flats=darks_flats["flats"],
         | 
| 452 | 
            +
                                darks=darks_flats["darks"],
         | 
| 453 | 
            +
                                radios_indices=self._ff_options["projs_indices"],
         | 
| 454 | 
            +
                                interpolation="linear",
         | 
| 455 | 
            +
                                distortion_correction=distortion_correction,
         | 
| 456 | 
            +
                                radios_srcurrent=self._ff_options["radios_srcurrent"],
         | 
| 457 | 
            +
                                flats_srcurrent=self._ff_options["flats_srcurrent"],
         | 
| 458 | 
            +
                            )
         | 
| 459 | 
            +
                        else:
         | 
| 460 | 
            +
                            flats = self.process_config.dataset_info.flats
         | 
| 461 | 
            +
                            darks = self.process_config.dataset_info.darks
         | 
| 462 | 
            +
                            if len(darks) != 1:
         | 
| 463 | 
            +
                                raise ValueError(f"There should be only one reduced dark. Found {len(darks)}.")
         | 
| 464 | 
            +
                            else:
         | 
| 465 | 
            +
                                dark_key = list(darks.keys())[0]
         | 
| 466 | 
            +
                            nb_pca_components = len(flats) - 1
         | 
| 467 | 
            +
                            img_subregion = tuple(slice(*sr) for sr in self.sub_region[1:])
         | 
| 468 | 
            +
                            self.flatfield = self.PCAFlatFieldClass(
         | 
| 469 | 
            +
                                np.array([flats[k][img_subregion] for k in range(1, nb_pca_components)]),
         | 
| 470 | 
            +
                                darks[dark_key][img_subregion],
         | 
| 471 | 
            +
                                flats[0][img_subregion],  # Mean
         | 
| 472 | 
            +
                            )
         | 
| 441 473 |  | 
| 442 474 | 
             
                @use_options("double_flatfield", "double_flatfield")
         | 
| 443 475 | 
             
                def _init_double_flatfield(self):
         | 
| @@ -630,13 +662,14 @@ class ChunkedPipeline: | |
| 630 662 | 
             
                                "clip_outer_circle": options["clip_outer_circle"],
         | 
| 631 663 | 
             
                                "outer_circle_value": options["outer_circle_value"],
         | 
| 632 664 | 
             
                                "filter_cutoff": options["fbp_filter_cutoff"],
         | 
| 665 | 
            +
                                "crop_filtered_data": options["crop_filtered_data"],
         | 
| 633 666 | 
             
                            },
         | 
| 634 667 | 
             
                        )
         | 
| 635 668 |  | 
| 636 669 | 
             
                    if options["method"] == "mlem" and options["implementation"] in (None, "corrct"):
         | 
| 637 670 | 
             
                        self.reconstruction = self.MLEMClass(  # pylint: disable=E1102
         | 
| 638 671 | 
             
                            (self.radios_shape[1],) + self.sino_shape,
         | 
| 639 | 
            -
                            angles_rad | 
| 672 | 
            +
                            angles_rad=options["angles"],
         | 
| 640 673 | 
             
                            shifts_uv=self.dataset_info.translations,  # In config file, one line per proj, each line is (tu,tv). Corrct expects one col per proj and (tv,tu).
         | 
| 641 674 | 
             
                            cor=options["rotation_axis_position"],
         | 
| 642 675 | 
             
                            n_iterations=options["iterations"],
         | 
| @@ -647,9 +680,19 @@ class ChunkedPipeline: | |
| 647 680 | 
             
                                "v_max_for_v_shifts": None,
         | 
| 648 681 | 
             
                                "v_min_for_u_shifts": 0,
         | 
| 649 682 | 
             
                                "v_max_for_u_shifts": None,
         | 
| 683 | 
            +
                                "scale_factor": 1.0 / options["voxel_size_cm"][0],
         | 
| 684 | 
            +
                                "clip_outer_circle": options["clip_outer_circle"],
         | 
| 685 | 
            +
                                "outer_circle_value": options["outer_circle_value"],
         | 
| 686 | 
            +
                                "filter_cutoff": options["fbp_filter_cutoff"],
         | 
| 687 | 
            +
                                "crop_filtered_data": options["crop_filtered_data"],
         | 
| 650 688 | 
             
                            },
         | 
| 651 689 | 
             
                        )
         | 
| 652 690 |  | 
| 691 | 
            +
                    if options.get("crop_filtered_data", True) is False:
         | 
| 692 | 
            +
                        self.logger.warning(
         | 
| 693 | 
            +
                            "Using [reconstruction] crop_filtered_data = False. This will use a large amount of memory."
         | 
| 694 | 
            +
                        )
         | 
| 695 | 
            +
             | 
| 653 696 | 
             
                    self._allocate_recs(*self.process_config.rec_shape, n_slices=n_slices)
         | 
| 654 697 | 
             
                    n_a, _, n_x = self.radios_cropped_shape
         | 
| 655 698 | 
             
                    self._tmp_sino = self._allocate_array((n_a, n_x), "f", name="tmp_sino")
         | 
| @@ -806,15 +849,7 @@ class ChunkedPipeline: | |
| 806 849 | 
             
                    """
         | 
| 807 850 | 
             
                    This reconstructs the entire sinograms stack at once
         | 
| 808 851 | 
             
                    """
         | 
| 809 | 
            -
             | 
| 810 | 
            -
                    n_angles, n_z, n_x = self.radios.shape
         | 
| 811 | 
            -
             | 
| 812 | 
            -
                    # FIXME
         | 
| 813 | 
            -
                    # can't do a discontiguous single copy...
         | 
| 814 | 
            -
                    sinos_contig = self._allocate_array((n_z, n_angles, n_x), np.float32, "sinos_cone")
         | 
| 815 | 
            -
                    for i in range(n_z):
         | 
| 816 | 
            -
                        sinos_contig[i] = self.radios[:, i, :]
         | 
| 817 | 
            -
                    # ---
         | 
| 852 | 
            +
                    sinos_discontig = self.radios.transpose(axes=(1, 0, 2))  # view
         | 
| 818 853 |  | 
| 819 854 | 
             
                    # In principle radios are not cropped at this stage,
         | 
| 820 855 | 
             
                    # so self.sub_region[2][0] can be used instead of self.get_slice_start_index() instead of self.sub_region[2][0]
         | 
| @@ -822,7 +857,8 @@ class ChunkedPipeline: | |
| 822 857 | 
             
                    n_z_tot = self.process_config.radio_shape(binning=True)[0]
         | 
| 823 858 |  | 
| 824 859 | 
             
                    self.reconstruction.reconstruct(  # pylint: disable=E1101
         | 
| 825 | 
            -
                        sinos_contig,
         | 
| 860 | 
            +
                        # sinos_contig,
         | 
| 861 | 
            +
                        sinos_discontig,
         | 
| 826 862 | 
             
                        output=self.recs,
         | 
| 827 863 | 
             
                        relative_z_position=((z_min + z_max) / self.process_config.binning_z / 2) - n_z_tot / 2,
         | 
| 828 864 | 
             
                    )
         |