careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (155) hide show
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,39 @@
1
+ """Configuration module."""
2
+
3
+ __all__ = [
4
+ "FCNAlgorithmConfig",
5
+ "VAEAlgorithmConfig",
6
+ "DataConfig",
7
+ "Configuration",
8
+ "CheckpointModel",
9
+ "InferenceConfig",
10
+ "load_configuration",
11
+ "save_configuration",
12
+ "TrainingConfig",
13
+ "create_n2v_configuration",
14
+ "create_n2n_configuration",
15
+ "create_care_configuration",
16
+ "register_model",
17
+ "CustomModel",
18
+ "clear_custom_models",
19
+ "GaussianMixtureNMConfig",
20
+ "MultiChannelNMConfig",
21
+ ]
22
+ from .architectures import CustomModel, clear_custom_models, register_model
23
+ from .callback_model import CheckpointModel
24
+ from .configuration_factory import (
25
+ create_care_configuration,
26
+ create_n2n_configuration,
27
+ create_n2v_configuration,
28
+ )
29
+ from .configuration_model import (
30
+ Configuration,
31
+ load_configuration,
32
+ save_configuration,
33
+ )
34
+ from .data_model import DataConfig
35
+ from .fcn_algorithm_model import FCNAlgorithmConfig
36
+ from .inference_model import InferenceConfig
37
+ from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
38
+ from .training_model import TrainingConfig
39
+ from .vae_algorithm_model import VAEAlgorithmConfig
@@ -0,0 +1,17 @@
1
+ """Deep-learning model configurations."""
2
+
3
+ __all__ = [
4
+ "ArchitectureModel",
5
+ "CustomModel",
6
+ "UNetModel",
7
+ "LVAEModel",
8
+ "clear_custom_models",
9
+ "get_custom_model",
10
+ "register_model",
11
+ ]
12
+
13
+ from .architecture_model import ArchitectureModel
14
+ from .custom_model import CustomModel
15
+ from .lvae_model import LVAEModel
16
+ from .register_model import clear_custom_models, get_custom_model, register_model
17
+ from .unet_model import UNetModel
@@ -0,0 +1,37 @@
1
+ """Base model for the various CAREamics architectures."""
2
+
3
+ from typing import Any, Dict
4
+
5
+ from pydantic import BaseModel
6
+
7
+
8
class ArchitectureModel(BaseModel):
    """
    Base Pydantic model for all model architectures.

    The `model_dump` method removes the `architecture` key from the dumped
    dictionary.
    """

    architecture: str
    """Name of the architecture."""

    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
        """
        Dump the model as a dictionary, ignoring the architecture keyword.

        Parameters
        ----------
        **kwargs : Any
            Additional keyword arguments from Pydantic BaseModel model_dump method.

        Returns
        -------
        dict[str, Any]
            Model as a dictionary, without the `architecture` key.
        """
        model_dict = super().model_dump(**kwargs)

        # remove the architecture key; it may already be absent when the caller
        # filtered it out (e.g. through `exclude` or `include`), hence the default
        model_dict.pop("architecture", None)

        return model_dict
