careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +726 -0
- careamics/config/__init__.py +35 -0
- careamics/config/algorithm_model.py +162 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +159 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/architectures/vae_model.py +42 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +575 -0
- careamics/config/configuration_model.py +600 -0
- careamics/config/data_model.py +502 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +26 -0
- careamics/config/support/supported_algorithms.py +20 -0
- careamics/config/support/supported_architectures.py +20 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +27 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +17 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +276 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +5 -0
- careamics/losses/loss_factory.py +49 -0
- careamics/losses/losses.py +98 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +233 -0
- careamics/model_io/model_io_utils.py +83 -0
- careamics/models/__init__.py +7 -0
- careamics/models/activation.py +37 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +52 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +98 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +115 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.2.dist-info/METADATA +78 -0
- careamics-0.0.2.dist-info/RECORD +140 -0
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
"""Pydantic model representing CAREamics prediction configuration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Literal, Optional, Union
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
8
|
+
from typing_extensions import Self
|
|
9
|
+
|
|
10
|
+
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class InferenceConfig(BaseModel):
    """Configuration class for the prediction model."""

    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

    data_type: Literal["array", "tiff", "custom"]  # As defined in SupportedData
    """Type of input data: numpy.ndarray (array) or path (tiff or custom)."""

    tile_size: Optional[Union[list[int]]] = Field(
        default=None, min_length=2, max_length=3
    )
    """Tile size of prediction, only effective if `tile_overlap` is specified."""

    tile_overlap: Optional[Union[list[int]]] = Field(
        default=None, min_length=2, max_length=3
    )
    """Overlap between tiles, only effective if `tile_size` is specified."""

    axes: str
    """Data axes (TSCZYX) in the order of the input data."""

    image_means: list = Field(..., min_length=0, max_length=32)
    """Mean values for each input channel."""

    image_stds: list = Field(..., min_length=0, max_length=32)
    """Standard deviation values for each input channel."""

    # TODO only default TTAs are supported for now
    tta_transforms: bool = Field(default=True)
    """Whether to apply test-time augmentation (all 90 degrees rotations and flips)."""

    # Dataloader parameters
    batch_size: int = Field(default=1, ge=1)
    """Batch size for prediction."""

    @field_validator("tile_overlap")
    @classmethod
    def all_elements_non_zero_even(
        cls, tile_overlap: Optional[list[int]]
    ) -> Optional[list[int]]:
        """
        Validate tile overlap.

        Overlaps must be non-zero, positive and even.

        Parameters
        ----------
        tile_overlap : list[int] or None
            Tile overlap.

        Returns
        -------
        list[int] or None
            Validated tile overlap.

        Raises
        ------
        ValueError
            If any overlap is smaller than 1.
        ValueError
            If any overlap is not even.
        """
        if tile_overlap is not None:
            for dim in tile_overlap:
                if dim < 1:
                    raise ValueError(
                        f"Tile overlap must be non-zero positive (got {dim})."
                    )

                if dim % 2 != 0:
                    raise ValueError(f"Tile overlap must be even (got {dim}).")

        return tile_overlap

    @field_validator("tile_size")
    @classmethod
    def tile_min_8_power_of_2(
        cls, tile_list: Optional[list[int]]
    ) -> Optional[list[int]]:
        """
        Validate that each entry is greater or equal than 8 and a power of 2.

        Parameters
        ----------
        tile_list : list of int
            Tile size.

        Returns
        -------
        list of int
            Validated tile size.

        Raises
        ------
        ValueError
            If the tile size is smaller than 8.
        ValueError
            If the tile size is not a power of 2.
        """
        patch_size_ge_than_8_power_of_2(tile_list)

        return tile_list

    @field_validator("axes")
    @classmethod
    def axes_valid(cls, axes: str) -> str:
        """
        Validate axes.

        Axes must:
        - be a combination of 'STCZYX'
        - not contain duplicates
        - contain at least 2 contiguous axes: X and Y
        - contain at most 4 axes
        - not contain both S and T axes

        Parameters
        ----------
        axes : str
            Axes to validate.

        Returns
        -------
        str
            Validated axes.

        Raises
        ------
        ValueError
            If axes are not valid.
        """
        # Validate axes
        check_axes_validity(axes)

        return axes

    @model_validator(mode="after")
    def validate_dimensions(self: Self) -> Self:
        """
        Validate 2D/3D dimensions between axes and tile size.

        The checks only run when both `tile_size` and `tile_overlap` are set.

        Returns
        -------
        Self
            Validated prediction model.

        Raises
        ------
        ValueError
            If tile size or tile overlap do not have 3 entries when `Z` is in
            the axes (2 otherwise), or if any overlap is not smaller than the
            corresponding tile size.
        """
        expected_len = 3 if "Z" in self.axes else 2

        if self.tile_size is not None and self.tile_overlap is not None:
            if len(self.tile_size) != expected_len:
                raise ValueError(
                    f"Tile size must have {expected_len} dimensions given axes "
                    f"{self.axes} (got {self.tile_size})."
                )

            if len(self.tile_overlap) != expected_len:
                raise ValueError(
                    f"Tile overlap must have {expected_len} dimensions given axes "
                    f"{self.axes} (got {self.tile_overlap})."
                )

            if any((i >= j) for i, j in zip(self.tile_overlap, self.tile_size)):
                raise ValueError("Tile overlap must be smaller than tile size.")

        return self

    @model_validator(mode="after")
    def std_only_with_mean(self: Self) -> Self:
        """
        Check that mean and std are both specified and have matching lengths.

        Returns
        -------
        Self
            Validated prediction model.

        Raises
        ------
        ValueError
            If both mean and std are empty.
        ValueError
            If only one of mean and std is non-empty.
        ValueError
            If mean and std do not have the same number of entries.
        """
        # inference requires normalization statistics
        if not self.image_means and not self.image_stds:
            raise ValueError("Mean and std must be specified during inference.")

        if (self.image_means and not self.image_stds) or (
            self.image_stds and not self.image_means
        ):
            raise ValueError(
                "Mean and std must be either both None, or both specified."
            )

        elif (self.image_means is not None and self.image_stds is not None) and (
            len(self.image_means) != len(self.image_stds)
        ):
            raise ValueError(
                "Mean and std must be specified for each " "input channel."
            )

        return self

    def _update(self, **kwargs: Any) -> None:
        """
        Update multiple arguments at once.

        The merged values are validated before the instance is modified, so a
        failed update leaves the configuration unchanged.

        Parameters
        ----------
        **kwargs : Any
            Key-value pairs of arguments to update.

        Raises
        ------
        pydantic.ValidationError
            If the updated configuration is not valid.
        """
        # Validate a merged copy first; only commit the new values once the
        # whole configuration is known to be consistent.
        merged = {**self.__dict__, **kwargs}
        self.__class__.model_validate(merged)
        self.__dict__.update(kwargs)

    def set_3D(self, axes: str, tile_size: list[int], tile_overlap: list[int]) -> None:
        """
        Set 3D parameters.

        Parameters
        ----------
        axes : str
            Axes.
        tile_size : list of int
            Tile size.
        tile_overlap : list of int
            Tile overlap.
        """
        self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap)
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""Optimizers and schedulers Pydantic models."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Literal
|
|
6
|
+
|
|
7
|
+
from pydantic import (
|
|
8
|
+
BaseModel,
|
|
9
|
+
ConfigDict,
|
|
10
|
+
Field,
|
|
11
|
+
ValidationInfo,
|
|
12
|
+
field_validator,
|
|
13
|
+
model_validator,
|
|
14
|
+
)
|
|
15
|
+
from torch import optim
|
|
16
|
+
from typing_extensions import Self
|
|
17
|
+
|
|
18
|
+
from careamics.utils.torch_utils import filter_parameters
|
|
19
|
+
|
|
20
|
+
from .support import SupportedOptimizer
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class OptimizerModel(BaseModel):
    """Torch optimizer Pydantic model.

    Only parameters supported by the corresponding torch optimizer will be taken
    into account. For more details, check:
    https://pytorch.org/docs/stable/optim.html#algorithms

    Note that mandatory parameters (see the specific Optimizer signature in the
    link above) must be provided. For example, SGD requires `lr`.

    Attributes
    ----------
    name : {"Adam", "SGD"}
        Name of the optimizer.
    parameters : dict
        Parameters of the optimizer (see torch documentation).
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        validate_assignment=True,
    )

    # Mandatory field
    name: Literal["Adam", "SGD"] = Field(default="Adam", validate_default=True)
    """Name of the optimizer, supported optimizers are defined in SupportedOptimizer."""

    # Optional parameters; the default is validated too so that it goes through
    # the same filtering as user-provided dictionaries
    parameters: dict = Field(
        default={
            "lr": 1e-4,
        },
        validate_default=True,
    )
    """Parameters of the optimizer, see PyTorch documentation for more details."""

    @field_validator("parameters")
    @classmethod
    def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
        """
        Validate optimizer parameters.

        This method filters out unknown parameters, given the optimizer name.

        Parameters
        ----------
        user_params : dict
            Parameters passed on to the torch optimizer.
        values : ValidationInfo
            Pydantic field validation info, used to get the optimizer name.

        Returns
        -------
        dict
            Filtered optimizer parameters.

        Raises
        ------
        ValueError
            If the optimizer name is not specified.
        """
        # `name` is missing from `values.data` when its own validation failed.
        # Raise a ValueError (reported by Pydantic as a validation error)
        # instead of letting a bare KeyError escape.
        optimizer_name = values.data.get("name")
        if optimizer_name is None:
            raise ValueError(
                "Optimizer name must be specified and valid before its "
                "parameters can be validated."
            )

        # retrieve the corresponding optimizer class
        optimizer_class = getattr(optim, optimizer_name)

        # filter the user parameters according to the optimizer's signature
        # (this calls the module-level `filter_parameters` helper, not this
        # validator, since class attributes are not in a method's scope)
        parameters = filter_parameters(optimizer_class, user_params)

        return parameters

    @model_validator(mode="after")
    def sgd_lr_parameter(self) -> Self:
        """
        Check that SGD optimizer has the mandatory `lr` parameter specified.

        This is specific for PyTorch < 2.2.

        Returns
        -------
        Self
            Validated optimizer.

        Raises
        ------
        ValueError
            If the optimizer is SGD and the lr parameter is not specified.
        """
        if self.name == SupportedOptimizer.SGD and "lr" not in self.parameters:
            raise ValueError(
                "SGD optimizer requires `lr` parameter, check that it has correctly "
                "been specified in `parameters`."
            )

        return self
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class LrSchedulerModel(BaseModel):
    """Torch learning rate scheduler Pydantic model.

    Only parameters supported by the corresponding torch lr scheduler will be taken
    into account. For more details, check:
    https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

    Note that mandatory parameters (see the specific LrScheduler signature in the
    link above) must be provided. For example, StepLR requires `step_size`.

    Attributes
    ----------
    name : {"ReduceLROnPlateau", "StepLR"}
        Name of the learning rate scheduler.
    parameters : dict
        Parameters of the learning rate scheduler (see torch documentation).
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        validate_assignment=True,
    )

    # Mandatory field
    name: Literal["ReduceLROnPlateau", "StepLR"] = Field(default="ReduceLROnPlateau")
    """Name of the learning rate scheduler, supported schedulers are defined in
    SupportedScheduler."""

    # Optional parameters
    parameters: dict = Field(default={}, validate_default=True)
    """Parameters of the learning rate scheduler, see PyTorch documentation for more
    details."""

    @field_validator("parameters")
    @classmethod
    def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
        """Filter parameters based on the learning rate scheduler's signature.

        Parameters
        ----------
        user_params : dict
            User parameters.
        values : ValidationInfo
            Pydantic field validation info, used to get the scheduler name.

        Returns
        -------
        dict
            Filtered scheduler parameters.

        Raises
        ------
        ValueError
            If the scheduler name is missing or failed validation.
        ValueError
            If the scheduler is StepLR and the step_size parameter is not specified.
        """
        # `name` is missing from `values.data` when its own validation failed.
        # Raise a ValueError (reported by Pydantic as a validation error)
        # instead of letting a bare KeyError escape.
        scheduler_name = values.data.get("name")
        if scheduler_name is None:
            raise ValueError(
                "Learning rate scheduler name must be specified and valid before "
                "its parameters can be validated."
            )

        # retrieve the corresponding scheduler class
        scheduler_class = getattr(optim.lr_scheduler, scheduler_name)

        # filter the user parameters according to the scheduler's signature
        # (this calls the module-level `filter_parameters` helper, not this
        # validator)
        parameters = filter_parameters(scheduler_class, user_params)

        if scheduler_name == "StepLR" and "step_size" not in parameters:
            raise ValueError(
                "StepLR scheduler requires `step_size` parameter, check that it has "
                "correctly been specified in `parameters`."
            )

        return parameters
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Module containing references to the algorithm used in CAREamics."""
|
|
2
|
+
|
|
3
|
+
# Public names re-exported by the `references` subpackage, in export order.
__all__ = [
    # citation entries
    "N2V2Ref",
    "N2VRef",
    "StructN2VRef",
    # N2V-family descriptions
    "N2VDescription",
    "N2V2Description",
    "StructN2VDescription",
    "StructN2V2Description",
    # algorithm display names
    "N2V",
    "N2V2",
    "STRUCT_N2V",
    "STRUCT_N2V2",
    "CUSTOM",
    "N2N",
    "CARE",
    # supervised algorithm descriptions and citations
    "CAREDescription",
    "N2NDescription",
    "CARERef",
    "N2NRef",
]
|
|
23
|
+
|
|
24
|
+
from .algorithm_descriptions import (
|
|
25
|
+
CARE,
|
|
26
|
+
CUSTOM,
|
|
27
|
+
N2N,
|
|
28
|
+
N2V,
|
|
29
|
+
N2V2,
|
|
30
|
+
STRUCT_N2V,
|
|
31
|
+
STRUCT_N2V2,
|
|
32
|
+
CAREDescription,
|
|
33
|
+
N2NDescription,
|
|
34
|
+
N2V2Description,
|
|
35
|
+
N2VDescription,
|
|
36
|
+
StructN2V2Description,
|
|
37
|
+
StructN2VDescription,
|
|
38
|
+
)
|
|
39
|
+
from .references import (
|
|
40
|
+
CARERef,
|
|
41
|
+
N2NRef,
|
|
42
|
+
N2V2Ref,
|
|
43
|
+
N2VRef,
|
|
44
|
+
StructN2VRef,
|
|
45
|
+
)
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
"""Descriptions of the algorithms used in CAREmics."""
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
CUSTOM = "Custom"
|
|
6
|
+
N2V = "Noise2Void"
|
|
7
|
+
N2V2 = "N2V2"
|
|
8
|
+
STRUCT_N2V = "StructN2V"
|
|
9
|
+
STRUCT_N2V2 = "StructN2V2"
|
|
10
|
+
N2N = "Noise2Noise"
|
|
11
|
+
CARE = "CARE"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
N2V_DESCRIPTION = (
|
|
15
|
+
"Noise2Void is a UNet-based self-supervised algorithm that "
|
|
16
|
+
"uses blind-spot training to denoise images. In short, in every "
|
|
17
|
+
"patches during training, random pixels are selected and their "
|
|
18
|
+
"value replaced by a neighboring pixel value. The network is then "
|
|
19
|
+
"trained to predict the original pixel value. The algorithm "
|
|
20
|
+
"relies on the continuity of the signal (neighboring pixels have "
|
|
21
|
+
"similar values) and the pixel-wise independence of the noise "
|
|
22
|
+
"(the noise in a pixel is not correlated with the noise in "
|
|
23
|
+
"neighboring pixels)."
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class AlgorithmDescription(BaseModel):
    """Base Pydantic model holding a free-text algorithm description.

    Attributes
    ----------
    description : str
        Human-readable description of the algorithm; required, no default.
    """

    description: str
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class N2VDescription(AlgorithmDescription):
    """Noise2Void description model.

    Attributes
    ----------
    description : str
        Description of Noise2Void, defaulting to the module-level
        ``N2V_DESCRIPTION`` text.
    """

    description: str = N2V_DESCRIPTION
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class N2V2Description(AlgorithmDescription):
    """Description of N2V2.

    The default text prefixes the shared Noise2Void description and appends
    the N2V2-specific architectural changes.

    Attributes
    ----------
    description : str
        Description of N2V2.
    """

    description: str = (
        "N2V2 is a variant of Noise2Void. "
        + N2V_DESCRIPTION
        + "\nN2V2 introduces blur-pool layers and removes skip "
        "connections in the UNet architecture to remove checkerboard "
        "artefacts, a common artefact occurring in Noise2Void."
    )
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class StructN2VDescription(AlgorithmDescription):
    """Description of StructN2V.

    The default text prefixes the shared Noise2Void description and appends
    the StructN2V-specific masking strategy.

    Attributes
    ----------
    description : str
        Description of StructN2V.
    """

    description: str = (
        "StructN2V is a variant of Noise2Void. "
        + N2V_DESCRIPTION
        + "\nStructN2V uses a linear mask (horizontal or vertical) to replace "
        "the pixel values of neighbors of the masked pixels by a random "
        "value. Such masking allows removing 1D structured noise from the "
        "images, the main failure case of the original N2V."
    )
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class StructN2V2Description(AlgorithmDescription):
    """Description of StructN2V2.

    The default text prefixes the shared Noise2Void description and appends
    both the StructN2V masking strategy and the N2V2 architectural changes.

    Attributes
    ----------
    description : str
        Description of StructN2V2.
    """

    description: str = (
        "StructN2V2 is a variant of Noise2Void that uses both "
        "structN2V and N2V2. "
        + N2V_DESCRIPTION
        + "\nStructN2V2 uses a linear mask (horizontal or vertical) to replace "
        "the pixel values of neighbors of the masked pixels by a random "
        "value. Such masking allows removing 1D structured noise from the "
        "images, the main failure case of the original N2V."
        "\nN2V2 introduces blur-pool layers and removes skip connections in "
        "the UNet architecture to remove checkerboard artefacts, a common "
        "artefact occurring in Noise2Void."
    )
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class N2NDescription(AlgorithmDescription):
    """Noise2Noise description model.

    Attributes
    ----------
    description : str
        Description of Noise2Noise (currently only the algorithm name).
    """

    # TODO: replace the placeholder with a full description of Noise2Noise.
    description: str = "Noise2Noise"
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class CAREDescription(AlgorithmDescription):
    """CARE description model.

    Attributes
    ----------
    description : str
        Description of CARE (currently only the algorithm name).
    """

    # TODO: replace the placeholder with a full description of CARE.
    description: str = "CARE"
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""References for the CAREamics algorithms."""
|
|
2
|
+
|
|
3
|
+
from bioimageio.spec.generic.v0_3 import CiteEntry
|
|
4
|
+
|
|
5
|
+
# Citation entries (bioimage.io `CiteEntry`) for the algorithms supported by
# CAREamics. Adjacent string literals are implicitly concatenated, so each
# fragment must carry its own separating whitespace.
N2VRef = CiteEntry(
    text='Krull, A., Buchholz, T.O. and Jug, F., 2019. "Noise2Void - Learning '
    'denoising from single noisy images". In Proceedings of the IEEE/CVF '
    "conference on computer vision and pattern recognition (pp. 2129-2137).",
    doi="10.1109/cvpr.2019.00223",
)

N2V2Ref = CiteEntry(
    text="Höck, E., Buchholz, T.O., Brachmann, A., Jug, F. and Freytag, A., "
    '2022. "N2V2 - Fixing Noise2Void checkerboard artifacts with modified '
    'sampling strategies and a tweaked network architecture". In European '
    "Conference on Computer Vision (pp. 503-518).",
    doi="10.1007/978-3-031-25069-9_33",
)

StructN2VRef = CiteEntry(
    # the year and the title were previously concatenated without a space
    text="Broaddus, C., Krull, A., Weigert, M., Schmidt, U. and Myers, G., "
    '2020. "Removing structured noise with self-supervised blind-spot '
    'networks". In 2020 IEEE 17th International Symposium on Biomedical '
    "Imaging (ISBI) (pp. 159-163).",
    doi="10.1109/isbi45749.2020.9098336",
)

N2NRef = CiteEntry(
    text="Lehtinen, J., Munkberg, J., Hasselgren, J., Laine, S., Karras, T., "
    'Aittala, M. and Aila, T., 2018. "Noise2Noise: Learning image restoration '
    'without clean data". arXiv preprint arXiv:1803.04189.',
    doi="10.48550/arXiv.1803.04189",
)

CARERef = CiteEntry(
    text='Weigert, Martin, et al. "Content-aware image restoration: pushing the '
    'limits of fluorescence microscopy." Nature methods 15.12 (2018): 1090-1097.',
    doi="10.1038/s41592-018-0216-7",
)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Supported configuration options.
|
|
2
|
+
|
|
3
|
+
Used throughout the code to ensure consistency. These should be kept in sync with the
|
|
4
|
+
corresponding configuration options in the Pydantic models.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
# Public API of the `support` subpackage: the enums listing the supported
# configuration options, in export order.
__all__ = [
    "SupportedArchitecture",
    "SupportedActivation",
    "SupportedOptimizer",
    "SupportedScheduler",
    "SupportedLoss",
    "SupportedAlgorithm",
    "SupportedPixelManipulation",
    "SupportedTransform",
    "SupportedData",
    "SupportedStructAxis",
    "SupportedLogger",
]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
from .supported_activations import SupportedActivation
|
|
23
|
+
from .supported_algorithms import SupportedAlgorithm
|
|
24
|
+
from .supported_architectures import SupportedArchitecture
|
|
25
|
+
from .supported_data import SupportedData
|
|
26
|
+
from .supported_loggers import SupportedLogger
|
|
27
|
+
from .supported_losses import SupportedLoss
|
|
28
|
+
from .supported_optimizers import SupportedOptimizer, SupportedScheduler
|
|
29
|
+
from .supported_pixel_manipulations import SupportedPixelManipulation
|
|
30
|
+
from .supported_struct_axis import SupportedStructAxis
|
|
31
|
+
from .supported_transforms import SupportedTransform
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Activations supported by CAREamics."""
|
|
2
|
+
|
|
3
|
+
from careamics.utils import BaseEnum
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SupportedActivation(str, BaseEnum):
    """Supported activation functions.

    - None, no activation will be used.
    - Sigmoid
    - Softmax
    - Tanh
    - ReLU
    - LeakyReLU

    Apart from None, all activations are defined in PyTorch.

    See:
    https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity
    and https://pytorch.org/docs/stable/nn.html#non-linear-activations-other
    """

    NONE = "None"
    SIGMOID = "Sigmoid"
    SOFTMAX = "Softmax"
    TANH = "Tanh"
    RELU = "ReLU"
    LEAKYRELU = "LeakyReLU"
|