careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Configuration module."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"FCNAlgorithmConfig",
|
|
5
|
+
"VAEAlgorithmConfig",
|
|
6
|
+
"DataConfig",
|
|
7
|
+
"Configuration",
|
|
8
|
+
"CheckpointModel",
|
|
9
|
+
"InferenceConfig",
|
|
10
|
+
"load_configuration",
|
|
11
|
+
"save_configuration",
|
|
12
|
+
"TrainingConfig",
|
|
13
|
+
"create_n2v_configuration",
|
|
14
|
+
"create_n2n_configuration",
|
|
15
|
+
"create_care_configuration",
|
|
16
|
+
"register_model",
|
|
17
|
+
"CustomModel",
|
|
18
|
+
"clear_custom_models",
|
|
19
|
+
"GaussianMixtureNMConfig",
|
|
20
|
+
"MultiChannelNMConfig",
|
|
21
|
+
]
|
|
22
|
+
from .architectures import CustomModel, clear_custom_models, register_model
|
|
23
|
+
from .callback_model import CheckpointModel
|
|
24
|
+
from .configuration_factory import (
|
|
25
|
+
create_care_configuration,
|
|
26
|
+
create_n2n_configuration,
|
|
27
|
+
create_n2v_configuration,
|
|
28
|
+
)
|
|
29
|
+
from .configuration_model import (
|
|
30
|
+
Configuration,
|
|
31
|
+
load_configuration,
|
|
32
|
+
save_configuration,
|
|
33
|
+
)
|
|
34
|
+
from .data_model import DataConfig
|
|
35
|
+
from .fcn_algorithm_model import FCNAlgorithmConfig
|
|
36
|
+
from .inference_model import InferenceConfig
|
|
37
|
+
from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
|
|
38
|
+
from .training_model import TrainingConfig
|
|
39
|
+
from .vae_algorithm_model import VAEAlgorithmConfig
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""Deep-learning model configurations."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"ArchitectureModel",
|
|
5
|
+
"CustomModel",
|
|
6
|
+
"UNetModel",
|
|
7
|
+
"LVAEModel",
|
|
8
|
+
"clear_custom_models",
|
|
9
|
+
"get_custom_model",
|
|
10
|
+
"register_model",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
from .architecture_model import ArchitectureModel
|
|
14
|
+
from .custom_model import CustomModel
|
|
15
|
+
from .lvae_model import LVAEModel
|
|
16
|
+
from .register_model import clear_custom_models, get_custom_model, register_model
|
|
17
|
+
from .unet_model import UNetModel
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Base model for the various CAREamics architectures."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ArchitectureModel(BaseModel):
    """
    Base Pydantic model for all model architectures.

    The `model_dump` method is overridden so that the `architecture` key is never
    part of the dumped dictionary.

    Attributes
    ----------
    architecture : str
        Name of the architecture.
    """

    architecture: str
    """Name of the architecture."""

    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
        """
        Dump the model as a dictionary, ignoring the architecture keyword.

        Parameters
        ----------
        **kwargs : Any
            Additional keyword arguments from Pydantic BaseModel model_dump method.

        Returns
        -------
        dict[str, Any]
            Model as a dictionary, without the `architecture` key.
        """
        model_dict = super().model_dump(**kwargs)

        # remove the architecture key; it may already be absent when the caller
        # filtered fields through `include`/`exclude`, which must not raise
        model_dict.pop("architecture", None)

        return model_dict
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""Custom architecture Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import inspect
|
|
6
|
+
from pprint import pformat
|
|
7
|
+
from typing import Any, Literal
|
|
8
|
+
|
|
9
|
+
from pydantic import ConfigDict, field_validator, model_validator
|
|
10
|
+
from torch.nn import Module
|
|
11
|
+
from typing_extensions import Self
|
|
12
|
+
|
|
13
|
+
from .architecture_model import ArchitectureModel
|
|
14
|
+
from .register_model import get_custom_model
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CustomModel(ArchitectureModel):
    """Custom model configuration.

    This Pydantic model allows storing parameters for a custom model. In order for the
    model to be valid, the specific model needs to be registered using the
    `register_model` decorator, and its name correctly passed to this model
    configuration (see Examples).

    Extra fields are allowed (`extra="allow"`): every field other than
    `architecture` and `name` is forwarded as a keyword argument to the registered
    class when the configuration is validated.

    Attributes
    ----------
    architecture : Literal["custom"]
        Discriminator for the custom model, must be set to "custom".
    name : str
        Name of the custom model.

    Raises
    ------
    ValueError
        If the custom model `name` is unknown.
    ValueError
        If the custom model is not a torch Module subclass.
    ValueError
        If the custom model parameters are not valid.

    Examples
    --------
    >>> from torch import nn, ones
    >>> from careamics.config import CustomModel, register_model
    >>> # Register a custom model
    >>> @register_model(name="my_linear")
    ... class LinearModel(nn.Module):
    ...    def __init__(self, in_features, out_features, *args, **kwargs):
    ...        super().__init__()
    ...        self.in_features = in_features
    ...        self.out_features = out_features
    ...        self.weight = nn.Parameter(ones(in_features, out_features))
    ...        self.bias = nn.Parameter(ones(out_features))
    ...    def forward(self, input):
    ...        return (input @ self.weight) + self.bias
    ...
    >>> # Create a configuration
    >>> config_dict = {
    ...     "architecture": "custom",
    ...     "name": "my_linear",
    ...     "in_features": 10,
    ...     "out_features": 5,
    ... }
    >>> config = CustomModel(**config_dict)
    """

    # pydantic model config
    model_config = ConfigDict(
        extra="allow",
    )

    # discriminator used for choosing the pydantic model in Model
    architecture: Literal["custom"]
    """Name of the architecture."""

    name: str
    """Name of the custom model."""

    @field_validator("name")
    @classmethod
    def custom_model_is_known(cls, value: str) -> str:
        """Check whether the custom model is known.

        Parameters
        ----------
        value : str
            Name of the custom model as registered using the `@register_model`
            decorator.

        Returns
        -------
        str
            The custom model name.

        Raises
        ------
        ValueError
            If no model is registered under `value`, or if the registered object is
            not a torch.nn.Module subclass.
        """
        # delegate error to get_custom_model
        model = get_custom_model(value)

        # check if it is a torch Module subclass; the isinstance guard avoids the
        # TypeError that issubclass raises on non-class objects, which Pydantic
        # would not report as a validation error
        if not (isinstance(model, type) and issubclass(model, Module)):
            raise ValueError(
                f'Retrieved class {model} with name "{value}" is not a '
                f"torch.nn.Module subclass."
            )

        return value

    @model_validator(mode="after")
    def check_parameters(self: Self) -> Self:
        """Validate model by instantiating the model with the parameters.

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If the registered class cannot be instantiated with the dumped
            parameters.
        """
        # instantiate model
        try:
            get_custom_model(self.name)(**self.model_dump())
        except Exception as e:
            # innermost frame of the exception currently being handled
            origin = inspect.trace()[-1]
            raise ValueError(
                f"while passing parameters to the model {e}. Verify that all "
                f"mandatory parameters are provided, and that either the model "
                f"accepts *args and **kwargs in its __init__() method, or that no "
                f"additional parameter is provided. Trace: "
                f"filename: {origin.filename}, function: "
                f"{origin.function}, line: {origin.lineno}"
            ) from None

        return self

    def __str__(self) -> str:
        """Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        """Dump the model configuration.

        The `name` key is removed, in addition to whatever the parent class removes,
        so that the result can be passed directly to the registered class.

        Parameters
        ----------
        **kwargs : Any
            Additional keyword arguments from Pydantic BaseModel model_dump method.

        Returns
        -------
        dict[str, Any]
            Model configuration.
        """
        # forward the keyword arguments instead of silently dropping them
        model_dict = super().model_dump(**kwargs)

        # remove the name key; it may be absent if the caller excluded it
        model_dict.pop("name", None)

        return model_dict
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""LVAE Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field, field_validator, model_validator
|
|
6
|
+
from typing_extensions import Self
|
|
7
|
+
|
|
8
|
+
from .architecture_model import ArchitectureModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# TODO: it is quite confusing to call this LVAEModel, as it is basically a config
|
|
12
|
+
class LVAEModel(ArchitectureModel):
    """LVAE architecture configuration.

    Despite its name, this class is a Pydantic configuration: it stores and
    validates the architecture parameters.

    Attributes
    ----------
    architecture : Literal["LVAE"]
        Discriminator, must be set to "LVAE".
    input_shape : int
        Input shape parameter, between 8 and 1024 (default 64).
    multiscale_count : int
        Multiscale count (default 5).
    z_dims : list
        Z dimensions, at least 2 entries (default [128, 128, 128, 128]).
    output_channels : int
        Number of output channels, at least 1 (default 1).
    encoder_n_filters : int
        Number of encoder filters, even and between 8 and 1024 (default 64).
    decoder_n_filters : int
        Number of decoder filters, even and between 8 and 1024 (default 64).
    encoder_dropout : float
        Encoder dropout, between 0.0 and 0.9 (default 0.1).
    decoder_dropout : float
        Decoder dropout, between 0.0 and 0.9 (default 0.1).
    nonlinearity : str
        Name of the non-linearity (default "ELU").
    predict_logvar : None or "pixelwise"
        Log-variance prediction mode (default None).
    enable_noise_model : bool
        Whether the noise model is enabled (default True).
    analytical_kl : bool
        Whether to use the analytical KL (default False).
    """

    model_config = ConfigDict(validate_assignment=True, validate_default=True)

    architecture: Literal["LVAE"]
    input_shape: int = Field(default=64, ge=8, le=1024)
    multiscale_count: int = Field(default=5)  # TODO clarify
    # 0 - off, len(z_dims) + 1 # TODO can/should be le to z_dims len + 1
    z_dims: list = Field(default=[128, 128, 128, 128])
    output_channels: int = Field(default=1, ge=1)
    encoder_n_filters: int = Field(default=64, ge=8, le=1024)
    decoder_n_filters: int = Field(default=64, ge=8, le=1024)
    encoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
    decoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
    nonlinearity: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
    ] = Field(
        default="ELU",
    )

    predict_logvar: Literal[None, "pixelwise"] = None

    # TODO this parameter is excessive -> Remove & refactor
    enable_noise_model: bool = Field(
        default=True,
    )
    analytical_kl: bool = Field(
        default=False,
    )

    @field_validator("encoder_n_filters")
    @classmethod
    def validate_encoder_even(cls, encoder_n_filters: int) -> int:
        """
        Validate that `encoder_n_filters` is even.

        Parameters
        ----------
        encoder_n_filters : int
            Number of channels.

        Returns
        -------
        int
            Validated number of channels.

        Raises
        ------
        ValueError
            If the number of channels is odd.
        """
        # if odd
        if encoder_n_filters % 2 != 0:
            raise ValueError(
                f"Number of channels for the bottom layer must be even"
                f" (got {encoder_n_filters})."
            )

        return encoder_n_filters

    @field_validator("decoder_n_filters")
    @classmethod
    def validate_decoder_even(cls, decoder_n_filters: int) -> int:
        """
        Validate that `decoder_n_filters` is even.

        Parameters
        ----------
        decoder_n_filters : int
            Number of channels.

        Returns
        -------
        int
            Validated number of channels.

        Raises
        ------
        ValueError
            If the number of channels is odd.
        """
        # if odd
        if decoder_n_filters % 2 != 0:
            raise ValueError(
                f"Number of channels for the bottom layer must be even"
                f" (got {decoder_n_filters})."
            )

        return decoder_n_filters

    @field_validator("z_dims")
    @classmethod
    def validate_z_dims(cls, z_dims: list) -> list:
        """
        Validate the z_dims.

        Parameters
        ----------
        z_dims : list
            List of z dimensions.

        Returns
        -------
        list
            Validated z dimensions.

        Raises
        ------
        ValueError
            If fewer than 2 z dimensions are provided.
        """
        if len(z_dims) < 2:
            raise ValueError(
                f"Number of z dimensions must be at least 2 (got {len(z_dims)})."
            )

        return z_dims

    @model_validator(mode="after")
    def validate_multiscale_count(self: Self) -> Self:
        """
        Validate the multiscale count.

        The check is currently disabled, so the model is returned unchanged.

        Returns
        -------
        Self
            The validated model.
        """
        # TODO: the check below is disabled; kept as a record of the intended rule
        # if self.multiscale_count != 0:
        #     if self.multiscale_count != len(self.z_dims) - 1:
        #         raise ValueError(
        #             f"Multiscale count must be 0 or equal to the number of Z "
        #             f"dims - 1 (got {self.multiscale_count} and {len(self.z_dims)})."
        #         )

        return self

    def set_3D(self, is_3D: bool) -> None:
        """
        Set 3D model by setting the `conv_dims` parameters.

        Parameters
        ----------
        is_3D : bool
            Whether the algorithm is 3D or not.

        Raises
        ------
        NotImplementedError
            Always, this method is not implemented for this model.
        """
        raise NotImplementedError("VAE is not implemented yet.")

    def is_3D(self) -> bool:
        """
        Return whether the model is 3D or not.

        Returns
        -------
        bool
            Whether the model is 3D or not.

        Raises
        ------
        NotImplementedError
            Always, this method is not implemented for this model.
        """
        raise NotImplementedError("VAE is not implemented yet.")
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
"""Custom model registration utilities."""
|
|
2
|
+
|
|
3
|
+
from typing import Callable
|
|
4
|
+
|
|
5
|
+
from torch.nn import Module
|
|
6
|
+
|
|
7
|
+
CUSTOM_MODELS = {} # dictionary of custom models {"name": __class__}
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def register_model(name: str) -> Callable:
    """Build a class decorator that stores a torch.nn.Module class under `name`.

    The name is checked when the decorator is created; the class itself is added
    to the `CUSTOM_MODELS` registry when the decorator is applied.

    Parameters
    ----------
    name : str
        Name of the model.

    Returns
    -------
    Callable
        Decorator that registers the decorated class and returns it unchanged.

    Raises
    ------
    ValueError
        If `name` is None or an empty string.
    ValueError
        If a model is already registered with that name.

    Examples
    --------
    ```python
    @register_model(name="linear")
    class LinearModel(nn.Module):
        def __init__(self, in_features, out_features):
            super().__init__()

            self.weight = nn.Parameter(ones(in_features, out_features))
            self.bias = nn.Parameter(ones(out_features))

        def forward(self, input):
            return (input @ self.weight) + self.bias
    ```
    """
    name_is_missing = name is None or name == ""
    if name_is_missing:
        raise ValueError("Model name cannot be empty.")

    name_is_taken = name in CUSTOM_MODELS
    if name_is_taken:
        message = (
            f"Model {name} already exists. Choose a different name or run "
            f"`clear_custom_models()` to empty the registry."
        )
        raise ValueError(message)

    def add_custom_model(model: Module) -> Module:
        """Store `model` in the registry under the enclosing `name`.

        Parameters
        ----------
        model : Module
            Module class to register.

        Returns
        -------
        Module
            The same class that was passed in.
        """
        CUSTOM_MODELS.update({name: model})
        return model

    return add_custom_model
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def get_custom_model(name: str) -> Module:
    """Look up the model registered under `name` in `CUSTOM_MODELS`.

    Parameters
    ----------
    name : str
        Name of the model to retrieve.

    Returns
    -------
    Module
        The object stored in the registry for `name`.

    Raises
    ------
    ValueError
        If no model is registered under `name`.
    """
    # happy path first: the name is known
    if name in CUSTOM_MODELS:
        return CUSTOM_MODELS[name]

    raise ValueError(
        f"Model {name} is unknown. Have you registered it using "
        f'@register_model("{name}") as decorator?'
    )
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def clear_custom_models() -> None:
    """Remove every entry from the custom models registry.

    The `CUSTOM_MODELS` dictionary is emptied in place, so existing references to
    it observe the change.
    """
    CUSTOM_MODELS.clear()
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""UNet Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Literal
|
|
6
|
+
|
|
7
|
+
from pydantic import ConfigDict, Field, field_validator
|
|
8
|
+
|
|
9
|
+
from .architecture_model import ArchitectureModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# TODO tests activation <-> pydantic model, test the literals!
|
|
13
|
+
# TODO annotations for the json schema?
|
|
14
|
+
class UNetModel(ArchitectureModel):
    """
    Pydantic model for a N2V(2)-compatible UNet.

    Attributes
    ----------
    conv_dims : Literal[2, 3]
        Dimensions (2D or 3D) of the convolutional layers (default 2).
    num_classes : int
        Number of classes or channels in the model output, minimum 1 (default 1).
    in_channels : int
        Number of channels in the input to the model, minimum 1 (default 1).
    depth : int
        Depth of the model, between 1 and 10 (default 2).
    num_channels_init : int
        Number of filters of the first level of the network, should be even
        and between 8 and 1024 (default 32).
    final_activation : str
        Name of the final activation function (default "None").
    n2v2 : bool
        Whether to use the N2V2 architecture modifications (default False).
    independent_channels : bool
        Whether channels are processed independently (default True).
    """

    # pydantic model config
    model_config = ConfigDict(validate_assignment=True)

    # discriminator used for choosing the pydantic model in Model
    architecture: Literal["UNet"]
    """Name of the architecture."""

    # parameters
    # validate_defaults allow ignoring default values in the dump if they were not set
    conv_dims: Literal[2, 3] = Field(default=2, validate_default=True)
    """Dimensions (2D or 3D) of the convolutional layers."""

    num_classes: int = Field(default=1, ge=1, validate_default=True)
    """Number of classes or channels in the model output."""

    in_channels: int = Field(default=1, ge=1, validate_default=True)
    """Number of channels in the input to the model."""

    depth: int = Field(default=2, ge=1, le=10, validate_default=True)
    """Number of levels in the UNet."""

    num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True)
    """Number of convolutional filters in the first layer of the UNet."""

    final_activation: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
    ] = Field(default="None", validate_default=True)
    """Final activation function."""

    n2v2: bool = Field(default=False, validate_default=True)
    """Whether to use N2V2 architecture modifications, with blur pool layers and fewer
    skip connections.
    """

    independent_channels: bool = Field(default=True, validate_default=True)
    """Whether information is processed independently in each channel, used to train
    channels independently."""

    @field_validator("num_channels_init")
    @classmethod
    def validate_num_channels_init(cls, num_channels_init: int) -> int:
        """
        Check that the initial number of channels is an even number.

        Parameters
        ----------
        num_channels_init : int
            Number of channels.

        Returns
        -------
        int
            The number of channels, unchanged.

        Raises
        ------
        ValueError
            If the number of channels is odd.
        """
        is_even = num_channels_init % 2 == 0
        if is_even:
            return num_channels_init

        raise ValueError(
            f"Number of channels for the bottom layer must be even"
            f" (got {num_channels_init})."
        )

    def set_3D(self, is_3D: bool) -> None:
        """
        Switch the model between 2D and 3D by updating `conv_dims`.

        Parameters
        ----------
        is_3D : bool
            Whether the algorithm is 3D or not.
        """
        self.conv_dims = 3 if is_3D else 2

    def is_3D(self) -> bool:
        """
        Tell whether the model is configured with 3D convolutions.

        Returns
        -------
        bool
            True if `conv_dims` is 3, False otherwise.
        """
        uses_3d_convolutions = self.conv_dims == 3
        return uses_3d_convolutions
|