careamics-0.0.19-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
@@ -0,0 +1,15 @@
+"""Pydantic models representing coordinate and patch filters."""
+
+__all__ = [
+    "FilterConfig",
+    "MaskFilterConfig",
+    "MaxFilterConfig",
+    "MeanSTDFilterConfig",
+    "ShannonFilterConfig",
+]
+
+from .filter_config import FilterConfig
+from .mask_filter_config import MaskFilterConfig
+from .max_filter_config import MaxFilterConfig
+from .meanstd_filter_config import MeanSTDFilterConfig
+from .shannon_filter_config import ShannonFilterConfig
@@ -0,0 +1,16 @@
+"""Base class for patch and coordinate filtering models."""
+
+from pydantic import BaseModel, Field
+
+
+class FilterConfig(BaseModel):
+    """Base class for patch and coordinate filtering models."""
+
+    name: str
+    """Name of the filter."""
+
+    p: float = Field(1.0, ge=0.0, le=1.0)
+    """Probability of applying the filter to a patch or coordinate."""
+
+    seed: int | None = Field(default=None, gt=0)
+    """Seed for the random number generator for reproducibility."""
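The hunk above adds `FilterConfig`, the base model shared by all patch and coordinate filters. A minimal sketch of how its field constraints behave, assuming careamics 0.0.19 and Pydantic v2 are installed (the `my_filter` name is a placeholder):

```python
# Sketch: FilterConfig field constraints under Pydantic v2 validation
# (assumes careamics 0.0.19 is installed; "my_filter" is a placeholder name).
from pydantic import ValidationError

from careamics.config.data.patch_filter import FilterConfig

config = FilterConfig(name="my_filter")  # p defaults to 1.0, seed to None
print(config.p)  # 1.0

try:
    FilterConfig(name="my_filter", p=1.5)  # violates le=1.0
except ValidationError as err:
    print(err.error_count(), "validation error(s)")
```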
@@ -0,0 +1,17 @@
+"""Pydantic model for the mask coordinate filter."""
+
+from typing import Literal
+
+from pydantic import Field
+
+from .filter_config import FilterConfig
+
+
+class MaskFilterConfig(FilterConfig):
+    """Pydantic model for the mask coordinate filter."""
+
+    name: Literal["mask"] = "mask"
+    """Name of the filter."""
+
+    coverage: float = Field(0.5, ge=0.0, le=1.0)
+    """Percentage of masked pixels required to keep a patch."""
@@ -0,0 +1,15 @@
+"""Pydantic model for the max patch filter."""
+
+from typing import Literal
+
+from .filter_config import FilterConfig
+
+
+class MaxFilterConfig(FilterConfig):
+    """Pydantic model for the max patch filter."""
+
+    name: Literal["max"] = "max"
+    """Name of the filter."""
+
+    threshold: float
+    """Threshold for the minimum of the max-filtered patch."""
@@ -0,0 +1,18 @@
+"""Pydantic model for the mean std patch filter."""
+
+from typing import Literal
+
+from .filter_config import FilterConfig
+
+
+class MeanSTDFilterConfig(FilterConfig):
+    """Pydantic model for the mean std patch filter."""
+
+    name: Literal["mean_std"] = "mean_std"
+    """Name of the filter."""
+
+    mean_threshold: float
+    """Minimum mean intensity required to keep a patch."""
+
+    std_threshold: float | None = None
+    """Minimum standard deviation required to keep a patch."""
@@ -0,0 +1,15 @@
+"""Pydantic model for the Shannon entropy patch filter."""
+
+from typing import Literal
+
+from .filter_config import FilterConfig
+
+
+class ShannonFilterConfig(FilterConfig):
+    """Pydantic model for the Shannon entropy patch filter."""
+
+    name: Literal["shannon"] = "shannon"
+    """Name of the filter."""
+
+    threshold: float
+    """Minimum Shannon entropy required to keep a patch."""
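The concrete filters above only pin `name` down as a `Literal` discriminator and add their own thresholds; the probability `p` and the `seed` are inherited from `FilterConfig`. A short sketch of constructing them, assuming careamics 0.0.19 is installed (threshold values are illustrative only):

```python
# Sketch: constructing the concrete filter configs
# (assumes careamics 0.0.19 is installed; threshold values are illustrative).
from careamics.config.data.patch_filter import (
    MaskFilterConfig,
    MaxFilterConfig,
    MeanSTDFilterConfig,
    ShannonFilterConfig,
)

mask = MaskFilterConfig(coverage=0.25, p=0.8)  # name is fixed to "mask" by the Literal default
maximum = MaxFilterConfig(threshold=100.0)  # inherited p defaults to 1.0
mean_std = MeanSTDFilterConfig(mean_threshold=10.0, std_threshold=2.0)
shannon = ShannonFilterConfig(threshold=4.0, seed=42)

for cfg in (mask, maximum, mean_std, shannon):
    print(cfg.name, cfg.model_dump())
```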
@@ -0,0 +1,15 @@
+"""Patching strategies Pydantic models."""
+
+__all__ = [
+    "FixedRandomPatchingConfig",
+    "RandomPatchingConfig",
+    "SequentialPatchingConfig",
+    "TiledPatchingConfig",
+    "WholePatchingConfig",
+]
+
+
+from .random_patching_config import FixedRandomPatchingConfig, RandomPatchingConfig
+from .sequential_patching_config import SequentialPatchingConfig
+from .tiled_patching_config import TiledPatchingConfig
+from .whole_patching_config import WholePatchingConfig
@@ -0,0 +1,102 @@
+"""Sequential patching Pydantic model."""
+
+from collections.abc import Sequence
+
+from pydantic import Field, ValidationInfo, field_validator
+
+from ._patched_config import _PatchedConfig
+
+
+class _OverlappingPatchedConfig(_PatchedConfig):
+    """Overlapping patching Pydantic model.
+
+    This model is only used for inheritance and validation purposes.
+
+    Attributes
+    ----------
+    patch_size : list of int
+        The size of the patch in each spatial dimension, each patch size must be a power
+        of 2 and larger than 8.
+    overlaps : sequence of int, optional
+        The overlaps between patches in each spatial dimension. If `None`, no overlap is
+        applied. The overlaps must be smaller than the patch size in each spatial
+        dimension, and the number of dimensions be either 2 or 3.
+    """
+
+    overlaps: Sequence[int] | None = Field(
+        default=None,
+        min_length=2,
+        max_length=3,
+    )
+    """The overlaps between patches in each spatial dimension. If `None`, no overlap is
+    applied. The overlaps must be smaller than the patch size in each spatial dimension,
+    and the number of dimensions be either 2 or 3.
+    """
+
+    @field_validator("overlaps")
+    @classmethod
+    def overlap_smaller_than_patch_size(
+        cls, overlaps: Sequence[int] | None, values: ValidationInfo
+    ) -> Sequence[int] | None:
+        """
+        Validate overlap.
+
+        Overlaps must be smaller than the patch size in each spatial dimension.
+
+        Parameters
+        ----------
+        overlaps : Sequence of int
+            Overlap in each dimension.
+        values : ValidationInfo
+            Dictionary of values.
+
+        Returns
+        -------
+        Sequence of int
+            Validated overlap.
+        """
+        if overlaps is None:
+            return None
+
+        patch_size = values.data["patch_size"]
+
+        if len(overlaps) != len(patch_size):
+            raise ValueError(
+                f"Overlaps must have the same number of dimensions as the patch size. "
+                f"Got {len(overlaps)} dimensions for overlaps and {len(patch_size)} "
+                f"dimensions for patch size."
+            )
+
+        if any(o >= p for o, p in zip(overlaps, patch_size, strict=False)):
+            raise ValueError(
+                f"Overlap must be smaller than the patch size, got {overlaps} versus "
+                f"{patch_size}."
+            )
+
+        return overlaps
+
+    @field_validator("overlaps")
+    @classmethod
+    def overlap_even(cls, overlaps: Sequence[int] | None) -> Sequence[int] | None:
+        """
+        Validate overlaps.
+
+        Overlap must be even.
+
+        Parameters
+        ----------
+        overlaps : Sequence of int
+            Overlaps.
+
+        Returns
+        -------
+        Sequence of int
+            Validated overlap.
+        """
+        if overlaps is None:
+            return None
+
+        if any(o % 2 != 0 for o in overlaps):
+            raise ValueError(f"Overlaps must be even, got {overlaps}.")
+
+        return overlaps
@@ -0,0 +1,56 @@
+"""Generic patching Pydantic model."""
+
+from collections.abc import Sequence
+
+from pydantic import BaseModel, ConfigDict, Field, field_validator
+
+from careamics.config.validators import patch_size_ge_than_8_power_of_2
+
+
+class _PatchedConfig(BaseModel):
+    """Generic patching Pydantic model.
+
+    This model is only used for inheritance and validation purposes.
+    """
+
+    model_config = ConfigDict(
+        extra="ignore",  # default behaviour, make it explicit
+    )
+
+    name: str
+    """The name of the patching strategy."""
+
+    patch_size: Sequence[int] = Field(..., min_length=2, max_length=3)
+    """The size of the patch in each spatial dimensions, each patch size must be a power
+    of 2 and larger than 8."""
+
+    @field_validator("patch_size")
+    @classmethod
+    def all_elements_power_of_2_minimum_8(
+        cls, patch_list: Sequence[int]
+    ) -> Sequence[int]:
+        """
+        Validate patch size.
+
+        Patch size must be powers of 2 and minimum 8.
+
+        Parameters
+        ----------
+        patch_list : Sequence of int
+            Patch size.
+
+        Returns
+        -------
+        Sequence of int
+            Validated patch size.
+
+        Raises
+        ------
+        ValueError
+            If the patch size is smaller than 8.
+        ValueError
+            If the patch size is not a power of 2.
+        """
+        patch_size_ge_than_8_power_of_2(patch_list)
+
+        return patch_list
@@ -0,0 +1,45 @@
+"""Random patching Pydantic model."""
+
+from typing import Literal
+
+from pydantic import Field
+
+from ._patched_config import _PatchedConfig
+
+
+class RandomPatchingConfig(_PatchedConfig):
+    """Random patching Pydantic model.
+
+    Attributes
+    ----------
+    name : "random"
+        The name of the patching strategy.
+    patch_size : sequence of int
+        The size of the patch in each spatial dimension, each patch size must be a power
+        of 2 and larger than 8.
+    """
+
+    name: Literal["random"] = "random"
+    """The name of the patching strategy."""
+
+    seed: int | None = Field(default=None, gt=0)
+    """Random seed for patch sampling, set to None for random seeding."""
+
+
+class FixedRandomPatchingConfig(_PatchedConfig):
+    """Fixed random patching Pydantic model.
+
+    Attributes
+    ----------
+    name : "fixed_random"
+        The name of the patching strategy.
+    patch_size : sequence of int
+        The size of the patch in each spatial dimension, each patch size must be a power
+        of 2 and larger than 8.
+    """
+
+    name: Literal["fixed_random"] = "fixed_random"
+    """The name of the patching strategy."""
+
+    seed: int | None = Field(default=None, gt=0)
+    """The random seed to use for patch sampling."""
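Both random strategies inherit the `patch_size` rules from `_PatchedConfig` (2 or 3 spatial dimensions, each a power of two and at least 8). A brief sketch of that validation, assuming careamics 0.0.19 is installed:

```python
# Sketch: patch_size validation inherited from _PatchedConfig
# (assumes careamics 0.0.19 is installed).
from pydantic import ValidationError

from careamics.config.data.patching_strategies import RandomPatchingConfig

ok = RandomPatchingConfig(patch_size=[64, 64], seed=42)
print(ok.name)  # "random"

for bad_size in ([65, 64], [4, 4], [64]):  # not a power of 2 / below 8 / fewer than 2 dims
    try:
        RandomPatchingConfig(patch_size=bad_size)
    except ValidationError:
        print("rejected patch_size:", bad_size)
```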
@@ -0,0 +1,25 @@
+"""Sequential patching Pydantic model."""
+
+from typing import Literal
+
+from ._overlapping_patched_config import _OverlappingPatchedConfig
+
+
+class SequentialPatchingConfig(_OverlappingPatchedConfig):
+    """Sequential patching Pydantic model.
+
+    Attributes
+    ----------
+    name : "sequential"
+        The name of the patching strategy.
+    patch_size : sequence of int
+        The size of the patch in each spatial dimension, each patch size must be a power
+        of 2 and larger than 8.
+    overlaps : list of int, optional
+        The overlaps between patches in each spatial dimension. If `None`, no overlap is
+        applied. The overlaps must be smaller than the patch size in each spatial
+        dimension, and the number of dimensions be either 2 or 3.
+    """
+
+    name: Literal["sequential"] = "sequential"
+    """The name of the patching strategy."""
@@ -0,0 +1,40 @@
+"""Tiled patching Pydantic model."""
+
+from collections.abc import Sequence
+from typing import Literal
+
+from pydantic import Field
+
+from ._overlapping_patched_config import _OverlappingPatchedConfig
+
+
+# TODO with UNet tiling must obey different rules than sequential tiling
+# - needs to validated at the level of the configuration
+class TiledPatchingConfig(_OverlappingPatchedConfig):
+    """Tiled patching Pydantic model.
+
+    Attributes
+    ----------
+    name : "tiled"
+        The name of the patching strategy.
+    patch_size : sequence of int
+        The size of the patch in each spatial dimension, each patch size must be a power
+        of 2 and larger than 8.
+    overlaps : sequence of int
+        The overlaps between patches in each spatial dimension. The overlaps must be
+        smaller than the patch size in each spatial dimension, and the number of
+        dimensions be either 2 or 3.
+    """
+
+    name: Literal["tiled"] = "tiled"
+    """The name of the patching strategy."""
+
+    overlaps: Sequence[int] = Field(
+        ...,
+        min_length=2,
+        max_length=3,
+    )
+    """The overlaps between patches in each spatial dimension. The overlaps must be
+    smaller than the patch size in each spatial dimension, and the number of dimensions
+    be either 2 or 3.
+    """
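For the tiled strategy, `overlaps` is required and the `_OverlappingPatchedConfig` validators (even values, same rank as the patch size, strictly smaller than it) kick in. A minimal sketch, assuming careamics 0.0.19 is installed:

```python
# Sketch: overlap validation on the tiled strategy
# (assumes careamics 0.0.19 is installed).
from pydantic import ValidationError

from careamics.config.data.patching_strategies import TiledPatchingConfig

tiles = TiledPatchingConfig(patch_size=[128, 128], overlaps=[32, 32])
print(tiles.name)  # "tiled"

try:
    TiledPatchingConfig(patch_size=[128, 128], overlaps=[33, 32])  # odd overlap
except ValidationError:
    print("rejected: overlaps must be even")

try:
    TiledPatchingConfig(patch_size=[64, 64], overlaps=[64, 64])  # overlap == patch size
except ValidationError:
    print("rejected: overlaps must be smaller than the patch size")
```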
@@ -0,0 +1,12 @@
+"""Whole image patching Pydantic model."""
+
+from typing import Literal
+
+from pydantic import BaseModel
+
+
+class WholePatchingConfig(BaseModel):
+    """Whole image patching Pydantic model."""
+
+    name: Literal["whole"] = "whole"
+    """The name of the patching strategy."""
@@ -0,0 +1,65 @@
+"""Pydantic model representing the metadata of a prediction tile."""
+
+from __future__ import annotations
+
+from typing import Annotated
+
+from annotated_types import Len
+from pydantic import BaseModel, ConfigDict
+
+DimTuple = Annotated[tuple, Len(min_length=3, max_length=4)]
+
+
+class TileInformation(BaseModel):
+    """
+    Pydantic model containing tile information.
+
+    This model is used to represent the information required to stitch back a tile into
+    a larger image. It is used throughout the prediction pipeline of CAREamics.
+
+    Array shape should be C(Z)YX, where Z is an optional dimensions.
+    """
+
+    model_config = ConfigDict(validate_default=True)
+
+    array_shape: DimTuple  # TODO: find a way to add custom error message?
+    """Shape of the original (untiled) array."""
+
+    last_tile: bool = False
+    """Whether this tile is the last one of the array."""
+
+    overlap_crop_coords: tuple[tuple[int, ...], ...]
+    """Inner coordinates of the tile where to crop the prediction in order to stitch
+    it back into the original image."""
+
+    stitch_coords: tuple[tuple[int, ...], ...]
+    """Coordinates in the original image where to stitch the cropped tile back."""
+
+    sample_id: int
+    """Sample ID of the tile."""
+
+    # TODO: Test that ZYX axes are not singleton ?
+
+    def __eq__(self, other_tile: object):
+        """Check if two tile information objects are equal.
+
+        Parameters
+        ----------
+        other_tile : object
+            Tile information object to compare with.
+
+        Returns
+        -------
+        bool
+            Whether the two tile information objects are equal.
+        """
+        if not isinstance(other_tile, TileInformation):
+            return NotImplemented
+
+        return (
+            self.array_shape == other_tile.array_shape
+            and self.last_tile == other_tile.last_tile
+            and self.overlap_crop_coords == other_tile.overlap_crop_coords
+            and self.stitch_coords == other_tile.stitch_coords
+            and self.sample_id == other_tile.sample_id
+        )
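A short sketch of what a `TileInformation` instance might look like for a single 2D (CYX) tile, assuming careamics 0.0.19 is installed; the coordinate values are made up for illustration:

```python
# Sketch: an illustrative TileInformation for a single 2D (CYX) tile
# (assumes careamics 0.0.19 is installed; coordinate values are made up).
from careamics.config.data.tile_information import TileInformation

tile = TileInformation(
    array_shape=(1, 256, 256),  # C, Y, X of the untiled array
    last_tile=False,
    overlap_crop_coords=((24, 104), (24, 104)),  # crop inside the predicted tile
    stitch_coords=((0, 80), (0, 80)),  # where the crop lands in the full image
    sample_id=0,
)

# __eq__ compares all stitching-relevant fields
print(tile == tile.model_copy())  # True
```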
@@ -0,0 +1,15 @@
+"""Training and lightning related Pydantic configurations."""
+
+__all__ = [
+    "CheckpointConfig",
+    "EarlyStoppingConfig",
+    "LrSchedulerConfig",
+    "OptimizerConfig",
+    "TrainerConfig",
+    "TrainingConfig",
+]
+
+
+from .callbacks import CheckpointConfig, EarlyStoppingConfig
+from .optimizer_configs import LrSchedulerConfig, OptimizerConfig
+from .training_config import TrainingConfig
@@ -0,0 +1,116 @@
+"""Callback Pydantic models."""
+
+from __future__ import annotations
+
+from datetime import timedelta
+from typing import Literal
+
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+)
+
+
+class CheckpointConfig(BaseModel):
+    """Checkpoint saving callback Pydantic model.
+
+    The parameters corresponds to those of
+    `pytorch_lightning.callbacks.ModelCheckpoint`.
+
+    See:
+    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
+    """
+
+    model_config = ConfigDict(validate_assignment=True, validate_default=True)
+
+    monitor: Literal["val_loss"] | str | None = Field(default="val_loss")
+    """Quantity to monitor, currently only `val_loss`."""
+
+    verbose: bool = Field(default=False)
+    """Verbosity mode."""
+
+    save_weights_only: bool = Field(default=False)
+    """When `True`, only the model's weights will be saved (model.save_weights)."""
+
+    save_last: Literal[True, False, "link"] | None = Field(default=True)
+    """When `True`, saves a last.ckpt copy whenever a checkpoint file gets saved."""
+
+    save_top_k: int = Field(
+        default=3,
+        ge=-1,
+        le=100,
+    )
+    """If `save_top_k == k, the best k models according to the quantity monitored
+    will be saved. If `save_top_k == 0`, no models are saved. if `save_top_k == -1`,
+    all models are saved."""
+
+    mode: Literal["min", "max"] = Field(default="min")
+    """One of {min, max}. If `save_top_k != 0`, the decision to overwrite the current
+    save file is made based on either the maximization or the minimization of the
+    monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should
+    be 'min', etc.
+    """
+
+    auto_insert_metric_name: bool = Field(default=False)
+    """When `True`, the checkpoints filenames will contain the metric name."""
+
+    every_n_train_steps: int | None = Field(default=None, ge=1, le=1000)
+    """Number of training steps between checkpoints."""
+
+    train_time_interval: timedelta | None = Field(default=None)
+    """Checkpoints are monitored at the specified time interval."""
+
+    every_n_epochs: int | None = Field(default=None, ge=1, le=100)
+    """Number of epochs between checkpoints."""
+
+
+class EarlyStoppingConfig(BaseModel):
+    """Early stopping callback Pydantic model.
+
+    The parameters corresponds to those of
+    `pytorch_lightning.callbacks.ModelCheckpoint`.
+
+    See:
+    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping
+    """
+
+    model_config = ConfigDict(
+        validate_assignment=True,
+        validate_default=True,
+    )
+
+    monitor: Literal["val_loss"] = Field(default="val_loss")
+    """Quantity to monitor."""
+
+    min_delta: float = Field(default=0.0, ge=0.0, le=1.0)
+    """Minimum change in the monitored quantity to qualify as an improvement, i.e. an
+    absolute change of less than or equal to min_delta, will count as no improvement."""
+
+    patience: int = Field(default=3, ge=1, le=10)
+    """Number of checks with no improvement after which training will be stopped."""
+
+    verbose: bool = Field(default=False)
+    """Verbosity mode."""
+
+    mode: Literal["min", "max", "auto"] = Field(default="min")
+    """One of {min, max, auto}."""
+
+    check_finite: bool = Field(default=True)
+    """When `True`, stops training when the monitored quantity becomes `NaN` or
+    `inf`."""
+
+    stopping_threshold: float | None = Field(default=None)
+    """Stop training immediately once the monitored quantity reaches this threshold."""
+
+    divergence_threshold: float | None = Field(default=None)
+    """Stop training as soon as the monitored quantity becomes worse than this
+    threshold."""
+
+    check_on_train_epoch_end: bool | None = Field(default=False)
+    """Whether to run early stopping at the end of the training epoch. If this is
+    `False`, then the check runs at the end of the validation."""
+
+    log_rank_zero_only: bool = Field(default=False)
+    """When set `True`, logs the status of the early stopping callback only for rank 0
+    process."""
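Both callback models mirror the keyword arguments of the corresponding Lightning callbacks, so one plausible use is to unpack them into the callback constructors. The actual wiring inside CAREamics is not part of this diff, so the snippet below is only a sketch assuming careamics 0.0.19 and pytorch-lightning are installed:

```python
# Sketch: unpacking the callback configs into Lightning callbacks
# (assumes careamics 0.0.19 and pytorch-lightning are installed; this is not
# necessarily how CAREamics wires them internally).
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from careamics.config.lightning import CheckpointConfig, EarlyStoppingConfig

ckpt_config = CheckpointConfig(save_top_k=1, every_n_epochs=5)
stop_config = EarlyStoppingConfig(patience=5, min_delta=0.01)

callbacks = [
    ModelCheckpoint(**ckpt_config.model_dump()),
    EarlyStopping(**stop_config.model_dump()),
]
print([type(cb).__name__ for cb in callbacks])
```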