@@ -0,0 +1,162 @@
1
+ """Custom architecture Pydantic model."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import inspect
6
+ from pprint import pformat
7
+ from typing import Any, Literal
8
+
9
+ from pydantic import ConfigDict, field_validator, model_validator
10
+ from torch.nn import Module
11
+ from typing_extensions import Self
12
+
13
+ from .architecture_model import ArchitectureModel
14
+ from .register_model import get_custom_model
15
+
16
+
17
class CustomModel(ArchitectureModel):
    """Custom model configuration.

    This Pydantic model allows storing parameters for a custom model. In order for the
    model to be valid, the specific model needs to be registered using the
    `register_model` decorator, and its name correctly passed to this model
    configuration (see Examples).

    All the parameters required to initialize the torch module are passed as extra
    fields (the Pydantic configuration uses `extra="allow"`), next to `architecture`
    and `name`.

    Attributes
    ----------
    architecture : Literal["custom"]
        Discriminator for the custom model, must be set to "custom".
    name : str
        Name of the custom model.

    Raises
    ------
    ValueError
        If the custom model `name` is unknown.
    ValueError
        If the custom model is not a torch Module subclass.
    ValueError
        If the custom model parameters are not valid.

    Examples
    --------
    >>> from torch import nn, ones
    >>> from careamics.config import CustomModel, register_model
    >>> # Register a custom model
    >>> @register_model(name="my_linear")
    ... class LinearModel(nn.Module):
    ...     def __init__(self, in_features, out_features, *args, **kwargs):
    ...         super().__init__()
    ...         self.in_features = in_features
    ...         self.out_features = out_features
    ...         self.weight = nn.Parameter(ones(in_features, out_features))
    ...         self.bias = nn.Parameter(ones(out_features))
    ...     def forward(self, input):
    ...         return (input @ self.weight) + self.bias
    ...
    >>> # Create a configuration
    >>> config_dict = {
    ...     "architecture": "custom",
    ...     "name": "my_linear",
    ...     "in_features": 10,
    ...     "out_features": 5,
    ... }
    >>> config = CustomModel(**config_dict)
    """

    # pydantic model config
    model_config = ConfigDict(
        extra="allow",
    )

    # discriminator used for choosing the pydantic model in Model
    architecture: Literal["custom"]
    """Name of the architecture."""

    name: str
    """Name of the custom model."""

    @field_validator("name")
    @classmethod
    def custom_model_is_known(cls, value: str) -> str:
        """Check whether the custom model is known.

        Parameters
        ----------
        value : str
            Name of the custom model as registered using the `@register_model`
            decorator.

        Returns
        -------
        str
            The custom model name.

        Raises
        ------
        ValueError
            If the registered object is not a torch.nn.Module subclass.
        """
        # delegate error to get_custom_model
        model = get_custom_model(value)

        # check if it is a torch Module subclass
        if not issubclass(model, Module):
            raise ValueError(
                f'Retrieved class {model} with name "{value}" is not a '
                f"torch.nn.Module subclass."
            )

        return value

    @model_validator(mode="after")
    def check_parameters(self: Self) -> Self:
        """Validate model by instantiating the model with the parameters.

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If the registered model cannot be instantiated with the dumped
            parameters.
        """
        # instantiate model with the dumped parameters (`architecture` and `name`
        # are removed by `model_dump`)
        try:
            get_custom_model(self.name)(**self.model_dump())
        except Exception as e:
            # last record of the traceback currently being handled, reported in
            # the error message since the original exception is not chained
            last_frame = inspect.trace()[-1]
            raise ValueError(
                f"while passing parameters to the model {e}. Verify that all "
                f"mandatory parameters are provided, and that either the model "
                f"accepts *args and **kwargs in its __init__() method, or that no "
                f"additional parameter is provided. Trace: "
                f"filename: {last_frame.filename}, function: "
                f"{last_frame.function}, line: {last_frame.lineno}"
            ) from None

        return self

    def __str__(self) -> str:
        """Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        """Dump the model configuration.

        Parameters
        ----------
        **kwargs : Any
            Additional keyword arguments from Pydantic BaseModel model_dump method.

        Returns
        -------
        dict[str, Any]
            Model configuration, without the `name` key.
        """
        # forward the keyword arguments so that options such as `exclude` or
        # `mode` are honoured rather than silently dropped
        model_dict = super().model_dump(**kwargs)

        # remove the name key; it may already have been filtered out by the caller
        model_dict.pop("name", None)

        return model_dict
