careamics 0.1.0rc7__tar.gz → 0.1.0rc8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/.pre-commit-config.yaml +2 -2
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/PKG-INFO +1 -1
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/pyproject.toml +1 -0
- careamics-0.1.0rc8/src/careamics/__init__.py +13 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/careamist.py +83 -62
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/__init__.py +0 -3
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/algorithm_model.py +8 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/architectures/architecture_model.py +1 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/architectures/custom_model.py +2 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/architectures/unet_model.py +19 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/architectures/vae_model.py +1 -0
- careamics-0.1.0rc8/src/careamics/config/callback_model.py +123 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/configuration_factory.py +1 -79
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/configuration_model.py +12 -7
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/data_model.py +29 -10
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/inference_model.py +12 -2
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/optimizer_models.py +6 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/support/supported_data.py +29 -4
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/tile_information.py +10 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/training_model.py +5 -1
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/dataset_utils/__init__.py +0 -6
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/dataset_utils/file_utils.py +1 -1
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/dataset_utils/iterate_over_files.py +1 -1
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/in_memory_dataset.py +37 -21
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/iterable_dataset.py +38 -34
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/iterable_pred_dataset.py +2 -1
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/patching/patching.py +53 -37
- careamics-0.1.0rc8/src/careamics/file_io/__init__.py +7 -0
- careamics-0.1.0rc8/src/careamics/file_io/read/__init__.py +11 -0
- careamics-0.1.0rc8/src/careamics/file_io/read/get_func.py +56 -0
- careamics-0.1.0rc7/src/careamics/dataset/dataset_utils/read_tiff.py → careamics-0.1.0rc8/src/careamics/file_io/read/tiff.py +3 -1
- careamics-0.1.0rc8/src/careamics/file_io/write/__init__.py +9 -0
- careamics-0.1.0rc8/src/careamics/file_io/write/get_func.py +59 -0
- careamics-0.1.0rc8/src/careamics/file_io/write/tiff.py +39 -0
- careamics-0.1.0rc8/src/careamics/lightning/__init__.py +17 -0
- {careamics-0.1.0rc7/src/careamics → careamics-0.1.0rc8/src/careamics/lightning}/lightning_module.py +58 -85
- careamics-0.1.0rc7/src/careamics/lightning_prediction_datamodule.py → careamics-0.1.0rc8/src/careamics/lightning/predict_data_module.py +78 -116
- careamics-0.1.0rc7/src/careamics/lightning_datamodule.py → careamics-0.1.0rc8/src/careamics/lightning/train_data_module.py +134 -214
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/model_io/bmz_io.py +1 -1
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/model_io/model_io_utils.py +1 -1
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/prediction_utils/__init__.py +0 -2
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/prediction_utils/prediction_outputs.py +18 -46
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/prediction_utils/stitch_prediction.py +17 -14
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/utils/__init__.py +2 -0
- careamics-0.1.0rc8/src/careamics/utils/autocorrelation.py +40 -0
- careamics-0.1.0rc8/tests/config/support/test_supported_data.py +103 -0
- careamics-0.1.0rc8/tests/config/test_callback_models.py +16 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/test_configuration_factory.py +0 -86
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/test_data_model.py +2 -2
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/conftest.py +2 -1
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/dataset/test_iterable_dataset.py +1 -1
- careamics-0.1.0rc8/tests/file_io/read/test_get_read_func.py +14 -0
- {careamics-0.1.0rc7/tests/dataset/dataset_utils → careamics-0.1.0rc8/tests/file_io/read}/test_read_tiff.py +1 -1
- careamics-0.1.0rc8/tests/file_io/write/test_get_write_func.py +14 -0
- careamics-0.1.0rc8/tests/file_io/write/test_write_tiff.py +27 -0
- careamics-0.1.0rc8/tests/lightning/test_lightning_api.py +133 -0
- {careamics-0.1.0rc7/tests → careamics-0.1.0rc8/tests/lightning}/test_lightning_module.py +2 -2
- careamics-0.1.0rc7/tests/test_lightning_prediction_datamodule.py → careamics-0.1.0rc8/tests/lightning/test_predict_data_module.py +7 -7
- careamics-0.1.0rc7/tests/test_lightning_datamodule.py → careamics-0.1.0rc8/tests/lightning/test_train_data_module.py +75 -13
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/model_io/test_bmz_io.py +4 -2
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/prediction_utils/test_prediction_outputs.py +5 -7
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/prediction_utils/test_stitch_prediction.py +2 -2
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/test_careamist.py +22 -7
- careamics-0.1.0rc8/tests/utils/test_autocorrelation.py +42 -0
- careamics-0.1.0rc7/examples/2D/n2n/example_SEM_careamist.ipynb +0 -239
- careamics-0.1.0rc7/examples/2D/n2n/n2n_2D_SEM.yml +0 -60
- careamics-0.1.0rc7/examples/2D/n2v/example_BSD68_careamist.ipynb +0 -303
- careamics-0.1.0rc7/examples/2D/n2v/example_BSD68_lightning.ipynb +0 -389
- careamics-0.1.0rc7/examples/2D/n2v/example_SEM_lightning.ipynb +0 -312
- careamics-0.1.0rc7/examples/2D/n2v/n2v_2D_BSD.yml +0 -60
- careamics-0.1.0rc7/examples/2D/pn2v/pN2V_Convallaria.yml +0 -71
- careamics-0.1.0rc7/src/careamics/__init__.py +0 -26
- careamics-0.1.0rc7/src/careamics/config/callback_model.py +0 -81
- careamics-0.1.0rc7/src/careamics/config/configuration_example.py +0 -86
- careamics-0.1.0rc7/src/careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics-0.1.0rc7/src/careamics/prediction_utils/create_pred_datamodule.py +0 -185
- careamics-0.1.0rc7/tests/config/support/test_supported_data.py +0 -76
- careamics-0.1.0rc7/tests/config/test_full_config_example.py +0 -6
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/.github/TEST_FAIL_TEMPLATE.md +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/.github/pull_request_template.md +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/.github/workflows/ci.yml +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/.gitignore +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/LICENSE +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/README.md +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/examples/3D/example_flywing_3D.ipynb +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/examples/3D/n2v_flywing_3D.yml +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/examples/evaluate_LVAE.ipynb +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/architectures/__init__.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/architectures/register_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/references/__init__.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/references/algorithm_descriptions.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/references/references.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/support/__init__.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/support/supported_activations.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/support/supported_algorithms.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/support/supported_architectures.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/support/supported_loggers.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/support/supported_losses.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/support/supported_optimizers.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/support/supported_pixel_manipulations.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/support/supported_struct_axis.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/support/supported_transforms.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/transformations/__init__.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/transformations/n2v_manipulate_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/transformations/normalize_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/transformations/transform_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/transformations/xy_flip_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/transformations/xy_random_rotate90_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/validators/__init__.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/validators/validator_utils.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/conftest.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/__init__.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/dataset_utils/dataset_utils.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/dataset_utils/running_stats.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/in_memory_pred_dataset.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/in_memory_tiled_pred_dataset.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/patching/__init__.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/patching/random_patching.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/patching/sequential_patching.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/patching/validate_patch_dimension.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/tiling/__init__.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/tiling/collate_tiles.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/tiling/tiled_patching.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/dataset/zarr_dataset.py +0 -0
- /careamics-0.1.0rc7/src/careamics/dataset/dataset_utils/read_zarr.py → /careamics-0.1.0rc8/src/careamics/file_io/read/zarr.py +0 -0
- {careamics-0.1.0rc7/src/careamics → careamics-0.1.0rc8/src/careamics/lightning}/callbacks/__init__.py +0 -0
- {careamics-0.1.0rc7/src/careamics → careamics-0.1.0rc8/src/careamics/lightning}/callbacks/hyperparameters_callback.py +0 -0
- {careamics-0.1.0rc7/src/careamics → careamics-0.1.0rc8/src/careamics/lightning}/callbacks/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/losses/__init__.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/losses/loss_factory.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/losses/losses.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/lvae_training/__init__.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/lvae_training/data_modules.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/lvae_training/data_utils.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/lvae_training/eval_utils.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/lvae_training/get_config.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/lvae_training/lightning_module.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/lvae_training/metrics.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/lvae_training/train_lvae.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/lvae_training/train_utils.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/model_io/__init__.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/model_io/bioimage/__init__.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/model_io/bioimage/_readme_factory.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/model_io/bioimage/bioimage_utils.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/model_io/bioimage/model_description.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/models/__init__.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/models/activation.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/models/layers.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/models/lvae/__init__.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/models/lvae/layers.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/models/lvae/likelihoods.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/models/lvae/lvae.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/models/lvae/noise_models.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/models/lvae/utils.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/models/model_factory.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/models/unet.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/py.typed +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/transforms/__init__.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/transforms/compose.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/transforms/n2v_manipulate.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/transforms/normalize.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/transforms/pixel_manipulation.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/transforms/struct_mask_parameters.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/transforms/transform.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/transforms/tta.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/transforms/xy_flip.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/transforms/xy_random_rotate90.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/utils/base_enum.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/utils/context.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/utils/logging.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/utils/metrics.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/utils/path_utils.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/utils/ram.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/utils/receptive_field.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/utils/torch_utils.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/architectures/test_architecture_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/architectures/test_custom_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/architectures/test_register_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/architectures/test_unet_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/support/test_supported_optimizers.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/test_algorithm_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/test_configuration_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/test_inference_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/test_optimizers_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/test_tile_information.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/test_training_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/transformations/test_n2v_manipulate_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/transformations/test_normalize_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/transformations/test_xy_flip_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/transformations/test_xy_random_rotate90_model.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/config/validators/test_validator_utils.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/dataset/dataset_utils/test_compute_normalization_stats.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/dataset/dataset_utils/test_list_files.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/dataset/patching/test_patching_utils.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/dataset/patching/test_random_patching.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/dataset/patching/test_sequential_patching.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/dataset/test_in_memory_dataset.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/dataset/test_in_memory_pred_dataset.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/dataset/test_in_memory_tiled_pred_dataset.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/dataset/test_iterable_pred_dataset.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/dataset/test_iterable_tiled_pred_dataset.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/dataset/tiling/test_collate_tiles.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/dataset/tiling/test_tiled_patching.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/models/test_model_factory.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/models/test_unet.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/test_conftest.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/transforms/test_compose.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/transforms/test_manipulate_n2v.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/transforms/test_normalize.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/transforms/test_pixel_manipulation.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/transforms/test_supported_transforms.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/transforms/test_tta.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/transforms/test_xy_flip.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/transforms/test_xy_random_rotate90.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/utils/test_base_enum.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/utils/test_context.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/utils/test_logging.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/utils/test_metrics.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/utils/test_torch_utils.py +0 -0
- {careamics-0.1.0rc7 → careamics-0.1.0rc8}/tests/utils/test_wandb.py +0 -0
|
@@ -14,7 +14,7 @@ repos:
|
|
|
14
14
|
- id: validate-pyproject
|
|
15
15
|
|
|
16
16
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
17
|
-
rev: v0.
|
|
17
|
+
rev: v0.5.0
|
|
18
18
|
hooks:
|
|
19
19
|
- id: ruff
|
|
20
20
|
exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*"
|
|
@@ -26,7 +26,7 @@ repos:
|
|
|
26
26
|
- id: black
|
|
27
27
|
|
|
28
28
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
|
29
|
-
rev: v1.10.
|
|
29
|
+
rev: v1.10.1
|
|
30
30
|
hooks:
|
|
31
31
|
- id: mypy
|
|
32
32
|
files: "^src/"
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Main CAREamics module."""
|
|
2
|
+
|
|
3
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
4
|
+
|
|
5
|
+
try:
|
|
6
|
+
__version__ = version("careamics")
|
|
7
|
+
except PackageNotFoundError:
|
|
8
|
+
__version__ = "uninstalled"
|
|
9
|
+
|
|
10
|
+
__all__ = ["CAREamist", "Configuration", "load_configuration", "save_configuration"]
|
|
11
|
+
|
|
12
|
+
from .careamist import CAREamist
|
|
13
|
+
from .config import Configuration, load_configuration, save_configuration
|
|
@@ -13,22 +13,29 @@ from pytorch_lightning.callbacks import (
|
|
|
13
13
|
)
|
|
14
14
|
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
|
|
15
15
|
|
|
16
|
-
from careamics.callbacks import ProgressBarCallback
|
|
17
16
|
from careamics.config import (
|
|
18
17
|
Configuration,
|
|
19
18
|
load_configuration,
|
|
20
19
|
)
|
|
21
|
-
from careamics.config.support import
|
|
20
|
+
from careamics.config.support import (
|
|
21
|
+
SupportedAlgorithm,
|
|
22
|
+
SupportedArchitecture,
|
|
23
|
+
SupportedData,
|
|
24
|
+
SupportedLogger,
|
|
25
|
+
)
|
|
22
26
|
from careamics.dataset.dataset_utils import reshape_array
|
|
23
|
-
from careamics.
|
|
24
|
-
|
|
27
|
+
from careamics.lightning import (
|
|
28
|
+
CAREamicsModule,
|
|
29
|
+
HyperParametersCallback,
|
|
30
|
+
PredictDataModule,
|
|
31
|
+
ProgressBarCallback,
|
|
32
|
+
TrainDataModule,
|
|
33
|
+
create_predict_datamodule,
|
|
34
|
+
)
|
|
25
35
|
from careamics.model_io import export_to_bmz, load_pretrained
|
|
26
|
-
from careamics.prediction_utils import convert_outputs
|
|
36
|
+
from careamics.prediction_utils import convert_outputs
|
|
27
37
|
from careamics.utils import check_path_exists, get_logger
|
|
28
38
|
|
|
29
|
-
from .callbacks import HyperParametersCallback
|
|
30
|
-
from .lightning_prediction_datamodule import CAREamicsPredictData
|
|
31
|
-
|
|
32
39
|
logger = get_logger(__name__)
|
|
33
40
|
|
|
34
41
|
LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
|
|
@@ -61,9 +68,9 @@ class CAREamist:
|
|
|
61
68
|
Experiment logger, "wandb" or "tensorboard".
|
|
62
69
|
work_dir : pathlib.Path
|
|
63
70
|
Working directory.
|
|
64
|
-
train_datamodule :
|
|
71
|
+
train_datamodule : TrainDataModule
|
|
65
72
|
Training datamodule.
|
|
66
|
-
pred_datamodule :
|
|
73
|
+
pred_datamodule : PredictDataModule
|
|
67
74
|
Prediction datamodule.
|
|
68
75
|
"""
|
|
69
76
|
|
|
@@ -193,8 +200,8 @@ class CAREamist:
|
|
|
193
200
|
)
|
|
194
201
|
|
|
195
202
|
# place holder for the datamodules
|
|
196
|
-
self.train_datamodule: Optional[
|
|
197
|
-
self.pred_datamodule: Optional[
|
|
203
|
+
self.train_datamodule: Optional[TrainDataModule] = None
|
|
204
|
+
self.pred_datamodule: Optional[PredictDataModule] = None
|
|
198
205
|
|
|
199
206
|
def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
|
|
200
207
|
"""
|
|
@@ -246,7 +253,7 @@ class CAREamist:
|
|
|
246
253
|
def train(
|
|
247
254
|
self,
|
|
248
255
|
*,
|
|
249
|
-
datamodule: Optional[
|
|
256
|
+
datamodule: Optional[TrainDataModule] = None,
|
|
250
257
|
train_source: Optional[Union[Path, str, NDArray]] = None,
|
|
251
258
|
val_source: Optional[Union[Path, str, NDArray]] = None,
|
|
252
259
|
train_target: Optional[Union[Path, str, NDArray]] = None,
|
|
@@ -273,7 +280,7 @@ class CAREamist:
|
|
|
273
280
|
|
|
274
281
|
Parameters
|
|
275
282
|
----------
|
|
276
|
-
datamodule :
|
|
283
|
+
datamodule : TrainDataModule, optional
|
|
277
284
|
Datamodule to train on, by default None.
|
|
278
285
|
train_source : pathlib.Path or str or NDArray, optional
|
|
279
286
|
Train source, if no datamodule is provided, by default None.
|
|
@@ -375,17 +382,17 @@ class CAREamist:
|
|
|
375
382
|
|
|
376
383
|
else:
|
|
377
384
|
raise ValueError(
|
|
378
|
-
f"Invalid input, expected a str, Path, array or
|
|
385
|
+
f"Invalid input, expected a str, Path, array or TrainDataModule "
|
|
379
386
|
f"instance (got {type(train_source)})."
|
|
380
387
|
)
|
|
381
388
|
|
|
382
|
-
def _train_on_datamodule(self, datamodule:
|
|
389
|
+
def _train_on_datamodule(self, datamodule: TrainDataModule) -> None:
|
|
383
390
|
"""
|
|
384
391
|
Train the model on the provided datamodule.
|
|
385
392
|
|
|
386
393
|
Parameters
|
|
387
394
|
----------
|
|
388
|
-
datamodule :
|
|
395
|
+
datamodule : TrainDataModule
|
|
389
396
|
Datamodule to train on.
|
|
390
397
|
"""
|
|
391
398
|
# record datamodule
|
|
@@ -421,7 +428,7 @@ class CAREamist:
|
|
|
421
428
|
Minimum number of patches to use for validation, by default 5.
|
|
422
429
|
"""
|
|
423
430
|
# create datamodule
|
|
424
|
-
datamodule =
|
|
431
|
+
datamodule = TrainDataModule(
|
|
425
432
|
data_config=self.cfg.data_config,
|
|
426
433
|
train_data=train_data,
|
|
427
434
|
val_data=val_data,
|
|
@@ -477,7 +484,7 @@ class CAREamist:
|
|
|
477
484
|
path_to_val_target = check_path_exists(path_to_val_target)
|
|
478
485
|
|
|
479
486
|
# create datamodule
|
|
480
|
-
datamodule =
|
|
487
|
+
datamodule = TrainDataModule(
|
|
481
488
|
data_config=self.cfg.data_config,
|
|
482
489
|
train_data=path_to_train_data,
|
|
483
490
|
val_data=path_to_val_data,
|
|
@@ -493,10 +500,7 @@ class CAREamist:
|
|
|
493
500
|
|
|
494
501
|
@overload
|
|
495
502
|
def predict( # numpydoc ignore=GL08
|
|
496
|
-
self,
|
|
497
|
-
source: CAREamicsPredictData,
|
|
498
|
-
*,
|
|
499
|
-
checkpoint: Optional[Literal["best", "last"]] = None,
|
|
503
|
+
self, source: PredictDataModule
|
|
500
504
|
) -> Union[list[NDArray], NDArray]: ...
|
|
501
505
|
|
|
502
506
|
@overload
|
|
@@ -513,7 +517,6 @@ class CAREamist:
|
|
|
513
517
|
dataloader_params: Optional[dict] = None,
|
|
514
518
|
read_source_func: Optional[Callable] = None,
|
|
515
519
|
extension_filter: str = "",
|
|
516
|
-
checkpoint: Optional[Literal["best", "last"]] = None,
|
|
517
520
|
) -> Union[list[NDArray], NDArray]: ...
|
|
518
521
|
|
|
519
522
|
@overload
|
|
@@ -528,12 +531,11 @@ class CAREamist:
|
|
|
528
531
|
data_type: Optional[Literal["array"]] = None,
|
|
529
532
|
tta_transforms: bool = True,
|
|
530
533
|
dataloader_params: Optional[dict] = None,
|
|
531
|
-
checkpoint: Optional[Literal["best", "last"]] = None,
|
|
532
534
|
) -> Union[list[NDArray], NDArray]: ...
|
|
533
535
|
|
|
534
536
|
def predict(
|
|
535
537
|
self,
|
|
536
|
-
source: Union[
|
|
538
|
+
source: Union[PredictDataModule, Path, str, NDArray],
|
|
537
539
|
*,
|
|
538
540
|
batch_size: Optional[int] = None,
|
|
539
541
|
tile_size: Optional[tuple[int, ...]] = None,
|
|
@@ -544,7 +546,6 @@ class CAREamist:
|
|
|
544
546
|
dataloader_params: Optional[dict] = None,
|
|
545
547
|
read_source_func: Optional[Callable] = None,
|
|
546
548
|
extension_filter: str = "",
|
|
547
|
-
checkpoint: Optional[Literal["best", "last"]] = None,
|
|
548
549
|
**kwargs: Any,
|
|
549
550
|
) -> Union[list[NDArray], NDArray]:
|
|
550
551
|
"""
|
|
@@ -590,8 +591,6 @@ class CAREamist:
|
|
|
590
591
|
Function to read the source data.
|
|
591
592
|
extension_filter : str, default=""
|
|
592
593
|
Filter for the file extension.
|
|
593
|
-
checkpoint : {"best", "last"}, optional
|
|
594
|
-
Checkpoint to use for prediction.
|
|
595
594
|
**kwargs : Any
|
|
596
595
|
Unused.
|
|
597
596
|
|
|
@@ -599,31 +598,64 @@ class CAREamist:
|
|
|
599
598
|
-------
|
|
600
599
|
list of NDArray or NDArray
|
|
601
600
|
Predictions made by the model.
|
|
602
|
-
"""
|
|
603
|
-
# Reuse batch size if not provided explicitly
|
|
604
|
-
if batch_size is None:
|
|
605
|
-
batch_size = (
|
|
606
|
-
self.train_datamodule.batch_size
|
|
607
|
-
if self.train_datamodule
|
|
608
|
-
else self.cfg.data_config.batch_size
|
|
609
|
-
)
|
|
610
601
|
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
602
|
+
Raises
|
|
603
|
+
------
|
|
604
|
+
ValueError
|
|
605
|
+
If mean and std are not provided in the configuration.
|
|
606
|
+
ValueError
|
|
607
|
+
If tile size is not divisible by 2**depth for UNet models.
|
|
608
|
+
ValueError
|
|
609
|
+
If tile overlap is not specified.
|
|
610
|
+
"""
|
|
611
|
+
if (
|
|
612
|
+
self.cfg.data_config.image_means is None
|
|
613
|
+
or self.cfg.data_config.image_stds is None
|
|
614
|
+
):
|
|
615
|
+
raise ValueError("Mean and std must be provided in the configuration.")
|
|
616
|
+
|
|
617
|
+
# tile size for UNets
|
|
618
|
+
if tile_size is not None:
|
|
619
|
+
model = self.cfg.algorithm_config.model
|
|
620
|
+
|
|
621
|
+
if model.architecture == SupportedArchitecture.UNET.value:
|
|
622
|
+
# tile size must be equal to k*2^n, where n is the number of pooling
|
|
623
|
+
# layers (equal to the depth) and k is an integer
|
|
624
|
+
depth = model.depth
|
|
625
|
+
tile_increment = 2**depth
|
|
626
|
+
|
|
627
|
+
for i, t in enumerate(tile_size):
|
|
628
|
+
if t % tile_increment != 0:
|
|
629
|
+
raise ValueError(
|
|
630
|
+
f"Tile size must be divisible by {tile_increment} along "
|
|
631
|
+
f"all axes (got {t} for axis {i}). If your image size is "
|
|
632
|
+
f"smaller along one axis (e.g. Z), consider padding the "
|
|
633
|
+
f"image."
|
|
634
|
+
)
|
|
635
|
+
|
|
636
|
+
# tile overlaps must be specified
|
|
637
|
+
if tile_overlap is None:
|
|
638
|
+
raise ValueError("Tile overlap must be specified.")
|
|
639
|
+
|
|
640
|
+
# create the prediction
|
|
641
|
+
self.pred_datamodule = create_predict_datamodule(
|
|
642
|
+
pred_data=source,
|
|
643
|
+
data_type=data_type or self.cfg.data_config.data_type,
|
|
644
|
+
axes=axes or self.cfg.data_config.axes,
|
|
645
|
+
image_means=self.cfg.data_config.image_means,
|
|
646
|
+
image_stds=self.cfg.data_config.image_stds,
|
|
615
647
|
tile_size=tile_size,
|
|
616
648
|
tile_overlap=tile_overlap,
|
|
617
|
-
|
|
618
|
-
data_type=data_type,
|
|
649
|
+
batch_size=batch_size or self.cfg.data_config.batch_size,
|
|
619
650
|
tta_transforms=tta_transforms,
|
|
620
|
-
dataloader_params=dataloader_params,
|
|
621
651
|
read_source_func=read_source_func,
|
|
622
652
|
extension_filter=extension_filter,
|
|
653
|
+
dataloader_params=dataloader_params,
|
|
623
654
|
)
|
|
624
655
|
|
|
656
|
+
# predict
|
|
625
657
|
predictions = self.trainer.predict(
|
|
626
|
-
model=self.model, datamodule=self.pred_datamodule
|
|
658
|
+
model=self.model, datamodule=self.pred_datamodule
|
|
627
659
|
)
|
|
628
660
|
return convert_outputs(predictions, self.pred_datamodule.tiled)
|
|
629
661
|
|
|
@@ -659,27 +691,16 @@ class CAREamist:
|
|
|
659
691
|
data_description : str, optional
|
|
660
692
|
Description of the data, by default None.
|
|
661
693
|
"""
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
# axes need to be reformated for the export because reshaping was done in the
|
|
665
|
-
# datamodule
|
|
666
|
-
if "Z" in self.cfg.data_config.axes:
|
|
667
|
-
axes = "SCZYX"
|
|
668
|
-
else:
|
|
669
|
-
axes = "SCYX"
|
|
694
|
+
# TODO: add in docs that it is expected that input_array dimensions match
|
|
695
|
+
# those in data_config
|
|
670
696
|
|
|
671
|
-
# predict output, remove extra dimensions for the purpose of the prediction
|
|
672
697
|
output_patch = self.predict(
|
|
673
|
-
|
|
698
|
+
input_array,
|
|
674
699
|
data_type=SupportedData.ARRAY.value,
|
|
675
|
-
axes=axes,
|
|
676
700
|
tta_transforms=False,
|
|
677
701
|
)
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
output = np.concatenate(output_patch, axis=0)
|
|
681
|
-
else:
|
|
682
|
-
output = output_patch
|
|
702
|
+
output = np.concatenate(output_patch, axis=0)
|
|
703
|
+
input_array = reshape_array(input_array, self.cfg.data_config.axes)
|
|
683
704
|
|
|
684
705
|
export_to_bmz(
|
|
685
706
|
model=self.model,
|
|
@@ -688,7 +709,7 @@ class CAREamist:
|
|
|
688
709
|
name=name,
|
|
689
710
|
general_description=general_description,
|
|
690
711
|
authors=authors,
|
|
691
|
-
input_array=
|
|
712
|
+
input_array=input_array,
|
|
692
713
|
output_array=output,
|
|
693
714
|
channel_names=channel_names,
|
|
694
715
|
data_description=data_description,
|
|
@@ -14,9 +14,7 @@ __all__ = [
|
|
|
14
14
|
"create_care_configuration",
|
|
15
15
|
"register_model",
|
|
16
16
|
"CustomModel",
|
|
17
|
-
"create_inference_configuration",
|
|
18
17
|
"clear_custom_models",
|
|
19
|
-
"ConfigurationInformation",
|
|
20
18
|
]
|
|
21
19
|
|
|
22
20
|
from .algorithm_model import AlgorithmConfig
|
|
@@ -24,7 +22,6 @@ from .architectures import CustomModel, clear_custom_models, register_model
|
|
|
24
22
|
from .callback_model import CheckpointModel
|
|
25
23
|
from .configuration_factory import (
|
|
26
24
|
create_care_configuration,
|
|
27
|
-
create_inference_configuration,
|
|
28
25
|
create_n2n_configuration,
|
|
29
26
|
create_n2v_configuration,
|
|
30
27
|
)
|
|
@@ -93,12 +93,20 @@ class AlgorithmConfig(BaseModel):
|
|
|
93
93
|
|
|
94
94
|
# Mandatory fields
|
|
95
95
|
algorithm: Literal["n2v", "care", "n2n", "custom"] # defined in SupportedAlgorithm
|
|
96
|
+
"""Name of the algorithm, as defined in SupportedAlgorithm."""
|
|
97
|
+
|
|
96
98
|
loss: Literal["n2v", "mae", "mse"]
|
|
99
|
+
"""Loss function to use, as defined in SupportedLoss."""
|
|
100
|
+
|
|
97
101
|
model: Union[UNetModel, VAEModel, CustomModel] = Field(discriminator="architecture")
|
|
102
|
+
"""Model architecture to use, defined in SupportedArchitecture."""
|
|
98
103
|
|
|
99
104
|
# Optional fields
|
|
100
105
|
optimizer: OptimizerModel = OptimizerModel()
|
|
106
|
+
"""Optimizer to use, defined in SupportedOptimizer."""
|
|
107
|
+
|
|
101
108
|
lr_scheduler: LrSchedulerModel = LrSchedulerModel()
|
|
109
|
+
"""Learning rate scheduler to use, defined in SupportedScheduler."""
|
|
102
110
|
|
|
103
111
|
@model_validator(mode="after")
|
|
104
112
|
def algorithm_cross_validation(self: Self) -> Self:
|
{careamics-0.1.0rc7 → careamics-0.1.0rc8}/src/careamics/config/architectures/custom_model.py
RENAMED
|
@@ -72,9 +72,11 @@ class CustomModel(ArchitectureModel):
|
|
|
72
72
|
|
|
73
73
|
# discriminator used for choosing the pydantic model in Model
|
|
74
74
|
architecture: Literal["Custom"]
|
|
75
|
+
"""Name of the architecture."""
|
|
75
76
|
|
|
76
77
|
# name of the custom model
|
|
77
78
|
name: str
|
|
79
|
+
"""Name of the custom model."""
|
|
78
80
|
|
|
79
81
|
@field_validator("name")
|
|
80
82
|
@classmethod
|
|
@@ -29,19 +29,38 @@ class UNetModel(ArchitectureModel):
|
|
|
29
29
|
|
|
30
30
|
# discriminator used for choosing the pydantic model in Model
|
|
31
31
|
architecture: Literal["UNet"]
|
|
32
|
+
"""Name of the architecture."""
|
|
32
33
|
|
|
33
34
|
# parameters
|
|
34
35
|
# validate_defaults allow ignoring default values in the dump if they were not set
|
|
35
36
|
conv_dims: Literal[2, 3] = Field(default=2, validate_default=True)
|
|
37
|
+
"""Dimensions (2D or 3D) of the convolutional layers."""
|
|
38
|
+
|
|
36
39
|
num_classes: int = Field(default=1, ge=1, validate_default=True)
|
|
40
|
+
"""Number of classes or channels in the model output."""
|
|
41
|
+
|
|
37
42
|
in_channels: int = Field(default=1, ge=1, validate_default=True)
|
|
43
|
+
"""Number of channels in the input to the model."""
|
|
44
|
+
|
|
38
45
|
depth: int = Field(default=2, ge=1, le=10, validate_default=True)
|
|
46
|
+
"""Number of levels in the UNet."""
|
|
47
|
+
|
|
39
48
|
num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True)
|
|
49
|
+
"""Number of convolutional filters in the first layer of the UNet."""
|
|
50
|
+
|
|
40
51
|
final_activation: Literal[
|
|
41
52
|
"None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
|
|
42
53
|
] = Field(default="None", validate_default=True)
|
|
54
|
+
"""Final activation function."""
|
|
55
|
+
|
|
43
56
|
n2v2: bool = Field(default=False, validate_default=True)
|
|
57
|
+
"""Whether to use N2V2 architecture modifications, with blur pool layers and fewer
|
|
58
|
+
skip connections.
|
|
59
|
+
"""
|
|
60
|
+
|
|
44
61
|
independent_channels: bool = Field(default=True, validate_default=True)
|
|
62
|
+
"""Whether information is processed independently in each channel, used to train
|
|
63
|
+
channels independently."""
|
|
45
64
|
|
|
46
65
|
@field_validator("num_channels_init")
|
|
47
66
|
@classmethod
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
"""Callback Pydantic models."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from datetime import timedelta
|
|
6
|
+
from typing import Literal, Optional
|
|
7
|
+
|
|
8
|
+
from pydantic import (
|
|
9
|
+
BaseModel,
|
|
10
|
+
ConfigDict,
|
|
11
|
+
Field,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CheckpointModel(BaseModel):
|
|
16
|
+
"""Checkpoint saving callback Pydantic model.
|
|
17
|
+
|
|
18
|
+
The parameters corresponds to those of
|
|
19
|
+
`pytorch_lightning.callbacks.ModelCheckpoint`.
|
|
20
|
+
|
|
21
|
+
See:
|
|
22
|
+
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
model_config = ConfigDict(
|
|
26
|
+
validate_assignment=True,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
|
|
30
|
+
"""Quantity to monitor."""
|
|
31
|
+
|
|
32
|
+
verbose: bool = Field(default=False, validate_default=True)
|
|
33
|
+
"""Verbosity mode."""
|
|
34
|
+
|
|
35
|
+
save_weights_only: bool = Field(default=False, validate_default=True)
|
|
36
|
+
"""When `True`, only the model's weights will be saved (model.save_weights)."""
|
|
37
|
+
|
|
38
|
+
save_last: Optional[Literal[True, False, "link"]] = Field(
|
|
39
|
+
default=True, validate_default=True
|
|
40
|
+
)
|
|
41
|
+
"""When `True`, saves a last.ckpt copy whenever a checkpoint file gets saved."""
|
|
42
|
+
|
|
43
|
+
save_top_k: int = Field(default=3, ge=1, le=10, validate_default=True)
|
|
44
|
+
"""If `save_top_k == kz, the best k models according to the quantity monitored
|
|
45
|
+
will be saved. If `save_top_k == 0`, no models are saved. if `save_top_k == -1`,
|
|
46
|
+
all models are saved."""
|
|
47
|
+
|
|
48
|
+
mode: Literal["min", "max"] = Field(default="min", validate_default=True)
|
|
49
|
+
"""One of {min, max}. If `save_top_k != 0`, the decision to overwrite the current
|
|
50
|
+
save file is made based on either the maximization or the minimization of the
|
|
51
|
+
monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should
|
|
52
|
+
be 'min', etc.
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
auto_insert_metric_name: bool = Field(default=False, validate_default=True)
|
|
56
|
+
"""When `True`, the checkpoints filenames will contain the metric name."""
|
|
57
|
+
|
|
58
|
+
every_n_train_steps: Optional[int] = Field(
|
|
59
|
+
default=None, ge=1, le=10, validate_default=True
|
|
60
|
+
)
|
|
61
|
+
"""Number of training steps between checkpoints."""
|
|
62
|
+
|
|
63
|
+
train_time_interval: Optional[timedelta] = Field(
|
|
64
|
+
default=None, validate_default=True
|
|
65
|
+
)
|
|
66
|
+
"""Checkpoints are monitored at the specified time interval."""
|
|
67
|
+
|
|
68
|
+
every_n_epochs: Optional[int] = Field(
|
|
69
|
+
default=None, ge=1, le=10, validate_default=True
|
|
70
|
+
)
|
|
71
|
+
"""Number of epochs between checkpoints."""
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class EarlyStoppingModel(BaseModel):
|
|
75
|
+
"""Early stopping callback Pydantic model.
|
|
76
|
+
|
|
77
|
+
The parameters corresponds to those of
|
|
78
|
+
`pytorch_lightning.callbacks.ModelCheckpoint`.
|
|
79
|
+
|
|
80
|
+
See:
|
|
81
|
+
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping
|
|
82
|
+
"""
|
|
83
|
+
|
|
84
|
+
model_config = ConfigDict(
|
|
85
|
+
validate_assignment=True,
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
monitor: Literal["val_loss"] = Field(default="val_loss", validate_default=True)
|
|
89
|
+
"""Quantity to monitor."""
|
|
90
|
+
|
|
91
|
+
min_delta: float = Field(default=0.0, ge=0.0, le=1.0, validate_default=True)
|
|
92
|
+
"""Minimum change in the monitored quantity to qualify as an improvement, i.e. an
|
|
93
|
+
absolute change of less than or equal to min_delta, will count as no improvement."""
|
|
94
|
+
|
|
95
|
+
patience: int = Field(default=3, ge=1, le=10, validate_default=True)
|
|
96
|
+
"""Number of checks with no improvement after which training will be stopped."""
|
|
97
|
+
|
|
98
|
+
verbose: bool = Field(default=False, validate_default=True)
|
|
99
|
+
"""Verbosity mode."""
|
|
100
|
+
|
|
101
|
+
mode: Literal["min", "max", "auto"] = Field(default="min", validate_default=True)
|
|
102
|
+
"""One of {min, max, auto}."""
|
|
103
|
+
|
|
104
|
+
check_finite: bool = Field(default=True, validate_default=True)
|
|
105
|
+
"""When `True`, stops training when the monitored quantity becomes `NaN` or
|
|
106
|
+
`inf`."""
|
|
107
|
+
|
|
108
|
+
stopping_threshold: Optional[float] = Field(default=None, validate_default=True)
|
|
109
|
+
"""Stop training immediately once the monitored quantity reaches this threshold."""
|
|
110
|
+
|
|
111
|
+
divergence_threshold: Optional[float] = Field(default=None, validate_default=True)
|
|
112
|
+
"""Stop training as soon as the monitored quantity becomes worse than this
|
|
113
|
+
threshold."""
|
|
114
|
+
|
|
115
|
+
check_on_train_epoch_end: Optional[bool] = Field(
|
|
116
|
+
default=False, validate_default=True
|
|
117
|
+
)
|
|
118
|
+
"""Whether to run early stopping at the end of the training epoch. If this is
|
|
119
|
+
`False`, then the check runs at the end of the validation."""
|
|
120
|
+
|
|
121
|
+
log_rank_zero_only: bool = Field(default=False, validate_default=True)
|
|
122
|
+
"""When set `True`, logs the status of the early stopping callback only for rank 0
|
|
123
|
+
process."""
|
|
@@ -1,12 +1,11 @@
|
|
|
1
1
|
"""Convenience functions to create configurations for training and inference."""
|
|
2
2
|
|
|
3
|
-
from typing import Any, Dict, List, Literal, Optional
|
|
3
|
+
from typing import Any, Dict, List, Literal, Optional
|
|
4
4
|
|
|
5
5
|
from .algorithm_model import AlgorithmConfig
|
|
6
6
|
from .architectures import UNetModel
|
|
7
7
|
from .configuration_model import Configuration
|
|
8
8
|
from .data_model import DataConfig
|
|
9
|
-
from .inference_model import InferenceConfig
|
|
10
9
|
from .support import (
|
|
11
10
|
SupportedAlgorithm,
|
|
12
11
|
SupportedArchitecture,
|
|
@@ -574,80 +573,3 @@ def create_n2v_configuration(
|
|
|
574
573
|
)
|
|
575
574
|
|
|
576
575
|
return configuration
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
def create_inference_configuration(
|
|
580
|
-
configuration: Configuration,
|
|
581
|
-
tile_size: Optional[Tuple[int, ...]] = None,
|
|
582
|
-
tile_overlap: Optional[Tuple[int, ...]] = None,
|
|
583
|
-
data_type: Optional[Literal["array", "tiff", "custom"]] = None,
|
|
584
|
-
axes: Optional[str] = None,
|
|
585
|
-
tta_transforms: bool = True,
|
|
586
|
-
batch_size: Optional[int] = 1,
|
|
587
|
-
) -> InferenceConfig:
|
|
588
|
-
"""
|
|
589
|
-
Create a configuration for inference with N2V.
|
|
590
|
-
|
|
591
|
-
If not provided, `data_type` and `axes` are taken from the training
|
|
592
|
-
configuration.
|
|
593
|
-
|
|
594
|
-
Parameters
|
|
595
|
-
----------
|
|
596
|
-
configuration : Configuration
|
|
597
|
-
Global configuration.
|
|
598
|
-
tile_size : Tuple[int, ...], optional
|
|
599
|
-
Size of the tiles.
|
|
600
|
-
tile_overlap : Tuple[int, ...], optional
|
|
601
|
-
Overlap of the tiles.
|
|
602
|
-
data_type : str, optional
|
|
603
|
-
Type of the data, by default "tiff".
|
|
604
|
-
axes : str, optional
|
|
605
|
-
Axes of the data, by default "YX".
|
|
606
|
-
tta_transforms : bool, optional
|
|
607
|
-
Whether to apply test-time augmentations, by default True.
|
|
608
|
-
batch_size : int, optional
|
|
609
|
-
Batch size, by default 1.
|
|
610
|
-
|
|
611
|
-
Returns
|
|
612
|
-
-------
|
|
613
|
-
InferenceConfiguration
|
|
614
|
-
Configuration used to configure CAREamicsPredictData.
|
|
615
|
-
"""
|
|
616
|
-
if (
|
|
617
|
-
configuration.data_config.image_means is None
|
|
618
|
-
or configuration.data_config.image_stds is None
|
|
619
|
-
):
|
|
620
|
-
raise ValueError("Mean and std must be provided in the configuration.")
|
|
621
|
-
|
|
622
|
-
# tile size for UNets
|
|
623
|
-
if tile_size is not None:
|
|
624
|
-
model = configuration.algorithm_config.model
|
|
625
|
-
|
|
626
|
-
if model.architecture == SupportedArchitecture.UNET.value:
|
|
627
|
-
# tile size must be equal to k*2^n, where n is the number of pooling layers
|
|
628
|
-
# (equal to the depth) and k is an integer
|
|
629
|
-
depth = model.depth
|
|
630
|
-
tile_increment = 2**depth
|
|
631
|
-
|
|
632
|
-
for i, t in enumerate(tile_size):
|
|
633
|
-
if t % tile_increment != 0:
|
|
634
|
-
raise ValueError(
|
|
635
|
-
f"Tile size must be divisible by {tile_increment} along all "
|
|
636
|
-
f"axes (got {t} for axis {i}). If your image size is smaller "
|
|
637
|
-
f"along one axis (e.g. Z), consider padding the image."
|
|
638
|
-
)
|
|
639
|
-
|
|
640
|
-
# tile overlaps must be specified
|
|
641
|
-
if tile_overlap is None:
|
|
642
|
-
raise ValueError("Tile overlap must be specified.")
|
|
643
|
-
|
|
644
|
-
return InferenceConfig(
|
|
645
|
-
data_type=data_type or configuration.data_config.data_type,
|
|
646
|
-
tile_size=tile_size,
|
|
647
|
-
tile_overlap=tile_overlap,
|
|
648
|
-
axes=axes or configuration.data_config.axes,
|
|
649
|
-
image_means=configuration.data_config.image_means,
|
|
650
|
-
image_stds=configuration.data_config.image_stds,
|
|
651
|
-
tta_transforms=tta_transforms,
|
|
652
|
-
batch_size=batch_size,
|
|
653
|
-
)
|