nabu 2024.2.14__py3-none-any.whl → 2025.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doc/doc_config.py +32 -0
- nabu/__init__.py +1 -1
- nabu/app/bootstrap_stitching.py +4 -2
- nabu/app/cast_volume.py +16 -14
- nabu/app/cli_configs.py +102 -9
- nabu/app/compare_volumes.py +1 -1
- nabu/app/composite_cor.py +2 -4
- nabu/app/diag_to_pix.py +5 -6
- nabu/app/diag_to_rot.py +10 -11
- nabu/app/double_flatfield.py +18 -5
- nabu/app/estimate_motion.py +75 -0
- nabu/app/multicor.py +28 -15
- nabu/app/parse_reconstruction_log.py +1 -0
- nabu/app/pcaflats.py +122 -0
- nabu/app/prepare_weights_double.py +1 -2
- nabu/app/reconstruct.py +1 -7
- nabu/app/reconstruct_helical.py +5 -9
- nabu/app/reduce_dark_flat.py +5 -4
- nabu/app/rotate.py +3 -1
- nabu/app/stitching.py +7 -2
- nabu/app/tests/test_reduce_dark_flat.py +2 -2
- nabu/app/validator.py +1 -4
- nabu/cuda/convolution.py +1 -1
- nabu/cuda/fft.py +1 -1
- nabu/cuda/medfilt.py +1 -1
- nabu/cuda/padding.py +1 -1
- nabu/cuda/src/backproj.cu +6 -6
- nabu/cuda/src/cone.cu +4 -0
- nabu/cuda/src/hierarchical_backproj.cu +14 -0
- nabu/cuda/utils.py +2 -2
- nabu/estimation/alignment.py +17 -31
- nabu/estimation/cor.py +27 -33
- nabu/estimation/cor_sino.py +2 -8
- nabu/estimation/focus.py +4 -8
- nabu/estimation/motion.py +557 -0
- nabu/estimation/tests/test_alignment.py +2 -0
- nabu/estimation/tests/test_motion_estimation.py +471 -0
- nabu/estimation/tests/test_tilt.py +1 -1
- nabu/estimation/tilt.py +6 -5
- nabu/estimation/translation.py +47 -1
- nabu/io/cast_volume.py +108 -18
- nabu/io/detector_distortion.py +5 -6
- nabu/io/reader.py +45 -6
- nabu/io/reader_helical.py +5 -4
- nabu/io/tests/test_cast_volume.py +2 -2
- nabu/io/tests/test_readers.py +41 -38
- nabu/io/tests/test_remove_volume.py +152 -0
- nabu/io/tests/test_writers.py +2 -2
- nabu/io/utils.py +8 -4
- nabu/io/writer.py +1 -2
- nabu/misc/fftshift.py +1 -1
- nabu/misc/fourier_filters.py +1 -1
- nabu/misc/histogram.py +1 -1
- nabu/misc/histogram_cuda.py +1 -1
- nabu/misc/padding_base.py +1 -1
- nabu/misc/rotation.py +1 -1
- nabu/misc/rotation_cuda.py +1 -1
- nabu/misc/tests/test_binning.py +1 -1
- nabu/misc/transpose.py +1 -1
- nabu/misc/unsharp.py +1 -1
- nabu/misc/unsharp_cuda.py +1 -1
- nabu/misc/unsharp_opencl.py +1 -1
- nabu/misc/utils.py +1 -1
- nabu/opencl/fft.py +1 -1
- nabu/opencl/padding.py +1 -1
- nabu/opencl/src/backproj.cl +6 -6
- nabu/opencl/utils.py +8 -8
- nabu/pipeline/config.py +2 -2
- nabu/pipeline/config_validators.py +46 -46
- nabu/pipeline/datadump.py +3 -3
- nabu/pipeline/estimators.py +271 -11
- nabu/pipeline/fullfield/chunked.py +103 -67
- nabu/pipeline/fullfield/chunked_cuda.py +5 -2
- nabu/pipeline/fullfield/computations.py +4 -1
- nabu/pipeline/fullfield/dataset_validator.py +0 -1
- nabu/pipeline/fullfield/get_double_flatfield.py +147 -0
- nabu/pipeline/fullfield/nabu_config.py +36 -17
- nabu/pipeline/fullfield/processconfig.py +41 -7
- nabu/pipeline/fullfield/reconstruction.py +14 -10
- nabu/pipeline/helical/dataset_validator.py +3 -4
- nabu/pipeline/helical/fbp.py +4 -4
- nabu/pipeline/helical/filtering.py +5 -4
- nabu/pipeline/helical/gridded_accumulator.py +10 -11
- nabu/pipeline/helical/helical_chunked_regridded.py +1 -0
- nabu/pipeline/helical/helical_reconstruction.py +12 -9
- nabu/pipeline/helical/helical_utils.py +1 -2
- nabu/pipeline/helical/nabu_config.py +2 -1
- nabu/pipeline/helical/span_strategy.py +1 -0
- nabu/pipeline/helical/weight_balancer.py +2 -3
- nabu/pipeline/params.py +20 -3
- nabu/pipeline/tests/__init__.py +0 -0
- nabu/pipeline/tests/test_estimators.py +240 -3
- nabu/pipeline/utils.py +1 -1
- nabu/pipeline/writer.py +1 -1
- nabu/preproc/alignment.py +0 -10
- nabu/preproc/ccd.py +53 -3
- nabu/preproc/ctf.py +8 -8
- nabu/preproc/ctf_cuda.py +1 -1
- nabu/preproc/double_flatfield_cuda.py +2 -2
- nabu/preproc/double_flatfield_variable_region.py +0 -1
- nabu/preproc/flatfield.py +307 -2
- nabu/preproc/flatfield_cuda.py +1 -2
- nabu/preproc/flatfield_variable_region.py +3 -3
- nabu/preproc/phase.py +2 -4
- nabu/preproc/phase_cuda.py +2 -2
- nabu/preproc/shift.py +4 -2
- nabu/preproc/shift_cuda.py +0 -1
- nabu/preproc/tests/test_ctf.py +4 -4
- nabu/preproc/tests/test_double_flatfield.py +1 -1
- nabu/preproc/tests/test_flatfield.py +1 -1
- nabu/preproc/tests/test_paganin.py +1 -3
- nabu/preproc/tests/test_pcaflats.py +154 -0
- nabu/preproc/tests/test_vshift.py +4 -1
- nabu/processing/azim.py +9 -5
- nabu/processing/convolution_cuda.py +6 -4
- nabu/processing/fft_base.py +7 -3
- nabu/processing/fft_cuda.py +25 -164
- nabu/processing/fft_opencl.py +28 -6
- nabu/processing/fftshift.py +1 -1
- nabu/processing/histogram.py +1 -1
- nabu/processing/muladd.py +0 -1
- nabu/processing/padding_base.py +1 -1
- nabu/processing/padding_cuda.py +0 -2
- nabu/processing/processing_base.py +12 -6
- nabu/processing/rotation_cuda.py +3 -1
- nabu/processing/tests/test_fft.py +2 -64
- nabu/processing/tests/test_fftshift.py +1 -1
- nabu/processing/tests/test_medfilt.py +1 -3
- nabu/processing/tests/test_padding.py +1 -1
- nabu/processing/tests/test_roll.py +1 -1
- nabu/processing/tests/test_rotation.py +4 -2
- nabu/processing/unsharp_opencl.py +1 -1
- nabu/reconstruction/astra.py +245 -0
- nabu/reconstruction/cone.py +39 -9
- nabu/reconstruction/fbp.py +7 -0
- nabu/reconstruction/fbp_base.py +36 -5
- nabu/reconstruction/filtering.py +59 -25
- nabu/reconstruction/filtering_cuda.py +22 -21
- nabu/reconstruction/filtering_opencl.py +10 -14
- nabu/reconstruction/hbp.py +26 -13
- nabu/reconstruction/mlem.py +55 -16
- nabu/reconstruction/projection.py +3 -5
- nabu/reconstruction/sinogram.py +1 -1
- nabu/reconstruction/sinogram_cuda.py +0 -1
- nabu/reconstruction/tests/test_cone.py +37 -2
- nabu/reconstruction/tests/test_deringer.py +4 -4
- nabu/reconstruction/tests/test_fbp.py +36 -15
- nabu/reconstruction/tests/test_filtering.py +27 -7
- nabu/reconstruction/tests/test_halftomo.py +28 -2
- nabu/reconstruction/tests/test_mlem.py +94 -64
- nabu/reconstruction/tests/test_projector.py +7 -2
- nabu/reconstruction/tests/test_reconstructor.py +1 -1
- nabu/reconstruction/tests/test_sino_normalization.py +0 -1
- nabu/resources/dataset_analyzer.py +210 -24
- nabu/resources/gpu.py +4 -4
- nabu/resources/logger.py +4 -4
- nabu/resources/nxflatfield.py +103 -37
- nabu/resources/tests/test_dataset_analyzer.py +37 -0
- nabu/resources/tests/test_extract.py +11 -0
- nabu/resources/tests/test_nxflatfield.py +5 -5
- nabu/resources/utils.py +16 -10
- nabu/stitching/alignment.py +8 -11
- nabu/stitching/config.py +44 -35
- nabu/stitching/definitions.py +2 -2
- nabu/stitching/frame_composition.py +8 -10
- nabu/stitching/overlap.py +4 -4
- nabu/stitching/sample_normalization.py +5 -5
- nabu/stitching/slurm_utils.py +2 -2
- nabu/stitching/stitcher/base.py +2 -0
- nabu/stitching/stitcher/dumper/base.py +0 -1
- nabu/stitching/stitcher/dumper/postprocessing.py +1 -1
- nabu/stitching/stitcher/post_processing.py +11 -9
- nabu/stitching/stitcher/pre_processing.py +37 -31
- nabu/stitching/stitcher/single_axis.py +2 -3
- nabu/stitching/stitcher_2D.py +2 -1
- nabu/stitching/tests/test_config.py +10 -11
- nabu/stitching/tests/test_sample_normalization.py +1 -1
- nabu/stitching/tests/test_slurm_utils.py +1 -2
- nabu/stitching/tests/test_y_preprocessing_stitching.py +11 -8
- nabu/stitching/tests/test_z_postprocessing_stitching.py +3 -3
- nabu/stitching/tests/test_z_preprocessing_stitching.py +27 -24
- nabu/stitching/utils/tests/__init__.py +0 -0
- nabu/stitching/utils/tests/test_post-processing.py +1 -0
- nabu/stitching/utils/utils.py +16 -18
- nabu/tests.py +0 -3
- nabu/testutils.py +62 -9
- nabu/utils.py +50 -20
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/METADATA +7 -7
- nabu-2025.1.0.dist-info/RECORD +328 -0
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/WHEEL +1 -1
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/entry_points.txt +2 -1
- nabu/app/correct_rot.py +0 -70
- nabu/io/tests/test_detector_distortion.py +0 -178
- nabu-2024.2.14.dist-info/RECORD +0 -317
- /nabu/{stitching → app}/tests/__init__.py +0 -0
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/licenses/LICENSE +0 -0
- {nabu-2024.2.14.dist-info → nabu-2025.1.0.dist-info}/top_level.txt +0 -0
| @@ -29,8 +29,7 @@ class GriddedAccumulator: | |
| 29 29 | 
             
                    This class creates, for a selected volume slab, a standard set of radios from an helical dataset.
         | 
