careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
careamics/config/inference_model.py
@@ -0,0 +1,239 @@
+ """Pydantic model representing CAREamics prediction configuration."""
+
+ from __future__ import annotations
+
+ from typing import Any, Literal, Optional
+
+ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+ from typing_extensions import Self
+
+ from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
+
+
+ class InferenceConfig(BaseModel):
+     """Configuration class for the prediction model."""
+
+     model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
+
+     data_type: Literal["array", "tiff", "custom"]  # As defined in SupportedData
+     """Type of input data: numpy.ndarray (array) or path (tiff or custom)."""
+
+     tile_size: Optional[list[int]] = Field(
+         default=None, min_length=2, max_length=3
+     )
+     """Tile size of prediction, only effective if `tile_overlap` is specified."""
+
+     tile_overlap: Optional[list[int]] = Field(
+         default=None, min_length=2, max_length=3
+     )
+     """Overlap between tiles, only effective if `tile_size` is specified."""
+
+     axes: str
+     """Data axes (TSCZYX) in the order of the input data."""
+
+     image_means: list = Field(..., min_length=0, max_length=32)
+     """Mean values for each input channel."""
+
+     image_stds: list = Field(..., min_length=0, max_length=32)
+     """Standard deviation values for each input channel."""
+
+     # TODO only default TTAs are supported for now
+     tta_transforms: bool = Field(default=True)
+     """Whether to apply test-time augmentation (all 90 degree rotations and flips)."""
+
+     # Dataloader parameters
+     batch_size: int = Field(default=1, ge=1)
+     """Batch size for prediction."""
+
+     @field_validator("tile_overlap")
+     @classmethod
+     def all_elements_non_zero_even(
+         cls, tile_overlap: Optional[list[int]]
+     ) -> Optional[list[int]]:
+         """
+         Validate tile overlap.
+
+         Overlaps must be non-zero, positive and even.
+
+         Parameters
+         ----------
+         tile_overlap : list[int] or None
+             Tile overlap.
+
+         Returns
+         -------
+         list[int] or None
+             Validated tile overlap.
+
+         Raises
+         ------
+         ValueError
+             If any overlap is zero or negative.
+         ValueError
+             If any overlap is not even.
+         """
+         if tile_overlap is not None:
+             for dim in tile_overlap:
+                 if dim < 1:
+                     raise ValueError(
+                         f"Tile overlap must be non-zero positive (got {dim})."
+                     )
+
+                 if dim % 2 != 0:
+                     raise ValueError(f"Tile overlap must be even (got {dim}).")
+
+         return tile_overlap
+
+     @field_validator("tile_size")
+     @classmethod
+     def tile_min_8_power_of_2(
+         cls, tile_list: Optional[list[int]]
+     ) -> Optional[list[int]]:
+         """
+         Validate that each entry is greater than or equal to 8 and a power of 2.
+
+         Parameters
+         ----------
+         tile_list : list of int
+             Tile size.
+
+         Returns
+         -------
+         list of int
+             Validated tile size.
+
+         Raises
+         ------
+         ValueError
+             If a tile dimension is smaller than 8.
+         ValueError
+             If a tile dimension is not a power of 2.
+         """
+         patch_size_ge_than_8_power_of_2(tile_list)
+
+         return tile_list
+
+     @field_validator("axes")
+     @classmethod
+     def axes_valid(cls, axes: str) -> str:
+         """
+         Validate axes.
+
+         Axes must:
+         - be a combination of 'STCZYX'
+         - not contain duplicates
+         - contain at least 2 contiguous axes: X and Y
+         - contain at most 4 axes
+         - not contain both S and T axes
+
+         Parameters
+         ----------
+         axes : str
+             Axes to validate.
+
+         Returns
+         -------
+         str
+             Validated axes.
+
+         Raises
+         ------
+         ValueError
+             If axes are not valid.
+         """
+         check_axes_validity(axes)
+
+         return axes
+
+     @model_validator(mode="after")
+     def validate_dimensions(self: Self) -> Self:
+         """
+         Validate 2D/3D dimensions between axes and tile size.
+
+         Returns
+         -------
+         Self
+             Validated prediction model.
+         """
+         expected_len = 3 if "Z" in self.axes else 2
+
+         if self.tile_size is not None and self.tile_overlap is not None:
+             if len(self.tile_size) != expected_len:
+                 raise ValueError(
+                     f"Tile size must have {expected_len} dimensions given axes "
+                     f"{self.axes} (got {self.tile_size})."
+                 )
+
+             if len(self.tile_overlap) != expected_len:
+                 raise ValueError(
+                     f"Tile overlap must have {expected_len} dimensions given axes "
+                     f"{self.axes} (got {self.tile_overlap})."
+                 )
+
+             if any((i >= j) for i, j in zip(self.tile_overlap, self.tile_size)):
+                 raise ValueError("Tile overlap must be smaller than tile size.")
+
+         return self
+
+     @model_validator(mode="after")
+     def std_only_with_mean(self: Self) -> Self:
+         """
+         Check that mean and std are both specified, with one value per channel.
+
+         Returns
+         -------
+         Self
+             Validated prediction model.
+
+         Raises
+         ------
+         ValueError
+             If mean or std is missing, or if their lengths differ.
+         """
+         # check that mean and std are both specified
+         if not self.image_means and not self.image_stds:
+             raise ValueError("Mean and std must be specified during inference.")
+
+         if (self.image_means and not self.image_stds) or (
+             self.image_stds and not self.image_means
+         ):
+             raise ValueError(
+                 "Mean and std must be either both None, or both specified."
+             )
+
+         elif (self.image_means is not None and self.image_stds is not None) and (
+             len(self.image_means) != len(self.image_stds)
+         ):
+             raise ValueError("Mean and std must be specified for each input channel.")
+
+         return self
+
+     def _update(self, **kwargs: Any) -> None:
+         """
+         Update multiple arguments at once.
+
+         Parameters
+         ----------
+         **kwargs : Any
+             Key-value pairs of arguments to update.
+         """
+         self.__dict__.update(kwargs)
+         self.__class__.model_validate(self.__dict__)
+
+     def set_3D(self, axes: str, tile_size: list[int], tile_overlap: list[int]) -> None:
+         """
+         Set 3D parameters.
+
+         Parameters
+         ----------
+         axes : str
+             Axes.
+         tile_size : list of int
+             Tile size.
+         tile_overlap : list of int
+             Tile overlap.
+         """
+         self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap)
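
For orientation, a minimal usage sketch of the prediction configuration added above. This is not part of the diff; it assumes `InferenceConfig` is importable from the new careamics/config/inference_model.py and uses only fields shown in this hunk:

# Hedged sketch: import path assumed from the file added in this diff.
from careamics.config.inference_model import InferenceConfig

config = InferenceConfig(
    data_type="tiff",
    axes="YX",
    image_means=[128.0],    # one value per input channel
    image_stds=[32.0],
    tile_size=[128, 128],   # each entry >= 8 and a power of 2
    tile_overlap=[32, 32],  # even, and smaller than the tile size
)

# switch to 3D prediction; axes, tile size and overlap are re-validated together
config.set_3D(axes="ZYX", tile_size=[32, 128, 128], tile_overlap=[16, 32, 32])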
careamics/config/likelihood_model.py
@@ -0,0 +1,43 @@
+ """Likelihood model."""
+
+ from typing import Literal, Optional, Union
+
+ import torch
+ from pydantic import BaseModel, ConfigDict
+
+ from careamics.models.lvae.noise_models import (
+     GaussianMixtureNoiseModel,
+     MultiChannelNoiseModel,
+ )
+
+ NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
+
+
+ class GaussianLikelihoodConfig(BaseModel):
+     """Gaussian likelihood configuration."""
+
+     model_config = ConfigDict(validate_assignment=True)
+
+     predict_logvar: Optional[Literal["pixelwise"]] = None
+     """If `pixelwise`, the log-variance is computed for each pixel; otherwise the
+     log-variance is not computed."""
+
+     logvar_lowerbound: Union[float, None] = None
+     """The lower bound value for the log-variance."""
+
+
+ class NMLikelihoodConfig(BaseModel):
+     """Noise model likelihood configuration."""
+
+     model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
+
+     data_mean: torch.Tensor = torch.zeros(1)
+     """The mean of the data, used to unnormalize data for noise model evaluation.
+     Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
+
+     data_std: torch.Tensor = torch.ones(1)
+     """The standard deviation of the data, used to unnormalize data for noise
+     model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
+
+     noise_model: Union[NoiseModel, None] = None
+     """The noise model instance used to compute the likelihood."""
careamics/config/nm_model.py
@@ -0,0 +1,101 @@
+ """Noise models config."""
+
+ from pathlib import Path
+ from typing import Literal, Optional, Union
+
+ import numpy as np
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
+ from typing_extensions import Self
+
+ # TODO: add histogram-based noise model
+
+
+ class GaussianMixtureNMConfig(BaseModel):
+     """Gaussian mixture noise model."""
+
+     model_config = ConfigDict(
+         protected_namespaces=(),
+         validate_assignment=True,
+         arbitrary_types_allowed=True,
+         extra="allow",
+     )
+     # model type
+     model_type: Literal["GaussianMixtureNoiseModel"]
+
+     path: Optional[Union[Path, str]] = None
+     """Path to the directory where the trained noise model (*.npz) is saved in the
+     `train` method."""
+
+     signal: Optional[Union[str, Path, np.ndarray]] = None
+     """Path to the file containing the signal, or the corresponding numpy array."""
+
+     observation: Optional[Union[str, Path, np.ndarray]] = None
+     """Path to the file containing the observation, or the corresponding numpy
+     array."""
+
+     weight: Optional[np.ndarray] = None
+     """A [3 * n_gaussian, n_coeff] sized array containing the values of the weights
+     describing the GMM noise model, with each row corresponding to one parameter
+     (mean, weight or standard deviation) of one Gaussian. Specifically, rows are
+     organized as follows:
+     - first n_gaussian rows correspond to the means
+     - next n_gaussian rows correspond to the weights
+     - last n_gaussian rows correspond to the standard deviations
+     If `weight=None`, the weight array is initialized using the `min_signal`
+     and `max_signal` parameters."""
+
+     n_gaussian: int = Field(default=1, ge=1)
+     """Number of Gaussians used for the GMM."""
+
+     n_coeff: int = Field(default=2, ge=2)
+     """Number of coefficients describing the functional relationship between the
+     Gaussian parameters and the signal: 2 implies a linear relationship, 3 a
+     quadratic relationship, and so on."""
+
+     min_signal: float = Field(default=0.0, ge=0.0)
+     """Minimum signal intensity expected in the image."""
+
+     max_signal: float = Field(default=1.0, ge=0.0)
+     """Maximum signal intensity expected in the image."""
+
+     min_sigma: float = Field(default=200.0, ge=0.0)  # TODO taken from nb in pn2v
+     """Minimum value of the standard deviation allowed in the GMM.
+     All standard deviation values below this are clamped to this value."""
+
+     tol: float = Field(default=1e-10)
+     """Tolerance used in the computation of the noise model likelihood."""
+
+     @model_validator(mode="after")
+     def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
+         """Validate paths provided in the config.
+
+         Returns
+         -------
+         Self
+             Returns itself.
+         """
+         if self.path and (self.signal is not None or self.observation is not None):
+             raise ValueError(
+                 "'path' to a pre-trained noise model and 'signal'/'observation' "
+                 "training data (as paths or numpy arrays) are mutually exclusive; "
+                 "provide only one of the two."
+             )
+         if not self.path and (self.signal is None or self.observation is None):
+             raise ValueError(
+                 "Provide either 'path' to a pre-trained noise model, or both "
+                 "'signal' and 'observation' (as paths or numpy arrays)."
+             )
+         return self
+
+
+ # The noise model is given by a set of GMMs, one for each target
+ # e.g., 2 target channels, 2 noise models
+ class MultiChannelNMConfig(BaseModel):
+     """Noise model config aggregating noise models for single output channels."""
+
+     # TODO: check that this model config is OK
+     model_config = ConfigDict(
+         validate_assignment=True, arbitrary_types_allowed=True, extra="allow"
+     )
+     noise_models: list[GaussianMixtureNMConfig]
+     """List of noise models, one for each target channel."""
careamics/config/optimizer_models.py
@@ -0,0 +1,187 @@
+ """Optimizers and schedulers Pydantic models."""
+
+ from __future__ import annotations
+
+ from typing import Literal
+
+ from pydantic import (
+     BaseModel,
+     ConfigDict,
+     Field,
+     ValidationInfo,
+     field_validator,
+     model_validator,
+ )
+ from torch import optim
+ from typing_extensions import Self
+
+ from careamics.utils.torch_utils import filter_parameters
+
+ from .support import SupportedOptimizer
+
+
+ class OptimizerModel(BaseModel):
+     """Torch optimizer Pydantic model.
+
+     Only parameters supported by the corresponding torch optimizer will be taken
+     into account. For more details, check:
+     https://pytorch.org/docs/stable/optim.html#algorithms
+
+     Note that mandatory parameters (see the specific optimizer signature in the
+     link above) must be provided. For example, SGD requires `lr`.
+
+     Attributes
+     ----------
+     name : {"Adam", "SGD"}
+         Name of the optimizer.
+     parameters : dict
+         Parameters of the optimizer (see torch documentation).
+     """
+
+     # Pydantic class configuration
+     model_config = ConfigDict(
+         validate_assignment=True,
+     )
+
+     # Mandatory field
+     name: Literal["Adam", "SGD"] = Field(default="Adam", validate_default=True)
+     """Name of the optimizer; supported optimizers are defined in
+     SupportedOptimizer."""
+
+     # Optional parameters, empty dict default value to allow filtering dictionary
+     parameters: dict = Field(
+         default={
+             "lr": 1e-4,
+         },
+         validate_default=True,
+     )
+     """Parameters of the optimizer, see the PyTorch documentation for more details."""
+
+     @field_validator("parameters")
+     @classmethod
+     def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
+         """
+         Validate optimizer parameters.
+
+         This method filters out unknown parameters, given the optimizer name.
+
+         Parameters
+         ----------
+         user_params : dict
+             Parameters passed on to the torch optimizer.
+         values : ValidationInfo
+             Pydantic field validation info, used to get the optimizer name.
+
+         Returns
+         -------
+         dict
+             Filtered optimizer parameters.
+         """
+         optimizer_name = values.data["name"]
+
+         # retrieve the corresponding optimizer class
+         optimizer_class = getattr(optim, optimizer_name)
+
+         # filter the user parameters according to the optimizer's signature
+         parameters = filter_parameters(optimizer_class, user_params)
+
+         return parameters
+
+     @model_validator(mode="after")
+     def sgd_lr_parameter(self) -> Self:
+         """
+         Check that the SGD optimizer has the mandatory `lr` parameter specified.
+
+         This is specific to PyTorch < 2.2.
+
+         Returns
+         -------
+         Self
+             Validated optimizer.
+
+         Raises
+         ------
+         ValueError
+             If the optimizer is SGD and the `lr` parameter is not specified.
+         """
+         if self.name == SupportedOptimizer.SGD and "lr" not in self.parameters:
+             raise ValueError(
+                 "SGD optimizer requires `lr` parameter, check that it has correctly "
+                 "been specified in `parameters`."
+             )
+
+         return self
+
+
+ class LrSchedulerModel(BaseModel):
+     """Torch learning rate scheduler Pydantic model.
+
+     Only parameters supported by the corresponding torch lr scheduler will be taken
+     into account. For more details, check:
+     https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
+
+     Note that mandatory parameters (see the specific scheduler signature in the
+     link above) must be provided. For example, StepLR requires `step_size`.
+
+     Attributes
+     ----------
+     name : {"ReduceLROnPlateau", "StepLR"}
+         Name of the learning rate scheduler.
+     parameters : dict
+         Parameters of the learning rate scheduler (see torch documentation).
+     """
+
+     # Pydantic class configuration
+     model_config = ConfigDict(
+         validate_assignment=True,
+     )
+
+     # Mandatory field
+     name: Literal["ReduceLROnPlateau", "StepLR"] = Field(default="ReduceLROnPlateau")
+     """Name of the learning rate scheduler; supported schedulers are defined in
+     SupportedScheduler."""
+
+     # Optional parameters
+     parameters: dict = Field(default={}, validate_default=True)
+     """Parameters of the learning rate scheduler, see the PyTorch documentation for
+     more details."""
+
+     @field_validator("parameters")
+     @classmethod
+     def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
+         """Filter parameters based on the learning rate scheduler's signature.
+
+         Parameters
+         ----------
+         user_params : dict
+             User parameters.
+         values : ValidationInfo
+             Pydantic field validation info, used to get the scheduler name.
+
+         Returns
+         -------
+         dict
+             Filtered scheduler parameters.
+
+         Raises
+         ------
+         ValueError
+             If the scheduler is StepLR and the `step_size` parameter is not
+             specified.
+         """
+         # retrieve the corresponding scheduler class
+         scheduler_class = getattr(optim.lr_scheduler, values.data["name"])
+
+         # filter the user parameters according to the scheduler's signature
+         parameters = filter_parameters(scheduler_class, user_params)
+
+         if values.data["name"] == "StepLR" and "step_size" not in parameters:
+             raise ValueError(
+                 "StepLR scheduler requires `step_size` parameter, check that it has "
+                 "correctly been specified in `parameters`."
+             )
+
+         return parameters
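
A hedged sketch of the parameter filtering behaviour above. This is not part of the diff; the import path is assumed from the new careamics/config/optimizer_models.py:

from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel

# unknown keys are dropped by the `parameters` validator, known ones are kept
optimizer = OptimizerModel(name="Adam", parameters={"lr": 1e-3, "not_a_param": 0})
assert "not_a_param" not in optimizer.parameters

# StepLR requires `step_size`; omitting it raises a ValueError at validation time
scheduler = LrSchedulerModel(name="StepLR", parameters={"step_size": 10})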
careamics/config/references/__init__.py
@@ -0,0 +1,45 @@
+ """Module containing references to the algorithms used in CAREamics."""
+
+ __all__ = [
+     "N2V2Ref",
+     "N2VRef",
+     "StructN2VRef",
+     "N2VDescription",
+     "N2V2Description",
+     "StructN2VDescription",
+     "StructN2V2Description",
+     "N2V",
+     "N2V2",
+     "STRUCT_N2V",
+     "STRUCT_N2V2",
+     "CUSTOM",
+     "N2N",
+     "CARE",
+     "CAREDescription",
+     "N2NDescription",
+     "CARERef",
+     "N2NRef",
+ ]
+
+ from .algorithm_descriptions import (
+     CARE,
+     CUSTOM,
+     N2N,
+     N2V,
+     N2V2,
+     STRUCT_N2V,
+     STRUCT_N2V2,
+     CAREDescription,
+     N2NDescription,
+     N2V2Description,
+     N2VDescription,
+     StructN2V2Description,
+     StructN2VDescription,
+ )
+ from .references import (
+     CARERef,
+     N2NRef,
+     N2V2Ref,
+     N2VRef,
+     StructN2VRef,
+ )