careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +14 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +27 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_factory.py +460 -0
- careamics/config/configuration_model.py +596 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +665 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +390 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc3.dist-info/METADATA +122 -0
- careamics-0.1.0rc3.dist-info/RECORD +109 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
careamics/config/__init__.py
CHANGED
|
@@ -1,11 +1,35 @@
|
|
|
1
1
|
"""Configuration module."""
|
|
2
2
|
|
|
3
3
|
|
|
4
|
-
__all__ = [
|
|
4
|
+
__all__ = [
|
|
5
|
+
"AlgorithmModel",
|
|
6
|
+
"DataModel",
|
|
7
|
+
"Configuration",
|
|
8
|
+
"CheckpointModel",
|
|
9
|
+
"InferenceModel",
|
|
10
|
+
"load_configuration",
|
|
11
|
+
"save_configuration",
|
|
12
|
+
"TrainingModel",
|
|
13
|
+
"create_n2v_configuration",
|
|
14
|
+
"register_model",
|
|
15
|
+
"CustomModel",
|
|
16
|
+
"create_inference_configuration",
|
|
17
|
+
"clear_custom_models",
|
|
18
|
+
"ConfigurationInformation",
|
|
19
|
+
]
|
|
5
20
|
|
|
6
|
-
from .
|
|
21
|
+
from .algorithm_model import AlgorithmModel
|
|
22
|
+
from .architectures import CustomModel, clear_custom_models, register_model
|
|
23
|
+
from .callback_model import CheckpointModel
|
|
24
|
+
from .configuration_factory import (
|
|
25
|
+
create_inference_configuration,
|
|
26
|
+
create_n2v_configuration,
|
|
27
|
+
)
|
|
28
|
+
from .configuration_model import (
|
|
7
29
|
Configuration,
|
|
8
30
|
load_configuration,
|
|
9
31
|
save_configuration,
|
|
10
32
|
)
|
|
11
|
-
from .
|
|
33
|
+
from .data_model import DataModel
|
|
34
|
+
from .inference_model import InferenceModel
|
|
35
|
+
from .training_model import TrainingModel
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pprint import pformat
|
|
4
|
+
from typing import Literal, Union
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
7
|
+
from typing_extensions import Self
|
|
8
|
+
|
|
9
|
+
from .architectures import CustomModel, UNetModel, VAEModel
|
|
10
|
+
from .optimizer_models import LrSchedulerModel, OptimizerModel
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AlgorithmModel(BaseModel):
    """Algorithm configuration.

    This Pydantic model validates the parameters governing the components of the
    training algorithm: which algorithm, loss function, model architecture, optimizer,
    and learning rate scheduler to use.

    Supported algorithms are `n2v`, `care`, `n2n` and `custom`:

    - `n2v` is only compatible with the `n2v` loss and the `UNet` architecture, and
      requires the same number of input and output channels.
    - `n2n` requires the `UNet` architecture and the same number of input and output
      channels.
    - `care` and `n2n` (supervised algorithms) do not accept the `n2v` loss.
    - `custom` allows you to register your own architecture and select it using its
      name as `name` in the custom pydantic model.

    VAE models are currently rejected for every algorithm, as they are not
    implemented.

    Attributes
    ----------
    algorithm : Literal["n2v", "care", "n2n", "custom"]
        Algorithm to use.
    loss : Literal["n2v", "mae", "mse"]
        Loss function to use.
    model : Union[UNetModel, VAEModel, CustomModel]
        Model architecture to use, selected through its `architecture` field.
    optimizer : OptimizerModel, optional
        Optimizer to use, defaults to `OptimizerModel()`.
    lr_scheduler : LrSchedulerModel, optional
        Learning rate scheduler to use, defaults to `LrSchedulerModel()`.

    Raises
    ------
    ValueError
        Algorithm parameter type validation errors.
    ValueError
        If the algorithm, loss and model are not compatible.

    Examples
    --------
    Minimum example:
    >>> from careamics.config import AlgorithmModel
    >>> config_dict = {
    ...     "algorithm": "n2v",
    ...     "loss": "n2v",
    ...     "model": {
    ...         "architecture": "UNet",
    ...     }
    ... }
    >>> config = AlgorithmModel(**config_dict)

    Using a custom model:
    >>> from torch import nn, ones
    >>> from careamics.config import AlgorithmModel, register_model
    ...
    >>> @register_model(name="linear_model")
    ... class LinearModel(nn.Module):
    ...    def __init__(self, in_features, out_features, *args, **kwargs):
    ...        super().__init__()
    ...        self.in_features = in_features
    ...        self.out_features = out_features
    ...        self.weight = nn.Parameter(ones(in_features, out_features))
    ...        self.bias = nn.Parameter(ones(out_features))
    ...    def forward(self, input):
    ...        return (input @ self.weight) + self.bias
    ...
    >>> config_dict = {
    ...     "algorithm": "custom",
    ...     "loss": "mse",
    ...     "model": {
    ...         "architecture": "Custom",
    ...         "name": "linear_model",
    ...         "in_features": 10,
    ...         "out_features": 5,
    ...     }
    ... }
    >>> config = AlgorithmModel(**config_dict)
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        protected_namespaces=(),  # allows to use model_* as a field name
        validate_assignment=True,
    )

    # Mandatory fields
    algorithm: Literal["n2v", "care", "n2n", "custom"]  # defined in SupportedAlgorithm
    loss: Literal["n2v", "mae", "mse"]
    # the `architecture` field of the sub-model selects which model class is used
    model: Union[UNetModel, VAEModel, CustomModel] = Field(discriminator="architecture")

    # Optional fields
    optimizer: OptimizerModel = OptimizerModel()
    lr_scheduler: LrSchedulerModel = LrSchedulerModel()

    @model_validator(mode="after")
    def algorithm_cross_validation(self: Self) -> Self:
        """Validate the algorithm model based on `algorithm`.

        N2V:
        - loss must be n2v
        - model must be a `UNetModel`
        - `in_channels` must be equal to `num_classes`

        N2N:
        - model must be a `UNetModel`
        - `in_channels` must be equal to `num_classes`

        CARE and N2N:
        - loss must not be n2v

        All algorithms:
        - model must not be a `VAEModel` (not implemented)

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If any of the constraints above is violated.
        """
        # N2V
        if self.algorithm == "n2v":
            # n2v is only compatible with the n2v loss
            if self.loss != "n2v":
                raise ValueError(
                    f"Algorithm {self.algorithm} only supports loss `n2v`."
                )

            # n2v is only compatible with the UNet model
            if not isinstance(self.model, UNetModel):
                raise ValueError(
                    f"Model for algorithm {self.algorithm} must be a `UNetModel`."
                )

            # n2v requires the number of input and output channels to be the same
            if self.model.in_channels != self.model.num_classes:
                raise ValueError(
                    "N2V requires the same number of input and output channels. Make "
                    "sure that `in_channels` and `num_classes` are the same."
                )

        # N2N
        if self.algorithm == "n2n":
            # n2n is only compatible with the UNet model
            if not isinstance(self.model, UNetModel):
                raise ValueError(
                    f"Model for algorithm {self.algorithm} must be a `UNetModel`."
                )

            # n2n requires the number of input and output channels to be the same
            if self.model.in_channels != self.model.num_classes:
                raise ValueError(
                    "N2N requires the same number of input and output channels. Make "
                    "sure that `in_channels` and `num_classes` are the same."
                )

        # supervised algorithms cannot be trained with the self-supervised n2v loss
        if self.algorithm in ("care", "n2n"):
            if self.loss == "n2v":
                raise ValueError("Supervised algorithms do not support loss `n2v`.")

        if isinstance(self.model, VAEModel):
            raise ValueError("VAE are currently not implemented.")

        return self

    def __str__(self) -> str:
        """Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""Deep-learning model configurations."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"ArchitectureModel",
|
|
5
|
+
"CustomModel",
|
|
6
|
+
"UNetModel",
|
|
7
|
+
"VAEModel",
|
|
8
|
+
"clear_custom_models",
|
|
9
|
+
"get_custom_model",
|
|
10
|
+
"register_model",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
from .architecture_model import ArchitectureModel
|
|
14
|
+
from .custom_model import CustomModel
|
|
15
|
+
from .register_model import clear_custom_models, get_custom_model, register_model
|
|
16
|
+
from .unet_model import UNetModel
|
|
17
|
+
from .vae_model import VAEModel
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from typing import Any, Dict
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ArchitectureModel(BaseModel):
    """
    Base Pydantic model for all model architectures.

    The `model_dump` method allows removing the `architecture` key from the model.
    """

    # discriminator naming the architecture; subclasses narrow it to a Literal
    architecture: str

    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
        """
        Dump the model as a dictionary, ignoring the architecture keyword.

        Parameters
        ----------
        **kwargs : Any
            Additional keyword arguments forwarded to Pydantic's
            `BaseModel.model_dump` method.

        Returns
        -------
        dict[str, Any]
            Model as a dictionary, without the `architecture` key.
        """
        model_dict = super().model_dump(**kwargs)

        # remove the architecture key; it may already be missing if the caller
        # filtered it out (e.g. `exclude={"architecture"}`), hence the default
        model_dict.pop("architecture", None)

        return model_dict
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pprint import pformat
|
|
4
|
+
from typing import Any, Dict, Literal
|
|
5
|
+
|
|
6
|
+
from pydantic import ConfigDict, field_validator, model_validator
|
|
7
|
+
from torch.nn import Module
|
|
8
|
+
from typing_extensions import Self
|
|
9
|
+
|
|
10
|
+
from .architecture_model import ArchitectureModel
|
|
11
|
+
from .register_model import get_custom_model
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class CustomModel(ArchitectureModel):
    """Custom model configuration.

    This Pydantic model allows storing parameters for a custom model. In order for the
    model to be valid, the specific model needs to be registered using the
    `register_model` decorator, and its name correctly passed to this model
    configuration (see Examples).

    Any additional field is accepted (`extra="allow"`) and, during validation, is
    passed as a keyword argument to the constructor of the registered model.

    Attributes
    ----------
    architecture : Literal["Custom"]
        Discriminator for the custom model, must be set to "Custom".
    name : str
        Name of the custom model, as registered with `register_model`.

    Raises
    ------
    ValueError
        If the custom model `name` is unknown.
    ValueError
        If the custom model is not a torch Module subclass.
    ValueError
        If the custom model parameters are not valid.

    Examples
    --------
    >>> from torch import nn, ones
    >>> from careamics.config import CustomModel, register_model
    >>> # Register a custom model
    >>> @register_model(name="my_linear")
    ... class LinearModel(nn.Module):
    ...    def __init__(self, in_features, out_features, *args, **kwargs):
    ...        super().__init__()
    ...        self.in_features = in_features
    ...        self.out_features = out_features
    ...        self.weight = nn.Parameter(ones(in_features, out_features))
    ...        self.bias = nn.Parameter(ones(out_features))
    ...    def forward(self, input):
    ...        return (input @ self.weight) + self.bias
    ...
    >>> # Create a configuration
    >>> config_dict = {
    ...     "architecture": "Custom",
    ...     "name": "my_linear",
    ...     "in_features": 10,
    ...     "out_features": 5,
    ... }
    >>> config = CustomModel(**config_dict)
    """

    # pydantic model config
    model_config = ConfigDict(
        extra="allow",
    )

    # discriminator used for choosing the pydantic model in Model
    architecture: Literal["Custom"]

    # name of the custom model
    name: str

    @field_validator("name")
    @classmethod
    def custom_model_is_known(cls, value: str) -> str:
        """Check whether the custom model is known.

        Parameters
        ----------
        value : str
            Name of the custom model as registered using the `@register_model`
            decorator.

        Returns
        -------
        str
            The validated name.

        Raises
        ------
        ValueError
            If no model is registered under `value`, or if the registered object
            is not a `torch.nn.Module` subclass.
        """
        # delegate error to get_custom_model
        model = get_custom_model(value)

        # check if it is a torch Module subclass; `issubclass` raises TypeError on
        # non-class objects (e.g. a decorated function), and Pydantic v2 does not
        # convert TypeError into a ValidationError, so guard with isinstance first
        if not (isinstance(model, type) and issubclass(model, Module)):
            raise ValueError(
                f'Retrieved class {model} with name "{value}" is not a '
                f"torch.nn.Module subclass."
            )

        return value

    @model_validator(mode="after")
    def check_parameters(self: Self) -> Self:
        """Validate model by instantiating the model with the parameters.

        The registered model is instantiated with every dumped field except
        `architecture` and `name`, i.e. with the extra fields.

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If instantiating the model with the parameters raises any exception.
        """
        # instantiate model
        try:
            get_custom_model(self.name)(**self.model_dump())
        except Exception as e:
            raise ValueError(
                f"error while passing parameters to the model {e}. Verify that all "
                f"mandatory parameters are provided, and that either the "
                f"{self.name} model accepts *args and **kwargs in its __init__() "
                f"method, or that no additional parameter is provided."
            ) from None

        return self

    def __str__(self) -> str:
        """Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
        """Dump the model configuration.

        Parameters
        ----------
        **kwargs : Any
            Additional keyword arguments from Pydantic BaseModel model_dump method,
            forwarded to the parent implementation.

        Returns
        -------
        Dict[str, Any]
            Model configuration, without the `architecture` and `name` keys.
        """
        model_dict = super().model_dump(**kwargs)

        # remove the name key; it may be missing if the caller excluded it
        model_dict.pop("name", None)

        return model_dict
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
from typing import Callable
|
|
2
|
+
|
|
3
|
+
from torch.nn import Module
|
|
4
|
+
|
|
5
|
+
# Module-level registry: `register_model` stores each decorated class under its
# name, `get_custom_model` reads from it and `clear_custom_models` empties it.
CUSTOM_MODELS = {} # dictionary of custom models {"name": __class__}
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def register_model(name: str) -> Callable:
    """Decorator used to register a torch.nn.Module class with a given `name`.

    Parameters
    ----------
    name : str
        Name of the model.

    Returns
    -------
    Callable
        Decorator that adds the decorated class to the registry under `name` and
        returns the class unchanged.

    Raises
    ------
    ValueError
        If `name` is `None` or an empty string.
    ValueError
        If a model is already registered with that name.

    Examples
    --------
    ```python
    @register_model(name="linear")
    class LinearModel(nn.Module):
        def __init__(self, in_features, out_features):
            super().__init__()

            self.weight = nn.Parameter(ones(in_features, out_features))
            self.bias = nn.Parameter(ones(out_features))

        def forward(self, input):
            return (input @ self.weight) + self.bias
    ```
    """
    if name is None or name == "":
        raise ValueError("Model name cannot be empty.")

    # the duplicate check happens when `register_model(name=...)` is called,
    # before the decorated class reaches the registry
    if name in CUSTOM_MODELS:
        raise ValueError(
            f"Model {name} already exists. Choose a different name or run "
            f"`clear_custom_models()` to empty the registry."
        )

    def add_custom_model(model: type) -> type:
        """Add a custom model class to the registry and return it unchanged.

        Parameters
        ----------
        model : type
            Module class to register.

        Returns
        -------
        type
            The registered class, so that decorating does not alter it.
        """
        # add model to the registry
        CUSTOM_MODELS[name] = model

        return model

    return add_custom_model
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_custom_model(name: str) -> Module:
    """Look up a registered custom model by its `name`.

    Parameters
    ----------
    name : str
        Key under which the model was stored by `register_model`.

    Returns
    -------
    Module
        The class registered under `name` (the class itself, not an instance).

    Raises
    ------
    ValueError
        When nothing has been registered under `name`.
    """
    try:
        return CUSTOM_MODELS[name]
    except KeyError:
        # re-raise as ValueError with a hint about the registration decorator
        raise ValueError(
            f"Model {name} is unknown. Have you registered it using "
            f'@register_model("{name}") as decorator?'
        ) from None
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def clear_custom_models() -> None:
    """Remove every entry from the custom models registry.

    Afterwards, names that were previously registered can be registered again
    with `register_model`.
    """
    # empty the dictionary in place rather than rebinding the global name
    CUSTOM_MODELS.clear()
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field, field_validator
|
|
6
|
+
|
|
7
|
+
from .architecture_model import ArchitectureModel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# TODO tests activation <-> pydantic model, test the literals!
|
|
11
|
+
# TODO annotations for the json schema?
|
|
12
|
+
class UNetModel(ArchitectureModel):
    """
    Pydantic model for a N2V(2)-compatible UNet.

    Attributes
    ----------
    conv_dims : Literal[2, 3]
        Dimensionality of the convolutions, 2 for 2D and 3 for 3D (default 2).
    num_classes : int
        Number of output channels, minimum 1 (default 1).
    in_channels : int
        Number of input channels, minimum 1 (default 1).
    depth : int
        Depth of the model, between 1 and 10 (default 2).
    num_channels_init : int
        Number of filters of the first level of the network, should be even
        and between 8 and 1024 (default 32).
    final_activation : Literal["None", "Sigmoid", "Softmax", "Tanh", "ReLU", \
"LeakyReLU"]
        Name of the final activation (default "None").
    n2v2 : bool
        Whether to use the N2V2 variant of the network (default False).
    """

    # pydantic model config
    model_config = ConfigDict(validate_assignment=True)

    # discriminator used for choosing the pydantic model in Model
    architecture: Literal["UNet"]

    # parameters
    # validate_default=True makes Pydantic also run validation on default values
    conv_dims: Literal[2, 3] = Field(default=2, validate_default=True)
    num_classes: int = Field(default=1, ge=1, validate_default=True)
    in_channels: int = Field(default=1, ge=1, validate_default=True)
    depth: int = Field(default=2, ge=1, le=10, validate_default=True)
    num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True)
    final_activation: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
    ] = Field(default="None", validate_default=True)
    n2v2: bool = Field(default=False, validate_default=True)

    @field_validator("num_channels_init")
    @classmethod
    def validate_num_channels_init(cls, num_channels_init: int) -> int:
        """
        Validate that num_channels_init is even.

        Parameters
        ----------
        num_channels_init : int
            Number of channels.

        Returns
        -------
        int
            Validated number of channels.

        Raises
        ------
        ValueError
            If the number of channels is odd.
        """
        # if odd
        if num_channels_init % 2 != 0:
            raise ValueError(
                "Number of channels for the first layer of the network must be even"
                f" (got {num_channels_init})."
            )

        return num_channels_init

    def set_3D(self, is_3D: bool) -> None:
        """
        Set 3D model by setting the `conv_dims` parameters.

        Parameters
        ----------
        is_3D : bool
            Whether the algorithm is 3D or not.
        """
        if is_3D:
            self.conv_dims = 3
        else:
            self.conv_dims = 2

    def is_3D(self) -> bool:
        """
        Return whether the model is 3D or not.

        Returns
        -------
        bool
            Whether the model is 3D or not.
        """
        return self.conv_dims == 3
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
from pydantic import (
|
|
4
|
+
ConfigDict,
|
|
5
|
+
)
|
|
6
|
+
|
|
7
|
+
from .architecture_model import ArchitectureModel
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class VAEModel(ArchitectureModel):
    """VAE model placeholder.

    Only the `architecture` discriminator is defined; `set_3D` and `is_3D` raise
    `NotImplementedError`.
    """

    model_config = ConfigDict(
        use_enum_values=True, protected_namespaces=(), validate_assignment=True
    )

    # discriminator used for choosing the pydantic model in Model
    architecture: Literal["VAE"]

    def set_3D(self, is_3D: bool) -> None:
        """
        Set whether the model is 3D (not implemented).

        Parameters
        ----------
        is_3D : bool
            Whether the algorithm is 3D or not.

        Raises
        ------
        NotImplementedError
            Always, the VAE model is not implemented yet.
        """
        raise NotImplementedError("VAE is not implemented yet.")

    def is_3D(self) -> bool:
        """
        Return whether the model is 3D or not (not implemented).

        Returns
        -------
        bool
            Whether the model is 3D or not.

        Raises
        ------
        NotImplementedError
            Always, the VAE model is not implemented yet.
        """
        raise NotImplementedError("VAE is not implemented yet.")
|