careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (141)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +726 -0
  3. careamics/config/__init__.py +35 -0
  4. careamics/config/algorithm_model.py +162 -0
  5. careamics/config/architectures/__init__.py +17 -0
  6. careamics/config/architectures/architecture_model.py +37 -0
  7. careamics/config/architectures/custom_model.py +159 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/architectures/vae_model.py +42 -0
  11. careamics/config/callback_model.py +123 -0
  12. careamics/config/configuration_factory.py +575 -0
  13. careamics/config/configuration_model.py +600 -0
  14. careamics/config/data_model.py +502 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/optimizer_models.py +187 -0
  17. careamics/config/references/__init__.py +45 -0
  18. careamics/config/references/algorithm_descriptions.py +132 -0
  19. careamics/config/references/references.py +39 -0
  20. careamics/config/support/__init__.py +31 -0
  21. careamics/config/support/supported_activations.py +26 -0
  22. careamics/config/support/supported_algorithms.py +20 -0
  23. careamics/config/support/supported_architectures.py +20 -0
  24. careamics/config/support/supported_data.py +109 -0
  25. careamics/config/support/supported_loggers.py +10 -0
  26. careamics/config/support/supported_losses.py +27 -0
  27. careamics/config/support/supported_optimizers.py +57 -0
  28. careamics/config/support/supported_pixel_manipulations.py +15 -0
  29. careamics/config/support/supported_struct_axis.py +21 -0
  30. careamics/config/support/supported_transforms.py +11 -0
  31. careamics/config/tile_information.py +65 -0
  32. careamics/config/training_model.py +72 -0
  33. careamics/config/transformations/__init__.py +15 -0
  34. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  35. careamics/config/transformations/normalize_model.py +60 -0
  36. careamics/config/transformations/transform_model.py +45 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  39. careamics/config/validators/__init__.py +5 -0
  40. careamics/config/validators/validator_utils.py +101 -0
  41. careamics/conftest.py +39 -0
  42. careamics/dataset/__init__.py +17 -0
  43. careamics/dataset/dataset_utils/__init__.py +19 -0
  44. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  45. careamics/dataset/dataset_utils/file_utils.py +141 -0
  46. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  47. careamics/dataset/dataset_utils/running_stats.py +186 -0
  48. careamics/dataset/in_memory_dataset.py +310 -0
  49. careamics/dataset/in_memory_pred_dataset.py +88 -0
  50. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  51. careamics/dataset/iterable_dataset.py +295 -0
  52. careamics/dataset/iterable_pred_dataset.py +122 -0
  53. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  54. careamics/dataset/patching/__init__.py +1 -0
  55. careamics/dataset/patching/patching.py +299 -0
  56. careamics/dataset/patching/random_patching.py +201 -0
  57. careamics/dataset/patching/sequential_patching.py +212 -0
  58. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  59. careamics/dataset/tiling/__init__.py +10 -0
  60. careamics/dataset/tiling/collate_tiles.py +33 -0
  61. careamics/dataset/tiling/tiled_patching.py +164 -0
  62. careamics/dataset/zarr_dataset.py +151 -0
  63. careamics/file_io/__init__.py +15 -0
  64. careamics/file_io/read/__init__.py +12 -0
  65. careamics/file_io/read/get_func.py +56 -0
  66. careamics/file_io/read/tiff.py +58 -0
  67. careamics/file_io/read/zarr.py +60 -0
  68. careamics/file_io/write/__init__.py +15 -0
  69. careamics/file_io/write/get_func.py +63 -0
  70. careamics/file_io/write/tiff.py +40 -0
  71. careamics/lightning/__init__.py +17 -0
  72. careamics/lightning/callbacks/__init__.py +11 -0
  73. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  74. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  75. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  76. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  77. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  79. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  80. careamics/lightning/lightning_module.py +276 -0
  81. careamics/lightning/predict_data_module.py +333 -0
  82. careamics/lightning/train_data_module.py +680 -0
  83. careamics/losses/__init__.py +5 -0
  84. careamics/losses/loss_factory.py +49 -0
  85. careamics/losses/losses.py +98 -0
  86. careamics/lvae_training/__init__.py +0 -0
  87. careamics/lvae_training/data_modules.py +1220 -0
  88. careamics/lvae_training/data_utils.py +618 -0
  89. careamics/lvae_training/eval_utils.py +905 -0
  90. careamics/lvae_training/get_config.py +84 -0
  91. careamics/lvae_training/lightning_module.py +701 -0
  92. careamics/lvae_training/metrics.py +214 -0
  93. careamics/lvae_training/train_lvae.py +339 -0
  94. careamics/lvae_training/train_utils.py +121 -0
  95. careamics/model_io/__init__.py +7 -0
  96. careamics/model_io/bioimage/__init__.py +11 -0
  97. careamics/model_io/bioimage/_readme_factory.py +121 -0
  98. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  99. careamics/model_io/bioimage/model_description.py +327 -0
  100. careamics/model_io/bmz_io.py +233 -0
  101. careamics/model_io/model_io_utils.py +83 -0
  102. careamics/models/__init__.py +7 -0
  103. careamics/models/activation.py +37 -0
  104. careamics/models/layers.py +493 -0
  105. careamics/models/lvae/__init__.py +0 -0
  106. careamics/models/lvae/layers.py +1998 -0
  107. careamics/models/lvae/likelihoods.py +312 -0
  108. careamics/models/lvae/lvae.py +985 -0
  109. careamics/models/lvae/noise_models.py +409 -0
  110. careamics/models/lvae/utils.py +395 -0
  111. careamics/models/model_factory.py +52 -0
  112. careamics/models/unet.py +443 -0
  113. careamics/prediction_utils/__init__.py +10 -0
  114. careamics/prediction_utils/prediction_outputs.py +135 -0
  115. careamics/prediction_utils/stitch_prediction.py +98 -0
  116. careamics/transforms/__init__.py +20 -0
  117. careamics/transforms/compose.py +107 -0
  118. careamics/transforms/n2v_manipulate.py +146 -0
  119. careamics/transforms/normalize.py +243 -0
  120. careamics/transforms/pixel_manipulation.py +407 -0
  121. careamics/transforms/struct_mask_parameters.py +20 -0
  122. careamics/transforms/transform.py +24 -0
  123. careamics/transforms/tta.py +88 -0
  124. careamics/transforms/xy_flip.py +123 -0
  125. careamics/transforms/xy_random_rotate90.py +101 -0
  126. careamics/utils/__init__.py +19 -0
  127. careamics/utils/autocorrelation.py +40 -0
  128. careamics/utils/base_enum.py +60 -0
  129. careamics/utils/context.py +66 -0
  130. careamics/utils/logging.py +322 -0
  131. careamics/utils/metrics.py +115 -0
  132. careamics/utils/path_utils.py +26 -0
  133. careamics/utils/ram.py +15 -0
  134. careamics/utils/receptive_field.py +108 -0
  135. careamics/utils/torch_utils.py +127 -0
  136. careamics-0.0.2.dist-info/METADATA +78 -0
  137. careamics-0.0.2.dist-info/RECORD +140 -0
  138. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
  139. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
  140. careamics-0.0.1.dist-info/METADATA +0 -46
  141. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,35 @@
+ """Configuration module."""
+
+ __all__ = [
+     "AlgorithmConfig",
+     "DataConfig",
+     "Configuration",
+     "CheckpointModel",
+     "InferenceConfig",
+     "load_configuration",
+     "save_configuration",
+     "TrainingConfig",
+     "create_n2v_configuration",
+     "create_n2n_configuration",
+     "create_care_configuration",
+     "register_model",
+     "CustomModel",
+     "clear_custom_models",
+ ]
+
+ from .algorithm_model import AlgorithmConfig
+ from .architectures import CustomModel, clear_custom_models, register_model
+ from .callback_model import CheckpointModel
+ from .configuration_factory import (
+     create_care_configuration,
+     create_n2n_configuration,
+     create_n2v_configuration,
+ )
+ from .configuration_model import (
+     Configuration,
+     load_configuration,
+     save_configuration,
+ )
+ from .data_model import DataConfig
+ from .inference_model import InferenceConfig
+ from .training_model import TrainingConfig
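
To show how these re-exported names fit together, here is a minimal, hypothetical sketch. The factory and (de)serialization functions live in `configuration_factory.py` and `configuration_model.py`, which are not reproduced in this excerpt, so their parameter names (`experiment_name`, `data_type`, `axes`, `patch_size`, `batch_size`, `num_epochs`) are assumptions:

```python
from careamics.config import (
    create_n2v_configuration,
    load_configuration,
    save_configuration,
)

# build a full N2V configuration from the convenience factory
# (parameter names are assumed, not taken from this diff)
config = create_n2v_configuration(
    experiment_name="n2v_demo",
    data_type="tiff",
    axes="YX",
    patch_size=[64, 64],
    batch_size=8,
    num_epochs=20,
)

# round-trip the configuration through a YAML file
save_configuration(config, "n2v_demo.yml")
restored = load_configuration("n2v_demo.yml")
```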
@@ -0,0 +1,162 @@
+ """Algorithm configuration."""
+
+ from __future__ import annotations
+
+ from pprint import pformat
+ from typing import Literal, Union
+
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
+ from typing_extensions import Self
+
+ from .architectures import CustomModel, UNetModel, VAEModel
+ from .optimizer_models import LrSchedulerModel, OptimizerModel
+
+
+ class AlgorithmConfig(BaseModel):
+     """Algorithm configuration.
+
+     This Pydantic model validates the parameters governing the components of the
+     training algorithm: which algorithm, loss function, model architecture,
+     optimizer, and learning rate scheduler to use.
+
+     Currently, we only support N2V, CARE, N2N and custom models. The `n2v`
+     algorithm is only compatible with the `n2v` loss and the `UNet` architecture.
+     The `custom` algorithm allows you to register your own architecture and
+     select it by passing its registered name as `name` in the custom Pydantic
+     model.
+
+     Attributes
+     ----------
+     algorithm : Literal["n2v", "care", "n2n", "custom"]
+         Algorithm to use.
+     loss : Literal["n2v", "mae", "mse"]
+         Loss function to use.
+     model : Union[UNetModel, VAEModel, CustomModel]
+         Model architecture to use.
+     optimizer : OptimizerModel, optional
+         Optimizer to use.
+     lr_scheduler : LrSchedulerModel, optional
+         Learning rate scheduler to use.
+
+     Raises
+     ------
+     ValueError
+         Algorithm parameter type validation errors.
+     ValueError
+         If the algorithm, loss and model are not compatible.
+
+     Examples
+     --------
+     Minimum example:
+     >>> from careamics.config import AlgorithmConfig
+     >>> config_dict = {
+     ...     "algorithm": "n2v",
+     ...     "loss": "n2v",
+     ...     "model": {
+     ...         "architecture": "UNet",
+     ...     }
+     ... }
+     >>> config = AlgorithmConfig(**config_dict)
+
+     Using a custom model:
+     >>> from torch import nn, ones
+     >>> from careamics.config import AlgorithmConfig, register_model
+     ...
+     >>> @register_model(name="linear_model")
+     ... class LinearModel(nn.Module):
+     ...     def __init__(self, in_features, out_features, *args, **kwargs):
+     ...         super().__init__()
+     ...         self.in_features = in_features
+     ...         self.out_features = out_features
+     ...         self.weight = nn.Parameter(ones(in_features, out_features))
+     ...         self.bias = nn.Parameter(ones(out_features))
+     ...     def forward(self, input):
+     ...         return (input @ self.weight) + self.bias
+     ...
+     >>> config_dict = {
+     ...     "algorithm": "custom",
+     ...     "loss": "mse",
+     ...     "model": {
+     ...         "architecture": "Custom",
+     ...         "name": "linear_model",
+     ...         "in_features": 10,
+     ...         "out_features": 5,
+     ...     }
+     ... }
+     >>> config = AlgorithmConfig(**config_dict)
+     """
+
+     # Pydantic class configuration
+     model_config = ConfigDict(
+         protected_namespaces=(),  # allows using model_* as a field name
+         validate_assignment=True,
+     )
+
+     # Mandatory fields
+     algorithm: Literal["n2v", "care", "n2n", "custom"]  # defined in SupportedAlgorithm
+     """Name of the algorithm, as defined in SupportedAlgorithm."""
+
+     loss: Literal["n2v", "mae", "mse"]
+     """Loss function to use, as defined in SupportedLoss."""
+
+     model: Union[UNetModel, VAEModel, CustomModel] = Field(discriminator="architecture")
+     """Model architecture to use, defined in SupportedArchitecture."""
+
+     # Optional fields
+     optimizer: OptimizerModel = OptimizerModel()
+     """Optimizer to use, defined in SupportedOptimizer."""
+
+     lr_scheduler: LrSchedulerModel = LrSchedulerModel()
+     """Learning rate scheduler to use, defined in SupportedScheduler."""
+
+     @model_validator(mode="after")
+     def algorithm_cross_validation(self: Self) -> Self:
+         """Validate the algorithm model based on `algorithm`.
+
+         N2V:
+         - loss must be n2v
+         - model must be a `UNetModel`
+
+         Returns
+         -------
+         Self
+             The validated model.
+         """
+         # N2V
+         if self.algorithm == "n2v":
+             # n2v is only compatible with the n2v loss
+             if self.loss != "n2v":
+                 raise ValueError(
+                     f"Algorithm {self.algorithm} only supports loss `n2v`."
+                 )
+
+             # n2v is only compatible with the UNet model
+             if not isinstance(self.model, UNetModel):
+                 raise ValueError(
+                     f"Model for algorithm {self.algorithm} must be a `UNetModel`."
+                 )
+
+             # n2v requires the number of input and output channels to be the same
+             if self.model.in_channels != self.model.num_classes:
+                 raise ValueError(
+                     "N2V requires the same number of input and output channels. "
+                     "Make sure that `in_channels` and `num_classes` are the same."
+                 )
+
+         if self.algorithm == "care" or self.algorithm == "n2n":
+             if self.loss == "n2v":
+                 raise ValueError("Supervised algorithms do not support loss `n2v`.")
+
+         if isinstance(self.model, VAEModel):
+             raise ValueError("VAEs are currently not implemented.")
+
+         return self
+
+     def __str__(self) -> str:
+         """Pretty string representing the configuration.
+
+         Returns
+         -------
+         str
+             Pretty string.
+         """
+         return pformat(self.model_dump())
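
The `algorithm_cross_validation` hook above rejects incompatible algorithm/loss/model combinations at construction time, and because `validate_assignment=True`, the same check re-runs whenever a field is reassigned. A short sketch using only the code in this hunk:

```python
from pydantic import ValidationError

from careamics.config import AlgorithmConfig

# n2v paired with an mse loss violates the cross-validator
try:
    AlgorithmConfig(algorithm="n2v", loss="mse", model={"architecture": "UNet"})
except ValidationError as err:
    print(err)  # Algorithm n2v only supports loss `n2v`.

# validate_assignment=True re-runs the validator on mutation too
config = AlgorithmConfig(algorithm="n2v", loss="n2v", model={"architecture": "UNet"})
try:
    config.loss = "mae"
except ValidationError as err:
    print(err)
```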
@@ -0,0 +1,17 @@
+ """Deep-learning model configurations."""
+
+ __all__ = [
+     "ArchitectureModel",
+     "CustomModel",
+     "UNetModel",
+     "VAEModel",
+     "clear_custom_models",
+     "get_custom_model",
+     "register_model",
+ ]
+
+ from .architecture_model import ArchitectureModel
+ from .custom_model import CustomModel
+ from .register_model import clear_custom_models, get_custom_model, register_model
+ from .unet_model import UNetModel
+ from .vae_model import VAEModel
@@ -0,0 +1,37 @@
+ """Base model for the various CAREamics architectures."""
+
+ from typing import Any, Dict
+
+ from pydantic import BaseModel
+
+
+ class ArchitectureModel(BaseModel):
+     """
+     Base Pydantic model for all model architectures.
+
+     The `model_dump` method allows removing the `architecture` key from the model.
+     """
+
+     architecture: str
+     """Name of the architecture."""
+
+     def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
+         """
+         Dump the model as a dictionary, ignoring the architecture keyword.
+
+         Parameters
+         ----------
+         **kwargs : Any
+             Additional keyword arguments from the Pydantic BaseModel `model_dump`
+             method.
+
+         Returns
+         -------
+         dict[str, Any]
+             Model as a dictionary.
+         """
+         model_dict = super().model_dump(**kwargs)
+
+         # remove the architecture key
+         model_dict.pop("architecture")
+
+         return model_dict
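
Every architecture below inherits this `model_dump` override, so the `architecture` discriminator never leaks into the keyword arguments handed to the torch model. A quick sketch using the `UNetModel` defined later in this diff:

```python
from careamics.config.architectures import UNetModel

unet = UNetModel(architecture="UNet", depth=3)

# the discriminator key is stripped from the dump
dumped = unet.model_dump()
print("architecture" in dumped)  # False
print(dumped["depth"])  # 3
```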
@@ -0,0 +1,159 @@
+ """Custom architecture Pydantic model."""
+
+ from __future__ import annotations
+
+ from pprint import pformat
+ from typing import Any, Literal
+
+ from pydantic import ConfigDict, field_validator, model_validator
+ from torch.nn import Module
+ from typing_extensions import Self
+
+ from .architecture_model import ArchitectureModel
+ from .register_model import get_custom_model
+
+
+ class CustomModel(ArchitectureModel):
+     """Custom model configuration.
+
+     This Pydantic model allows storing parameters for a custom model. In order for
+     the model to be valid, the specific model needs to be registered using the
+     `register_model` decorator, and its name correctly passed to this model
+     configuration (see Examples).
+
+     Attributes
+     ----------
+     architecture : Literal["Custom"]
+         Discriminator for the custom model, must be set to "Custom".
+     name : str
+         Name of the custom model.
+
+     Any additional field is allowed and is passed to the custom model constructor
+     during validation.
+
+     Raises
+     ------
+     ValueError
+         If the custom model `name` is unknown.
+     ValueError
+         If the custom model is not a torch Module subclass.
+     ValueError
+         If the custom model parameters are not valid.
+
+     Examples
+     --------
+     >>> from torch import nn, ones
+     >>> from careamics.config import CustomModel, register_model
+     >>> # Register a custom model
+     >>> @register_model(name="my_linear")
+     ... class LinearModel(nn.Module):
+     ...     def __init__(self, in_features, out_features, *args, **kwargs):
+     ...         super().__init__()
+     ...         self.in_features = in_features
+     ...         self.out_features = out_features
+     ...         self.weight = nn.Parameter(ones(in_features, out_features))
+     ...         self.bias = nn.Parameter(ones(out_features))
+     ...     def forward(self, input):
+     ...         return (input @ self.weight) + self.bias
+     ...
+     >>> # Create a configuration
+     >>> config_dict = {
+     ...     "architecture": "Custom",
+     ...     "name": "my_linear",
+     ...     "in_features": 10,
+     ...     "out_features": 5,
+     ... }
+     >>> config = CustomModel(**config_dict)
+     """
+
+     # pydantic model config
+     model_config = ConfigDict(
+         extra="allow",
+     )
+
+     # discriminator used for choosing the pydantic model in Model
+     architecture: Literal["Custom"]
+     """Name of the architecture."""
+
+     # name of the custom model
+     name: str
+     """Name of the custom model."""
+
+     @field_validator("name")
+     @classmethod
+     def custom_model_is_known(cls, value: str) -> str:
+         """Check whether the custom model is known.
+
+         Parameters
+         ----------
+         value : str
+             Name of the custom model as registered using the `@register_model`
+             decorator.
+
+         Returns
+         -------
+         str
+             The custom model name.
+         """
+         # delegate error to get_custom_model
+         model = get_custom_model(value)
+
+         # check if it is a torch Module subclass
+         if not issubclass(model, Module):
+             raise ValueError(
+                 f'Retrieved class {model} with name "{value}" is not a '
+                 f"torch.nn.Module subclass."
+             )
+
+         return value
+
+     @model_validator(mode="after")
+     def check_parameters(self: Self) -> Self:
+         """Validate the model by instantiating it with the parameters.
+
+         Returns
+         -------
+         Self
+             The validated model.
+         """
+         # instantiate model
+         try:
+             get_custom_model(self.name)(**self.model_dump())
+         except Exception as e:
+             raise ValueError(
+                 f"Error while passing parameters to the model: {e}. Verify that "
+                 f"all mandatory parameters are provided, and that the model "
+                 f"either accepts *args and **kwargs in its __init__() method, "
+                 f"or that no additional parameter is provided."
+             ) from None
+
+         return self
+
+     def __str__(self) -> str:
+         """Pretty string representing the configuration.
+
+         Returns
+         -------
+         str
+             Pretty string.
+         """
+         return pformat(self.model_dump())
+
+     def model_dump(self, **kwargs: Any) -> dict[str, Any]:
+         """Dump the model configuration.
+
+         Parameters
+         ----------
+         **kwargs : Any
+             Additional keyword arguments from the Pydantic BaseModel `model_dump`
+             method.
+
+         Returns
+         -------
+         dict[str, Any]
+             Model configuration.
+         """
+         model_dict = super().model_dump()
+
+         # remove the name key
+         model_dict.pop("name")
+
+         return model_dict
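
Because `model_config` sets `extra="allow"`, any extra field on `CustomModel` is stored and forwarded to the registered class when `check_parameters` instantiates it, while `model_dump` strips both bookkeeping keys. A sketch under those assumptions; the `Scaler` class is illustrative only, not part of the package:

```python
from torch import nn, ones

from careamics.config import CustomModel, register_model


@register_model(name="scaler")
class Scaler(nn.Module):
    """Illustrative model scaling its input by a learnable factor."""

    def __init__(self, scale, *args, **kwargs):
        super().__init__()
        self.scale = nn.Parameter(ones(1) * scale)

    def forward(self, x):
        return x * self.scale


# the extra field `scale` is validated by instantiating Scaler(scale=2.0)
config = CustomModel(architecture="Custom", name="scaler", scale=2.0)

# `architecture` and `name` are removed, leaving only constructor kwargs
print(config.model_dump())  # {'scale': 2.0}
```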
@@ -0,0 +1,103 @@
+ """Custom model registration utilities."""
+
+ from typing import Callable
+
+ from torch.nn import Module
+
+ CUSTOM_MODELS = {}  # dictionary of custom models {"name": __class__}
+
+
+ def register_model(name: str) -> Callable:
+     """Decorator used to register a torch.nn.Module class with a given `name`.
+
+     Parameters
+     ----------
+     name : str
+         Name of the model.
+
+     Returns
+     -------
+     Callable
+         Function registering the wrapped Module class.
+
+     Raises
+     ------
+     ValueError
+         If a model is already registered with that name.
+
+     Examples
+     --------
+     ```python
+     @register_model(name="linear")
+     class LinearModel(nn.Module):
+         def __init__(self, in_features, out_features):
+             super().__init__()
+
+             self.weight = nn.Parameter(ones(in_features, out_features))
+             self.bias = nn.Parameter(ones(out_features))
+
+         def forward(self, input):
+             return (input @ self.weight) + self.bias
+     ```
+     """
+     if name is None or name == "":
+         raise ValueError("Model name cannot be empty.")
+
+     if name in CUSTOM_MODELS:
+         raise ValueError(
+             f"Model {name} already exists. Choose a different name or run "
+             f"`clear_custom_models()` to empty the registry."
+         )
+
+     def add_custom_model(model: Module) -> Module:
+         """Add a custom model to the registry and return it.
+
+         Parameters
+         ----------
+         model : Module
+             Module class to register.
+
+         Returns
+         -------
+         Module
+             The registered model.
+         """
+         # add model to the registry
+         CUSTOM_MODELS[name] = model
+
+         return model
+
+     return add_custom_model
+
+
+ def get_custom_model(name: str) -> Module:
+     """Get the custom model corresponding to `name` from the registry.
+
+     Parameters
+     ----------
+     name : str
+         Name of the model to retrieve.
+
+     Returns
+     -------
+     Module
+         The requested model.
+
+     Raises
+     ------
+     ValueError
+         If the model is not registered.
+     """
+     if name not in CUSTOM_MODELS:
+         raise ValueError(
+             f"Model {name} is unknown. Have you registered it using "
+             f'@register_model("{name}") as decorator?'
+         )
+
+     return CUSTOM_MODELS[name]
+
+
+ def clear_custom_models() -> None:
+     """Clear the custom models registry."""
+     # clear dictionary
+     CUSTOM_MODELS.clear()
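
A quick round-trip through the registry, sketched from the three functions above; the `Identity` class is illustrative only:

```python
from torch import nn

from careamics.config.architectures import (
    clear_custom_models,
    get_custom_model,
    register_model,
)


@register_model(name="identity")
class Identity(nn.Module):
    def forward(self, x):
        return x


# retrieval returns the registered class itself, not an instance
assert get_custom_model("identity") is Identity

# registering the same name twice raises, so empty the registry first
clear_custom_models()
```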
@@ -0,0 +1,118 @@
+ """UNet Pydantic model."""
+
+ from __future__ import annotations
+
+ from typing import Literal
+
+ from pydantic import ConfigDict, Field, field_validator
+
+ from .architecture_model import ArchitectureModel
+
+
+ # TODO tests activation <-> pydantic model, test the literals!
+ # TODO annotations for the json schema?
+ class UNetModel(ArchitectureModel):
+     """
+     Pydantic model for a N2V(2)-compatible UNet.
+
+     Attributes
+     ----------
+     depth : int
+         Depth of the model, between 1 and 10 (default 2).
+     num_channels_init : int
+         Number of filters of the first level of the network, should be even
+         and at least 8 (default 32).
+     """
+
+     # pydantic model config
+     model_config = ConfigDict(validate_assignment=True)
+
+     # discriminator used for choosing the pydantic model in Model
+     architecture: Literal["UNet"]
+     """Name of the architecture."""
+
+     # parameters
+     # validate_default=True ensures that default values are also validated
+     conv_dims: Literal[2, 3] = Field(default=2, validate_default=True)
+     """Dimensions (2D or 3D) of the convolutional layers."""
+
+     num_classes: int = Field(default=1, ge=1, validate_default=True)
+     """Number of classes or channels in the model output."""
+
+     in_channels: int = Field(default=1, ge=1, validate_default=True)
+     """Number of channels in the input to the model."""
+
+     depth: int = Field(default=2, ge=1, le=10, validate_default=True)
+     """Number of levels in the UNet."""
+
+     num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True)
+     """Number of convolutional filters in the first layer of the UNet."""
+
+     final_activation: Literal[
+         "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
+     ] = Field(default="None", validate_default=True)
+     """Final activation function."""
+
+     n2v2: bool = Field(default=False, validate_default=True)
+     """Whether to use the N2V2 architecture modifications, with blur pool layers
+     and fewer skip connections.
+     """
+
+     independent_channels: bool = Field(default=True, validate_default=True)
+     """Whether information is processed independently in each channel, used to
+     train channels independently."""
+
+     @field_validator("num_channels_init")
+     @classmethod
+     def validate_num_channels_init(cls, num_channels_init: int) -> int:
+         """
+         Validate that num_channels_init is even.
+
+         Parameters
+         ----------
+         num_channels_init : int
+             Number of channels.
+
+         Returns
+         -------
+         int
+             Validated number of channels.
+
+         Raises
+         ------
+         ValueError
+             If the number of channels is odd.
+         """
+         # if odd
+         if num_channels_init % 2 != 0:
+             raise ValueError(
+                 f"Number of channels for the bottom layer must be even"
+                 f" (got {num_channels_init})."
+             )
+
+         return num_channels_init
+
+     def set_3D(self, is_3D: bool) -> None:
+         """
+         Set the model to 3D by setting the `conv_dims` parameter.
+
+         Parameters
+         ----------
+         is_3D : bool
+             Whether the algorithm is 3D or not.
+         """
+         if is_3D:
+             self.conv_dims = 3
+         else:
+             self.conv_dims = 2
+
+     def is_3D(self) -> bool:
+         """
+         Return whether the model is 3D or not.
+
+         Returns
+         -------
+         bool
+             Whether the model is 3D or not.
+         """
+         return self.conv_dims == 3
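
A short sketch of the validators and the 2D/3D toggle defined above, using only this hunk's code:

```python
from pydantic import ValidationError

from careamics.config.architectures import UNetModel

unet = UNetModel(architecture="UNet")
print(unet.is_3D())  # False, conv_dims defaults to 2

# switch the convolutions to 3D; validate_assignment re-checks the Literal
unet.set_3D(True)
print(unet.conv_dims)  # 3

# the field validator rejects an odd number of initial filters
try:
    UNetModel(architecture="UNet", num_channels_init=33)
except ValidationError as err:
    print(err)
```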
@@ -0,0 +1,42 @@
+ """VAE Pydantic model."""
+
+ from typing import Literal
+
+ from pydantic import (
+     ConfigDict,
+ )
+
+ from .architecture_model import ArchitectureModel
+
+
+ class VAEModel(ArchitectureModel):
+     """VAE model placeholder."""
+
+     model_config = ConfigDict(
+         use_enum_values=True, protected_namespaces=(), validate_assignment=True
+     )
+
+     architecture: Literal["VAE"]
+     """Name of the architecture."""
+
+     def set_3D(self, is_3D: bool) -> None:
+         """
+         Set the model to 3D by setting the `conv_dims` parameter.
+
+         Parameters
+         ----------
+         is_3D : bool
+             Whether the algorithm is 3D or not.
+         """
+         raise NotImplementedError("VAE is not implemented yet.")
+
+     def is_3D(self) -> bool:
+         """
+         Return whether the model is 3D or not.
+
+         Returns
+         -------
+         bool
+             Whether the model is 3D or not.
+         """
+         raise NotImplementedError("VAE is not implemented yet.")