@@ -0,0 +1,174 @@
1
+ """LVAE Pydantic model."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import ConfigDict, Field, field_validator, model_validator
6
+ from typing_extensions import Self
7
+
8
+ from .architecture_model import ArchitectureModel
9
+
10
+
11
+ # TODO: it is quite confusing to call this LVAEModel, as it is basically a config
12
class LVAEModel(ArchitectureModel):
    """Pydantic configuration of the LVAE architecture.

    Fields are validated on instantiation and on assignment, defaults included.
    """

    model_config = ConfigDict(validate_assignment=True, validate_default=True)

    architecture: Literal["LVAE"]
    input_shape: int = Field(default=64, ge=8, le=1024)
    multiscale_count: int = Field(default=5)  # TODO clarify
    # 0 - off, len(z_dims) + 1 # TODO can/should be le to z_dims len + 1
    z_dims: list = Field(default=[128, 128, 128, 128])
    output_channels: int = Field(default=1, ge=1)
    encoder_n_filters: int = Field(default=64, ge=8, le=1024)
    decoder_n_filters: int = Field(default=64, ge=8, le=1024)
    encoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
    decoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
    nonlinearity: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
    ] = Field(
        default="ELU",
    )

    predict_logvar: Literal[None, "pixelwise"] = None

    # TODO this parameter is excessive -> Remove & refactor
    enable_noise_model: bool = Field(
        default=True,
    )
    analytical_kl: bool = Field(
        default=False,
    )

    @field_validator("encoder_n_filters")
    @classmethod
    def validate_encoder_even(cls, encoder_n_filters: int) -> int:
        """
        Validate that `encoder_n_filters` is even.

        Parameters
        ----------
        encoder_n_filters : int
            Number of channels.

        Returns
        -------
        int
            Validated number of channels.

        Raises
        ------
        ValueError
            If the number of channels is odd.
        """
        # if odd
        if encoder_n_filters % 2 != 0:
            raise ValueError(
                f"Number of channels for the bottom layer must be even"
                f" (got {encoder_n_filters})."
            )

        return encoder_n_filters

    @field_validator("decoder_n_filters")
    @classmethod
    def validate_decoder_even(cls, decoder_n_filters: int) -> int:
        """
        Validate that `decoder_n_filters` is even.

        Parameters
        ----------
        decoder_n_filters : int
            Number of channels.

        Returns
        -------
        int
            Validated number of channels.

        Raises
        ------
        ValueError
            If the number of channels is odd.
        """
        # if odd
        if decoder_n_filters % 2 != 0:
            raise ValueError(
                f"Number of channels for the bottom layer must be even"
                f" (got {decoder_n_filters})."
            )

        return decoder_n_filters

    @field_validator("z_dims")
    @classmethod
    def validate_z_dims(cls, z_dims: list) -> list:
        """
        Validate the z_dims.

        Parameters
        ----------
        z_dims : list
            List of z dimensions.

        Returns
        -------
        list
            Validated z dimensions.

        Raises
        ------
        ValueError
            If there are fewer than 2 z dimensions.
        """
        if len(z_dims) < 2:
            raise ValueError(
                f"Number of z dimensions must be at least 2 (got {len(z_dims)})."
            )

        return z_dims

    @model_validator(mode="after")
    def validate_multiscale_count(self: Self) -> Self:
        """
        Validate the multiscale count.

        The check is currently disabled, the model is returned unchanged.

        Returns
        -------
        Self
            The validated model.
        """
        # TODO the constraint between multiscale_count and z_dims is not enforced
        # yet, the intended check is kept below for reference
        # if self.multiscale_count != 0:
        #     if self.multiscale_count != len(self.z_dims) - 1:
        #         raise ValueError(
        #             f"Multiscale count must be 0 or equal to the number of Z "
        #             f"dims - 1 (got {self.multiscale_count} and {len(self.z_dims)})."
        #         )

        return self

    def set_3D(self, is_3D: bool) -> None:
        """
        Set 3D model by setting the `conv_dims` parameters.

        Parameters
        ----------
        is_3D : bool
            Whether the algorithm is 3D or not.

        Raises
        ------
        NotImplementedError
            Always, this method is not implemented for this model.
        """
        raise NotImplementedError("VAE is not implemented yet.")

    def is_3D(self) -> bool:
        """
        Return whether the model is 3D or not.

        Returns
        -------
        bool
            Whether the model is 3D or not.

        Raises
        ------
        NotImplementedError
            Always, this method is not implemented for this model.
        """
        raise NotImplementedError("VAE is not implemented yet.")
@@ -0,0 +1,103 @@
1
+ """Custom model registration utilities."""
2
+
3
+ from typing import Callable
4
+
5
+ from torch.nn import Module
6
+
7
+ CUSTOM_MODELS = {} # dictionary of custom models {"name": __class__}
8
+
9
+
10
def register_model(name: str) -> Callable:
    """Build a class decorator storing a torch.nn.Module class under `name`.

    The name is checked when the decorator is created; the class is added to the
    `CUSTOM_MODELS` registry when the decorator is applied.

    Parameters
    ----------
    name : str
        Name of the model.

    Returns
    -------
    Callable
        Decorator that registers the class it receives and returns it unchanged.

    Raises
    ------
    ValueError
        If the name is empty or None.
    ValueError
        If a model is already registered with that name.

    Examples
    --------
    ```python
    @register_model(name="linear")
    class LinearModel(nn.Module):
        def __init__(self, in_features, out_features):
            super().__init__()

            self.weight = nn.Parameter(ones(in_features, out_features))
            self.bias = nn.Parameter(ones(out_features))

        def forward(self, input):
            return (input @ self.weight) + self.bias
    ```
    """
    # reject missing names first
    if name is None or name == "":
        raise ValueError("Model name cannot be empty.")

    # refuse to overwrite an existing entry
    already_registered = name in CUSTOM_MODELS
    if already_registered:
        message = (
            f"Model {name} already exists. Choose a different name or run "
            f"`clear_custom_models()` to empty the registry."
        )
        raise ValueError(message)

    def add_custom_model(model: Module) -> Module:
        """Store the decorated class in the registry and hand it back.

        Parameters
        ----------
        model : Module
            Module class to register.

        Returns
        -------
        Module
            The registered model.
        """
        CUSTOM_MODELS[name] = model
        return model

    return add_custom_model
71
+
72
+
73
def get_custom_model(name: str) -> Module:
    """Look up `name` in the custom model registry.

    Parameters
    ----------
    name : str
        Name of the model to retrieve.

    Returns
    -------
    Module
        The object stored in the registry under `name`.

    Raises
    ------
    ValueError
        If the model is not registered.
    """
    try:
        registered = CUSTOM_MODELS[name]
    except KeyError:
        # report unknown names as a ValueError, without chaining the KeyError
        raise ValueError(
            f"Model {name} is unknown. Have you registered it using "
            f'@register_model("{name}") as decorator?'
        ) from None

    return registered
98
+
99
+
100
def clear_custom_models() -> None:
    """Remove every entry from the custom models registry.

    The `CUSTOM_MODELS` dictionary is emptied in place.
    """
    CUSTOM_MODELS.clear()
@@ -0,0 +1,118 @@
1
+ """UNet Pydantic model."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Literal
6
+
7
+ from pydantic import ConfigDict, Field, field_validator
8
+
9
+ from .architecture_model import ArchitectureModel
10
+
11
+
12
+ # TODO tests activation <-> pydantic model, test the literals!
13
+ # TODO annotations for the json schema?
14
class UNetModel(ArchitectureModel):
    """
    Pydantic configuration of a N2V(2)-compatible UNet.

    Attributes
    ----------
    depth : int
        Depth of the model, between 1 and 10 (default 2).
    num_channels_init : int
        Number of filters of the first level of the network, should be even
        and between 8 and 1024 (default 32).
    """

    # pydantic model config
    model_config = ConfigDict(validate_assignment=True)

    # discriminator used for choosing the pydantic model in Model
    architecture: Literal["UNet"]
    """Name of the architecture."""

    # parameters
    # validate_defaults allow ignoring default values in the dump if they were not set
    conv_dims: Literal[2, 3] = Field(default=2, validate_default=True)
    """Dimensions (2D or 3D) of the convolutional layers."""

    num_classes: int = Field(default=1, ge=1, validate_default=True)
    """Number of classes or channels in the model output."""

    in_channels: int = Field(default=1, ge=1, validate_default=True)
    """Number of channels in the input to the model."""

    depth: int = Field(default=2, ge=1, le=10, validate_default=True)
    """Number of levels in the UNet."""

    num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True)
    """Number of convolutional filters in the first layer of the UNet."""

    final_activation: Literal[
        "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU"
    ] = Field(default="None", validate_default=True)
    """Final activation function."""

    n2v2: bool = Field(default=False, validate_default=True)
    """Whether to use N2V2 architecture modifications, with blur pool layers and fewer
    skip connections.
    """

    independent_channels: bool = Field(default=True, validate_default=True)
    """Whether information is processed independently in each channel, used to train
    channels independently."""

    @field_validator("num_channels_init")
    @classmethod
    def validate_num_channels_init(cls, num_channels_init: int) -> int:
        """
        Check that the initial number of channels is an even number.

        Parameters
        ----------
        num_channels_init : int
            Number of channels.

        Returns
        -------
        int
            Validated number of channels.

        Raises
        ------
        ValueError
            If the number of channels is odd.
        """
        is_even = num_channels_init % 2 == 0
        if is_even:
            return num_channels_init

        raise ValueError(
            f"Number of channels for the bottom layer must be even"
            f" (got {num_channels_init})."
        )

    def set_3D(self, is_3D: bool) -> None:
        """
        Switch between a 2D and a 3D model by updating `conv_dims`.

        Parameters
        ----------
        is_3D : bool
            Whether the algorithm is 3D or not.
        """
        self.conv_dims = 3 if is_3D else 2

    def is_3D(self) -> bool:
        """
        Tell whether the convolutions are three-dimensional.

        Returns
        -------
        bool
            Whether the model is 3D or not.
        """
        uses_3d_convolutions = self.conv_dims == 3
        return uses_3d_convolutions