careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
"""Pydantic model representing CAREamics prediction configuration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Literal, Optional, Union
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
8
|
+
from typing_extensions import Self
|
|
9
|
+
|
|
10
|
+
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class InferenceConfig(BaseModel):
    """Configuration class for the prediction model."""

    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

    data_type: Literal["array", "tiff", "custom"]  # As defined in SupportedData
    """Type of input data: numpy.ndarray (array) or path (tiff or custom)."""

    tile_size: Optional[list[int]] = Field(default=None, min_length=2, max_length=3)
    """Tile size of prediction, only effective if `tile_overlap` is specified."""

    tile_overlap: Optional[list[int]] = Field(
        default=None, min_length=2, max_length=3
    )
    """Overlap between tiles, only effective if `tile_size` is specified."""

    axes: str
    """Data axes (TSCZYX) in the order of the input data."""

    image_means: list = Field(..., min_length=0, max_length=32)
    """Mean values for each input channel."""

    image_stds: list = Field(..., min_length=0, max_length=32)
    """Standard deviation values for each input channel."""

    # TODO only default TTAs are supported for now
    tta_transforms: bool = Field(default=True)
    """Whether to apply test-time augmentation (all 90 degrees rotations and flips)."""

    # Dataloader parameters
    batch_size: int = Field(default=1, ge=1)
    """Batch size for prediction."""

    @field_validator("tile_overlap")
    @classmethod
    def all_elements_non_zero_even(
        cls, tile_overlap: Optional[list[int]]
    ) -> Optional[list[int]]:
        """
        Validate tile overlap.

        Overlaps must be non-zero, positive and even.

        Parameters
        ----------
        tile_overlap : list[int] or None
            Tile overlap, one entry per spatial dimension.

        Returns
        -------
        list[int] or None
            Validated tile overlap.

        Raises
        ------
        ValueError
            If any overlap entry is smaller than 1.
        ValueError
            If any overlap entry is not even.
        """
        if tile_overlap is not None:
            for dim in tile_overlap:
                if dim < 1:
                    raise ValueError(
                        f"Tile overlap must be non-zero positive (got {dim})."
                    )

                if dim % 2 != 0:
                    raise ValueError(f"Tile overlap must be even (got {dim}).")

        return tile_overlap

    @field_validator("tile_size")
    @classmethod
    def tile_min_8_power_of_2(
        cls, tile_list: Optional[list[int]]
    ) -> Optional[list[int]]:
        """
        Validate that each entry is greater or equal than 8 and a power of 2.

        Parameters
        ----------
        tile_list : list of int
            Tile size.

        Returns
        -------
        list of int
            Validated tile size.

        Raises
        ------
        ValueError
            If the tile size is smaller than 8.
        ValueError
            If the tile size is not a power of 2.
        """
        # Shared helper also used for patch sizes; it raises on invalid entries.
        patch_size_ge_than_8_power_of_2(tile_list)

        return tile_list

    @field_validator("axes")
    @classmethod
    def axes_valid(cls, axes: str) -> str:
        """
        Validate axes.

        Axes must:
        - be a combination of 'STCZYX'
        - not contain duplicates
        - contain at least 2 contiguous axes: X and Y
        - contain at most 4 axes
        - not contain both S and T axes

        Parameters
        ----------
        axes : str
            Axes to validate.

        Returns
        -------
        str
            Validated axes.

        Raises
        ------
        ValueError
            If axes are not valid.
        """
        # Validate axes
        check_axes_validity(axes)

        return axes

    @model_validator(mode="after")
    def validate_dimensions(self: Self) -> Self:
        """
        Validate 2D/3D dimensions between axes and tile size.

        The checks only run when both `tile_size` and `tile_overlap` are set.

        Returns
        -------
        Self
            Validated prediction model.

        Raises
        ------
        ValueError
            If the tile size or overlap length does not match the number of
            spatial dimensions implied by `axes`, or if any overlap entry is
            not strictly smaller than the matching tile size entry.
        """
        # A "Z" axis means 3D tiles (ZYX), otherwise 2D tiles (YX).
        expected_len = 3 if "Z" in self.axes else 2

        if self.tile_size is not None and self.tile_overlap is not None:
            if len(self.tile_size) != expected_len:
                raise ValueError(
                    f"Tile size must have {expected_len} dimensions given axes "
                    f"{self.axes} (got {self.tile_size})."
                )

            if len(self.tile_overlap) != expected_len:
                raise ValueError(
                    f"Tile overlap must have {expected_len} dimensions given axes "
                    f"{self.axes} (got {self.tile_overlap})."
                )

            if any((i >= j) for i, j in zip(self.tile_overlap, self.tile_size)):
                raise ValueError("Tile overlap must be smaller than tile size.")

        return self

    @model_validator(mode="after")
    def std_only_with_mean(self: Self) -> Self:
        """
        Check that mean and std are both specified, with one value per channel.

        Returns
        -------
        Self
            Validated prediction model.

        Raises
        ------
        ValueError
            If both mean and std are empty.
        ValueError
            If only one of mean and std is specified.
        ValueError
            If mean and std have different lengths.
        """
        # Normalization statistics are mandatory at inference time.
        if not self.image_means and not self.image_stds:
            raise ValueError("Mean and std must be specified during inference.")

        if (self.image_means and not self.image_stds) or (
            self.image_stds and not self.image_means
        ):
            raise ValueError(
                "Mean and std must be either both None, or both specified."
            )

        elif (self.image_means is not None and self.image_stds is not None) and (
            len(self.image_means) != len(self.image_stds)
        ):
            raise ValueError(
                "Mean and std must be specified for each " "input channel."
            )

        return self

    def _update(self, **kwargs: Any) -> None:
        """
        Update multiple arguments at once.

        The merged field values are validated as a whole before anything is
        written back. Interdependent fields (e.g. `axes` and `tile_size`) can
        therefore change together, and a failed validation leaves the instance
        unchanged.

        Parameters
        ----------
        **kwargs : Any
            Key-value pairs of arguments to update.

        Raises
        ------
        pydantic.ValidationError
            If the updated configuration is not valid.
        """
        # Validate a merged copy first. Writing into `self.__dict__` before
        # validating would leave the instance invalid if validation fails.
        validated = self.__class__.model_validate({**self.__dict__, **kwargs})

        # Commit directly to `__dict__` to bypass `validate_assignment`.
        # Setting fields one by one would trip `validate_dimensions` halfway
        # through a multi-field change.
        self.__dict__.update(validated.__dict__)

    def set_3D(self, axes: str, tile_size: list[int], tile_overlap: list[int]) -> None:
        """
        Set 3D parameters.

        Parameters
        ----------
        axes : str
            Axes.
        tile_size : list of int
            Tile size.
        tile_overlap : list of int
            Tile overlap.
        """
        self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap)
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Likelihood model."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal, Optional, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from pydantic import BaseModel, ConfigDict
|
|
7
|
+
|
|
8
|
+
from careamics.models.lvae.noise_models import (
|
|
9
|
+
GaussianMixtureNoiseModel,
|
|
10
|
+
MultiChannelNoiseModel,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
# Type alias for the noise model types accepted by `NMLikelihoodConfig.noise_model`:
# either a `GaussianMixtureNoiseModel` or a `MultiChannelNoiseModel` instance.
NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class GaussianLikelihoodConfig(BaseModel):
    """Gaussian likelihood configuration.

    Controls whether a log-variance is predicted and how it is bounded below.
    Assignments are re-validated (`validate_assignment=True`).
    """

    model_config = ConfigDict(validate_assignment=True)

    predict_logvar: Union[Literal["pixelwise"], None] = None
    """If `pixelwise`, log-variance is computed for each pixel, else log-variance
    is not computed."""

    logvar_lowerbound: Optional[float] = None
    """The lowerbound value for log-variance."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class NMLikelihoodConfig(BaseModel):
    """Noise model likelihood configuration.

    Holds the data statistics used to unnormalize data before it is evaluated
    by the noise model, plus the noise model instance itself. Arbitrary types
    are allowed because the fields hold `torch.Tensor` and noise model objects.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

    data_mean: torch.Tensor = torch.zeros(1)
    """The mean of the data, used to unnormalize data for noise model evaluation.
    Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""

    data_std: torch.Tensor = torch.ones(1)
    """The standard deviation of the data, used to unnormalize data for noise
    model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""

    noise_model: Optional[NoiseModel] = None
    """The noise model instance used to compute the likelihood."""
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
"""Noise models config."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Literal, Optional, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
8
|
+
from typing_extensions import Self
|
|
9
|
+
|
|
10
|
+
# TODO: add histogram-based noise model
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class GaussianMixtureNMConfig(BaseModel):
    """Gaussian mixture noise model."""

    model_config = ConfigDict(
        # Empty protected namespaces allow the `model_type` field name, which
        # otherwise clashes with pydantic's reserved `model_` prefix.
        protected_namespaces=(),
        validate_assignment=True,
        arbitrary_types_allowed=True,
        extra="allow",
    )
    # model type
    model_type: Literal["GaussianMixtureNoiseModel"]

    path: Optional[Union[Path, str]] = None
    """Path to the directory where the trained noise model (*.npz) is saved in the
    `train` method."""

    signal: Optional[Union[str, Path, np.ndarray]] = None
    """Path to the file containing signal or respective numpy array."""

    observation: Optional[Union[str, Path, np.ndarray]] = None
    """Path to the file containing observation or respective numpy array."""

    # NOTE(review): the summary "[mean, standard deviation and weight]" and the
    # per-block row list below disagree on the order of the std and weight
    # rows. Confirm against GaussianMixtureNoiseModel and fix the docstring.
    weight: Optional[np.ndarray] = None
    """A [3*n_gaussian, n_coeff] sized array containing the values of the weights
    describing the GMM noise model, with each row corresponding to one
    parameter of each gaussian, namely [mean, standard deviation and weight].
    Specifically, rows are organized as follows:
    - first n_gaussian rows correspond to the means
    - next n_gaussian rows correspond to the weights
    - last n_gaussian rows correspond to the standard deviations
    If `weight=None`, the weight array is initialized using the `min_signal`
    and `max_signal` parameters."""

    n_gaussian: int = Field(default=1, ge=1)
    """Number of gaussians used for the GMM."""

    n_coeff: int = Field(default=2, ge=2)
    """Number of coefficients to describe the functional relationship between gaussian
    parameters and the signal. 2 implies a linear relationship, 3 implies a quadratic
    relationship and so on."""

    min_signal: float = Field(default=0.0, ge=0.0)
    """Minimum signal intensity expected in the image."""

    max_signal: float = Field(default=1.0, ge=0.0)
    """Maximum signal intensity expected in the image."""

    min_sigma: float = Field(default=200.0, ge=0.0)  # TODO took from nb in pn2v
    """Minimum value of `standard deviation` allowed in the GMM.
    All values of `standard deviation` below this are clamped to this value."""

    tol: float = Field(default=1e-10)
    """Tolerance used in the computation of the noise model likelihood."""

    @model_validator(mode="after")
    def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
        """Validate paths provided in the config.

        Exactly one data source must be configured: either `path` to a
        pre-trained noise model, or both `signal` and `observation`.

        Returns
        -------
        Self
            Returns itself.

        Raises
        ------
        ValueError
            If `path` is given together with `signal` or `observation`.
        ValueError
            If `path` is not given and `signal` or `observation` is missing.
        """
        # Adjacent string literals are joined verbatim, so each fragment
        # carries its own trailing space.
        error_msg = (
            "Either only 'path' to pre-trained noise model should be "
            "provided or only signal and observation in form of paths "
            "or numpy arrays."
        )
        if self.path and (self.signal is not None or self.observation is not None):
            raise ValueError(error_msg)
        if not self.path and (self.signal is None or self.observation is None):
            raise ValueError(error_msg)
        return self
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# One GMM noise model per target channel: for example, two target channels are
# described by two entries in `noise_models`.
class MultiChannelNMConfig(BaseModel):
    """Noise Model config aggregating noise models for single output channels.

    Each entry of `noise_models` configures the noise model of one target
    channel.
    """

    # TODO: check that this model config is OK
    model_config = ConfigDict(
        extra="allow",
        arbitrary_types_allowed=True,
        validate_assignment=True,
    )

    noise_models: list[GaussianMixtureNMConfig]
    """List of noise models, one for each target channel."""
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""Optimizers and schedulers Pydantic models."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Literal
|
|
6
|
+
|
|
7
|
+
from pydantic import (
|
|
8
|
+
BaseModel,
|
|
9
|
+
ConfigDict,
|
|
10
|
+
Field,
|
|
11
|
+
ValidationInfo,
|
|
12
|
+
field_validator,
|
|
13
|
+
model_validator,
|
|
14
|
+
)
|
|
15
|
+
from torch import optim
|
|
16
|
+
from typing_extensions import Self
|
|
17
|
+
|
|
18
|
+
from careamics.utils.torch_utils import filter_parameters
|
|
19
|
+
|
|
20
|
+
from .support import SupportedOptimizer
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class OptimizerModel(BaseModel):
    """Torch optimizer Pydantic model.

    Only parameters supported by the corresponding torch optimizer will be taken
    into account. For more details, check:
    https://pytorch.org/docs/stable/optim.html#algorithms

    Note that mandatory parameters (see the specific Optimizer signature in the
    link above) must be provided. For example, SGD requires `lr`.

    Attributes
    ----------
    name : {"Adam", "SGD"}
        Name of the optimizer.
    parameters : dict
        Parameters of the optimizer (see torch documentation).
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        validate_assignment=True,
    )

    # Optimizer name. It must be declared before `parameters`, whose validator
    # reads it from the already-validated data.
    name: Literal["Adam", "SGD"] = Field(default="Adam", validate_default=True)
    """Name of the optimizer, supported optimizers are defined in SupportedOptimizer."""

    # Optional parameters, empty dict default value to allow filtering dictionary
    parameters: dict = Field(
        default={
            "lr": 1e-4,
        },
        validate_default=True,
    )
    """Parameters of the optimizer, see PyTorch documentation for more details."""

    @field_validator("parameters")
    @classmethod
    def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
        """
        Validate optimizer parameters.

        This method filters out unknown parameters, given the optimizer name.

        Parameters
        ----------
        user_params : dict
            Parameters passed on to the torch optimizer.
        values : ValidationInfo
            Pydantic field validation info, used to get the optimizer name.

        Returns
        -------
        dict
            Filtered optimizer parameters.

        Raises
        ------
        ValueError
            If the optimizer name is not specified (e.g. it failed validation).
        """
        # `name` is absent from `values.data` when its own validation failed.
        # Indexing would then raise a bare KeyError, which pydantic does not
        # convert into a ValidationError.
        optimizer_name = values.data.get("name")
        if optimizer_name is None:
            raise ValueError(
                "A valid optimizer `name` is required to filter its parameters."
            )

        # retrieve the corresponding optimizer class
        optimizer_class = getattr(optim, optimizer_name)

        # Filter the user parameters according to the optimizer's signature.
        # The name resolves to the module-level helper imported from
        # careamics.utils.torch_utils: a method body does not see class scope.
        parameters = filter_parameters(optimizer_class, user_params)

        return parameters

    @model_validator(mode="after")
    def sgd_lr_parameter(self) -> Self:
        """
        Check that SGD optimizer has the mandatory `lr` parameter specified.

        This is specific for PyTorch < 2.2.

        Returns
        -------
        Self
            Validated optimizer.

        Raises
        ------
        ValueError
            If the optimizer is SGD and the lr parameter is not specified.
        """
        if self.name == SupportedOptimizer.SGD and "lr" not in self.parameters:
            raise ValueError(
                "SGD optimizer requires `lr` parameter, check that it has correctly "
                "been specified in `parameters`."
            )

        return self
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class LrSchedulerModel(BaseModel):
    """Torch learning rate scheduler Pydantic model.

    Only parameters supported by the corresponding torch lr scheduler will be taken
    into account. For more details, check:
    https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

    Note that mandatory parameters (see the specific LrScheduler signature in the
    link above) must be provided. For example, StepLR requires `step_size`.

    Attributes
    ----------
    name : {"ReduceLROnPlateau", "StepLR"}
        Name of the learning rate scheduler.
    parameters : dict
        Parameters of the learning rate scheduler (see torch documentation).
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        validate_assignment=True,
    )

    # Scheduler name. It must be declared before `parameters`, whose validator
    # reads it from the already-validated data.
    name: Literal["ReduceLROnPlateau", "StepLR"] = Field(default="ReduceLROnPlateau")
    """Name of the learning rate scheduler, supported schedulers are defined in
    SupportedScheduler."""

    # Optional parameters
    parameters: dict = Field(default={}, validate_default=True)
    """Parameters of the learning rate scheduler, see PyTorch documentation for more
    details."""

    @field_validator("parameters")
    @classmethod
    def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
        """Filter parameters based on the learning rate scheduler's signature.

        Parameters
        ----------
        user_params : dict
            User parameters.
        values : ValidationInfo
            Pydantic field validation info, used to get the scheduler name.

        Returns
        -------
        dict
            Filtered scheduler parameters.

        Raises
        ------
        ValueError
            If the scheduler name is not specified (e.g. it failed validation).
        ValueError
            If the scheduler is StepLR and the step_size parameter is not specified.
        """
        # `name` is absent from `values.data` when its own validation failed.
        # Indexing would then raise a bare KeyError, which pydantic does not
        # convert into a ValidationError.
        scheduler_name = values.data.get("name")
        if scheduler_name is None:
            raise ValueError(
                "A valid scheduler `name` is required to filter its parameters."
            )

        # retrieve the corresponding scheduler class
        scheduler_class = getattr(optim.lr_scheduler, scheduler_name)

        # Filter the user parameters according to the scheduler's signature.
        # The name resolves to the module-level helper imported from
        # careamics.utils.torch_utils: a method body does not see class scope.
        parameters = filter_parameters(scheduler_class, user_params)

        if scheduler_name == "StepLR" and "step_size" not in parameters:
            raise ValueError(
                "StepLR scheduler requires `step_size` parameter, check that it has "
                "correctly been specified in `parameters`."
            )

        return parameters
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Module containing references to the algorithm used in CAREamics."""
|
|
2
|
+
|
|
3
|
+
# Public API of the references package, re-exported from
# `.algorithm_descriptions` and `.references` below.
__all__ = [
    # Reference objects (from `.references`)
    "N2V2Ref",
    "N2VRef",
    "StructN2VRef",
    # Algorithm descriptions (from `.algorithm_descriptions`)
    "N2VDescription",
    "N2V2Description",
    "StructN2VDescription",
    "StructN2V2Description",
    # Algorithm names/constants (from `.algorithm_descriptions`)
    "N2V",
    "N2V2",
    "STRUCT_N2V",
    "STRUCT_N2V2",
    "CUSTOM",
    "N2N",
    "CARE",
    # Further descriptions and references
    "CAREDescription",
    "N2NDescription",
    "CARERef",
    "N2NRef",
]
|
|
23
|
+
|
|
24
|
+
from .algorithm_descriptions import (
|
|
25
|
+
CARE,
|
|
26
|
+
CUSTOM,
|
|
27
|
+
N2N,
|
|
28
|
+
N2V,
|
|
29
|
+
N2V2,
|
|
30
|
+
STRUCT_N2V,
|
|
31
|
+
STRUCT_N2V2,
|
|
32
|
+
CAREDescription,
|
|
33
|
+
N2NDescription,
|
|
34
|
+
N2V2Description,
|
|
35
|
+
N2VDescription,
|
|
36
|
+
StructN2V2Description,
|
|
37
|
+
StructN2VDescription,
|
|
38
|
+
)
|
|
39
|
+
from .references import (
|
|
40
|
+
CARERef,
|
|
41
|
+
N2NRef,
|
|
42
|
+
N2V2Ref,
|
|
43
|
+
N2VRef,
|
|
44
|
+
StructN2VRef,
|
|
45
|
+
)
|