| 30 30 |  | 
| 31 31 | 
             
                    Parameters
         | 
| 32 | 
            -
                     | 
| 33 | 
            -
             | 
| 32 | 
            +
                    ----------
         | 
| 34 33 | 
             
                    gridded_radios : 3D np.array
         | 
| 35 34 | 
             
                       this is the stack of new radios which will be resynthetised, by this class,
         | 
| 36 35 | 
             
                       for a selected slab.
         | 
| @@ -97,7 +96,7 @@ class GriddedAccumulator: | |
| 97 96 | 
             
                    self.flats_srcurrent = flats_srcurrent
         | 
| 98 97 |  | 
| 99 98 | 
             
                    self.flat_indexes = flat_indexes
         | 
| 100 | 
            -
                    self.flat_indexes_reverse_map = dict(
         | 
| 99 | 
            +
                    self.flat_indexes_reverse_map = dict(  # noqa: C404
         | 
| 101 100 | 
             
                        [(global_index, local_index) for (local_index, global_index) in enumerate(flat_indexes)]
         | 
| 102 101 | 
             
                    )
         | 
| 103 102 | 
             
                    self.flats = flats
         | 
| @@ -121,7 +120,7 @@ class GriddedAccumulator: | |
| 121 120 | 
             
                    the accumulators are ready.
         | 
| 122 121 |  | 
| 123 122 | 
             
                    Parameters
         | 
| 124 | 
            -
                     | 
| 123 | 
            +
                    ----------
         | 
| 125 124 | 
             
                    subchunk_slice: an object of the python class "slice"
         | 
| 126 125 | 
             
                      this slice slices the angular domain which corresponds to the useful
         | 
| 127 126 | 
             
                      projections  which are useful for the chunk, and whose informations
         | 
| @@ -278,7 +277,7 @@ class GriddedAccumulator: | |
| 278 277 | 
             
                            i_diag_list = [(i0 - 1) // 2, (i0 - 1) // 2 + len(self.diagnostic_searched_angles_rad_clipped)]
         | 
| 279 278 | 
             
                            for i_redundancy, i_diag in enumerate(i_diag_list):
         | 
| 280 279 | 
             
                                # print("IRED ", i_redundancy)
         | 
| 281 | 
            -
                                if i_redundancy:
         | 
| 280 | 
            +
                                if i_redundancy:  # noqa: SIM102
         | 
| 282 281 | 
             
                                    # to avoid, in z_stages with >360 range for one single stage, to fill the second items which should instead be filled by another stage.
         | 
| 283 282 | 
             
                                    if abs(original_zpix_transl - self.diagnostic_zpix_transl[i_diag_list[0]]) < 2.0:
         | 
| 284 283 | 
             
                                        # print( " >>>>>> stesso z" , i_redundancy )
         | 
| @@ -306,7 +305,7 @@ class GriddedAccumulator: | |
| 306 305 | 
             
                                    self.diagnostic_radios[i_diag] += data_token * factor
         | 
| 307 306 | 
             
                                    self.diagnostic_weights[i_diag] += weight * factor
         | 
| 308 307 | 
             
                                    break
         | 
| 309 | 
            -
                                else:
         | 
| 308 | 
            +
                                else:  # noqa: RET508
         | 
| 310 309 | 
             
                                    pass
         | 
| 311 310 |  | 
| 312 311 | 
             
                class _ReframingInfos:
         | 
| @@ -483,8 +482,8 @@ def overlap_logic(subr_start_z, subr_end_z, dtasrc_start_z, dtasrc_end_z): | |
| 483 482 |  | 
| 484 483 | 
             
            def padding_logic(subr_start_z, subr_end_z, dtasrc_start_z, dtasrc_end_z):
         | 
| 485 484 | 
             
                """.......... and the missing ranges which possibly could be obtained by extension padding"""
         | 
| 486 | 
            -
                t_h = subr_end_z - subr_start_z
         | 
| 487 | 
            -
                s_h = dtasrc_end_z - dtasrc_start_z
         | 
| 485 | 
            +
                # t_h = subr_end_z - subr_start_z
         | 
| 486 | 
            +
                # s_h = dtasrc_end_z - dtasrc_start_z
         | 
| 488 487 |  | 
| 489 488 | 
             
                if dtasrc_start_z <= subr_start_z:
         | 
| 490 489 | 
             
                    target_lower_padding = None
         | 
| @@ -503,9 +502,9 @@ def get_reconstruction_space(span_info, min_scanwise_z, end_scanwise_z, phase_ma | |
| 503 502 | 
             
                """Utility function, so far used only by the unit test, which, given the span_info object, creates the auxiliary collection arrays
         | 
| 504 503 | 
             
                and initialises the  my_z_min, my_z_end variable keeping into account the scan direction
         | 
| 505 504 | 
             
                and the min_scanwise_z, end_scanwise_z input arguments
         | 
| 506 | 
            -
                Parameters
         | 
| 507 | 
            -
                ==========
         | 
| 508 505 |  | 
| 506 | 
            +
                Parameters
         | 
| 507 | 
            +
                ----------
         | 
| 509 508 | 
             
                span_info: SpanStrategy
         | 
| 510 509 |  | 
| 511 510 | 
             
                min_scanwise_z: int
         | 
| @@ -533,7 +532,7 @@ def get_reconstruction_space(span_info, min_scanwise_z, end_scanwise_z, phase_ma | |
| 533 532 | 
             
                # regridded dataset, estimating a meaningul angular step representative
         | 
| 534 533 | 
             
                # of the raw data
         | 
| 535 534 | 
             
                my_angle_step = abs(np.diff(span_info.projection_angles_deg).mean())
         | 
| 536 | 
            -
                n_gridded_angles =  | 
| 535 | 
            +
                n_gridded_angles = round(360.0 / my_angle_step)
         | 
| 537 536 |  | 
| 538 537 | 
             
                radios_h = phase_margin_pix + (my_z_end - my_z_min) + phase_margin_pix
         | 
| 539 538 |  | 
| @@ -3,6 +3,8 @@ from math import ceil | |
| 3 3 | 
             
            from time import time
         | 
| 4 4 | 
             
            import numpy as np
         | 
| 5 5 | 
             
            import copy
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from nabu.utils import first_generator_item
         | 
| 6 8 | 
             
            from ...resources.logger import LoggerOrPrint
         | 
| 7 9 | 
             
            from ...io.writer import merge_hdf5_files
         | 
| 8 10 | 
             
            from ...cuda.utils import collect_cuda_gpus
         | 
| @@ -18,7 +20,7 @@ except: | |
| 18 20 |  | 
| 19 21 | 
             
            from .helical_chunked_regridded_cuda import CudaHelicalChunkedRegriddedPipeline
         | 
| 20 22 |  | 
| 21 | 
            -
            from ..fullfield.reconstruction import  | 
| 23 | 
            +
            from ..fullfield.reconstruction import FullFieldReconstructor
         | 
| 22 24 |  | 
| 23 25 | 
             
            avail_gpus = collect_cuda_gpus() or {}
         | 
| 24 26 |  | 
| @@ -53,7 +55,8 @@ class HelicalReconstructorRegridded: | |
| 53 55 | 
             
                        Dictionary with advanced options. Please see 'Other parameters' below
         | 
| 54 56 | 
             
                    cuda_options: dict, optional
         | 
| 55 57 | 
             
                        Dictionary with cuda options passed to `nabu.cuda.processing.CudaProcessing`
         | 
| 56 | 
            -
             | 
| 58 | 
            +
             | 
| 59 | 
            +
                    Other Parameters
         | 
| 57 60 | 
             
                    -----------------
         | 
| 58 61 | 
             
                    Advanced options can be passed in the 'extra_options' dictionary. These can be:
         | 
| 59 62 |  | 
| @@ -165,8 +168,8 @@ class HelicalReconstructorRegridded: | |
| 165 168 |  | 
| 166 169 | 
             
                        # the meaming of z_min and z_max is: position in slices units from the
         | 
| 167 170 | 
             
                        # first available slice and in the direction of the scan
         | 
| 168 | 
            -
                        self.z_min =  | 
| 169 | 
            -
                        self.z_max =  | 
| 171 | 
            +
                        self.z_min = round(z_start * (0 - z_fract_min) + z_max * z_fract_min)
         | 
| 172 | 
            +
                        self.z_max = round(z_start * (0 - z_fract_max) + z_max * z_fract_max) + 1
         | 
| 170 173 |  | 
| 171 174 | 
             
                def _compute_translations_margin(self):
         | 
| 172 175 | 
             
                    return 0, 0
         | 
| @@ -242,9 +245,9 @@ class HelicalReconstructorRegridded: | |
| 242 245 | 
             
                        "reconstruction" in process_config.processing_steps
         | 
| 243 246 | 
             
                        and process_config.processing_options["reconstruction"]["enable_halftomo"]
         | 
| 244 247 | 
             
                    ):
         | 
| 245 | 
            -
                        radios_and_sinos = True
         | 
| 248 | 
            +
                        radios_and_sinos = True  # noqa: F841
         | 
| 246 249 |  | 
| 247 | 
            -
                    max_dz = process_config.dataset_info.radio_dims[1]
         | 
| 250 | 
            +
                    # max_dz = process_config.dataset_info.radio_dims[1]
         | 
| 248 251 | 
             
                    chunk_size = chunk_step
         | 
| 249 252 | 
             
                    last_good_chunk_size = chunk_size
         | 
| 250 253 | 
             
                    while True:
         | 
| @@ -431,7 +434,7 @@ class HelicalReconstructorRegridded: | |
| 431 434 | 
             
                    angles_deg = np.rad2deg(angles_rad)
         | 
| 432 435 |  | 
| 433 436 | 
             
                    redundancy_angle_deg = self.process_config.nabu_config["reconstruction"]["redundancy_angle_deg"]
         | 
| 434 | 
            -
                    do_helical_half_tomo = self.process_config.nabu_config["reconstruction"]["helical_halftomo"]
         | 
| 437 | 
            +
                    # do_helical_half_tomo = self.process_config.nabu_config["reconstruction"]["helical_halftomo"]
         | 
| 435 438 |  | 
| 436 439 | 
             
                    self.logger.info("Creating SpanStrategy object for helical ")
         | 
| 437 440 | 
             
                    t0 = time()
         | 
| @@ -460,7 +463,7 @@ class HelicalReconstructorRegridded: | |
| 460 463 | 
             
                    self.logger.debug("Creating a new pipeline object")
         | 
| 461 464 | 
             
                    args = [self.process_config, task["sub_region"]]
         | 
| 462 465 |  | 
| 463 | 
            -
                    dz = self._get_delta_z(task)
         | 
| 466 | 
            +
                    # dz = self._get_delta_z(task)
         | 
| 464 467 |  | 
| 465 468 | 
             
                    pipeline = self._pipeline_cls(
         | 
| 466 469 | 
             
                        *args,
         | 
| @@ -542,7 +545,7 @@ class HelicalReconstructorRegridded: | |
| 542 545 | 
             
                    # Prevent issue when out_dir is empty, which happens only if dataset/location is a relative path.
         | 
| 543 546 | 
             
                    # TODO this should be prevented earlier
         | 
| 544 547 | 
             
                    if out_dir is None or len(out_dir.strip()) == 0:
         | 
| 545 | 
            -
                        out_dir = dirname(dirname(self.results[ | 
| 548 | 
            +
                        out_dir = dirname(dirname(self.results[first_generator_item(self.results.keys())]))
         | 
| 546 549 | 
             
                    #
         | 
| 547 550 | 
             
                    if output_file is None:
         | 
| 548 551 | 
             
                        output_file = join(out_dir, prefix + out_cfg["file_prefix"]) + ".hdf5"
         | 
| @@ -9,9 +9,8 @@ def find_mirror_indexes(angles_deg, tolerance_factor=1.0): | |
| 9 9 | 
             
                contains the index of the angles_deg array element which has the value the closest
         | 
| 10 10 | 
             
                to angles_deg[i] + 180. It is used for padding in halftomo.
         | 
| 11 11 |  | 
| 12 | 
            -
                Parameters | 
| 12 | 
            +
                Parameters
         | 
| 13 13 | 
             
                -----------
         | 
| 14 | 
            -
             | 
| 15 14 | 
             
                angles_deg: a nd.array of floats
         | 
| 16 15 |  | 
| 17 16 | 
             
                tolerance: float
         | 
| @@ -1,3 +1,4 @@ | |
| 1 | 
            +
            # ruff: noqa
         | 
| 1 2 | 
             
            from ..fullfield.nabu_config import *
         | 
| 2 3 | 
             
            import copy
         | 
| 3 4 |  | 
| @@ -42,7 +43,7 @@ nabu_config["preproc"]["processes_file"] = { | |
| 42 43 | 
             
                "validator": optional_file_location_validator,
         | 
| 43 44 | 
             
                "type": "required",
         | 
| 44 45 | 
             
            }
         | 
| 45 | 
            -
            nabu_config["preproc"][" | 
| 46 | 
            +
            nabu_config["preproc"]["double_flatfield"]["default"] = 1
         | 
| 46 47 |  | 
| 47 48 |  | 
| 48 49 | 
             
            nabu_config["reconstruction"].update(
         | 
| @@ -12,8 +12,7 @@ class WeightBalancer: | |
| 12 12 | 
             
                    to Nabu, we create this class and follow the scheme initialisation + application.
         | 
| 13 13 |  | 
| 14 14 | 
             
                    Parameters
         | 
| 15 | 
            -
                     | 
| 16 | 
            -
             | 
| 15 | 
            +
                    ----------
         | 
| 17 16 | 
             
                    rot_center : float
         | 
| 18 17 | 
             
                                the center of rotation in pixel units
         | 
| 19 18 | 
             
                    angles_rad :
         | 
| @@ -84,7 +83,7 @@ def shift(arr, shift, fill_value=0.0): | |
| 84 83 | 
             
                """
         | 
| 85 84 | 
             
                result = np.zeros_like(arr)
         | 
| 86 85 |  | 
| 87 | 
            -
                num1 =  | 
| 86 | 
            +
                num1 = math.floor(shift)
         | 
| 88 87 | 
             
                num2 = num1 + 1
         | 
| 89 88 | 
             
                partition = shift - num1
         | 
| 90 89 |  | 
    
        nabu/pipeline/params.py
    CHANGED
    
    | @@ -3,6 +3,17 @@ flatfield_modes = { | |
| 3 3 | 
             
                "1": True,
         | 
| 4 4 | 
             
                "false": False,
         | 
| 5 5 | 
             
                "0": False,
         | 
| 6 | 
            +
                # These three should be removed after a while (moved to 'flatfield_loading_mode')
         | 
| 7 | 
            +
                "forced": "force-load",
         | 
| 8 | 
            +
                "force-load": "force-load",
         | 
| 9 | 
            +
                "force-compute": "force-compute",
         | 
| 10 | 
            +
                #
         | 
| 11 | 
            +
                "pca": "pca",
         | 
| 12 | 
            +
            }
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            flatfield_loading_mode = {
         | 
| 15 | 
            +
                "": "load_if_present",
         | 
| 16 | 
            +
                "load_if_present": "load_if_present",
         | 
| 6 17 | 
             
                "forced": "force-load",
         | 
| 7 18 | 
             
                "force-load": "force-load",
         | 
| 8 19 | 
             
                "force-compute": "force-compute",
         | 
| @@ -25,12 +36,17 @@ unsharp_methods = { | |
| 25 36 | 
             
                "": None,
         | 
| 26 37 | 
             
            }
         | 
| 27 38 |  | 
| 39 | 
            +
            # see PaddingBase.supported_modes
         | 
| 28 40 | 
             
            padding_modes = {
         | 
| 29 | 
            -
                "edges": "edge",
         | 
| 30 | 
            -
                "edge": "edge",
         | 
| 31 | 
            -
                "mirror": "mirror",
         | 
| 32 41 | 
             
                "zeros": "zeros",
         | 
| 33 42 | 
             
                "zero": "zeros",
         | 
| 43 | 
            +
                "constant": "zeros",
         | 
| 44 | 
            +
                "edges": "edge",
         | 
| 45 | 
            +
                "edge": "edge",
         | 
| 46 | 
            +
                "mirror": "reflect",
         | 
| 47 | 
            +
                "reflect": "reflect",
         | 
| 48 | 
            +
                "symmetric": "symmetric",
         | 
| 49 | 
            +
                "wrap": "wrap",
         | 
| 34 50 | 
             
            }
         | 
| 35 51 |  | 
| 36 52 | 
             
            reconstruction_methods = {
         | 
| @@ -72,6 +88,7 @@ iterative_methods = { | |
| 72 88 | 
             
            optim_algorithms = {
         | 
| 73 89 | 
             
                "chambolle": "chambolle-pock",
         | 
| 74 90 | 
             
                "chambollepock": "chambolle-pock",
         | 
| 91 | 
            +
                "chambolle-pock": "chambolle-pock",
         | 
| 75 92 | 
             
                "fista": "fista",
         | 
| 76 93 | 
             
            }
         | 
| 77 94 |  | 
| 
            File without changes
         | 
| @@ -1,14 +1,23 @@ | |
| 1 1 | 
             
            import os
         | 
| 2 | 
            +
            from tempfile import TemporaryDirectory
         | 
| 2 3 | 
             
            import pytest
         | 
| 3 4 | 
             
            import numpy as np
         | 
| 4 | 
            -
            from  | 
| 5 | 
            -
            from  | 
| 5 | 
            +
            from pint import get_application_registry
         | 
| 6 | 
            +
            from nxtomo import NXtomo
         | 
| 7 | 
            +
            from nabu.testutils import utilstest, __do_long_tests__, get_data
         | 
| 8 | 
            +
            from nabu.resources.dataset_analyzer import HDF5DatasetAnalyzer, analyze_dataset, ImageKey
         | 
| 6 9 | 
             
            from nabu.resources.nxflatfield import update_dataset_info_flats_darks
         | 
| 7 10 | 
             
            from nabu.resources.utils import extract_parameters
         | 
| 8 | 
            -
            from nabu.pipeline.estimators import CompositeCOREstimator
         | 
| 11 | 
            +
            from nabu.pipeline.estimators import CompositeCOREstimator, TranslationsEstimator
         | 
| 9 12 | 
             
            from nabu.pipeline.config import parse_nabu_config_file
         | 
| 10 13 | 
             
            from nabu.pipeline.estimators import SinoCORFinder, CORFinder
         | 
| 11 14 |  | 
| 15 | 
            +
            from nabu.estimation.tests.test_motion_estimation import (
         | 
| 16 | 
            +
                check_motion_estimation,
         | 
| 17 | 
            +
                project_volume,
         | 
| 18 | 
            +
                _create_translations_vector,
         | 
| 19 | 
            +
            )
         | 
| 20 | 
            +
             | 
| 12 21 |  | 
| 13 22 | 
             
            #
         | 
| 14 23 | 
             
            # Test CoR estimation with "composite-coarse-to-fine" (aka "near" in the legacy system vocable)
         | 
| @@ -119,3 +128,231 @@ class TestCorNearPos: | |
| 119 128 | 
             
                        cor = finder.find_cor()
         | 
| 120 129 | 
             
                        message = f"Computed CoR {cor} and expected CoR {self.true_cor} do not coincide. Near_pos options was set to {cor_options.get('near_pos',None)}."
         | 
| 121 130 | 
             
                        assert np.isclose(self.true_cor + 0.5, cor, atol=self.abs_tol), message
         | 
| 131 | 
            +
             | 
| 132 | 
            +
             | 
| 133 | 
            +
            def _add_fake_flats_and_dark_to_data(data, n_darks=10, n_flats=21, dark_val=1, flat_val=3):
         | 
| 134 | 
            +
                img_shape = data.shape[1:]
         | 
| 135 | 
            +
                # Use constant darks/flats, to avoid "reduction" (mean/median) issues
         | 
| 136 | 
            +
                fake_darks = np.ones((n_darks,) + img_shape, dtype=np.uint16) * dark_val
         | 
| 137 | 
            +
                fake_flats = np.ones((n_flats,) + img_shape, dtype=np.uint16) * flat_val
         | 
| 138 | 
            +
                return data * (fake_flats[0, 0, 0] - fake_darks[0, 0, 0]) + fake_darks[0, 0, 0], fake_darks, fake_flats
         | 
| 139 | 
            +
             | 
| 140 | 
            +
             | 
| 141 | 
            +
            def _generate_nx_for_180_dataset(volume, output_file_path, n_darks=10, n_flats=21):
                """
                Simulate a 180-degree scan of ``volume`` with sample motion, and save it as an NXtomo file.

                250 projections are taken over [0, 180) degrees. They are followed by 5 "return"
                projections at 180, 135, 90, 45 and 0 degrees, stored with the ALIGNMENT image key.

                Parameters
                ----------
                volume: numpy.ndarray
                    Volume to project.
                output_file_path: str
                    Path of the NXtomo file to write (overwritten if it exists).
                n_darks, n_flats: int
                    Number of constant dark and flat frames written before the projections.

                Returns
                -------
                sample_motion_xy: numpy.ndarray
                    Ground-truth (x, y) sample motion, one row per projection (return projections included).
                sample_motion_z: numpy.ndarray
                    Ground-truth vertical sample motion.
                cor: float
                    Center of rotation used for the simulation.
                """
                n_angles = 250
                cor = -10
                orig_det_dist = 0
                # Motion coefficients, passed to _create_translations_vector
                alpha_x, beta_x = 4, 3
                alpha_y, beta_y = -5, 10
                beta_z = 0

                scan_angles = np.linspace(0, np.pi, n_angles, False)
                return_angles = np.deg2rad([180.0, 135.0, 90.0, 45.0, 0.0])
                all_angles = np.hstack([scan_angles, return_angles]).ravel()
                # Normalized acquisition "time": exceeds 1 for the return projections
                t = np.arange(all_angles.size) / scan_angles.size

                tx, ty, tz = (
                    _create_translations_vector(t, alpha, beta)
                    for alpha, beta in ((alpha_x, beta_x), (alpha_y, beta_y), (0, beta_z))
                )

                sinos = project_volume(volume, all_angles, -tx, -ty, -tz, cor=-cor, orig_det_dist=orig_det_dist)
                projections = np.moveaxis(sinos, 1, 0)

                gt_motion_xy = np.stack([-tx, ty], axis=1)
                gt_motion_z = -tz
                n_return_radios = return_angles.size
                n_radios = projections.shape[0] - n_return_radios

                ureg = get_application_registry()
                raw_frames, darks, flats = _add_fake_flats_and_dark_to_data(projections, n_darks=n_darks, n_flats=n_flats)

                nxtomo = NXtomo()
                detector = nxtomo.instrument.detector
                # darks, then flats, then radios + return radios (in float32 !)
                detector.data = np.concatenate([darks, flats, raw_frames])
                detector.image_key_control = np.concatenate(
                    [
                        [ImageKey.DARK_FIELD.value] * n_darks,
                        [ImageKey.FLAT_FIELD.value] * n_flats,
                        [ImageKey.PROJECTION.value] * n_radios,
                        [ImageKey.ALIGNMENT.value] * n_return_radios,
                    ]
                )
                rotation_angle_deg = np.concatenate(
                    [
                        np.zeros(n_darks, dtype="f"),
                        np.zeros(n_flats, dtype="f"),
                        np.degrees(scan_angles),
                        np.degrees(return_angles),
                    ]
                )
                nxtomo.sample.rotation_angle = rotation_angle_deg * ureg.degree
                detector.field_of_view = "Full"
                detector.x_pixel_size = detector.y_pixel_size = 1 * ureg.micrometer
                nxtomo.save(file_path=output_file_path, data_path="entry", overwrite=True)

                return gt_motion_xy, gt_motion_z, cor
         | 
| 202 | 
            +
             | 
| 203 | 
            +
             | 
| 204 | 
            +
            def _generate_nx_for_360_dataset(volume, output_file_path, n_darks=10, n_flats=21):
                """
                Simulate a 360-degree scan of ``volume`` with sample motion, and save it as an NXtomo file.

                250 projections are taken over [0, 360) degrees. There are no return projections.

                Parameters
                ----------
                volume: numpy.ndarray
                    Volume to project.
                output_file_path: str
                    Path of the NXtomo file to write (overwritten if it exists).
                n_darks, n_flats: int
                    Number of constant dark and flat frames written before the projections.

                Returns
                -------
                sample_motion_xy: numpy.ndarray
                    Ground-truth (x, y) sample motion, one row per projection.
                sample_motion_z: numpy.ndarray
                    Ground-truth vertical sample motion.
                cor: float
                    Center of rotation used for the simulation.
                """
                n_angles = 250
                cor = -5.5
                orig_det_dist = 0
                # Motion coefficients, passed to _create_translations_vector
                alpha_x, beta_x = -2, 7.0
                alpha_y, beta_y = -2, 3
                beta_z = 100

                scan_angles = np.linspace(0, 2 * np.pi, n_angles, False)
                t = np.linspace(0, 1, scan_angles.size, endpoint=False)  # theta/theta_max

                tx, ty, tz = (
                    _create_translations_vector(t, alpha, beta)
                    for alpha, beta in ((alpha_x, beta_x), (alpha_y, beta_y), (0, beta_z))
                )

                sinos = project_volume(volume, scan_angles, -tx, -ty, -tz, cor=-cor, orig_det_dist=orig_det_dist)
                projections = np.moveaxis(sinos, 1, 0)

                gt_motion_xy = np.stack([-tx, ty], axis=1)
                gt_motion_z = -tz

                ureg = get_application_registry()
                raw_frames, darks, flats = _add_fake_flats_and_dark_to_data(projections, n_darks=n_darks, n_flats=n_flats)

                nxtomo = NXtomo()
                detector = nxtomo.instrument.detector
                detector.data = np.concatenate([darks, flats, raw_frames])  # in float32 !
                detector.image_key_control = np.concatenate(
                    [
                        [ImageKey.DARK_FIELD.value] * n_darks,
                        [ImageKey.FLAT_FIELD.value] * n_flats,
                        [ImageKey.PROJECTION.value] * projections.shape[0],
                    ]
                )
                rotation_angle_deg = np.concatenate(
                    [np.zeros(n_darks, dtype="f"), np.zeros(n_flats, dtype="f"), np.degrees(scan_angles)]
                )
                nxtomo.sample.rotation_angle = rotation_angle_deg * ureg.degree
                detector.field_of_view = "Full"
                detector.x_pixel_size = detector.y_pixel_size = 1 * ureg.micrometer
                nxtomo.save(file_path=output_file_path, data_path="entry", overwrite=True)

                return gt_motion_xy, gt_motion_z, cor
         | 
| 259 | 
            +
             | 
| 260 | 
            +
             | 
| 261 | 
            +
            @pytest.fixture(scope="class")
            def setup_test_motion_estimator(request):
                """Class-scoped fixture: load the subsampled MRI test volume once, as an attribute of the test class."""
                test_class = request.cls
                test_class.volume = get_data("motion/mri_volume_subsampled.npy")
         | 
| 265 | 
            +
             | 
| 266 | 
            +
             | 
| 267 | 
            +
            @pytest.mark.skipif(not (__do_long_tests__), reason="need environment variable NABU_LONG_TESTS=1")
            @pytest.mark.usefixtures("setup_test_motion_estimator")
            class TestMotionEstimator:
                """
                End-to-end tests of TranslationsEstimator on simulated scans with known sample motion.

                Only run when long tests are enabled (NABU_LONG_TESTS=1).
                """

                def _setup(self, tmpdir):
                    """
                    Common setup for each test.

                    Loads the test volume if the class fixture did not provide it (e.g. when the
                    test methods are called directly instead of through pytest).

                    Parameters
                    ----------
                    tmpdir: str or path-like
                        Temporary directory. pytest passes a path object, direct callers may pass a str.

                    Returns
                    -------
                    tmpdir: str
                        The same directory, as a plain string.
                    """
                    if getattr(self, "volume", None) is None:
                        self.volume = get_data("motion/mri_volume_subsampled.npy")
                    # pytest uses some weird data structure for "tmpdir": return a plain string
                    # (previously the conversion was done on a local variable and then discarded).
                    return str(tmpdir)

                def test_estimate_motion_360_dataset(self, tmpdir, verbose=False):
                    tmpdir = self._setup(tmpdir)
                    nx_file_path = os.path.join(tmpdir, "mri_projected_360_motion.nx")
                    sample_motion_xy, sample_motion_z, cor = _generate_nx_for_360_dataset(self.volume, nx_file_path)

                    dataset_info = analyze_dataset(nx_file_path)

                    translations_estimator = TranslationsEstimator(
                        dataset_info, do_flatfield=True, rot_center=cor, angular_subsampling=5, deg_xy=2, deg_z=2
                    )
                    # The fitted model is checked through translations_estimator.motion_estimator below
                    estimated_shifts_h, estimated_shifts_v, estimated_cor = translations_estimator.estimate_motion()

                    # Ground truth is subsampled the same way as the estimator's angles
                    s = translations_estimator.angular_subsampling
                    if verbose:
                        translations_estimator.motion_estimator.plot_detector_shifts(cor=cor)
                        translations_estimator.motion_estimator.plot_movements(
                            cor=cor,
                            angles_rad=dataset_info.rotation_angles[::s],
                            gt_xy=sample_motion_xy[::s, :],
                            gt_z=sample_motion_z[::s],
                        )
                    check_motion_estimation(
                        translations_estimator.motion_estimator,
                        dataset_info.rotation_angles[::s],
                        cor,
                        sample_motion_xy[::s, :],
                        sample_motion_z[::s],
                        fit_error_shifts_tol_vu=(0.2, 0.2),
                        fit_error_det_tol_vu=(1e-5, 5e-2),
                        fit_error_tol_xyz=(0.05, 0.05, 0.05),
                        fit_error_det_all_angles_tol_vu=(1e-5, 0.05),
                    )

                def test_estimate_motion_180_dataset(self, tmpdir, verbose=False):
                    tmpdir = self._setup(tmpdir)
                    nx_file_path = os.path.join(tmpdir, "mri_projected_180_motion.nx")

                    sample_motion_xy, sample_motion_z, cor = _generate_nx_for_180_dataset(self.volume, nx_file_path)

                    dataset_info = analyze_dataset(nx_file_path)

                    translations_estimator = TranslationsEstimator(
                        dataset_info,
                        do_flatfield=True,
                        rot_center=cor,
                        angular_subsampling=2,
                        deg_xy=2,
                        deg_z=2,
                        shifts_estimator="DetectorTranslationAlongBeam",
                    )
                    # The fitted model is checked through translations_estimator.motion_estimator below
                    estimated_shifts_h, estimated_shifts_v, estimated_cor = translations_estimator.estimate_motion()

                    if verbose:
                        translations_estimator.motion_estimator.plot_detector_shifts(cor=cor)
                        translations_estimator.motion_estimator.plot_movements(
                            cor=cor,
                            angles_rad=dataset_info.rotation_angles,
                            gt_xy=sample_motion_xy[: dataset_info.n_angles],
                            gt_z=sample_motion_z[: dataset_info.n_angles],
                        )

                    check_motion_estimation(
                        translations_estimator.motion_estimator,
                        dataset_info.rotation_angles,
                        cor,
                        sample_motion_xy,
                        sample_motion_z,
                        fit_error_shifts_tol_vu=(0.02, 0.1),
                        fit_error_det_tol_vu=(1e-2, 0.5),
                        fit_error_tol_xyz=(0.5, 2, 1e-2),
                        fit_error_det_all_angles_tol_vu=(1e-2, 2),
                    )
         | 
| 351 | 
            +
             | 
| 352 | 
            +
             | 
| 353 | 
            +
            if __name__ == "__main__":
                # Run both tests directly (outside pytest), with plotting enabled through verbose=True.
                tester = TestMotionEstimator()
                with TemporaryDirectory(suffix="_motion", prefix="nabu_testdata") as tmp_dir:
                    for run_test in (tester.test_estimate_motion_360_dataset, tester.test_estimate_motion_180_dataset):
                        run_test(tmp_dir, verbose=True)
         | 
    
        nabu/pipeline/utils.py
    CHANGED
    
    | @@ -72,7 +72,7 @@ def get_subregion(sub_region, ndim=3): | |
| 72 72 | 
             
                if sub_region is None:
         | 
| 73 73 | 
             
                    res = ((None, None),)
         | 
| 74 74 | 
             
                elif hasattr(sub_region[0], "__iter__"):
         | 
| 75 | 
            -
                    if set(map(len, sub_region)) !=  | 
| 75 | 
            +
                    if set(map(len, sub_region)) != {2}:
         | 
| 76 76 | 
             
                        raise ValueError("Expected each tuple to be in the form (start, end)")
         | 
| 77 77 | 
             
                    res = sub_region
         | 
| 78 78 | 
             
                else:
         | 
    
        nabu/pipeline/writer.py
    CHANGED
    
    | @@ -163,7 +163,7 @@ class WriterManager: | |
| 163 163 | 
             
                def _init_histogram_writer(self):
         | 
| 164 164 | 
             
                    if not self.histogram:
         | 
| 165 165 | 
             
                        return
         | 
| 166 | 
            -
                    separate_histogram_file =  | 
| 166 | 
            +
                    separate_histogram_file = self.file_format != "hdf5"
         | 
| 167 167 | 
             
                    if separate_histogram_file:
         | 
| 168 168 | 
             
                        fmode = "w"
         | 
| 169 169 | 
             
                        hist_fname = path.join(self.output_dir, "histogram_%05d.hdf5" % self.start_index)
         | 
    
        nabu/preproc/alignment.py
    CHANGED
    
    | @@ -1,11 +1 @@ | |
| 1 1 | 
             
            # Backward compat.
         | 
| 2 | 
            -
            from ..estimation.alignment import AlignmentBase
         | 
| 3 | 
            -
            from ..estimation.cor import (
         | 
| 4 | 
            -
                CenterOfRotation,
         | 
| 5 | 
            -
                CenterOfRotationAdaptiveSearch,
         | 
| 6 | 
            -
                CenterOfRotationGrowingWindow,
         | 
| 7 | 
            -
                CenterOfRotationSlidingWindow,
         | 
| 8 | 
            -
            )
         | 
| 9 | 
            -
            from ..estimation.translation import DetectorTranslationAlongBeam
         | 
| 10 | 
            -
            from ..estimation.focus import CameraFocus
         | 
| 11 | 
            -
            from ..estimation.tilt import CameraTilt
         | 
    
        nabu/preproc/ccd.py
    CHANGED
    
    | @@ -1,5 +1,6 @@ | |
| 1 1 | 
             
            import numpy as np
         | 
| 2 2 | 
             
            from ..utils import check_supported
         | 
| 3 | 
            +
            from scipy.ndimage import binary_dilation
         | 
| 3 4 | 
             
            from silx.math.medianfilter import medfilt2d
         | 
| 4 5 |  | 
| 5 6 |  | 
| @@ -13,6 +14,7 @@ class CCDFilter: | |
| 13 14 | 
             
                def __init__(
         | 
| 14 15 | 
             
                    self,
         | 
| 15 16 | 
             
                    radios_shape: tuple,
         | 
| 17 | 
            +
                    kernel_size: int = 3,
         | 
| 16 18 | 
             
                    correction_type: str = "median_clip",
         | 
| 17 19 | 
             
                    median_clip_thresh: float = 0.1,
         | 
| 18 20 | 
             
                    abs_diff=False,
         | 
| @@ -26,6 +28,9 @@ class CCDFilter: | |
| 26 28 | 
             
                    radios_shape: tuple
         | 
| 27 29 | 
             
                        A tuple describing the shape of the radios stack, in the form
         | 
| 28 30 | 
             
                        `(n_radios, n_z, n_x)`.
         | 
| 31 | 
            +
                    kernel_size: int
         | 
| 32 | 
            +
                        Size of the kernel for the median filter.
         | 
| 33 | 
            +
                        Default is 3.
         | 
| 29 34 | 
             
                    correction_type: str
         | 
| 30 35 | 
             
                        Correction type for radios ("median_clip", "sigma_clip", ...)
         | 
| 31 36 | 
             
                    median_clip_thresh: float, optional
         | 
| @@ -48,6 +53,7 @@ class CCDFilter: | |
| 48 53 | 
             
                         then this pixel value is set to the median value.
         | 
| 49 54 | 
             
                    """
         | 
| 50 55 | 
             
                    self._set_radios_shape(radios_shape)
         | 
| 56 | 
            +
                    self.kernel_size = kernel_size
         | 
| 51 57 | 
             
                    check_supported(correction_type, self._supported_ccd_corrections, "CCD correction mode")
         | 
| 52 58 | 
             
                    self.correction_type = correction_type
         | 
| 53 59 | 
             
                    self.median_clip_thresh = median_clip_thresh
         | 
| @@ -67,11 +73,11 @@ class CCDFilter: | |
| 67 73 | 
             
                    self.shape = (n_z, n_x)
         | 
| 68 74 |  | 
| 69 75 | 
             
                @staticmethod
                def median_filter(img, kernel_size=3):
                    """
                    Apply a square median filter of side ``kernel_size`` to a 2D image.

                    Borders are handled with ``mode="reflect"`` (silx ``medfilt2d``).
                    """
                    window = (kernel_size, kernel_size)
                    return medfilt2d(img, window, mode="reflect")
         | 
| 75 81 |  | 
| 76 82 | 
             
                def median_clip_mask(self, img, return_medians=False):
         | 
| 77 83 | 
             
                    """
         | 
| @@ -85,7 +91,7 @@ class CCDFilter: | |
| 85 91 | 
             
                    return_medians: bool, optional
         | 
| 86 92 | 
             
                        Whether to return the median values additionally to the mask.
         | 
| 87 93 | 
             
                    """
         | 
| 88 | 
            -
                    median_values = self.median_filter(img)
         | 
| 94 | 
            +
                    median_values = self.median_filter(img, kernel_size=self.kernel_size)
         | 
| 89 95 | 
             
                    if not self.abs_diff:
         | 
| 90 96 | 
             
                        invalid_mask = img >= median_values + self.median_clip_thresh
         | 
| 91 97 | 
             
                    else:
         | 
| @@ -124,6 +130,50 @@ class CCDFilter: | |
| 124 130 |  | 
| 125 131 | 
             
                    return output
         | 
| 126 132 |  | 
| 133 | 
            +
                def dezinger_correction(self, radios, dark=None, nsigma=5, output=None):
                    """
                    Compute the median clip correction on a radios stack, and propagate the invalid pixels
                    into the vertical and horizontal directions.

                    For each radio, the (optionally dark-subtracted) image is compared to its median-filtered
                    version. Pixels exceeding the median by more than ``nsigma`` times the standard deviation
                    of the residual are flagged. The mask is grown by one pixel (binary dilation), and the
                    flagged pixels are replaced by the local median.

                    Parameters
                    ----------
                    radios: numpy.ndarray
                        A radios stack, in the form `(n_radios, n_z, n_x)`.
                    dark: numpy.ndarray, optional
                        A dark image of shape `(n_z, n_x)`. Default is None. If not None, it is subtracted
                        from the radios for zinger detection. The replacement median is shifted back by
                        ``dark``, so that the output stays on the same scale as ``radios``.
                    nsigma: float, optional
                        Number of standard deviations to use for the zinger detection.
                        Default is 5.
                    output: numpy.ndarray, optional
                        Output array. If None, a new array is allocated.

                    Returns
                    -------
                    output: numpy.ndarray
                        The corrected radios stack.

                    Raises
                    ------
                    ValueError
                        If the radios image shape or the dark shape does not match the expected image shape.
                    """
                    if radios.shape[1:] != self.radios_shape[1:]:
                        raise ValueError(f"Expected radios shape {self.radios_shape}, got {radios.shape}")
                    # This check does not depend on the radio index: do it once, before the loop
                    if dark is not None and dark.shape != radios.shape[1:]:
                        raise ValueError("Dark image shape does not match radios shape.")

                    if output is None:
                        output = np.copy(radios)
                    else:
                        output[:] = radios[:]

                    for i in range(radios.shape[0]):
                        # dark is optional: only subtract it when provided
                        dimg = radios[i] if dark is None else radios[i] - dark
                        med = self.median_filter(dimg, self.kernel_size)
                        err = dimg - med
                        # Only positive outliers (bright spots) are flagged
                        msk = err > (err.std() * nsigma)
                        # Grow the mask by one pixel (scipy's default cross-shaped structuring element)
                        gromsk = binary_dilation(msk)
                        # Unflagged pixels keep their raw value, so the replacement must be on the raw scale too
                        replacement = med if dark is None else med + dark
                        output[i] = np.where(gromsk, replacement, radios[i])

                    return output
         | 
| 176 | 
            +
             | 
| 127 177 |  | 
| 128 178 | 
             
            class Log:
         | 
| 129 179 | 
             
                """
         |