careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of careamics might be problematic.

Files changed (133)
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +323 -134
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -14
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -221
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -12
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +112 -75
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -182
  111. careamics/bioimage/rdf.py +0 -105
  112. careamics/config/algorithm.py +0 -231
  113. careamics/config/config.py +0 -297
  114. careamics/config/config_filter.py +0 -44
  115. careamics/config/data.py +0 -194
  116. careamics/config/torch_optim.py +0 -118
  117. careamics/config/training.py +0 -534
  118. careamics/dataset/dataset_utils.py +0 -111
  119. careamics/dataset/patching.py +0 -492
  120. careamics/dataset/prepare_dataset.py +0 -175
  121. careamics/dataset/tiff_dataset.py +0 -212
  122. careamics/engine.py +0 -1014
  123. careamics/manipulation/__init__.py +0 -4
  124. careamics/manipulation/pixel_manipulation.py +0 -158
  125. careamics/prediction/prediction_utils.py +0 -106
  126. careamics/utils/ascii_logo.txt +0 -9
  127. careamics/utils/augment.py +0 -65
  128. careamics/utils/normalization.py +0 -55
  129. careamics/utils/validators.py +0 -170
  130. careamics/utils/wandb.py +0 -121
  131. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  132. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  133. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
careamics/config/inference_model.py
@@ -0,0 +1,283 @@
+ """Pydantic model representing CAREamics prediction configuration."""
+ from __future__ import annotations
+
+ from typing import Any, List, Literal, Optional, Union
+
+ from albumentations import Compose
+ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+ from typing_extensions import Self
+
+ from .support import SupportedTransform
+ from .transformations.normalize_model import NormalizeModel
+ from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
+
+ TRANSFORMS_UNION = Union[NormalizeModel]
+
+
+ class InferenceModel(BaseModel):
+     """Configuration class for the prediction model."""
+
+     model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
+
+     # Mandatory fields
+     data_type: Literal["array", "tiff", "custom"]  # As defined in SupportedData
+     tile_size: Optional[List[int]] = Field(default=None, min_length=2, max_length=3)
+     tile_overlap: Optional[List[int]] = Field(default=None, min_length=2, max_length=3)
+
+     axes: str
+
+     mean: float
+     std: float = Field(..., ge=0.0)
+
+     transforms: Union[List[TRANSFORMS_UNION], Compose] = Field(
+         default=[
+             {
+                 "name": SupportedTransform.NORMALIZE.value,
+             },
+         ],
+         validate_default=True,
+     )
+
+     # only default TTAs are supported for now
+     tta_transforms: bool = Field(default=True)
+
+     # Dataloader parameters
+     batch_size: int = Field(default=1, ge=1)
+
+     @field_validator("tile_overlap")
+     @classmethod
+     def all_elements_non_zero_even(
+         cls, patch_list: Optional[List[int]]
+     ) -> Optional[List[int]]:
+         """
+         Validate tile overlap.
+
+         Each overlap must be non-zero, positive and even.
+
+         Parameters
+         ----------
+         patch_list : Optional[List[int]]
+             Tile overlap.
+
+         Returns
+         -------
+         Optional[List[int]]
+             Validated tile overlap.
+
+         Raises
+         ------
+         ValueError
+             If an overlap is 0 or negative.
+         ValueError
+             If an overlap is not even.
+         """
+         if patch_list is not None:
+             for dim in patch_list:
+                 if dim < 1:
+                     raise ValueError(
+                         f"Patch size must be non-zero positive (got {dim})."
+                     )
+
+                 if dim % 2 != 0:
+                     raise ValueError(f"Patch size must be even (got {dim}).")
+
+         return patch_list
+
+     @field_validator("tile_size")
+     @classmethod
+     def tile_min_8_power_of_2(
+         cls, tile_list: Optional[List[int]]
+     ) -> Optional[List[int]]:
+         """
+         Validate that each entry is greater than or equal to 8 and a power of 2.
+
+         Parameters
+         ----------
+         tile_list : Optional[List[int]]
+             Tile size.
+
+         Returns
+         -------
+         Optional[List[int]]
+             Validated tile size.
+
+         Raises
+         ------
+         ValueError
+             If a tile dimension is smaller than 8.
+         ValueError
+             If a tile dimension is not a power of 2.
+         """
+         patch_size_ge_than_8_power_of_2(tile_list)
+
+         return tile_list
+
+     @field_validator("axes")
+     @classmethod
+     def axes_valid(cls, axes: str) -> str:
+         """
+         Validate axes.
+
+         Axes must:
+         - be a combination of 'STCZYX'
+         - not contain duplicates
+         - contain at least 2 contiguous axes: X and Y
+         - contain at most 4 axes
+         - not contain both S and T axes
+
+         Parameters
+         ----------
+         axes : str
+             Axes to validate.
+
+         Returns
+         -------
+         str
+             Validated axes.
+
+         Raises
+         ------
+         ValueError
+             If axes are not valid.
+         """
+         # Validate axes
+         check_axes_validity(axes)
+
+         return axes
+
+     @field_validator("transforms")
+     @classmethod
+     def validate_transforms(
+         cls, transforms: Union[List[TRANSFORMS_UNION], Compose]
+     ) -> Union[List[TRANSFORMS_UNION], Compose]:
+         """
+         Validate that transforms do not contain the N2V pixel manipulation transform.
+
+         Parameters
+         ----------
+         transforms : Union[List[TRANSFORMS_UNION], Compose]
+             Transforms.
+
+         Returns
+         -------
+         Union[List[TRANSFORMS_UNION], Compose]
+             Validated transforms.
+
+         Raises
+         ------
+         ValueError
+             If transforms contain the N2V pixel manipulation transform.
+         """
+         if not isinstance(transforms, Compose) and transforms is not None:
+             for transform in transforms:
+                 if transform.name == SupportedTransform.N2V_MANIPULATE.value:
+                     raise ValueError(
+                         "N2V_Manipulate transform is not allowed in "
+                         "prediction transforms."
+                     )
+
+         return transforms
+
+     @model_validator(mode="after")
+     def validate_dimensions(self: Self) -> Self:
+         """
+         Validate 2D/3D dimensions between axes and tile size.
+
+         Returns
+         -------
+         Self
+             Validated prediction model.
+         """
+         expected_len = 3 if "Z" in self.axes else 2
+
+         if self.tile_size is not None and self.tile_overlap is not None:
+             if len(self.tile_size) != expected_len:
+                 raise ValueError(
+                     f"Tile size must have {expected_len} dimensions given axes "
+                     f"{self.axes} (got {self.tile_size})."
+                 )
+
+             if len(self.tile_overlap) != expected_len:
+                 raise ValueError(
+                     f"Tile overlap must have {expected_len} dimensions given axes "
+                     f"{self.axes} (got {self.tile_overlap})."
+                 )
+
+             if any((i >= j) for i, j in zip(self.tile_overlap, self.tile_size)):
+                 raise ValueError("Tile overlap must be smaller than tile size.")
+
+         return self
+
+     @model_validator(mode="after")
+     def std_only_with_mean(self: Self) -> Self:
+         """
+         Check that mean and std are either both None, or both specified.
+
+         Returns
+         -------
+         Self
+             Validated prediction model.
+
+         Raises
+         ------
+         ValueError
+             If only one of mean and std is specified.
+         """
+         # check that mean and std are either both None, or both specified
+         if (self.mean is None) != (self.std is None):
+             raise ValueError(
+                 "Mean and std must be either both None, or both specified."
+             )
+
+         return self
+
+     @model_validator(mode="after")
+     def add_std_and_mean_to_normalize(self: Self) -> Self:
+         """
+         Add mean and std to the Normalize transform if it is present.
+
+         Returns
+         -------
+         Self
+             Inference model with mean and std added to the Normalize transform.
+         """
+         if self.mean is not None or self.std is not None:
+             # search in the transforms for Normalize and update parameters
+             if not isinstance(self.transforms, Compose):
+                 for transform in self.transforms:
+                     if transform.name == SupportedTransform.NORMALIZE.value:
+                         transform.mean = self.mean
+                         transform.std = self.std
+
+         return self
+
+     def _update(self, **kwargs: Any) -> None:
+         """
+         Update multiple arguments at once.
+
+         Parameters
+         ----------
+         **kwargs : Any
+             Key-value pairs of arguments to update.
+         """
+         self.__dict__.update(kwargs)
+         self.__class__.model_validate(self.__dict__)
+
+     def set_3D(self, axes: str, tile_size: List[int], tile_overlap: List[int]) -> None:
+         """
+         Set 3D parameters.
+
+         Parameters
+         ----------
+         axes : str
+             Axes.
+         tile_size : List[int]
+             Tile size.
+         tile_overlap : List[int]
+             Tile overlap.
+         """
+         self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap)
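To illustrate the new prediction configuration, here is a minimal usage sketch. It assumes `InferenceModel` is importable from `careamics.config.inference_model`, as the file list above suggests; the values are illustrative, not package defaults.

```python
from careamics.config.inference_model import InferenceModel

# Tiled 2D prediction: tile sizes must be powers of 2 and at least 8,
# overlaps must be even and strictly smaller than the tile size.
config = InferenceModel(
    data_type="tiff",
    tile_size=[128, 128],
    tile_overlap=[32, 32],
    axes="YX",
    mean=0.0,
    std=1.0,
)

# set_3D re-validates axes, tile size and tile overlap together,
# so adding a "Z" axis must come with 3 tile dimensions.
config.set_3D(axes="ZYX", tile_size=[32, 64, 64], tile_overlap=[8, 16, 16])
```

Because `validate_assignment=True`, assigning a single field (e.g. `config.axes = "ZYX"` on its own) triggers validation immediately and would fail against the still-2D tile size; this is presumably why `_update`/`set_3D` change the interdependent fields in one step and re-validate the whole model afterwards.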
careamics/config/noise_models.py
@@ -0,0 +1,162 @@
+ from __future__ import annotations
+
+ from enum import Enum
+ from typing import Dict, Union
+
+ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
+
+
+ class NoiseModelType(str, Enum):
+     """
+     Available noise models.
+
+     Currently supported noise models:
+
+     - hist: Histogram noise model.
+     - gmm: Gaussian mixture model noise model.
+     """
+
+     NONE = "none"
+     HIST = "hist"
+     GMM = "gmm"
+
+     # TODO add validator decorator
+     @classmethod
+     def validate_noise_model_type(
+         cls, noise_model: Union[str, NoiseModel], parameters: dict
+     ) -> dict:
+         """
+         Validate the noise model type and resolve its parameters.
+
+         Parameters
+         ----------
+         noise_model : Union[str, NoiseModel]
+             Noise model type to validate.
+         parameters : dict
+             User-provided noise model parameters.
+
+         Returns
+         -------
+         dict
+             The validated parameters, or the defaults of the corresponding
+             noise model if no parameters were provided.
+         """
+         if noise_model == NoiseModelType.HIST.value:
+             HistogramNoiseModel(**parameters)
+             return HistogramNoiseModel().model_dump() if not parameters else parameters
+
+         elif noise_model == NoiseModelType.GMM.value:
+             GaussianMixtureNoiseModel(**parameters)
+             return (
+                 GaussianMixtureNoiseModel().model_dump()
+                 if not parameters
+                 else parameters
+             )
+
+
+ class NoiseModel(BaseModel):
+     """
+     Noise model configuration.
+
+     Attributes
+     ----------
+     model_type : NoiseModelType
+         Type of noise model.
+     parameters : Dict
+         Parameters of the noise model, validated against the model type.
+     """
+
+     model_config = ConfigDict(
+         use_enum_values=True,
+         protected_namespaces=(),  # allows to use model_* as a field name
+         validate_assignment=True,
+     )
+
+     model_type: NoiseModelType
+     parameters: Dict = Field(default_factory=dict, validate_default=True)
+
+     @field_validator("parameters")
+     @classmethod
+     def validate_parameters(cls, data: Dict, values: ValidationInfo) -> Dict:
+         """
+         Validate the parameters against the noise model type.
+
+         Parameters
+         ----------
+         data : Dict
+             User-provided noise model parameters.
+         values : ValidationInfo
+             Pydantic field validation info, used to get the model type.
+
+         Returns
+         -------
+         Dict
+             Validated parameters.
+
+         Raises
+         ------
+         ValueError
+             If the noise model type is not supported.
+         """
+         if values.data["model_type"] not in [NoiseModelType.GMM, NoiseModelType.HIST]:
+             raise ValueError(
+                 f"Incorrect noise model {values.data['model_type']}. "
+                 "Please refer to the documentation"  # TODO add link to documentation
+             )
+
+         parameters = NoiseModelType.validate_noise_model_type(
+             values.data["model_type"], data
+         )
+         return parameters
+
+
+ class HistogramNoiseModel(BaseModel):
+     """
+     Histogram noise model.
+
+     Attributes
+     ----------
+     min_value : float
+         Minimum value in the input.
+     max_value : float
+         Maximum value in the input.
+     bins : int
+         Number of bins of the histogram.
+     """
+
+     min_value: float = Field(default=350.0, ge=0.0, le=65535.0)
+     max_value: float = Field(default=6500.0, ge=0.0, le=65535.0)
+     bins: int = Field(default=256, ge=1)
+
+
+ class GaussianMixtureNoiseModel(BaseModel):
+     """
+     Gaussian mixture model noise model.
+
+     Attributes
+     ----------
+     num_components : int
+         Number of components in the mixture.
+     min_value : float
+         Minimum signal intensity expected in the image.
+     max_value : float
+         Maximum signal intensity expected in the image.
+     n_gaussian : int
+         Number of gaussians. Each gaussian contributes three parameters
+         (mean, standard deviation and weight).
+     n_coeff : int
+         Number of coefficients to describe the functional relationship between
+         gaussian parameters and the signal. 2 implies a linear relationship,
+         3 implies a quadratic relationship and so on.
+     min_sigma : int
+         Minimum value allowed for the standard deviation of the gaussians.
+     """
+
+     num_components: int = Field(default=3, ge=1)
+     min_value: float = Field(default=350.0, ge=0.0, le=65535.0)
+     max_value: float = Field(default=6500.0, ge=0.0, le=65535.0)
+     n_gaussian: int = Field(default=3, ge=1)
+     n_coeff: int = Field(default=2, ge=1)
+     min_sigma: int = Field(default=50, ge=1)
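A short sketch of how the noise model configuration resolves parameters, assuming the classes are importable from `careamics.config.noise_models`; the values are illustrative.

```python
from careamics.config.noise_models import NoiseModel, NoiseModelType

# An empty parameters dict is replaced by the defaults of the chosen model.
gmm = NoiseModel(model_type=NoiseModelType.GMM)
print(gmm.parameters)  # defaults dumped from GaussianMixtureNoiseModel

# User-supplied parameters are first validated against the matching
# pydantic model, then returned unchanged.
hist = NoiseModel(model_type="hist", parameters={"bins": 512})
print(hist.parameters)  # {'bins': 512}
```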
careamics/config/optimizer_models.py
@@ -0,0 +1,181 @@
+ from __future__ import annotations
+
+ from typing import Dict, Literal
+
+ from pydantic import (
+     BaseModel,
+     ConfigDict,
+     Field,
+     ValidationInfo,
+     field_validator,
+     model_validator,
+ )
+ from torch import optim
+ from typing_extensions import Self
+
+ from careamics.utils.torch_utils import filter_parameters
+
+ from .support import SupportedOptimizer
+
+
+ class OptimizerModel(BaseModel):
+     """
+     Torch optimizer.
+
+     Only parameters supported by the corresponding torch optimizer will be taken
+     into account. For more details, check:
+     https://pytorch.org/docs/stable/optim.html#algorithms
+
+     Note that mandatory parameters (see the specific optimizer signature in the
+     link above) must be provided. For example, SGD requires `lr`.
+
+     Attributes
+     ----------
+     name : str
+         Name of the optimizer.
+     parameters : dict
+         Parameters of the optimizer (see torch documentation).
+     """
+
+     # Pydantic class configuration
+     model_config = ConfigDict(
+         validate_assignment=True,
+     )
+
+     # Mandatory field
+     name: Literal["Adam", "SGD"] = Field(default="Adam", validate_default=True)
+
+     # Optional parameters, empty dict default value to allow filtering dictionary
+     parameters: dict = Field(
+         default={
+             "lr": 1e-4,
+         },
+         validate_default=True,
+     )
+
+     @field_validator("parameters")
+     @classmethod
+     def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> Dict:
+         """
+         Validate optimizer parameters.
+
+         This method filters out unknown parameters, given the optimizer name.
+
+         Parameters
+         ----------
+         user_params : dict
+             Parameters passed on to the torch optimizer.
+         values : ValidationInfo
+             Pydantic field validation info, used to get the optimizer name.
+
+         Returns
+         -------
+         Dict
+             Filtered optimizer parameters.
+
+         Raises
+         ------
+         ValueError
+             If the optimizer name is not specified.
+         """
+         optimizer_name = values.data["name"]
+
+         # retrieve the corresponding optimizer class
+         optimizer_class = getattr(optim, optimizer_name)
+
+         # filter the user parameters according to the optimizer's signature
+         parameters = filter_parameters(optimizer_class, user_params)
+
+         return parameters
+
+     @model_validator(mode="after")
+     def sgd_lr_parameter(self) -> Self:
+         """
+         Check that the SGD optimizer has the mandatory `lr` parameter specified.
+
+         This is specific to PyTorch < 2.2.
+
+         Returns
+         -------
+         Self
+             Validated optimizer.
+
+         Raises
+         ------
+         ValueError
+             If the optimizer is SGD and the lr parameter is not specified.
+         """
+         if self.name == SupportedOptimizer.SGD and "lr" not in self.parameters:
+             raise ValueError(
+                 "SGD optimizer requires `lr` parameter, check that it has correctly "
+                 "been specified in `parameters`."
+             )
+
+         return self
+
+
+ class LrSchedulerModel(BaseModel):
+     """
+     Torch learning rate scheduler.
+
+     Only parameters supported by the corresponding torch lr scheduler will be taken
+     into account. For more details, check:
+     https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
+
+     Note that mandatory parameters (see the specific scheduler signature in the
+     link above) must be provided. For example, StepLR requires `step_size`.
+
+     Attributes
+     ----------
+     name : str
+         Name of the learning rate scheduler.
+     parameters : dict
+         Parameters of the learning rate scheduler (see torch documentation).
+     """
+
+     # Pydantic class configuration
+     model_config = ConfigDict(
+         validate_assignment=True,
+     )
+
+     # Mandatory field
+     name: Literal["ReduceLROnPlateau", "StepLR"] = Field(default="ReduceLROnPlateau")
+
+     # Optional parameters
+     parameters: dict = Field(default={}, validate_default=True)
+
+     @field_validator("parameters")
+     @classmethod
+     def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> Dict:
+         """Filter parameters based on the learning rate scheduler's signature.
+
+         Parameters
+         ----------
+         user_params : dict
+             User parameters.
+         values : ValidationInfo
+             Pydantic field validation info, used to get the scheduler name.
+
+         Returns
+         -------
+         Dict
+             Filtered scheduler parameters.
+
+         Raises
+         ------
+         ValueError
+             If the scheduler is StepLR and the step_size parameter is not specified.
+         """
+         # retrieve the corresponding scheduler class
+         scheduler_class = getattr(optim.lr_scheduler, values.data["name"])
+
+         # filter the user parameters according to the scheduler's signature
+         parameters = filter_parameters(scheduler_class, user_params)
+
+         if values.data["name"] == "StepLR" and "step_size" not in parameters:
+             raise ValueError(
+                 "StepLR scheduler requires `step_size` parameter, check that it has "
+                 "correctly been specified in `parameters`."
+             )
+
+         return parameters
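The parameter filtering is easiest to see in a small sketch, assuming `filter_parameters` drops any keyword not in the target class's signature (illustrative values):

```python
from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel

# Keys unknown to torch.optim.Adam are filtered out rather than raising;
# Adam has no `momentum` argument, so only `lr` should survive.
optimizer = OptimizerModel(name="Adam", parameters={"lr": 1e-3, "momentum": 0.9})
print(optimizer.parameters)  # expected: {'lr': 0.001}

# StepLR requires `step_size`; omitting it raises a ValueError.
scheduler = LrSchedulerModel(name="StepLR", parameters={"step_size": 10})
```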
careamics/config/references/__init__.py
@@ -0,0 +1,45 @@
+ """Module containing references to the algorithms used in CAREamics."""
+
+ __all__ = [
+     "N2V2Ref",
+     "N2VRef",
+     "StructN2VRef",
+     "N2VDescription",
+     "N2V2Description",
+     "StructN2VDescription",
+     "StructN2V2Description",
+     "N2V",
+     "N2V2",
+     "STRUCT_N2V",
+     "STRUCT_N2V2",
+     "CUSTOM",
+     "N2N",
+     "CARE",
+     "CAREDescription",
+     "N2NDescription",
+     "CARERef",
+     "N2NRef",
+ ]
+
+ from .algorithm_descriptions import (
+     CARE,
+     CUSTOM,
+     N2N,
+     N2V,
+     N2V2,
+     STRUCT_N2V,
+     STRUCT_N2V2,
+     CAREDescription,
+     N2NDescription,
+     N2V2Description,
+     N2VDescription,
+     StructN2V2Description,
+     StructN2VDescription,
+ )
+ from .references import (
+     CARERef,
+     N2NRef,
+     N2V2Ref,
+     N2VRef,
+     StructN2VRef,
+ )
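As a quick check that the re-exports resolve, a reference can be imported directly from the subpackage (a sketch, assuming the names are plain importable objects):

```python
from careamics.config.references import N2VDescription, N2VRef

# Bibliographic reference and description entries for Noise2Void.
print(N2VRef)
print(N2VDescription)
```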