careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""VAE-based algorithm Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pprint import pformat
|
|
6
|
+
from typing import Literal, Self
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, ConfigDict, model_validator
|
|
9
|
+
|
|
10
|
+
from careamics.config.architectures import LVAEConfig
|
|
11
|
+
from careamics.config.lightning.optimizer_configs import (
|
|
12
|
+
LrSchedulerConfig,
|
|
13
|
+
OptimizerConfig,
|
|
14
|
+
)
|
|
15
|
+
from careamics.config.losses.loss_config import LVAELossConfig
|
|
16
|
+
from careamics.config.noise_model.likelihood_config import (
|
|
17
|
+
GaussianLikelihoodConfig,
|
|
18
|
+
NMLikelihoodConfig,
|
|
19
|
+
)
|
|
20
|
+
from careamics.config.noise_model.noise_model_config import MultiChannelNMConfig
|
|
21
|
+
from careamics.config.support import SupportedAlgorithm, SupportedLoss
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class VAEBasedAlgorithm(BaseModel):
    """VAE-based algorithm configuration.

    Groups the configurations needed by the VAE-based algorithms (`hdn` and
    `microsplit`):
    - the LVAE architecture;
    - the loss;
    - the optional noise model and likelihoods;
    - the optimizer and learning rate scheduler.

    Cross-field consistency is enforced by the model validators:

    - `hdn` requires the `SupportedLoss.HDN` loss, a single-scale model
      (`model.multiscale_count == 1`) and a single output channel.
    - `microsplit` requires one of the `SupportedLoss.MUSPLIT`,
      `SupportedLoss.DENOISPLIT` or `SupportedLoss.DENOISPLIT_MUSPLIT` losses.
      The `DENOISPLIT` loss additionally requires `model.predict_logvar` to be
      `None` and a noise model to be set.
    - If a noise model is set, it must contain exactly one noise model per
      model output channel.
    - If a Gaussian likelihood is set, its `predict_logvar` must match the
      model's `predict_logvar`.

    Attributes
    ----------
    algorithm : {"hdn", "microsplit"}
        Name of the algorithm.
    loss : LVAELossConfig
        Loss configuration; the allowed loss types depend on `algorithm`.
    model : LVAEConfig
        LVAE architecture configuration.
    noise_model : MultiChannelNMConfig or None
        Optional noise model configuration, by default None.
    noise_model_likelihood : NMLikelihoodConfig or None
        Optional noise model likelihood configuration, by default None.
    gaussian_likelihood : GaussianLikelihoodConfig or None
        Optional Gaussian likelihood configuration, by default None.
    mmse_count : int
        Number of MMSE samples, by default 1 (presumably averaged at
        prediction time - confirm against the prediction code).
    is_supervised : bool
        Flag presumably indicating supervised training, by default False.
    optimizer : OptimizerConfig
        Optimizer configuration.
    lr_scheduler : LrSchedulerConfig
        Learning rate scheduler configuration.
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        protected_namespaces=(),  # allows to use model_* as a field name
        validate_assignment=True,
        extra="allow",
    )

    # Mandatory fields
    # defined in SupportedAlgorithm
    # TODO: Use supported Enum classes for typing?
    #   - values can still be passed as strings and they will be cast to Enum
    algorithm: Literal["hdn", "microsplit"]

    # NOTE: these are all configs (pydantic models)
    loss: LVAELossConfig
    model: LVAEConfig
    noise_model: MultiChannelNMConfig | None = None
    noise_model_likelihood: NMLikelihoodConfig | None = None
    gaussian_likelihood: GaussianLikelihoodConfig | None = None  # TODO change to str

    mmse_count: int = 1
    is_supervised: bool = False

    # Optional fields
    optimizer: OptimizerConfig = OptimizerConfig()
    """Optimizer to use, defined in SupportedOptimizer."""

    lr_scheduler: LrSchedulerConfig = LrSchedulerConfig()

    @model_validator(mode="after")
    def algorithm_cross_validation(self: Self) -> Self:
        """Validate the algorithm model based on `algorithm`.

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If the loss type is not compatible with the algorithm.
        ValueError
            If `hdn` is used with a multiscale model.
        ValueError
            If the `denoisplit` loss is used with a non-`None` `predict_logvar`
            or without a noise model.
        """
        # hdn
        # TODO move to designated configurations
        if self.algorithm == SupportedAlgorithm.HDN:
            if self.loss.loss_type != SupportedLoss.HDN:
                raise ValueError(
                    f"Algorithm {self.algorithm} only supports loss `hdn`."
                )
            if self.model.multiscale_count > 1:
                raise ValueError("Algorithm `hdn` does not support multiscale models.")
        # musplit
        if self.algorithm == SupportedAlgorithm.MICROSPLIT:
            if self.loss.loss_type not in [
                SupportedLoss.MUSPLIT,
                SupportedLoss.DENOISPLIT,
                SupportedLoss.DENOISPLIT_MUSPLIT,
            ]:  # TODO Update losses configs, make loss just microsplit
                raise ValueError(
                    f"Algorithm {self.algorithm} only supports loss `microsplit`."
                )  # TODO Update losses configs

            if (
                self.loss.loss_type == SupportedLoss.DENOISPLIT
                and self.model.predict_logvar is not None
            ):
                raise ValueError(
                    "Algorithm `denoisplit` with loss `denoisplit` only supports "
                    "`predict_logvar` as `None`."
                )
            if (
                self.loss.loss_type == SupportedLoss.DENOISPLIT
                and self.noise_model is None
            ):
                raise ValueError("Algorithm `denoisplit` requires a noise model.")
        # TODO: what if algorithm is not musplit or denoisplit
        return self

    @model_validator(mode="after")
    def output_channels_validation(self: Self) -> Self:
        """Validate the consistency between number of out channels and noise models.

        Explicit `ValueError`s are raised instead of `assert` statements so that
        the checks are not stripped when Python runs with `-O`.

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If a noise model is set and its number of noise models differs
            from the number of output channels.
        ValueError
            If the algorithm is `hdn` and the model has more than one output
            channel.
        """
        if self.noise_model is not None:
            n_noise_models = len(self.noise_model.noise_models)
            if self.model.output_channels != n_noise_models:
                raise ValueError(
                    f"Number of output channels ({self.model.output_channels}) "
                    f"must match the number of noise models ({n_noise_models})."
                )

        if self.algorithm == SupportedAlgorithm.HDN and self.model.output_channels != 1:
            raise ValueError(
                f"Number of output channels ({self.model.output_channels}) must be 1 "
                "for algorithm `hdn`."
            )
        return self

    @model_validator(mode="after")
    def predict_logvar_validation(self: Self) -> Self:
        """Validate the consistency of `predict_logvar` throughout the model.

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If a Gaussian likelihood is set and its `predict_logvar` differs
            from the model's `predict_logvar`.
        """
        if (
            self.gaussian_likelihood is not None
            and self.model.predict_logvar != self.gaussian_likelihood.predict_logvar
        ):
            raise ValueError(
                f"Model `predict_logvar` ({self.model.predict_logvar}) must match "
                "Gaussian likelihood model `predict_logvar` "
                f"({self.gaussian_likelihood.predict_logvar})."
            )
        # TODO: decide whether `hdn` should require `predict_logvar` to be `None`
        # for both the model and the Gaussian likelihood.
        return self

    def __str__(self) -> str:
        """Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    @classmethod
    def get_compatible_algorithms(cls) -> list[str]:
        """Get the list of compatible algorithms.

        Returns
        -------
        list of str
            List of compatible algorithms.
        """
        return ["hdn", "microsplit"]
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Base model for the various CAREamics architectures."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ArchitectureConfig(BaseModel):
    """
    Base Pydantic model for all model architectures.

    Subclasses are expected to narrow `architecture` (for instance to a
    `Literal`) to identify the concrete architecture. The `model_dump` method
    removes the `architecture` key from the dumped dictionary.
    """

    architecture: str
    """Name of the architecture."""

    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        """
        Dump the model as a dictionary, ignoring the architecture keyword.

        Parameters
        ----------
        **kwargs : Any
            Additional keyword arguments from Pydantic BaseModel model_dump method.

        Returns
        -------
        {str: Any}
            Model as a dictionary, without the `architecture` key.
        """
        model_dict = super().model_dump(**kwargs)

        # remove the architecture key; it can already be missing when the caller
        # filtered it out through `include`/`exclude`, hence the default value
        model_dict.pop("architecture", None)

        return model_dict
