careamics 0.1.0rc1__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +14 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +27 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_factory.py +460 -0
- careamics/config/configuration_model.py +596 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +321 -131
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +665 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +390 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -13
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -202
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -13
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +89 -56
- careamics-0.1.0rc3.dist-info/METADATA +122 -0
- careamics-0.1.0rc3.dist-info/RECORD +109 -0
- {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -271
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -296
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -115
- careamics/dataset/patching.py +0 -493
- careamics/dataset/prepare_dataset.py +0 -174
- careamics/dataset/tiff_dataset.py +0 -211
- careamics/engine.py +0 -954
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -102
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -156
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc1.dist-info/METADATA +0 -80
- careamics-0.1.0rc1.dist-info/RECORD +0 -46
- {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
"""Pydantic model representing CAREamics prediction configuration."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Any, List, Literal, Optional, Union
|
|
5
|
+
|
|
6
|
+
from albumentations import Compose
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
8
|
+
from typing_extensions import Self
|
|
9
|
+
|
|
10
|
+
from .support import SupportedTransform
|
|
11
|
+
from .transformations.normalize_model import NormalizeModel
|
|
12
|
+
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
13
|
+
|
|
14
|
+
# Transform models accepted at prediction time. It currently has a single member,
# in which case ``Union[...]`` evaluates to that model itself; new members can be
# appended to the list below.
TRANSFORMS_UNION = Union[
    NormalizeModel,
]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class InferenceModel(BaseModel):
    """
    Configuration class for the prediction model.

    Attributes
    ----------
    data_type : Literal["array", "tiff", "custom"]
        Type of the input data (as defined in `SupportedData`).
    tile_size : Optional[List[int]]
        Tile size (2 or 3 entries), each entry must be greater or equal than 8 and
        a power of 2. `None` by default.
    tile_overlap : Optional[List[int]]
        Tile overlap (2 or 3 entries), each entry must be a non-zero, positive, even
        integer strictly smaller than the corresponding tile size. `None` by default.
    axes : str
        Axes of the data, validated by `check_axes_validity`.
    mean : float
        Mean value, propagated to the Normalize transform.
    std : float
        Standard deviation (>= 0), propagated to the Normalize transform.
    transforms : Union[List[TRANSFORMS_UNION], Compose]
        Transforms applied to the data, a single Normalize transform by default.
    tta_transforms : bool
        Whether to use test-time augmentation (only default TTAs are supported).
    batch_size : int
        Batch size of the prediction dataloader (>= 1).
    """

    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

    # Mandatory field
    data_type: Literal["array", "tiff", "custom"]  # As defined in SupportedData

    # Optional tiling parameters
    tile_size: Optional[List[int]] = Field(default=None, min_length=2, max_length=3)
    tile_overlap: Optional[List[int]] = Field(
        default=None, min_length=2, max_length=3
    )

    # Mandatory fields
    axes: str

    mean: float
    std: float = Field(..., ge=0.0)

    transforms: Union[List[TRANSFORMS_UNION], Compose] = Field(
        default=[
            {
                "name": SupportedTransform.NORMALIZE.value,
            },
        ],
        validate_default=True,
    )

    # only default TTAs are supported for now
    tta_transforms: bool = Field(default=True)

    # Dataloader parameters
    batch_size: int = Field(default=1, ge=1)

    @field_validator("tile_overlap")
    @classmethod
    def all_elements_non_zero_even(
        cls, patch_list: Optional[List[int]]
    ) -> Optional[List[int]]:
        """
        Validate tile overlap.

        Each entry of the tile overlap must be non-zero, positive and even.

        Parameters
        ----------
        patch_list : Optional[List[int]]
            Tile overlap.

        Returns
        -------
        Optional[List[int]]
            Validated tile overlap.

        Raises
        ------
        ValueError
            If an entry of the tile overlap is smaller than 1.
        ValueError
            If an entry of the tile overlap is not even.
        """
        if patch_list is not None:
            for dim in patch_list:
                if dim < 1:
                    raise ValueError(
                        f"Patch size must be non-zero positive (got {dim})."
                    )

                if dim % 2 != 0:
                    raise ValueError(f"Patch size must be even (got {dim}).")

        return patch_list

    @field_validator("tile_size")
    @classmethod
    def tile_min_8_power_of_2(
        cls, tile_list: Optional[List[int]]
    ) -> Optional[List[int]]:
        """
        Validate that each entry is greater or equal than 8 and a power of 2.

        Parameters
        ----------
        tile_list : Optional[List[int]]
            Tile size.

        Returns
        -------
        Optional[List[int]]
            Validated tile size.

        Raises
        ------
        ValueError
            If the tile size is smaller than 8.
        ValueError
            If the tile size is not a power of 2.
        """
        patch_size_ge_than_8_power_of_2(tile_list)

        return tile_list

    @field_validator("axes")
    @classmethod
    def axes_valid(cls, axes: str) -> str:
        """
        Validate axes.

        Axes must:
        - be a combination of 'STCZYX'
        - not contain duplicates
        - contain at least 2 contiguous axes: X and Y
        - contain at most 4 axes
        - not contain both S and T axes

        Parameters
        ----------
        axes : str
            Axes to validate.

        Returns
        -------
        str
            Validated axes.

        Raises
        ------
        ValueError
            If axes are not valid.
        """
        # Validate axes
        check_axes_validity(axes)

        return axes

    @field_validator("transforms")
    @classmethod
    def validate_transforms(
        cls, transforms: Union[List[TRANSFORMS_UNION], Compose]
    ) -> Union[List[TRANSFORMS_UNION], Compose]:
        """
        Validate that transforms do not have N2V pixel manipulate transforms.

        Parameters
        ----------
        transforms : Union[List[TRANSFORMS_UNION], Compose]
            Transforms.

        Returns
        -------
        Union[List[TRANSFORMS_UNION], Compose]
            Validated transforms.

        Raises
        ------
        ValueError
            If transforms contain N2V pixel manipulate transforms.
        """
        # a Compose object is opaque here, only lists of models are inspected
        if not isinstance(transforms, Compose) and transforms is not None:
            for transform in transforms:
                if transform.name == SupportedTransform.N2V_MANIPULATE.value:
                    raise ValueError(
                        "N2V_Manipulate transform is not allowed in "
                        "prediction transforms."
                    )

        return transforms

    @model_validator(mode="after")
    def validate_dimensions(self: Self) -> Self:
        """
        Validate 2D/3D dimensions between axes and tile size.

        The checks only run when both tile size and tile overlap are specified.

        Returns
        -------
        Self
            Validated prediction model.

        Raises
        ------
        ValueError
            If tile size or tile overlap do not have 3 entries when axes contain Z
            (2 otherwise), or if any overlap entry is not smaller than the
            corresponding tile size entry.
        """
        expected_len = 3 if "Z" in self.axes else 2

        if self.tile_size is not None and self.tile_overlap is not None:
            if len(self.tile_size) != expected_len:
                raise ValueError(
                    f"Tile size must have {expected_len} dimensions given axes "
                    f"{self.axes} (got {self.tile_size})."
                )

            if len(self.tile_overlap) != expected_len:
                raise ValueError(
                    f"Tile overlap must have {expected_len} dimensions given axes "
                    f"{self.axes} (got {self.tile_overlap})."
                )

            if any((i >= j) for i, j in zip(self.tile_overlap, self.tile_size)):
                raise ValueError("Tile overlap must be smaller than tile size.")

        return self

    @model_validator(mode="after")
    def std_only_with_mean(self: Self) -> Self:
        """
        Check that mean and std are either both None, or both specified.

        Returns
        -------
        Self
            Validated prediction model.

        Raises
        ------
        ValueError
            If exactly one of mean and std is None.
        """
        # defensive: both fields are currently typed as mandatory floats
        if (self.mean is None) != (self.std is None):
            raise ValueError(
                "Mean and std must be either both None, or both specified."
            )

        return self

    @model_validator(mode="after")
    def add_std_and_mean_to_normalize(self: Self) -> Self:
        """
        Add mean and std to the Normalize transform if it is present.

        Returns
        -------
        Self
            Inference model with mean and std added to the Normalize transform.
        """
        if self.mean is not None or self.std is not None:
            # search in the transforms for Normalize and update parameters
            if not isinstance(self.transforms, Compose):
                for transform in self.transforms:
                    if transform.name == SupportedTransform.NORMALIZE.value:
                        transform.mean = self.mean
                        transform.std = self.std

        return self

    def _update(self, **kwargs: Any) -> None:
        """
        Update multiple arguments at once.

        The current values merged with `kwargs` are validated as a whole before
        `self` is modified. A failed validation therefore leaves the model
        unchanged, and on success the fields hold the validated (coerced) values.

        Parameters
        ----------
        **kwargs : Any
            Key-value pairs of arguments to update.

        Raises
        ------
        pydantic.ValidationError
            If the updated parameters are not valid.
        """
        # Bulk validation: assigning fields one by one would trigger
        # `validate_assignment` on inconsistent intermediate states (e.g. new 3D
        # axes with the old 2D tile size).
        validated = self.__class__.model_validate({**self.__dict__, **kwargs})
        self.__dict__.update(validated.__dict__)

    def set_3D(self, axes: str, tile_size: List[int], tile_overlap: List[int]) -> None:
        """
        Set 3D parameters.

        Parameters
        ----------
        axes : str
            Axes.
        tile_size : List[int]
            Tile size.
        tile_overlap : List[int]
            Tile overlap.
        """
        self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap)
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Dict, Union
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class NoiseModelType(str, Enum):
    """
    Available noise models.

    Currently supported noise models:

        - hist: Histogram noise model.
        - gmm: Gaussian mixture model noise model.

    `none` is also a member, but it has no associated parameters.
    """

    NONE = "none"
    HIST = "hist"
    GMM = "gmm"

    # TODO add validator decorator
    @classmethod
    def validate_noise_model_type(
        cls, noise_model: Union[str, NoiseModelType], parameters: dict
    ) -> Union[dict, None]:
        """
        Validate the parameters of a noise model and fill in defaults.

        Parameters
        ----------
        noise_model : Union[str, NoiseModelType]
            Type of the noise model ("hist" or "gmm").
        parameters : dict
            Parameters of the noise model, may be empty.

        Returns
        -------
        Union[dict, None]
            `parameters` if not empty, otherwise the default parameters of the
            corresponding noise model. `None` if the type is neither "hist" nor
            "gmm".

        Raises
        ------
        pydantic.ValidationError
            If the parameters are not valid for the noise model.
        """
        if noise_model == NoiseModelType.HIST.value:
            # instantiation validates the parameters (raises if invalid)
            HistogramNoiseModel(**parameters)
            return HistogramNoiseModel().model_dump() if not parameters else parameters

        if noise_model == NoiseModelType.GMM.value:
            # instantiation validates the parameters (raises if invalid)
            GaussianMixtureNoiseModel(**parameters)
            return (
                GaussianMixtureNoiseModel().model_dump()
                if not parameters
                else parameters
            )

        # other types (e.g. "none") have no parameters to validate
        return None
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class NoiseModel(BaseModel):
    """
    Noise model configuration.

    Attributes
    ----------
    model_type : NoiseModelType
        Type of the noise model, stored as its string value (`use_enum_values`).
        Only "hist" and "gmm" pass validation.
    parameters : Dict
        Parameters of the noise model. If empty, they are replaced by the default
        parameters of the corresponding noise model.
    """

    model_config = ConfigDict(
        use_enum_values=True,
        protected_namespaces=(),  # allows to use model_* as a field name
        validate_assignment=True,
    )

    model_type: NoiseModelType
    parameters: Dict = Field(default_factory=dict, validate_default=True)

    @field_validator("parameters")
    @classmethod
    def validate_parameters(cls, data, values) -> Dict:
        """
        Validate the noise model parameters against the noise model type.

        Parameters
        ----------
        data : Dict
            Noise model parameters.
        values : ValidationInfo
            Pydantic field validation info, used to get the noise model type.

        Returns
        -------
        Dict
            Validated parameters, or the default parameters of the noise model if
            `data` is empty.

        Raises
        ------
        ValueError
            If the noise model type is neither "gmm" nor "hist".
        """
        model_type = values.data.get("model_type")
        if model_type is None:
            # `model_type` failed its own validation and pydantic already reports
            # that error; indexing `values.data` would raise a bare KeyError that
            # pydantic does not wrap into a ValidationError.
            return data

        if model_type not in [NoiseModelType.GMM, NoiseModelType.HIST]:
            raise ValueError(
                f"Incorrect noise model {model_type}. "
                "Please refer to the documentation."  # TODO add link to documentation
            )

        parameters = NoiseModelType.validate_noise_model_type(model_type, data)
        return parameters
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class HistogramNoiseModel(BaseModel):
    """
    Configuration of a histogram noise model.

    Attributes
    ----------
    min_value : float
        Minimum value in the input, within [0, 65535] (350 by default).
    max_value : float
        Maximum value in the input, within [0, 65535] (6500 by default).
    bins : int
        Number of bins of the histogram, at least 1 (256 by default).
    """

    # intensity bounds are restricted to [0, 65535], i.e. the uint16 value range
    min_value: float = Field(350.0, ge=0.0, le=65535.0)
    max_value: float = Field(6500.0, ge=0.0, le=65535.0)

    bins: int = Field(256, ge=1)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class GaussianMixtureNoiseModel(BaseModel):
    """
    Gaussian mixture model noise model.

    Attributes
    ----------
    num_components : int
        Number of mixture components, at least 1 (3 by default).
        NOTE(review): overlaps with `n_gaussian`; confirm which one the noise
        model implementation reads.
    min_value : float
        Lower bound of the signal intensity expected in the image, within
        [0, 65535] (350 by default).
    max_value : float
        Upper bound of the signal intensity expected in the image, within
        [0, 65535] (6500 by default).
    n_gaussian : int
        Number of gaussians, at least 1 (3 by default). Each gaussian contributes
        three parameters (mean, standard deviation and weight).
    n_coeff : int
        Number of coefficients describing the functional relationship between the
        gaussian parameters and the signal, at least 1 (2 by default). 2 implies a
        linear relationship, 3 implies a quadratic relationship and so on.
    min_sigma : int
        At least 1 (50 by default). Presumably the minimum standard deviation
        allowed for the gaussians; TODO confirm against the noise model
        implementation.
    """

    num_components: int = Field(default=3, ge=1)
    min_value: float = Field(default=350.0, ge=0.0, le=65535.0)
    max_value: float = Field(default=6500.0, ge=0.0, le=65535.0)
    n_gaussian: int = Field(default=3, ge=1)
    n_coeff: int = Field(default=2, ge=1)
    min_sigma: int = Field(default=50, ge=1)
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Dict, Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import (
|
|
6
|
+
BaseModel,
|
|
7
|
+
ConfigDict,
|
|
8
|
+
Field,
|
|
9
|
+
ValidationInfo,
|
|
10
|
+
field_validator,
|
|
11
|
+
model_validator,
|
|
12
|
+
)
|
|
13
|
+
from torch import optim
|
|
14
|
+
from typing_extensions import Self
|
|
15
|
+
|
|
16
|
+
from careamics.utils.torch_utils import filter_parameters
|
|
17
|
+
|
|
18
|
+
from .support import SupportedOptimizer
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class OptimizerModel(BaseModel):
    """
    Torch optimizer.

    Only parameters supported by the corresponding torch optimizer will be taken
    into account. For more details, check:
    https://pytorch.org/docs/stable/optim.html#algorithms

    Note that mandatory parameters (see the specific Optimizer signature in the
    link above) must be provided. For example, SGD requires `lr`.

    Attributes
    ----------
    name : Literal["Adam", "SGD"]
        Name of the optimizer.
    parameters : dict
        Parameters of the optimizer (see torch documentation).
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        validate_assignment=True,
    )

    # Mandatory field
    name: Literal["Adam", "SGD"] = Field(default="Adam", validate_default=True)

    # Optional parameters; the default is validated too so that it goes through
    # the same filtering as user-provided parameters
    parameters: dict = Field(
        default={
            "lr": 1e-4,
        },
        validate_default=True,
    )

    @field_validator("parameters")
    @classmethod
    def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> Dict:
        """
        Validate optimizer parameters.

        This method filters out unknown parameters, given the optimizer name.

        Parameters
        ----------
        user_params : dict
            Parameters passed on to the torch optimizer.
        values : ValidationInfo
            Pydantic field validation info, used to get the optimizer name.

        Returns
        -------
        Dict
            Filtered optimizer parameters.

        Raises
        ------
        ValueError
            If the optimizer name is not specified (or failed its validation).
        """
        # `name` is absent from the validation data when it failed validation;
        # indexing would raise a bare KeyError that pydantic does not wrap.
        optimizer_name = values.data.get("name")
        if optimizer_name is None:
            raise ValueError(
                "Cannot filter optimizer parameters: the optimizer name is missing "
                "or invalid."
            )

        # retrieve the corresponding optimizer class
        optimizer_class = getattr(optim, optimizer_name)

        # filter the user parameters according to the optimizer's signature.
        # NOTE: inside a method body, class attributes are not in scope, so this
        # resolves to the module-level `filter_parameters` helper imported from
        # careamics.utils.torch_utils, not to this method.
        parameters = filter_parameters(optimizer_class, user_params)

        return parameters

    @model_validator(mode="after")
    def sgd_lr_parameter(self) -> Self:
        """
        Check that SGD optimizer has the mandatory `lr` parameter specified.

        This is specific for PyTorch < 2.2.

        Returns
        -------
        Self
            Validated optimizer.

        Raises
        ------
        ValueError
            If the optimizer is SGD and the lr parameter is not specified.
        """
        if self.name == SupportedOptimizer.SGD and "lr" not in self.parameters:
            raise ValueError(
                "SGD optimizer requires `lr` parameter, check that it has correctly "
                "been specified in `parameters`."
            )

        return self
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class LrSchedulerModel(BaseModel):
    """
    Torch learning rate scheduler.

    Only parameters supported by the corresponding torch lr scheduler will be taken
    into account. For more details, check:
    https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

    Note that mandatory parameters (see the specific LrScheduler signature in the
    link above) must be provided. For example, StepLR requires `step_size`.

    Attributes
    ----------
    name : Literal["ReduceLROnPlateau", "StepLR"]
        Name of the learning rate scheduler.
    parameters : dict
        Parameters of the learning rate scheduler (see torch documentation).
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        validate_assignment=True,
    )

    # Mandatory field
    name: Literal["ReduceLROnPlateau", "StepLR"] = Field(default="ReduceLROnPlateau")

    # Optional parameters
    parameters: dict = Field(default={}, validate_default=True)

    @field_validator("parameters")
    @classmethod
    def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> Dict:
        """Filter parameters based on the learning rate scheduler's signature.

        Parameters
        ----------
        user_params : dict
            User parameters.
        values : ValidationInfo
            Pydantic field validation info, used to get the scheduler name.

        Returns
        -------
        Dict
            Filtered scheduler parameters.

        Raises
        ------
        ValueError
            If the scheduler name is not specified (or failed its validation).
        ValueError
            If the scheduler is StepLR and the step_size parameter is not specified.
        """
        # `name` is absent from the validation data when it failed validation;
        # indexing would raise a bare KeyError that pydantic does not wrap.
        scheduler_name = values.data.get("name")
        if scheduler_name is None:
            raise ValueError(
                "Cannot filter scheduler parameters: the scheduler name is missing "
                "or invalid."
            )

        # retrieve the corresponding scheduler class
        scheduler_class = getattr(optim.lr_scheduler, scheduler_name)

        # filter the user parameters according to the scheduler's signature
        # (resolves to the module-level helper, not to this method)
        parameters = filter_parameters(scheduler_class, user_params)

        if scheduler_name == "StepLR" and "step_size" not in parameters:
            raise ValueError(
                "StepLR scheduler requires `step_size` parameter, check that it has "
                "correctly been specified in `parameters`."
            )

        return parameters
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Module containing references to the algorithm used in CAREamics."""
|
|
2
|
+
|
|
3
|
+
# Public API of the references package (order preserved).
__all__ = [
    # bibliographic references
    "N2V2Ref",
    "N2VRef",
    "StructN2VRef",
    # algorithm descriptions
    "N2VDescription",
    "N2V2Description",
    "StructN2VDescription",
    "StructN2V2Description",
    # algorithm names
    "N2V",
    "N2V2",
    "STRUCT_N2V",
    "STRUCT_N2V2",
    "CUSTOM",
    "N2N",
    "CARE",
    # CARE / N2N descriptions and references
    "CAREDescription",
    "N2NDescription",
    "CARERef",
    "N2NRef",
]
|
|
23
|
+
|
|
24
|
+
from .algorithm_descriptions import (
|
|
25
|
+
CARE,
|
|
26
|
+
CUSTOM,
|
|
27
|
+
N2N,
|
|
28
|
+
N2V,
|
|
29
|
+
N2V2,
|
|
30
|
+
STRUCT_N2V,
|
|
31
|
+
STRUCT_N2V2,
|
|
32
|
+
CAREDescription,
|
|
33
|
+
N2NDescription,
|
|
34
|
+
N2V2Description,
|
|
35
|
+
N2VDescription,
|
|
36
|
+
StructN2V2Description,
|
|
37
|
+
StructN2VDescription,
|
|
38
|
+
)
|
|
39
|
+
from .references import (
|
|
40
|
+
CARERef,
|
|
41
|
+
N2NRef,
|
|
42
|
+
N2V2Ref,
|
|
43
|
+
N2VRef,
|
|
44
|
+
StructN2VRef,
|
|
45
|
+
)
|