careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (133)
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +323 -134
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -14
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -221
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -12
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +112 -75
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -182
  111. careamics/bioimage/rdf.py +0 -105
  112. careamics/config/algorithm.py +0 -231
  113. careamics/config/config.py +0 -297
  114. careamics/config/config_filter.py +0 -44
  115. careamics/config/data.py +0 -194
  116. careamics/config/torch_optim.py +0 -118
  117. careamics/config/training.py +0 -534
  118. careamics/dataset/dataset_utils.py +0 -111
  119. careamics/dataset/patching.py +0 -492
  120. careamics/dataset/prepare_dataset.py +0 -175
  121. careamics/dataset/tiff_dataset.py +0 -212
  122. careamics/engine.py +0 -1014
  123. careamics/manipulation/__init__.py +0 -4
  124. careamics/manipulation/pixel_manipulation.py +0 -158
  125. careamics/prediction/prediction_utils.py +0 -106
  126. careamics/utils/ascii_logo.txt +0 -9
  127. careamics/utils/augment.py +0 -65
  128. careamics/utils/normalization.py +0 -55
  129. careamics/utils/validators.py +0 -170
  130. careamics/utils/wandb.py +0 -121
  131. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  132. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  133. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0

careamics/config/__init__.py
@@ -1,11 +1,35 @@
 """Configuration module."""
 
 
-__all__ = ["Configuration", "load_configuration", "save_configuration"]
+__all__ = [
+    "AlgorithmModel",
+    "DataModel",
+    "Configuration",
+    "CheckpointModel",
+    "InferenceModel",
+    "load_configuration",
+    "save_configuration",
+    "TrainingModel",
+    "create_n2v_configuration",
+    "register_model",
+    "CustomModel",
+    "create_inference_configuration",
+    "clear_custom_models",
+    "ConfigurationInformation",
+]
 
-from .config import (
+from .algorithm_model import AlgorithmModel
+from .architectures import CustomModel, clear_custom_models, register_model
+from .callback_model import CheckpointModel
+from .configuration_factory import (
+    create_inference_configuration,
+    create_n2v_configuration,
+)
+from .configuration_model import (
     Configuration,
     load_configuration,
     save_configuration,
 )
-from .torch_optim import get_parameters as get_parameters
+from .data_model import DataModel
+from .inference_model import InferenceModel
+from .training_model import TrainingModel
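
Note: the rewritten `careamics.config` package re-exports the new Pydantic models and factory functions at the top level. A minimal usage sketch, assuming `load_configuration` and `save_configuration` keep path-based signatures as in rc2:

```python
# Hypothetical round trip using the re-exported helpers; assumes that
# load_configuration/save_configuration accept a file path.
from careamics.config import load_configuration, save_configuration

config = load_configuration("config.yml")  # returns a validated Configuration
config.training.num_epochs = 10            # hypothetical field, for illustration
save_configuration(config, "config.yml")
```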

careamics/config/algorithm_model.py
@@ -0,0 +1,167 @@
+from __future__ import annotations
+
+from pprint import pformat
+from typing import Literal, Union
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+from typing_extensions import Self
+
+from .architectures import CustomModel, UNetModel, VAEModel
+from .optimizer_models import LrSchedulerModel, OptimizerModel
+
+
+class AlgorithmModel(BaseModel):
+    """Algorithm configuration.
+
+    This Pydantic model validates the parameters governing the components of the
+    training algorithm: which algorithm, loss function, model architecture,
+    optimizer, and learning rate scheduler to use.
+
+    Currently supported algorithms are `n2v`, `care`, `n2n` and `custom`. The
+    `n2v` algorithm is only compatible with the `n2v` loss and the `UNet`
+    architecture. The `custom` algorithm allows you to register your own
+    architecture and select it by passing its registered name as `name` in the
+    custom Pydantic model.
+
+    Attributes
+    ----------
+    algorithm : Literal["n2v", "care", "n2n", "custom"]
+        Algorithm to use.
+    loss : Literal["n2v", "mae", "mse"]
+        Loss function to use.
+    model : Union[UNetModel, VAEModel, CustomModel]
+        Model architecture to use.
+    optimizer : OptimizerModel, optional
+        Optimizer to use.
+    lr_scheduler : LrSchedulerModel, optional
+        Learning rate scheduler to use.
+
+    Raises
+    ------
+    ValueError
+        Algorithm parameter type validation errors.
+    ValueError
+        If the algorithm, loss and model are not compatible.
+
+    Examples
+    --------
+    Minimum example:
+    >>> from careamics.config import AlgorithmModel
+    >>> config_dict = {
+    ...     "algorithm": "n2v",
+    ...     "loss": "n2v",
+    ...     "model": {
+    ...         "architecture": "UNet",
+    ...     }
+    ... }
+    >>> config = AlgorithmModel(**config_dict)
+
+    Using a custom model:
+    >>> from torch import nn, ones
+    >>> from careamics.config import AlgorithmModel, register_model
+    ...
+    >>> @register_model(name="linear_model")
+    ... class LinearModel(nn.Module):
+    ...     def __init__(self, in_features, out_features, *args, **kwargs):
+    ...         super().__init__()
+    ...         self.in_features = in_features
+    ...         self.out_features = out_features
+    ...         self.weight = nn.Parameter(ones(in_features, out_features))
+    ...         self.bias = nn.Parameter(ones(out_features))
+    ...     def forward(self, input):
+    ...         return (input @ self.weight) + self.bias
+    ...
+    >>> config_dict = {
+    ...     "algorithm": "custom",
+    ...     "loss": "mse",
+    ...     "model": {
+    ...         "architecture": "Custom",
+    ...         "name": "linear_model",
+    ...         "in_features": 10,
+    ...         "out_features": 5,
+    ...     }
+    ... }
+    >>> config = AlgorithmModel(**config_dict)
+    """
+
+    # Pydantic class configuration
+    model_config = ConfigDict(
+        protected_namespaces=(),  # allows to use model_* as a field name
+        validate_assignment=True,
+    )
+
+    # Mandatory fields
+    algorithm: Literal["n2v", "care", "n2n", "custom"]  # defined in SupportedAlgorithm
+    loss: Literal["n2v", "mae", "mse"]
+    model: Union[UNetModel, VAEModel, CustomModel] = Field(discriminator="architecture")
+
+    # Optional fields
+    optimizer: OptimizerModel = OptimizerModel()
+    lr_scheduler: LrSchedulerModel = LrSchedulerModel()
+
+    @model_validator(mode="after")
+    def algorithm_cross_validation(self: Self) -> Self:
+        """Validate the algorithm model based on `algorithm`.
+
+        N2V:
+        - loss must be n2v
+        - model must be a `UNetModel`
+
+        Returns
+        -------
+        Self
+            The validated model.
+        """
+        # N2V
+        if self.algorithm == "n2v":
+            # n2v is only compatible with the n2v loss
+            if self.loss != "n2v":
+                raise ValueError(
+                    f"Algorithm {self.algorithm} only supports loss `n2v`."
+                )
+
+            # n2v is only compatible with the UNet model
+            if not isinstance(self.model, UNetModel):
+                raise ValueError(
+                    f"Model for algorithm {self.algorithm} must be a `UNetModel`."
+                )
+
+            # n2v requires the number of input and output channels to be the same
+            if self.model.in_channels != self.model.num_classes:
+                raise ValueError(
+                    "N2V requires the same number of input and output channels. Make "
+                    "sure that `in_channels` and `num_classes` are the same."
+                )
+
+        # N2N
+        if self.algorithm == "n2n":
+            # n2n is only compatible with the UNet model
+            if not isinstance(self.model, UNetModel):
+                raise ValueError(
+                    f"Model for algorithm {self.algorithm} must be a `UNetModel`."
+                )
+
+            # n2n requires the number of input and output channels to be the same
+            if self.model.in_channels != self.model.num_classes:
+                raise ValueError(
+                    "N2N requires the same number of input and output channels. Make "
+                    "sure that `in_channels` and `num_classes` are the same."
+                )
+
+        if self.algorithm == "care" or self.algorithm == "n2n":
+            if self.loss == "n2v":
+                raise ValueError("Supervised algorithms do not support loss `n2v`.")
+
+        if isinstance(self.model, VAEModel):
+            raise ValueError("VAE are currently not implemented.")
+
+        return self
+
+    def __str__(self) -> str:
+        """Pretty string representing the configuration.
+
+        Returns
+        -------
+        str
+            Pretty string.
+        """
+        return pformat(self.model_dump())
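
Note: the `algorithm_cross_validation` validator above rejects incompatible algorithm/loss/model combinations at construction time. A minimal sketch of the failure mode (hypothetical usage, not part of the diff):

```python
# Sketch: the validator raises for an n2v algorithm paired with an mse loss.
from pydantic import ValidationError

from careamics.config import AlgorithmModel

try:
    AlgorithmModel(
        algorithm="n2v",
        loss="mse",  # incompatible: n2v only supports the n2v loss
        model={"architecture": "UNet"},
    )
except ValidationError as err:
    print(err)  # reports "Algorithm n2v only supports loss `n2v`."
```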

careamics/config/architectures/__init__.py
@@ -0,0 +1,17 @@
+"""Deep-learning model configurations."""
+
+__all__ = [
+    "ArchitectureModel",
+    "CustomModel",
+    "UNetModel",
+    "VAEModel",
+    "clear_custom_models",
+    "get_custom_model",
+    "register_model",
+]
+
+from .architecture_model import ArchitectureModel
+from .custom_model import CustomModel
+from .register_model import clear_custom_models, get_custom_model, register_model
+from .unet_model import UNetModel
+from .vae_model import VAEModel

careamics/config/architectures/architecture_model.py
@@ -0,0 +1,29 @@
+from typing import Any, Dict
+
+from pydantic import BaseModel
+
+
+class ArchitectureModel(BaseModel):
+    """
+    Base Pydantic model for all model architectures.
+
+    The `model_dump` method removes the `architecture` key from the dumped model.
+    """
+
+    architecture: str
+
+    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
+        """
+        Dump the model as a dictionary, ignoring the architecture keyword.
+
+        Returns
+        -------
+        dict[str, Any]
+            Model as a dictionary.
+        """
+        model_dict = super().model_dump(**kwargs)
+
+        # remove the architecture key
+        model_dict.pop("architecture")
+
+        return model_dict
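
Note: a minimal sketch of the `model_dump` override; `DummyModel` is a hypothetical subclass used only for illustration:

```python
# Hypothetical subclass to illustrate the model_dump override.
from careamics.config.architectures import ArchitectureModel

class DummyModel(ArchitectureModel):
    layers: int = 3

dummy = DummyModel(architecture="Dummy", layers=5)
print(dummy.model_dump())  # {'layers': 5}; the architecture key is removed
```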

careamics/config/architectures/custom_model.py
@@ -0,0 +1,150 @@
+from __future__ import annotations
+
+from pprint import pformat
+from typing import Any, Dict, Literal
+
+from pydantic import ConfigDict, field_validator, model_validator
+from torch.nn import Module
+from typing_extensions import Self
+
+from .architecture_model import ArchitectureModel
+from .register_model import get_custom_model
+
+
+class CustomModel(ArchitectureModel):
+    """Custom model configuration.
+
+    This Pydantic model allows storing parameters for a custom model. In order for
+    the model to be valid, the specific model needs to be registered using the
+    `register_model` decorator, and its name correctly passed to this model
+    configuration (see Examples). Any additional parameter is accepted and will be
+    passed to the model at instantiation.
+
+    Attributes
+    ----------
+    architecture : Literal["Custom"]
+        Discriminator for the custom model, must be set to "Custom".
+    name : str
+        Name of the custom model.
+
+    Raises
+    ------
+    ValueError
+        If the custom model `name` is unknown.
+    ValueError
+        If the custom model is not a torch Module subclass.
+    ValueError
+        If the custom model parameters are not valid.
+
+    Examples
+    --------
+    >>> from torch import nn, ones
+    >>> from careamics.config import CustomModel, register_model
+    >>> # Register a custom model
+    >>> @register_model(name="my_linear")
+    ... class LinearModel(nn.Module):
+    ...     def __init__(self, in_features, out_features, *args, **kwargs):
+    ...         super().__init__()
+    ...         self.in_features = in_features
+    ...         self.out_features = out_features
+    ...         self.weight = nn.Parameter(ones(in_features, out_features))
+    ...         self.bias = nn.Parameter(ones(out_features))
+    ...     def forward(self, input):
+    ...         return (input @ self.weight) + self.bias
+    ...
+    >>> # Create a configuration
+    >>> config_dict = {
+    ...     "architecture": "Custom",
+    ...     "name": "my_linear",
+    ...     "in_features": 10,
+    ...     "out_features": 5,
+    ... }
+    >>> config = CustomModel(**config_dict)
+    """
+
+    # pydantic model config
+    model_config = ConfigDict(
+        extra="allow",
+    )
+
+    # discriminator used for choosing the pydantic model in Model
+    architecture: Literal["Custom"]
+
+    # name of the custom model
+    name: str
+
+    @field_validator("name")
+    @classmethod
+    def custom_model_is_known(cls, value: str) -> str:
+        """Check whether the custom model is known.
+
+        Parameters
+        ----------
+        value : str
+            Name of the custom model as registered using the `@register_model`
+            decorator.
+
+        Returns
+        -------
+        str
+            The validated name.
+        """
+        # delegate error to get_custom_model
+        model = get_custom_model(value)
+
+        # check if it is a torch Module subclass
+        if not issubclass(model, Module):
+            raise ValueError(
+                f'Retrieved class {model} with name "{value}" is not a '
+                f"torch.nn.Module subclass."
+            )
+
+        return value
+
+    @model_validator(mode="after")
+    def check_parameters(self: Self) -> Self:
+        """Validate the model by instantiating it with the parameters.
+
+        Returns
+        -------
+        Self
+            The validated model.
+        """
+        # instantiate model
+        try:
+            get_custom_model(self.name)(**self.model_dump())
+        except Exception as e:
+            raise ValueError(
+                f"Error while passing parameters to the model: {e}. Verify that "
+                f"all mandatory parameters are provided, and that the model either "
+                f"accepts *args and **kwargs in its __init__() method, or that no "
+                f"additional parameter is provided."
+            ) from None
+
+        return self
+
+    def __str__(self) -> str:
+        """Pretty string representing the configuration.
+
+        Returns
+        -------
+        str
+            Pretty string.
+        """
+        return pformat(self.model_dump())
+
+    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
+        """Dump the model configuration.
+
+        Parameters
+        ----------
+        kwargs : Any
+            Additional keyword arguments from the Pydantic BaseModel model_dump
+            method.
+
+        Returns
+        -------
+        Dict[str, Any]
+            Model configuration.
+        """
+        model_dict = super().model_dump()
+
+        # remove the name key
+        model_dict.pop("name")
+
+        return model_dict
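
Note: a minimal sketch of how `extra="allow"` and the `model_dump` override interact, assuming the `my_linear` model from the docstring example above has been registered:

```python
# Sketch: with extra="allow", user parameters become extra fields, and
# model_dump() drops both `architecture` and `name`, so the remaining keys
# can be passed straight to the registered class.
from careamics.config import CustomModel

config = CustomModel(
    architecture="Custom",
    name="my_linear",
    in_features=10,
    out_features=5,
)
print(config.model_dump())  # {'in_features': 10, 'out_features': 5}
```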

careamics/config/architectures/register_model.py
@@ -0,0 +1,101 @@
+from typing import Callable
+
+from torch.nn import Module
+
+CUSTOM_MODELS = {}  # dictionary of custom models {"name": __class__}
+
+
+def register_model(name: str) -> Callable:
+    """Decorator used to register a torch.nn.Module class with a given `name`.
+
+    Parameters
+    ----------
+    name : str
+        Name of the model.
+
+    Returns
+    -------
+    Callable
+        Function allowing to instantiate the wrapped Module class.
+
+    Raises
+    ------
+    ValueError
+        If a model is already registered with that name.
+
+    Examples
+    --------
+    ```python
+    @register_model(name="linear")
+    class LinearModel(nn.Module):
+        def __init__(self, in_features, out_features):
+            super().__init__()
+
+            self.weight = nn.Parameter(ones(in_features, out_features))
+            self.bias = nn.Parameter(ones(out_features))
+
+        def forward(self, input):
+            return (input @ self.weight) + self.bias
+    ```
+    """
+    if name is None or name == "":
+        raise ValueError("Model name cannot be empty.")
+
+    if name in CUSTOM_MODELS:
+        raise ValueError(
+            f"Model {name} already exists. Choose a different name or run "
+            f"`clear_custom_models()` to empty the registry."
+        )
+
+    def add_custom_model(model: Module) -> Module:
+        """Add a custom model to the registry and return it.
+
+        Parameters
+        ----------
+        model : Module
+            Module class to register.
+
+        Returns
+        -------
+        Module
+            The registered model.
+        """
+        # add model to the registry
+        CUSTOM_MODELS[name] = model
+
+        return model
+
+    return add_custom_model
+
+
+def get_custom_model(name: str) -> Module:
+    """Get the custom model corresponding to `name` from the registry.
+
+    Parameters
+    ----------
+    name : str
+        Name of the model to retrieve.
+
+    Returns
+    -------
+    Module
+        The requested model.
+
+    Raises
+    ------
+    ValueError
+        If the model is not registered.
+    """
+    if name not in CUSTOM_MODELS:
+        raise ValueError(
+            f"Model {name} is unknown. Have you registered it using "
+            f'the @register_model("{name}") decorator?'
+        )
+
+    return CUSTOM_MODELS[name]
+
+
+def clear_custom_models() -> None:
+    """Clear the custom models registry."""
+    # clear dictionary
+    CUSTOM_MODELS.clear()
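
Note: a minimal sketch of the registry flow defined above (the model name "identity" is hypothetical):

```python
# Sketch of the registry round trip under the API above.
from torch import nn

from careamics.config.architectures import (
    clear_custom_models,
    get_custom_model,
    register_model,
)

@register_model(name="identity")
class Identity(nn.Module):
    def forward(self, x):
        return x

model_class = get_custom_model("identity")  # returns the class, not an instance
model = model_class()

clear_custom_models()  # empty the registry, e.g. between test runs
```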

careamics/config/architectures/unet_model.py
@@ -0,0 +1,96 @@
+from __future__ import annotations
+
+from typing import Literal
+
+from pydantic import ConfigDict, Field, field_validator
+
+from .architecture_model import ArchitectureModel
+
+
+# TODO tests activation <-> pydantic model, test the literals!
+# TODO annotations for the json schema?
+class UNetModel(ArchitectureModel):
+    """
+    Pydantic model for a N2V(2)-compatible UNet.
+
+    Attributes
+    ----------
+    depth : int
+        Depth of the model, between 1 and 10 (default 2).
+    num_channels_init : int
+        Number of filters of the first level of the network, should be even
+        and minimum 8 (default 32).
+    """
+
+    # pydantic model config
+    model_config = ConfigDict(validate_assignment=True)
+
+    # discriminator used for choosing the pydantic model in Model
+    architecture: Literal["UNet"]
+
+    # parameters
+    # validate_default=True ensures that default values are validated as well
+    conv_dims: Literal[2, 3] = Field(default=2, validate_default=True)
+    num_classes: int = Field(default=1, ge=1, validate_default=True)
+    in_channels: int = Field(default=1, ge=1, validate_default=True)
+    depth: int = Field(default=2, ge=1, le=10, validate_default=True)
+    num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True)
+    final_activation: Literal[
+        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
+    ] = Field(default="None", validate_default=True)
+    n2v2: bool = Field(default=False, validate_default=True)
+
+    @field_validator("num_channels_init")
+    @classmethod
+    def validate_num_channels_init(cls, num_channels_init: int) -> int:
+        """
+        Validate that num_channels_init is even.
+
+        Parameters
+        ----------
+        num_channels_init : int
+            Number of channels.
+
+        Returns
+        -------
+        int
+            Validated number of channels.
+
+        Raises
+        ------
+        ValueError
+            If the number of channels is odd.
+        """
+        # if odd
+        if num_channels_init % 2 != 0:
+            raise ValueError(
+                f"Number of channels for the first layer must be even"
+                f" (got {num_channels_init})."
+            )
+
+        return num_channels_init
+
+    def set_3D(self, is_3D: bool) -> None:
+        """
+        Set the model to 3D or 2D via the `conv_dims` parameter.
+
+        Parameters
+        ----------
+        is_3D : bool
+            Whether the model is 3D or not.
+        """
+        if is_3D:
+            self.conv_dims = 3
+        else:
+            self.conv_dims = 2
+
+    def is_3D(self) -> bool:
+        """
+        Return whether the model is 3D or not.
+
+        Returns
+        -------
+        bool
+            Whether the model is 3D or not.
+        """
+        return self.conv_dims == 3
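
Note: a minimal sketch of the UNet configuration behavior defined above (hypothetical usage, not part of the diff):

```python
# Sketch: defaults validate, set_3D flips conv_dims (revalidated thanks to
# validate_assignment), and an odd num_channels_init is rejected.
from pydantic import ValidationError

from careamics.config.architectures import UNetModel

unet = UNetModel(architecture="UNet")  # depth=2, num_channels_init=32
unet.set_3D(True)
print(unet.is_3D())  # True

try:
    UNetModel(architecture="UNet", num_channels_init=33)  # odd, rejected
except ValidationError as err:
    print(err)
```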

careamics/config/architectures/vae_model.py
@@ -0,0 +1,39 @@
+from typing import Literal
+
+from pydantic import (
+    ConfigDict,
+)
+
+from .architecture_model import ArchitectureModel
+
+
+class VAEModel(ArchitectureModel):
+    """VAE model placeholder."""
+
+    model_config = ConfigDict(
+        use_enum_values=True, protected_namespaces=(), validate_assignment=True
+    )
+
+    architecture: Literal["VAE"]
+
+    def set_3D(self, is_3D: bool) -> None:
+        """
+        Set the model to 3D or 2D via the `conv_dims` parameter.
+
+        Parameters
+        ----------
+        is_3D : bool
+            Whether the model is 3D or not.
+        """
+        raise NotImplementedError("VAE is not implemented yet.")
+
+    def is_3D(self) -> bool:
+        """
+        Return whether the model is 3D or not.
+
+        Returns
+        -------
+        bool
+            Whether the model is 3D or not.
+        """
+        raise NotImplementedError("VAE is not implemented yet.")
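
Note: `VAEModel` is a placeholder; `AlgorithmModel.algorithm_cross_validation` above also rejects any configuration whose model is a `VAEModel`. A minimal sketch (hypothetical usage):

```python
# Sketch: the placeholder validates as a configuration, but its 3D helpers
# raise until VAEs are implemented.
from careamics.config.architectures import VAEModel

vae = VAEModel(architecture="VAE")
vae.set_3D(True)  # raises NotImplementedError("VAE is not implemented yet.")
```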