careamics 0.0.2__tar.gz → 0.0.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of careamics might be problematic.
- {careamics-0.0.2 → careamics-0.0.3}/.pre-commit-config.yaml +3 -2
- {careamics-0.0.2 → careamics-0.0.3}/PKG-INFO +2 -2
- careamics-0.0.3/examples/example_training_LVAE_split.ipynb +439 -0
- careamics-0.0.3/mypy.ini +17 -0
- {careamics-0.0.2 → careamics-0.0.3}/pyproject.toml +4 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/careamist.py +14 -11
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/__init__.py +7 -3
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/architectures/__init__.py +2 -2
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/architectures/architecture_model.py +1 -1
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/architectures/custom_model.py +11 -8
- careamics-0.0.3/src/careamics/config/architectures/lvae_model.py +174 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/configuration_factory.py +11 -3
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/configuration_model.py +7 -3
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/data_model.py +33 -8
- careamics-0.0.2/src/careamics/config/algorithm_model.py → careamics-0.0.3/src/careamics/config/fcn_algorithm_model.py +28 -43
- careamics-0.0.3/src/careamics/config/likelihood_model.py +43 -0
- careamics-0.0.3/src/careamics/config/nm_model.py +101 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_activations.py +1 -0
- careamics-0.0.3/src/careamics/config/support/supported_algorithms.py +33 -0
- careamics-0.0.3/src/careamics/config/support/supported_architectures.py +17 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_losses.py +3 -1
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics-0.0.3/src/careamics/config/vae_algorithm_model.py +171 -0
- careamics-0.0.3/src/careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/read/tiff.py +1 -1
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/__init__.py +3 -2
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics-0.0.3/src/careamics/lightning/lightning_module.py +632 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/predict_data_module.py +2 -2
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/train_data_module.py +2 -2
- careamics-0.0.3/src/careamics/losses/__init__.py +15 -0
- careamics-0.0.3/src/careamics/losses/fcn/__init__.py +1 -0
- {careamics-0.0.2/src/careamics/losses → careamics-0.0.3/src/careamics/losses/fcn}/losses.py +1 -1
- careamics-0.0.3/src/careamics/losses/loss_factory.py +155 -0
- careamics-0.0.3/src/careamics/losses/lvae/__init__.py +1 -0
- careamics-0.0.3/src/careamics/losses/lvae/loss_utils.py +83 -0
- careamics-0.0.3/src/careamics/losses/lvae/losses.py +445 -0
- {careamics-0.0.2/src/careamics/lvae_training → careamics-0.0.3/src/careamics/lvae_training/dataset}/data_utils.py +277 -194
- careamics-0.0.3/src/careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics-0.0.3/src/careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics-0.0.3/src/careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics-0.0.2/src/careamics/lvae_training/data_modules.py → careamics-0.0.3/src/careamics/lvae_training/dataset/vae_dataset.py +306 -472
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lvae_training/get_config.py +1 -1
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lvae_training/train_lvae.py +6 -3
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/model_io/bioimage/bioimage_utils.py +1 -1
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/model_io/bioimage/model_description.py +2 -2
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/model_io/bmz_io.py +19 -6
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/model_io/model_io_utils.py +16 -4
- careamics-0.0.3/src/careamics/models/__init__.py +5 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/models/activation.py +2 -0
- careamics-0.0.3/src/careamics/models/lvae/__init__.py +3 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/models/lvae/layers.py +21 -21
- careamics-0.0.3/src/careamics/models/lvae/likelihoods.py +364 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/models/lvae/lvae.py +52 -136
- careamics-0.0.3/src/careamics/models/lvae/noise_models.py +541 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/models/lvae/utils.py +2 -2
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/models/model_factory.py +22 -7
- careamics-0.0.3/src/careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics-0.0.3/src/careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/prediction_utils/stitch_prediction.py +16 -2
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/pixel_manipulation.py +1 -1
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/metrics.py +74 -1
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/architectures/test_custom_model.py +14 -32
- careamics-0.0.3/tests/config/architectures/test_lvae_model.py +134 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/architectures/test_register_model.py +1 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/architectures/test_unet_model.py +6 -5
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_configuration_model.py +2 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_data_model.py +22 -0
- careamics-0.0.3/tests/config/test_fcn_algorithm_model.py +125 -0
- careamics-0.0.3/tests/config/test_vae_algorithm_model.py +68 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/conftest.py +108 -11
- careamics-0.0.3/tests/dataset/tiling/test_lvae_tiled_patching.py +188 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py +2 -0
- careamics-0.0.3/tests/lightning/test_LVAE_lightning_module.py +579 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/test_lightning_api.py +2 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/test_lightning_module.py +40 -25
- careamics-0.0.3/tests/likelihood_modules/test_likelihoods.py +97 -0
- careamics-0.0.3/tests/losses/test_lvae_losses.py +335 -0
- careamics-0.0.3/tests/models/lvae/test_dataset.py +131 -0
- careamics-0.0.3/tests/models/lvae/test_noise_model.py +141 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/models/test_model_factory.py +2 -13
- careamics-0.0.3/tests/prediction_utils/test_lvae_prediction.py +182 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/test_conftest.py +2 -2
- {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_pixel_manipulation.py +1 -1
- careamics-0.0.2/src/careamics/config/architectures/vae_model.py +0 -42
- careamics-0.0.2/src/careamics/config/support/supported_algorithms.py +0 -20
- careamics-0.0.2/src/careamics/config/support/supported_architectures.py +0 -20
- careamics-0.0.2/src/careamics/lightning/lightning_module.py +0 -276
- careamics-0.0.2/src/careamics/losses/__init__.py +0 -5
- careamics-0.0.2/src/careamics/losses/loss_factory.py +0 -49
- careamics-0.0.2/src/careamics/models/__init__.py +0 -7
- careamics-0.0.2/src/careamics/models/lvae/likelihoods.py +0 -312
- careamics-0.0.2/src/careamics/models/lvae/noise_models.py +0 -409
- careamics-0.0.2/tests/config/test_algorithm_model.py +0 -101
- {careamics-0.0.2 → careamics-0.0.3}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/.github/TEST_FAIL_TEMPLATE.md +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/.github/pull_request_template.md +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/.github/workflows/ci.yml +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/.gitignore +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/LICENSE +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/README.md +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/examples/3D/example_flywing_3D.ipynb +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/examples/3D/n2v_flywing_3D.yml +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/examples/evaluate_LVAE.ipynb +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/architectures/register_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/architectures/unet_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/callback_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/inference_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/optimizer_models.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/references/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/references/algorithm_descriptions.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/references/references.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_data.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_loggers.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_optimizers.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_pixel_manipulations.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_struct_axis.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_transforms.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/tile_information.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/training_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/transformations/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/transformations/normalize_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/transformations/transform_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/transformations/xy_flip_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/transformations/xy_random_rotate90_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/validators/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/validators/validator_utils.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/conftest.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/dataset_utils/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/dataset_utils/dataset_utils.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/dataset_utils/file_utils.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/dataset_utils/iterate_over_files.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/dataset_utils/running_stats.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/in_memory_dataset.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/in_memory_pred_dataset.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/in_memory_tiled_pred_dataset.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/iterable_dataset.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/iterable_pred_dataset.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/iterable_tiled_pred_dataset.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/patching/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/patching/patching.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/patching/random_patching.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/patching/sequential_patching.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/patching/validate_patch_dimension.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/tiling/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/tiling/collate_tiles.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/tiling/tiled_patching.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/zarr_dataset.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/read/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/read/get_func.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/read/zarr.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/write/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/write/get_func.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/write/tiff.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/prediction_writer_callback/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/progress_bar_callback.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lvae_training/__init__.py +0 -0
- {careamics-0.0.2/src/careamics/models/lvae → careamics-0.0.3/src/careamics/lvae_training/dataset}/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lvae_training/eval_utils.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lvae_training/lightning_module.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lvae_training/metrics.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lvae_training/train_utils.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/model_io/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/model_io/bioimage/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/model_io/bioimage/_readme_factory.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/models/layers.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/models/unet.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/prediction_utils/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/prediction_utils/prediction_outputs.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/py.typed +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/compose.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/n2v_manipulate.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/normalize.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/struct_mask_parameters.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/transform.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/tta.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/xy_flip.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/xy_random_rotate90.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/__init__.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/autocorrelation.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/base_enum.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/context.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/logging.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/path_utils.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/ram.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/receptive_field.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/torch_utils.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/architectures/test_architecture_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/support/test_supported_data.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/support/test_supported_optimizers.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_callback_models.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_configuration_factory.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_inference_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_optimizers_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_tile_information.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_training_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/transformations/test_n2v_manipulate_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/transformations/test_normalize_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/transformations/test_xy_flip_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/transformations/test_xy_random_rotate90_model.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/config/validators/test_validator_utils.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/dataset_utils/test_compute_normalization_stats.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/dataset_utils/test_list_files.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/patching/test_patching_utils.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/patching/test_random_patching.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/patching/test_sequential_patching.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/test_in_memory_dataset.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/test_in_memory_pred_dataset.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/test_in_memory_tiled_pred_dataset.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/test_iterable_dataset.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/test_iterable_pred_dataset.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/test_iterable_tiled_pred_dataset.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/tiling/test_collate_tiles.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/tiling/test_tiled_patching.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/file_io/read/test_get_read_func.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/file_io/read/test_read_tiff.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/file_io/write/test_get_write_func.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/file_io/write/test_write_tiff.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/callbacks/prediction_writer_callback/test_file_path_utils.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/callbacks/prediction_writer_callback/test_write_strategy_factory.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/test_predict_data_module.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/test_train_data_module.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/model_io/test_bmz_io.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/models/test_unet.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/prediction_utils/test_prediction_outputs.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/prediction_utils/test_stitch_prediction.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/test_careamist.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_compose.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_manipulate_n2v.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_normalize.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_supported_transforms.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_tta.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_xy_flip.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_xy_random_rotate90.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/utils/test_autocorrelation.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/utils/test_base_enum.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/utils/test_context.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/utils/test_logging.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/utils/test_metrics.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/utils/test_torch_utils.py +0 -0
- {careamics-0.0.2 → careamics-0.0.3}/tests/utils/test_wandb.py +0 -0
{careamics-0.0.2 → careamics-0.0.3}/.pre-commit-config.yaml

@@ -30,7 +30,8 @@ repos:
     hooks:
       - id: mypy
         files: "^src/"
-        exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae
+        exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^src/careamics/config/likelihood_model.py|^src/careamics/losses/loss_factory.py|^src/careamics/losses/lvae/losses.py"
+        args: ['--config-file', 'mypy.ini']
         additional_dependencies:
           - numpy
           - types-PyYAML
@@ -41,7 +42,7 @@ repos:
     rev: v1.8.0rc2
     hooks:
       - id: numpydoc-validation
-        exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*"
+        exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^src/careamics/losses/lvae/.*"
 
   # # jupyter linting and formatting
   # - repo: https://github.com/nbQA-dev/nbQA
{careamics-0.0.2 → careamics-0.0.3}/PKG-INFO

@@ -1,10 +1,10 @@
 Metadata-Version: 2.3
 Name: careamics
-Version: 0.0.2
+Version: 0.0.3
 Summary: Toolbox for running N2V and friends.
 Project-URL: homepage, https://careamics.github.io/
 Project-URL: repository, https://github.com/CAREamics/careamics
-Author-email: Melisande Croft <melisande.croft@fht.org>, Joran Deschamps <joran.deschamps@fht.org>, Igor Zubarev <igor.zubarev@fht.org>
+Author-email: CAREamics team <rse@fht.org>, Ashesh <ashesh.ashesh@fht.org>, Federico Carrara <federico.carrara@fht.org>, Melisande Croft <melisande.croft@fht.org>, Joran Deschamps <joran.deschamps@fht.org>, Vera Galinova <vera.galinova@fht.org>, Igor Zubarev <igor.zubarev@fht.org>
 License: BSD-3-Clause
 License-File: LICENSE
 Classifier: Development Status :: 3 - Alpha
careamics-0.0.3/examples/example_training_LVAE_split.ipynb ADDED

@@ -0,0 +1,439 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import socket\n",
+    "from pathlib import Path\n",
+    "from typing import Optional, Literal, Union\n",
+    "\n",
+    "import numpy as np\n",
+    "import torch \n",
+    "from torch.utils.data import Dataset, DataLoader\n",
+    "from pytorch_lightning import Trainer\n",
+    "from pytorch_lightning.loggers import WandbLogger\n",
+    "from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping\n",
+    "\n",
+    "from careamics.config import VAEAlgorithmConfig\n",
+    "from careamics.config.architectures import LVAEModel\n",
+    "from careamics.config.callback_model import CheckpointModel, EarlyStoppingModel\n",
+    "from careamics.config.likelihood_model import (\n",
+    "    GaussianLikelihoodConfig,\n",
+    "    NMLikelihoodConfig,\n",
+    ")\n",
+    "from careamics.config.nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig\n",
+    "from careamics.lightning import VAEModule\n",
+    "from careamics.models.lvae.noise_models import noise_model_factory"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Set some parameters for the current training simulation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "img_size: int = 64\n",
+    "\"\"\"Spatial size of the input image.\"\"\"\n",
+    "target_channels: int = 2\n",
+    "\"\"\"Number of channels in the target image.\"\"\"\n",
+    "multiscale_count: int = 5\n",
+    "\"\"\"The number of LC inputs plus one (the actual input).\"\"\"\n",
+    "predict_logvar: Optional[Literal[\"pixelwise\"]] = \"pixelwise\"\n",
+    "\"\"\"Whether to compute also the log-variance as LVAE output.\"\"\"\n",
+    "loss_type: Optional[Literal[\"musplit\", \"denoisplit\", \"denoisplit_musplit\"]] = \"musplit\"\n",
+    "\"\"\"The type of reconstruction loss (i.e., likelihood) to use.\"\"\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 1. Create `Dataset` and `Dataloader`"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1.1. Dummy Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class DummyDataset(Dataset):\n",
+    "    def __init__(\n",
+    "        self, \n",
+    "        img_size: int = 64, \n",
+    "        target_ch: int = 1,\n",
+    "        multiscale_count: int = 1,\n",
+    "    ):\n",
+    "        self.num_samples = 100\n",
+    "        self.img_size = img_size\n",
+    "        self.target_ch = target_ch\n",
+    "        self.multiscale_count = multiscale_count\n",
+    "    \n",
+    "    def __len__(self):\n",
+    "        return self.num_samples\n",
+    "    \n",
+    "    def __getitem__(self, idx: int):\n",
+    "        input_ = torch.randn(self.multiscale_count, self.img_size, self.img_size)\n",
+    "        target = torch.randn(self.target_ch, self.img_size, self.img_size)\n",
+    "        return input_, target\n",
+    "\n",
+    "def dummy_dataloader(\n",
+    "    batch_size: int = 1,\n",
+    "    img_size: int = 64,\n",
+    "    target_ch: int = 1,\n",
+    "    multiscale_count: int = 1,\n",
+    "):\n",
+    "    dataset = DummyDataset(\n",
+    "        img_size=img_size,\n",
+    "        target_ch=target_ch,\n",
+    "        multiscale_count=multiscale_count,\n",
+    "    )\n",
+    "    return DataLoader(dataset, batch_size=batch_size, num_workers=3, shuffle=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dloader = dummy_dataloader(\n",
+    "    img_size=img_size,\n",
+    "    target_ch=target_channels,\n",
+    "    multiscale_count=multiscale_count,\n",
+    ")\n",
+    "input_, target = next(iter(dloader))\n",
+    "input_.shape, target.shape"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1.2. Real Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2. Instantiate the lightning module"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def create_dummy_noise_model(\n",
+    "    save_path: Optional[Union[Path, str]] = None,\n",
+    "    n_gaussians: int = 3,\n",
+    "    n_coeffs: int = 3,\n",
+    ") -> Path:\n",
+    "    weights = np.random.rand(3*n_gaussians, n_coeffs)\n",
+    "    nm_dict = {\n",
+    "        \"trained_weight\": weights,\n",
+    "        \"min_signal\": np.array([0]),\n",
+    "        \"max_signal\": np.array([2**16 - 1]),\n",
+    "        \"min_sigma\": 0.125,\n",
+    "    }\n",
+    "    out_path = Path(save_path) / \"dummy_noise_model.npz\"\n",
+    "    np.savez(out_path, **nm_dict)\n",
+    "    return out_path"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def create_split_lightning_model(\n",
+    "    algorithm: str,\n",
+    "    loss_type: str,\n",
+    "    multiscale_count: int = 1,\n",
+    "    predict_logvar: Optional[Literal[\"pixelwise\"]] = None,\n",
+    "    target_ch: int = 1,\n",
+    "    NM_path: Optional[Path] = None,\n",
+    ") -> VAEModule:\n",
+    "    \"\"\"Instantiate the muSplit lightining model.\"\"\"\n",
+    "    lvae_config = LVAEModel(\n",
+    "        architecture=\"LVAE\",\n",
+    "        input_shape=64,\n",
+    "        multiscale_count=multiscale_count,\n",
+    "        z_dims=[128, 128, 128, 128],\n",
+    "        output_channels=target_ch,\n",
+    "        predict_logvar=predict_logvar,\n",
+    "    )\n",
+    "\n",
+    "    # gaussian likelihood\n",
+    "    if loss_type in [\"musplit\", \"denoisplit_musplit\"]:\n",
+    "        gaussian_lik_config = GaussianLikelihoodConfig(\n",
+    "            predict_logvar=predict_logvar,\n",
+    "            logvar_lowerbound=0.0,\n",
+    "        )\n",
+    "    else:\n",
+    "        gaussian_lik_config = None\n",
+    "    # noise model likelihood\n",
+    "    if loss_type in [\"denoisplit\", \"denoisplit_musplit\"]:\n",
+    "        if NM_path is None:\n",
+    "            NM_path = create_dummy_noise_model(Path(\"./\"), 3, 3)\n",
+    "        gmm = GaussianMixtureNMConfig(\n",
+    "            model_type=\"GaussianMixtureNoiseModel\",\n",
+    "            path=NM_path,\n",
+    "        )\n",
+    "        noise_model_config = MultiChannelNMConfig(noise_models=[gmm] * target_ch)\n",
+    "        nm = noise_model_factory(noise_model_config)\n",
+    "        nm_lik_config = NMLikelihoodConfig(noise_model=nm)\n",
+    "    else:\n",
+    "        noise_model_config = None\n",
+    "        nm_lik_config = None\n",
+    "\n",
+    "    vae_config = VAEAlgorithmConfig(\n",
+    "        algorithm_type=\"vae\",\n",
+    "        algorithm=algorithm,\n",
+    "        loss=loss_type,\n",
+    "        model=lvae_config,\n",
+    "        gaussian_likelihood_model=gaussian_lik_config,\n",
+    "        noise_model=noise_model_config,\n",
+    "        noise_model_likelihood_model=nm_lik_config,\n",
+    "    )\n",
+    "\n",
+    "    return VAEModule(\n",
+    "        algorithm_config=vae_config,\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "algo = \"musplit\" if loss_type == \"musplit\" else \"denoisplit\"\n",
+    "lightning_model = create_split_lightning_model(\n",
+    "    algorithm=algo,\n",
+    "    loss_type=loss_type,\n",
+    "    multiscale_count=multiscale_count,\n",
+    "    predict_logvar=predict_logvar,\n",
+    "    target_ch=target_channels,\n",
+    "    NM_path=None\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 3. Set utils for training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datetime import datetime\n",
+    "\n",
+    "from careamics.lvae_training.train_utils import get_new_model_version\n",
+    "\n",
+    "def get_new_model_version(model_dir: Union[Path, str]) -> int:\n",
+    "    \"\"\"Create a unique version ID for a new model run.\"\"\"\n",
+    "    versions = []\n",
+    "    for version_dir in os.listdir(model_dir):\n",
+    "        try:\n",
+    "            versions.append(int(version_dir))\n",
+    "        except:\n",
+    "            print(\n",
+    "                f\"Invalid subdirectory:{model_dir}/{version_dir}. Only integer versions are allowed\"\n",
+    "            )\n",
+    "            exit()\n",
+    "    if len(versions) == 0:\n",
+    "        return \"0\"\n",
+    "    return f\"{max(versions) + 1}\"\n",
+    "\n",
+    "def get_workdir(\n",
+    "    root_dir: str,\n",
+    "    model_name: str,\n",
+    ") -> tuple[Path, Path]:\n",
+    "    \"\"\"Get the workdir for the current model.\n",
+    "    \n",
+    "    It has the following structure: \"root_dir/YYMM/model_name/version\"\n",
+    "    \"\"\"\n",
+    "    rel_path = datetime.now().strftime(\"%y%m\")\n",
+    "    cur_workdir = os.path.join(root_dir, rel_path)\n",
+    "    Path(cur_workdir).mkdir(exist_ok=True)\n",
+    "\n",
+    "    rel_path = os.path.join(rel_path, model_name)\n",
+    "    cur_workdir = os.path.join(root_dir, rel_path)\n",
+    "    Path(cur_workdir).mkdir(exist_ok=True)\n",
+    "\n",
+    "    rel_path = os.path.join(rel_path, get_new_model_version(cur_workdir))\n",
+    "    cur_workdir = os.path.join(root_dir, rel_path)\n",
+    "    try:\n",
+    "        Path(cur_workdir).mkdir(exist_ok=False)\n",
+    "    except FileExistsError:\n",
+    "        print(\n",
+    "            f\"Workdir {cur_workdir} already exists.\"\n",
+    "        )\n",
+    "    return cur_workdir, rel_path"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ROOT_DIR = \"/group/jug/federico/careamics_training/refac_v2/\"\n",
+    "workdir, exp_tag = get_workdir(ROOT_DIR, \"dummy_debugging\")\n",
+    "print(f\"Current workdir: {workdir}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define the logger\n",
+    "custom_logger = WandbLogger(\n",
+    "    name=os.path.join(socket.gethostname(), exp_tag),\n",
+    "    save_dir=workdir,\n",
+    "    project=\"careamics_debugging_LVAE\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define callbacks (e.g., ModelCheckpoint, EarlyStopping, etc.)\n",
+    "early_stopping_config = EarlyStoppingModel(\n",
+    "    monitor=\"val_loss\",\n",
+    "    min_delta=1e-6,\n",
+    "    patience=10,\n",
+    "    mode=\"min\",\n",
+    "    verbose=True,\n",
+    ")\n",
+    "checkpoint_config = CheckpointModel(\n",
+    "    monitor=\"val_loss\",\n",
+    "    save_top_k=2,\n",
+    "    mode=\"min\",\n",
+    ")\n",
+    "custom_callbacks = [\n",
+    "    EarlyStopping(**early_stopping_config.model_dump()), \n",
+    "    ModelCheckpoint(**checkpoint_config.model_dump()),\n",
+    "    LearningRateMonitor(logging_interval=\"epoch\")\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Save AlgorithmConfig\n",
+    "with open(os.path.join(workdir, \"algorithm_config.json\"), \"w\") as f:\n",
+    "    f.write(lightning_model.algorithm_config.model_dump_json())\n",
+    "\n",
+    "custom_logger.experiment.config.update(\n",
+    "    lightning_model.algorithm_config.model_dump() \n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 4. Train the model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer = Trainer(\n",
+    "    max_epochs=10,\n",
+    "    accelerator=\"cpu\",\n",
+    "    enable_progress_bar=True,\n",
+    "    logger=custom_logger,\n",
+    "    callbacks=custom_callbacks,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer.fit(\n",
+    "    model=lightning_model,\n",
+    "    train_dataloaders=dloader,\n",
+    "    val_dataloaders=dloader,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "train_lvae",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.19"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
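For readers skimming the diff, the essential flow of the new notebook (with its default "musplit" settings) condenses to the short sketch below. This is only an editorial summary of the cells above: every name and value is taken from the notebook itself, and the dataloader is assumed to be built as in its section 1.1.

    # Condensed sketch of the notebook's training flow for the "musplit" loss.
    from pytorch_lightning import Trainer

    from careamics.config import VAEAlgorithmConfig
    from careamics.config.architectures import LVAEModel
    from careamics.config.likelihood_model import GaussianLikelihoodConfig
    from careamics.lightning import VAEModule

    # LVAE architecture, as configured in the notebook
    lvae = LVAEModel(
        architecture="LVAE",
        input_shape=64,
        multiscale_count=5,
        z_dims=[128, 128, 128, 128],
        output_channels=2,
        predict_logvar="pixelwise",
    )

    # VAE algorithm configuration with a Gaussian likelihood (no noise model needed for "musplit")
    config = VAEAlgorithmConfig(
        algorithm_type="vae",
        algorithm="musplit",
        loss="musplit",
        model=lvae,
        gaussian_likelihood_model=GaussianLikelihoodConfig(
            predict_logvar="pixelwise", logvar_lowerbound=0.0
        ),
    )

    model = VAEModule(algorithm_config=config)
    # dloader: the DataLoader built in the notebook's section 1.1
    Trainer(max_epochs=10, accelerator="cpu").fit(model, dloader, dloader)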
careamics-0.0.3/mypy.ini ADDED

@@ -0,0 +1,17 @@
+[mypy]
+ignore_missing_imports = True
+
+[mypy-careamics.lvae_training.*]
+follow_imports = skip
+
+[mypy-careamics.models.lvae.*]
+follow_imports = skip
+
+[mypy-careamics.losses.loss_factory]
+follow_imports = skip
+
+[mypy-careamics.losses.lvae.losses]
+follow_imports = skip
+
+[mypy-careamics.config.likelihood_model]
+follow_imports = skip
{careamics-0.0.2 → careamics-0.0.3}/pyproject.toml

@@ -23,8 +23,12 @@ readme = "README.md"
 requires-python = ">=3.9"
 license = { text = "BSD-3-Clause" }
 authors = [
+    { name = 'CAREamics team', email = 'rse@fht.org' },
+    { name = 'Ashesh', email = 'ashesh.ashesh@fht.org' },
+    { name = 'Federico Carrara', email = 'federico.carrara@fht.org' },
     { name = 'Melisande Croft', email = 'melisande.croft@fht.org' },
     { name = 'Joran Deschamps', email = 'joran.deschamps@fht.org' },
+    { name = 'Vera Galinova', email = 'vera.galinova@fht.org' },
     { name = 'Igor Zubarev', email = 'igor.zubarev@fht.org' },
 ]
 classifiers = [
{careamics-0.0.2 → careamics-0.0.3}/src/careamics/careamist.py

@@ -13,10 +13,7 @@ from pytorch_lightning.callbacks import (
 )
 from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
 
-from careamics.config import (
-    Configuration,
-    load_configuration,
-)
+from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration
 from careamics.config.support import (
     SupportedAlgorithm,
     SupportedArchitecture,
@@ -25,7 +22,7 @@ from careamics.config.support import (
 )
 from careamics.dataset.dataset_utils import reshape_array
 from careamics.lightning import (
-
+    FCNModule,
     HyperParametersCallback,
     PredictDataModule,
     ProgressBarCallback,
@@ -148,9 +145,12 @@ class CAREamist:
             self.cfg = source
 
             # instantiate model
-            self.
-
-
+            if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
+                self.model = FCNModule(
+                    algorithm_config=self.cfg.algorithm_config,
+                )
+            else:
+                raise NotImplementedError("Architecture not supported.")
 
         # path to configuration file or model
         else:
@@ -164,9 +164,12 @@ class CAREamist:
                 self.cfg = load_configuration(source)
 
                 # instantiate model
-                self.
-
-
+                if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
+                    self.model = FCNModule(
+                        algorithm_config=self.cfg.algorithm_config,
+                    )  # type: ignore
+                else:
+                    raise NotImplementedError("Architecture not supported.")
 
             # attempt loading a pre-trained model
             else:
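In other words, the CAREamist entry point now dispatches on the type of the algorithm configuration and only builds an FCN-based lightning module; VAE-based configurations appear to be used through VAEModule directly instead (as in the new notebook above). A minimal sketch of the dispatch pattern, using only names visible in this diff and not the full CAREamist code:

    # Sketch of the dispatch introduced in CAREamist.__init__ in 0.0.3.
    from careamics.config import FCNAlgorithmConfig
    from careamics.lightning import FCNModule

    def build_module(algorithm_config):
        # Only FCN-style algorithm configurations are routed through CAREamist here.
        if isinstance(algorithm_config, FCNAlgorithmConfig):
            return FCNModule(algorithm_config=algorithm_config)
        raise NotImplementedError("Architecture not supported.")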
{careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/__init__.py

@@ -1,7 +1,8 @@
 """Configuration module."""
 
 __all__ = [
-    "AlgorithmConfig",
+    "FCNAlgorithmConfig",
+    "VAEAlgorithmConfig",
     "DataConfig",
     "Configuration",
     "CheckpointModel",
@@ -15,9 +16,9 @@ __all__ = [
     "register_model",
     "CustomModel",
     "clear_custom_models",
+    "GaussianMixtureNMConfig",
+    "MultiChannelNMConfig",
 ]
-
-from .algorithm_model import AlgorithmConfig
 from .architectures import CustomModel, clear_custom_models, register_model
 from .callback_model import CheckpointModel
 from .configuration_factory import (
@@ -31,5 +32,8 @@ from .configuration_model import (
     save_configuration,
 )
 from .data_model import DataConfig
+from .fcn_algorithm_model import FCNAlgorithmConfig
 from .inference_model import InferenceConfig
+from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
 from .training_model import TrainingConfig
+from .vae_algorithm_model import VAEAlgorithmConfig
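The net effect on the public configuration API: AlgorithmConfig is gone, and the new names below are importable from careamics.config. The snippet simply restates the __all__ entries and imports added above:

    from careamics.config import (
        FCNAlgorithmConfig,
        GaussianMixtureNMConfig,
        MultiChannelNMConfig,
        VAEAlgorithmConfig,
    )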
{careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/architectures/__init__.py

@@ -4,7 +4,7 @@ __all__ = [
     "ArchitectureModel",
     "CustomModel",
     "UNetModel",
-    "VAEModel",
+    "LVAEModel",
     "clear_custom_models",
     "get_custom_model",
     "register_model",
@@ -12,6 +12,6 @@ __all__ = [
 
 from .architecture_model import ArchitectureModel
 from .custom_model import CustomModel
+from .lvae_model import LVAEModel
 from .register_model import clear_custom_models, get_custom_model, register_model
 from .unet_model import UNetModel
-from .vae_model import VAEModel
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import inspect
|
|
5
6
|
from pprint import pformat
|
|
6
7
|
from typing import Any, Literal
|
|
7
8
|
|
|
@@ -23,12 +24,13 @@ class CustomModel(ArchitectureModel):
|
|
|
23
24
|
|
|
24
25
|
Attributes
|
|
25
26
|
----------
|
|
26
|
-
architecture : Literal["
|
|
27
|
-
Discriminator for the custom model, must be set to "
|
|
27
|
+
architecture : Literal["custom"]
|
|
28
|
+
Discriminator for the custom model, must be set to "custom".
|
|
28
29
|
name : str
|
|
29
30
|
Name of the custom model.
|
|
30
31
|
parameters : CustomParametersModel
|
|
31
|
-
|
|
32
|
+
All parameters, required for the initialization of the torch module have to be
|
|
33
|
+
passed here.
|
|
32
34
|
|
|
33
35
|
Raises
|
|
34
36
|
------
|
|
@@ -57,7 +59,7 @@ class CustomModel(ArchitectureModel):
|
|
|
57
59
|
...
|
|
58
60
|
>>> # Create a configuration
|
|
59
61
|
>>> config_dict = {
|
|
60
|
-
... "architecture": "
|
|
62
|
+
... "architecture": "custom",
|
|
61
63
|
... "name": "my_linear",
|
|
62
64
|
... "in_features": 10,
|
|
63
65
|
... "out_features": 5,
|
|
@@ -71,10 +73,9 @@ class CustomModel(ArchitectureModel):
|
|
|
71
73
|
)
|
|
72
74
|
|
|
73
75
|
# discriminator used for choosing the pydantic model in Model
|
|
74
|
-
architecture: Literal["
|
|
76
|
+
architecture: Literal["custom"]
|
|
75
77
|
"""Name of the architecture."""
|
|
76
78
|
|
|
77
|
-
# name of the custom model
|
|
78
79
|
name: str
|
|
79
80
|
"""Name of the custom model."""
|
|
80
81
|
|
|
@@ -120,10 +121,12 @@ class CustomModel(ArchitectureModel):
|
|
|
120
121
|
get_custom_model(self.name)(**self.model_dump())
|
|
121
122
|
except Exception as e:
|
|
122
123
|
raise ValueError(
|
|
123
|
-
f"
|
|
124
|
+
f"while passing parameters to the model {e}. Verify that all "
|
|
124
125
|
f"mandatory parameters are provided, and that either the {e} accepts "
|
|
125
126
|
f"*args and **kwargs in its __init__() method, or that no additional"
|
|
126
|
-
f"parameter is provided."
|
|
127
|
+
f"parameter is provided. Trace: "
|
|
128
|
+
f"filename: {inspect.trace()[-1].filename}, function: "
|
|
129
|
+
f"{inspect.trace()[-1].function}, line: {inspect.trace()[-1].lineno}"
|
|
127
130
|
) from None
|
|
128
131
|
|
|
129
132
|
return self
|