|
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
"""LVAE Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal, Self
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field, field_validator, model_validator
|
|
6
|
+
|
|
7
|
+
from .architecture_config import ArchitectureConfig
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# TODO: it is quite confusing to call this LVAEModel, as it is basically a config
|
|
11
|
+
class LVAEConfig(ArchitectureConfig):
|
|
12
|
+
"""LVAE model."""
|
|
13
|
+
|
|
14
|
+
model_config = ConfigDict(validate_assignment=True, validate_default=True)
|
|
15
|
+
|
|
16
|
+
architecture: Literal["LVAE"]
|
|
17
|
+
|
|
18
|
+
input_shape: tuple[int, ...] = Field(default=(64, 64), validate_default=True)
|
|
19
|
+
"""Shape of the input patch (Z, Y, X) or (Y, X) if the data is 2D."""
|
|
20
|
+
encoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
|
|
21
|
+
|
|
22
|
+
# TODO make this per hierarchy step ?
|
|
23
|
+
decoder_conv_strides: list = Field(default=[2, 2], validate_default=True)
|
|
24
|
+
"""Dimensions (2D or 3D) of the convolutional layers."""
|
|
25
|
+
|
|
26
|
+
multiscale_count: int = Field(default=1)
|
|
27
|
+
# TODO there should be a check for multiscale_count in dataset !!
|
|
28
|
+
|
|
29
|
+
# 1 - off, len(z_dims) + 1 # TODO Consider starting from 0
|
|
30
|
+
z_dims: list = Field(default=[128, 128, 128, 128])
|
|
31
|
+
output_channels: int = Field(default=1, ge=1)
|
|
32
|
+
encoder_n_filters: int = Field(default=64, ge=8, le=1024)
|
|
33
|
+
decoder_n_filters: int = Field(default=64, ge=8, le=1024)
|
|
34
|
+
encoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
|
|
35
|
+
decoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
|
|
36
|
+
nonlinearity: Literal[
|
|
37
|
+
"None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
|
|
38
|
+
] = Field(
|
|
39
|
+
default="ELU",
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
predict_logvar: Literal[None, "pixelwise"] = "pixelwise"
|
|
43
|
+
analytical_kl: bool = Field(default=False)
|
|
44
|
+
|
|
45
|
+
@model_validator(mode="after")
|
|
46
|
+
def validate_conv_strides(self: Self) -> Self:
|
|
47
|
+
"""
|
|
48
|
+
Validate the convolutional strides.
|
|
49
|
+
|
|
50
|
+
Returns
|
|
51
|
+
-------
|
|
52
|
+
list
|
|
53
|
+
Validated strides.
|
|
54
|
+
|
|
55
|
+
Raises
|
|
56
|
+
------
|
|
57
|
+
ValueError
|
|
58
|
+
If the number of strides is not 2.
|
|
59
|
+
"""
|
|
60
|
+
if len(self.encoder_conv_strides) < 2 or len(self.encoder_conv_strides) > 3:
|
|
61
|
+
raise ValueError(
|
|
62
|
+
f"Strides must be 2 or 3 (got {len(self.encoder_conv_strides)})."
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
if len(self.decoder_conv_strides) < 2 or len(self.decoder_conv_strides) > 3:
|
|
66
|
+
raise ValueError(
|
|
67
|
+
f"Strides must be 2 or 3 (got {len(self.decoder_conv_strides)})."
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
# adding 1 to encoder strides for the number of input channels
|
|
71
|
+
if len(self.input_shape) != len(self.encoder_conv_strides):
|
|
72
|
+
raise ValueError(
|
|
73
|
+
f"Input dimensions must be equal to the number of encoder conv strides"
|
|
74
|
+
f" (got {len(self.input_shape)} and {len(self.encoder_conv_strides)})."
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
if len(self.encoder_conv_strides) < len(self.decoder_conv_strides):
|
|
78
|
+
raise ValueError(
|
|
79
|
+
f"Decoder can't be 3D when encoder is 2D (got"
|
|
80
|
+
f" {len(self.encoder_conv_strides)} and"
|
|
81
|
+
f"{len(self.decoder_conv_strides)})."
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
if any(s < 1 for s in self.encoder_conv_strides) or any(
|
|
85
|
+
s < 1 for s in self.decoder_conv_strides
|
|
86
|
+
):
|
|
87
|
+
raise ValueError(
|
|
88
|
+
f"All strides must be greater or equal to 1"
|
|
89
|
+
f"(got {self.encoder_conv_strides} and {self.decoder_conv_strides})."
|
|
90
|
+
)
|
|
91
|
+
# TODO: validate max stride size ?
|
|
92
|
+
return self
|
|
93
|
+
|
|
94
|
+
@field_validator("input_shape")
|
|
95
|
+
@classmethod
|
|
96
|
+
def validate_input_shape(cls, input_shape: list) -> list:
|
|
97
|
+
"""
|
|
98
|
+
Validate the input shape.
|
|
99
|
+
|
|
100
|
+
Parameters
|
|
101
|
+
----------
|
|
102
|
+
input_shape : list
|
|
103
|
+
Shape of the input patch.
|
|
104
|
+
|
|
105
|
+
Returns
|
|
106
|
+
-------
|
|
107
|
+
list
|
|
108
|
+
Validated input shape.
|
|
109
|
+
|
|
110
|
+
Raises
|
|
111
|
+
------
|
|
112
|
+
ValueError
|
|
113
|
+
If the number of dimensions is not 3 or 4.
|
|
114
|
+
"""
|
|
115
|
+
if len(input_shape) < 2 or len(input_shape) > 3:
|
|
116
|
+
raise ValueError(
|
|
117
|
+
f"Number of input dimensions must be 2 for 2D data 3 for 3D"
|
|
118
|
+
f"(got {len(input_shape)})."
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
if any(s < 1 for s in input_shape):
|
|
122
|
+
raise ValueError(
|
|
123
|
+
f"Input shape must be greater than 1 in all dimensions"
|
|
124
|
+
f"(got {input_shape})."
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
if any(s < 64 for s in input_shape[-2:]):
|
|
128
|
+
raise ValueError(
|
|
129
|
+
f"Input shape must be greater or equal to 64 in XY dimensions"
|
|
130
|
+
f"(got {input_shape})."
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
return input_shape
|
|
134
|
+
|
|
135
|
+
@field_validator("encoder_n_filters")
|
|
136
|
+
@classmethod
|
|
137
|
+
def validate_encoder_even(cls, encoder_n_filters: int) -> int:
|
|
138
|
+
"""
|
|
139
|
+
Validate that num_channels_init is even.
|
|
140
|
+
|
|
141
|
+
Parameters
|
|
142
|
+
----------
|
|
143
|
+
encoder_n_filters : int
|
|
144
|
+
Number of channels.
|
|
145
|
+
|
|
146
|
+
Returns
|
|
147
|
+
-------
|
|
148
|
+
int
|
|
149
|
+
Validated number of channels.
|
|
150
|
+
|
|
151
|
+
Raises
|
|
152
|
+
------
|
|
153
|
+
ValueError
|
|
154
|
+
If the number of channels is odd.
|
|
155
|
+
"""
|
|
156
|
+
# if odd
|
|
157
|
+
if encoder_n_filters % 2 != 0:
|
|
158
|
+
raise ValueError(
|
|
159
|
+
f"Number of channels for the bottom layer must be even"
|
|
160
|
+
f" (got {encoder_n_filters})."
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
return encoder_n_filters
|
|
164
|
+
|
|
165
|
+
@field_validator("decoder_n_filters")
|
|
166
|
+
@classmethod
|
|
167
|
+
def validate_decoder_even(cls, decoder_n_filters: int) -> int:
|
|
168
|
+
"""
|
|
169
|
+
Validate that num_channels_init is even.
|
|
170
|
+
|
|
171
|
+
Parameters
|
|
172
|
+
----------
|
|
173
|
+
decoder_n_filters : int
|
|
174
|
+
Number of channels.
|
|
175
|
+
|
|
176
|
+
Returns
|
|
177
|
+
-------
|
|
178
|
+
int
|
|
179
|
+
Validated number of channels.
|
|
180
|
+
|
|
181
|
+
Raises
|
|
182
|
+
------
|
|
183
|
+
ValueError
|
|
184
|
+
If the number of channels is odd.
|
|
185
|
+
"""
|
|
186
|
+
# if odd
|
|
187
|
+
if decoder_n_filters % 2 != 0:
|
|
188
|
+
raise ValueError(
|
|
189
|
+
f"Number of channels for the bottom layer must be even"
|
|
190
|
+
f" (got {decoder_n_filters})."
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
return decoder_n_filters
|
|
194
|
+
|
|
195
|
+
@field_validator("z_dims")
|
|
196
|
+
def validate_z_dims(cls, z_dims: tuple) -> tuple:
|
|
197
|
+
"""
|
|
198
|
+
Validate the z_dims.
|
|
199
|
+
|
|
200
|
+
Parameters
|
|
201
|
+
----------
|
|
202
|
+
z_dims : tuple
|
|
203
|
+
Tuple of z dimensions.
|
|
204
|
+
|
|
205
|
+
Returns
|
|
206
|
+
-------
|
|
207
|
+
tuple
|
|
208
|
+
Validated z dimensions.
|
|
209
|
+
|
|
210
|
+
Raises
|
|
211
|
+
------
|
|
212
|
+
ValueError
|
|
213
|
+
If the number of z dimensions is not 4.
|
|
214
|
+
"""
|
|
215
|
+
if len(z_dims) < 2:
|
|
216
|
+
raise ValueError(
|
|
217
|
+
f"Number of z dimensions must be at least 2 (got {len(z_dims)})."
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
return z_dims
|
|
221
|
+
|
|
222
|
+
@model_validator(mode="after")
|
|
223
|
+
def validate_multiscale_count(self: Self) -> Self:
|
|
224
|
+
"""
|
|
225
|
+
Validate the multiscale count.
|
|
226
|
+
|
|
227
|
+
Returns
|
|
228
|
+
-------
|
|
229
|
+
Self
|
|
230
|
+
The validated model.
|
|
231
|
+
"""
|
|
232
|
+
if self.multiscale_count < 1 or self.multiscale_count > len(self.z_dims) + 1:
|
|
233
|
+
raise ValueError(
|
|
234
|
+
f"Multiscale count must be 1 for LC off or less or equal to the number"
|
|
235
|
+
f" of Z dims + 1 (got {self.multiscale_count} and {len(self.z_dims)})."
|
|
236
|
+
)
|
|
237
|
+
return self
|
|
238
|
+
|
|
239
|
+
def set_3D(self, is_3D: bool) -> None:
    """
    Choose 2D or 3D convolutions by assigning the `conv_dims` attribute.

    Parameters
    ----------
    is_3D : bool
        When truthy, `conv_dims` becomes 3; otherwise it becomes 2.

    Notes
    -----
    NOTE(review): `is_3D()` in this class derives dimensionality from
    `input_shape`, not from `conv_dims`, so calling this method does not change
    what `is_3D()` reports. Confirm that `conv_dims` is a declared field of
    this model, since pydantic models normally reject assignment to undeclared
    attributes.
    """
    self.conv_dims = 3 if is_3D else 2
|
|
252
|
+
|
|
253
|
+
def is_3D(self) -> bool:
    """
    Report whether the model is 3D.

    The answer is based solely on how many entries `input_shape` has: three
    entries means 3D, anything else means not 3D.

    Returns
    -------
    bool
        True if `input_shape` has exactly three entries, False otherwise.
    """
    n_dims = len(self.input_shape)
    return n_dims == 3
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
"""UNet Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Literal
|
|
6
|
+
|
|
7
|
+
from pydantic import ConfigDict, Field, field_validator
|
|
8
|
+
|
|
9
|
+
from .architecture_config import ArchitectureConfig
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# TODO tests activation <-> pydantic model, test the literals!
|
|
13
|
+
# TODO annotations for the json schema?
|
|
14
|
+
class UNetConfig(ArchitectureConfig):
    """
    Pydantic model for a N2V(2)-compatible UNet.

    Attributes
    ----------
    architecture : {"UNet"}
        Discriminator used for choosing this pydantic model.
    conv_dims : {2, 3}
        Dimensions of the convolutional layers (default 2).
    num_classes : int
        Number of classes or channels in the model output, at least 1
        (default 1).
    in_channels : int
        Number of channels in the input to the model, at least 1 (default 1).
    depth : int
        Depth of the model, between 1 and 10 (default 2).
    num_channels_init : int
        Number of filters of the first level of the network, should be even
        and between 8 and 1024 (default 32).
    final_activation : str
        Final activation function, one of "None", "Sigmoid", "Softmax",
        "Tanh", "ReLU" or "LeakyReLU" (default "None").
    n2v2 : bool
        Whether to use N2V2 architecture modifications (default False).
    independent_channels : bool
        Whether channels are processed independently (default True).
    use_batch_norm : bool
        Whether to use batch normalization in the model (default True).
    """

    # pydantic model config: validate_assignment re-validates a field whenever
    # it is assigned, e.g. when `set_3D` changes `conv_dims`
    model_config = ConfigDict(validate_assignment=True)

    # discriminator used for choosing the pydantic model in Model
    architecture: Literal["UNet"]
    """Name of the architecture."""

    # parameters
    # validate_default=True makes pydantic also validate the default values
    conv_dims: Literal[2, 3] = Field(default=2, validate_default=True)
    """Dimensions (2D or 3D) of the convolutional layers."""

    num_classes: int = Field(default=1, ge=1, validate_default=True)
    """Number of classes or channels in the model output."""

    in_channels: int = Field(default=1, ge=1, validate_default=True)
    """Number of channels in the input to the model."""

    depth: int = Field(default=2, ge=1, le=10, validate_default=True)
    """Number of levels in the UNet."""

    num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True)
    """Number of convolutional filters in the first layer of the UNet."""

    # TODO we are not using this, so why make it a choice?
    final_activation: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
    ] = Field(default="None", validate_default=True)
    """Final activation function."""

    n2v2: bool = Field(default=False, validate_default=True)
    """Whether to use N2V2 architecture modifications, with blur pool layers and fewer
    skip connections.
    """

    independent_channels: bool = Field(default=True, validate_default=True)
    """Whether information is processed independently in each channel, used to train
    channels independently."""

    use_batch_norm: bool = Field(default=True, validate_default=True)
    """Whether to use batch normalization in the model."""

    @field_validator("num_channels_init")
    @classmethod
    def validate_num_channels_init(cls, num_channels_init: int) -> int:
        """
        Validate that num_channels_init is even.

        Parameters
        ----------
        num_channels_init : int
            Number of channels of the first layer.

        Returns
        -------
        int
            Validated number of channels.

        Raises
        ------
        ValueError
            If the number of channels is odd.
        """
        if num_channels_init % 2 != 0:
            # the field describes the first layer of the UNet, not the bottom
            raise ValueError(
                "Number of channels for the first layer must be even"
                f" (got {num_channels_init})."
            )

        return num_channels_init

    def set_3D(self, is_3D: bool) -> None:
        """
        Set 3D model by setting the `conv_dims` parameters.

        Parameters
        ----------
        is_3D : bool
            Whether the algorithm is 3D or not; sets `conv_dims` to 3 when
            truthy and to 2 otherwise.
        """
        self.conv_dims = 3 if is_3D else 2

    def is_3D(self) -> bool:
        """
        Return whether the model is 3D or not.

        This method is used in the NG configuration validation to check that the model
        dimensions match the data dimensions.

        Returns
        -------
        bool
            True if `conv_dims` is 3, False otherwise.
        """
        return self.conv_dims == 3
|