careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,367 @@
|
|
|
1
|
+
"""Pydantic CAREamics configuration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from pprint import pformat
|
|
7
|
+
from typing import Any, Literal, Self, Union
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from bioimageio.spec.generic.v0_3 import CiteEntry
|
|
11
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
12
|
+
|
|
13
|
+
from careamics.config.algorithms import (
|
|
14
|
+
CAREAlgorithm,
|
|
15
|
+
HDNAlgorithm,
|
|
16
|
+
MicroSplitAlgorithm,
|
|
17
|
+
N2NAlgorithm,
|
|
18
|
+
N2VAlgorithm,
|
|
19
|
+
PN2VAlgorithm,
|
|
20
|
+
)
|
|
21
|
+
from careamics.config.data import DataConfig
|
|
22
|
+
from careamics.config.lightning.training_config import TrainingConfig
|
|
23
|
+
from careamics.lvae_training.dataset.config import MicroSplitDataConfig
|
|
24
|
+
|
|
25
|
+
# Union of every supported algorithm configuration. `Configuration` declares
# its `algorithm_config` field with `Field(discriminator="algorithm")`, so the
# `algorithm` value of the input decides which of these models is built.
ALGORITHMS = Union[
    CAREAlgorithm, HDNAlgorithm, MicroSplitAlgorithm,
    N2NAlgorithm, N2VAlgorithm, PN2VAlgorithm,
]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Configuration(BaseModel):
    """
    CAREamics configuration.

    The configuration defines all parameters used to build and train a CAREamics model.
    These parameters are validated to ensure that they are compatible with each other.

    It contains three sub-configurations:

    - algorithm_config: configuration for the algorithm training, which includes the
        architecture, loss function, optimizer, and other hyperparameters.
    - data_config: configuration for the dataloader, which includes the type of data,
        transformations, mean/std and other parameters.
    - training_config: configuration for the training, which includes the number of
        epochs or the callbacks.

    Attributes
    ----------
    version : Literal["0.1.0"]
        CAREamics configuration version.
    experiment_name : str
        Name of the experiment, used when saving logs and checkpoints.
    algorithm_config : ALGORITHMS
        Algorithm configuration, selected through its `algorithm` field.
    data_config : DataConfig | MicroSplitDataConfig
        Data configuration.
    training_config : TrainingConfig
        Training configuration.

    Methods
    -------
    set_3D(is_3D: bool, axes: str, patch_size: list[int]) -> None
        Switch configuration between 2D and 3D.
    model_dump(**kwargs: Any) -> dict[str, Any]
        Export configuration to a dictionary, excluding `None` values by default.

    Raises
    ------
    ValueError
        Configuration parameter type validation errors.
    ValueError
        If the experiment name contains invalid characters or is empty.
    ValueError
        For N2V and PN2V, if the masked pixel percentage is too low for a patch to
        contain at least one blind-spot pixel on average.
    ValueError
        Algorithm, data or training validation errors.

    Notes
    -----
    The dimensionality of the model (2D or 3D) is automatically aligned with the
    presence of "Z" in the data axes during validation.

    We provide convenience methods to create standard configurations, for instance:
    >>> from careamics.config import create_n2v_configuration
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ... )

    The configuration can be exported to a dictionary using the model_dump method:
    >>> config_dict = config.model_dump()

    Configurations can also be exported or imported from yaml files:
    >>> from careamics.config import save_configuration, load_configuration
    >>> path_to_config = save_configuration(config, my_path / "config.yml")
    >>> other_config = load_configuration(path_to_config)

    Examples
    --------
    Minimum example:
    >>> from careamics import Configuration
    >>> config_dict = {
    ...         "experiment_name": "N2V_experiment",
    ...         "algorithm_config": {
    ...             "algorithm": "n2v",
    ...             "loss": "n2v",
    ...             "model": {
    ...                 "architecture": "UNet",
    ...             },
    ...         },
    ...         "training_config": {},
    ...         "data_config": {
    ...             "data_type": "tiff",
    ...             "patch_size": [64, 64],
    ...             "axes": "SYX",
    ...         },
    ...     }
    >>> config = Configuration(**config_dict)
    """

    model_config = ConfigDict(
        validate_assignment=True,
        arbitrary_types_allowed=True,
    )

    # version
    version: Literal["0.1.0"] = "0.1.0"
    """CAREamics configuration version."""

    # required parameters
    experiment_name: str
    """Name of the experiment, used to name logs and checkpoints."""

    # Sub-configurations
    algorithm_config: ALGORITHMS = Field(discriminator="algorithm")
    """Algorithm configuration, holding all parameters required to configure the
    model."""

    data_config: DataConfig | MicroSplitDataConfig
    """Data configuration, holding all parameters required to configure the training
    data loader."""

    training_config: TrainingConfig
    """Training configuration, holding all parameters required to configure the
    training process."""

    @field_validator("experiment_name")
    @classmethod
    def no_symbol(cls, name: str) -> str:
        """
        Validate experiment name.

        A valid experiment name is a non-empty string that only contains letters,
        numbers, underscores, dashes and spaces.

        Parameters
        ----------
        name : str
            Name to validate.

        Returns
        -------
        str
            Validated name.

        Raises
        ------
        ValueError
            If the name is empty or contains invalid characters.
        """
        if len(name) == 0 or name.isspace():
            raise ValueError("Experiment name is empty.")

        # `fullmatch` anchors on the whole string; `re.match` with `$` would also
        # accept a trailing newline, which is not an allowed character.
        if re.fullmatch(r"[a-zA-Z0-9_\- ]+", name) is None:
            raise ValueError(
                f"Experiment name contains invalid characters (got {name}). "
                "Only letters, numbers, underscores, dashes and spaces are allowed."
            )

        return name

    @model_validator(mode="after")  # TODO move to n2v configs or remove
    def validate_n2v_mask_pixel_perc(self: Self) -> Self:
        """
        Validate that every patch is expected to contain at least one blind-spot pixel.

        The expected number of blind-spot pixels per patch is a function of the chosen
        masked pixel percentage and patch size.

        Returns
        -------
        Self
            Validated configuration.

        Raises
        ------
        ValueError
            If the expected number of masked pixels within a patch is less than 1 for
            the chosen masked pixel percentage and patch size.
        """
        # No validation needed for non n2v algorithms # TODO: why ?
        if not isinstance(self.algorithm_config, N2VAlgorithm | PN2VAlgorithm):
            return self

        mask_pixel_perc = self.algorithm_config.n2v_config.masked_pixel_percentage
        patch_size = self.data_config.patch_size
        n_patch_pixels = np.prod(patch_size)

        # Number of pixels that, on average, contain a single masked pixel.
        expected_area_per_pixel = 1 / (mask_pixel_perc / 100)

        if expected_area_per_pixel > n_patch_pixels:
            # Use the dimensionality of the patch itself rather than the model's
            # 3D flag: the flag may not yet be aligned with the data axes when this
            # validator runs, which would make the suggested patch size have the
            # wrong number of dimensions.
            n_dims = len(patch_size)
            patch_size_lower_bound = int(
                np.ceil(expected_area_per_pixel ** (1 / n_dims))
            )
            # suggest the next power of two in every dimension
            required_patch_size = tuple(
                2 ** int(np.ceil(np.log2(patch_size_lower_bound)))
                for _ in range(n_dims)
            )
            required_mask_pixel_perc = (1 / n_patch_pixels) * 100

            raise ValueError(
                "The probability of creating a blind-spot pixel within a patch is "
                f"below 1, for a patch size of {patch_size} with a masked pixel "
                f"percentage of {mask_pixel_perc}%. Either increase the patch size to "
                f"{required_patch_size} or increase the masked pixel percentage to "
                f"at least {required_mask_pixel_perc}%."
            )

        return self

    @model_validator(mode="after")
    def validate_3D(self: Self) -> Self:
        """
        Change algorithm dimensions to match data.axes.

        The model is switched to 3D when "Z" is present in the data axes, and to 2D
        when it is absent. No error is raised.

        Returns
        -------
        Self
            Validated configuration.
        """
        if "Z" in self.data_config.axes and not self.algorithm_config.model.is_3D():
            # change algorithm to 3D
            self.algorithm_config.model.set_3D(True)
        elif "Z" not in self.data_config.axes and self.algorithm_config.model.is_3D():
            # change algorithm to 2D
            self.algorithm_config.model.set_3D(False)

        return self

    def __str__(self) -> str:
        """
        Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def set_3D(self, is_3D: bool, axes: str, patch_size: list[int]) -> None:
        """
        Set 3D flag and axes.

        Parameters
        ----------
        is_3D : bool
            Whether the algorithm is 3D or not.
        axes : str
            Axes of the data.
        patch_size : list[int]
            Patch size.
        """
        # set the flag and axes (this will not trigger validation at the config level)
        self.algorithm_config.model.set_3D(is_3D)
        self.data_config.set_3D(axes, patch_size)

        # cheap hack: re-assigning a field triggers the model validators because
        # `validate_assignment=True` in `model_config`
        self.algorithm_config = self.algorithm_config

    def get_algorithm_friendly_name(self) -> str:
        """
        Get the algorithm name.

        Returns
        -------
        str
            Algorithm name.
        """
        return self.algorithm_config.get_algorithm_friendly_name()

    def get_algorithm_description(self) -> str:
        """
        Return a description of the algorithm.

        This method is used to generate the README of the BioImage Model Zoo export.

        Returns
        -------
        str
            Description of the algorithm.
        """
        return self.algorithm_config.get_algorithm_description()

    def get_algorithm_citations(self) -> list[CiteEntry]:
        """
        Return a list of citation entries of the current algorithm.

        This is used to generate the model description for the BioImage Model Zoo.

        Returns
        -------
        list[CiteEntry]
            List of citation entries.
        """
        return self.algorithm_config.get_algorithm_citations()

    def get_algorithm_references(self) -> str:
        """
        Get the algorithm references.

        This is used to generate the README of the BioImage Model Zoo export.

        Returns
        -------
        str
            Algorithm references.
        """
        return self.algorithm_config.get_algorithm_references()

    def get_algorithm_keywords(self) -> list[str]:
        """
        Get algorithm keywords.

        Returns
        -------
        list[str]
            List of keywords.
        """
        return self.algorithm_config.get_algorithm_keywords()

    def model_dump(
        self,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Override model_dump method in order to set default values.

        As opposed to the parent model_dump method, this method sets `exclude_none`
        to `True` unless the caller explicitly passes it.

        Parameters
        ----------
        **kwargs : Any
            Additional arguments to pass to the parent model_dump method.

        Returns
        -------
        dict
            Dictionary containing the model parameters.
        """
        kwargs.setdefault("exclude_none", True)

        return super().model_dump(**kwargs)
|