careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,527 @@
|
|
|
1
|
+
"""Data configuration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pprint import pformat
|
|
6
|
+
from typing import Any, Literal, Optional, Union
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from numpy.typing import NDArray
|
|
10
|
+
from pydantic import (
|
|
11
|
+
BaseModel,
|
|
12
|
+
ConfigDict,
|
|
13
|
+
Discriminator,
|
|
14
|
+
Field,
|
|
15
|
+
PlainSerializer,
|
|
16
|
+
field_validator,
|
|
17
|
+
model_validator,
|
|
18
|
+
)
|
|
19
|
+
from typing_extensions import Annotated, Self
|
|
20
|
+
|
|
21
|
+
from .support import SupportedTransform
|
|
22
|
+
from .transformations.n2v_manipulate_model import N2VManipulateModel
|
|
23
|
+
from .transformations.xy_flip_model import XYFlipModel
|
|
24
|
+
from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
|
|
25
|
+
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def np_float_to_scientific_str(x: float) -> str:
    """Format a float as a string in scientific notation.

    Used as a Pydantic serializer so that floats (including `numpy.float32`
    values) can be stored in the model and written to a yaml file as strings.

    Parameters
    ----------
    x : float
        Value to format.

    Returns
    -------
    str
        `x` written in scientific notation with a precision of 7 digits.
    """
    n_digits = 7
    formatted = np.format_float_scientific(x, precision=n_digits)
    return formatted
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
Float = Annotated[
    float,
    PlainSerializer(func=np_float_to_scientific_str, return_type=str),
]
"""Annotated float type, used to serialize floats to strings."""


TRANSFORMS_UNION = Annotated[
    Union[XYFlipModel, XYRandomRotate90Model, N2VManipulateModel],
    # the `name` field is what tells the different transform models apart
    Discriminator(discriminator="name"),
]
"""Available transforms in CAREamics."""
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class DataConfig(BaseModel):
    """
    Data configuration.

    If std is specified, mean must be specified as well. Note that setting the std first
    and then the mean (if they were both `None` before) will raise a validation error.
    Prefer instead `set_means_and_stds` to set both at once. Means and stds are
    expected to be lists of floats, one for each channel. For supervised tasks, the
    mean and std of the target could be different from the input data.

    All supported transforms are defined in the SupportedTransform enum.

    Examples
    --------
    Minimum example:

    >>> data = DataConfig(
    ...     data_type="array", # defined in SupportedData
    ...     patch_size=[128, 128],
    ...     batch_size=4,
    ...     axes="YX"
    ... )

    To change the image_means and image_stds of the data:
    >>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])

    One can pass also a list of transformations, by keyword, using the
    SupportedTransform value:
    >>> from careamics.config.support import SupportedTransform
    >>> data = DataConfig(
    ...     data_type="tiff",
    ...     patch_size=[128, 128],
    ...     batch_size=4,
    ...     axes="YX",
    ...     transforms=[
    ...         {
    ...             "name": "XYFlip",
    ...         }
    ...     ]
    ... )
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        validate_assignment=True,
    )

    # Dataset configuration
    data_type: Literal["array", "tiff", "custom"]
    """Type of input data, numpy.ndarray (array) or paths (tiff and custom), as defined
    in SupportedData."""

    axes: str
    """Axes of the data, as defined in SupportedAxes."""

    patch_size: list[int] = Field(..., min_length=2, max_length=3)
    """Patch size, as used during training."""

    batch_size: int = Field(default=1, ge=1, validate_default=True)
    """Batch size for training."""

    # Optional fields
    image_means: Optional[list[Float]] = Field(
        default=None, min_length=0, max_length=32
    )
    """Means of the data across channels, used for normalization."""

    image_stds: Optional[list[Float]] = Field(default=None, min_length=0, max_length=32)
    """Standard deviations of the data across channels, used for normalization."""

    target_means: Optional[list[Float]] = Field(
        default=None, min_length=0, max_length=32
    )
    """Means of the target data across channels, used for normalization."""

    target_stds: Optional[list[Float]] = Field(
        default=None, min_length=0, max_length=32
    )
    """Standard deviations of the target data across channels, used for
    normalization."""

    transforms: list[TRANSFORMS_UNION] = Field(
        default=[
            {
                "name": SupportedTransform.XY_FLIP.value,
            },
            {
                "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
            },
            {
                "name": SupportedTransform.N2V_MANIPULATE.value,
            },
        ],
        validate_default=True,
    )
    """List of transformations to apply to the data, available transforms are defined
    in SupportedTransform. The default values are set for Noise2Void."""

    dataloader_params: Optional[dict] = None
    """Dictionary of PyTorch dataloader parameters."""

    @field_validator("patch_size")
    @classmethod
    def all_elements_power_of_2_minimum_8(cls, patch_list: list[int]) -> list[int]:
        """
        Validate patch size.

        Patch size must be powers of 2 and minimum 8.

        Parameters
        ----------
        patch_list : list of int
            Patch size.

        Returns
        -------
        list of int
            Validated patch size.

        Raises
        ------
        ValueError
            If the patch size is smaller than 8.
        ValueError
            If the patch size is not a power of 2.
        """
        patch_size_ge_than_8_power_of_2(patch_list)

        return patch_list

    @field_validator("axes")
    @classmethod
    def axes_valid(cls, axes: str) -> str:
        """
        Validate axes.

        Axes must:
        - be a combination of 'STCZYX'
        - not contain duplicates
        - contain at least 2 contiguous axes: X and Y
        - contain at most 4 axes
        - not contain both S and T axes

        Parameters
        ----------
        axes : str
            Axes to validate.

        Returns
        -------
        str
            Validated axes.

        Raises
        ------
        ValueError
            If axes are not valid.
        """
        # Validate axes
        check_axes_validity(axes)

        return axes

    @field_validator("transforms")
    @classmethod
    def validate_prediction_transforms(
        cls, transforms: list[TRANSFORMS_UNION]
    ) -> list[TRANSFORMS_UNION]:
        """
        Validate N2VManipulate transform position in the transform list.

        At most one N2VManipulate transform is allowed; if present and not already
        last, it is moved to the end of the list.

        Parameters
        ----------
        transforms : list[Transformations_Union]
            Transforms.

        Returns
        -------
        list of transforms
            Validated transforms.

        Raises
        ------
        ValueError
            If multiple instances of N2VManipulate are found.
        """
        # compare against the string value, since `name` fields hold strings
        n2v_name = SupportedTransform.N2V_MANIPULATE.value
        transform_list = [t.name for t in transforms]

        if n2v_name in transform_list:
            # multiple N2V_MANIPULATE
            if transform_list.count(n2v_name) > 1:
                raise ValueError(
                    f"Multiple instances of "
                    f"{SupportedTransform.N2V_MANIPULATE} transforms "
                    f"are not allowed."
                )

            # N2V_MANIPULATE not the last transform: move it to the end
            elif transform_list[-1] != n2v_name:
                index = transform_list.index(n2v_name)
                transform = transforms.pop(index)
                transforms.append(transform)

        return transforms

    @model_validator(mode="after")
    def std_only_with_mean(self: Self) -> Self:
        """
        Check that mean and std are either both None, or both specified.

        This is checked independently for the image and the target statistics, and
        when both are specified they must have the same length (one value per
        channel). Empty lists are treated as "not specified".

        Returns
        -------
        Self
            Validated data model.

        Raises
        ------
        ValueError
            If only one of mean and std is specified, or if their lengths differ.
        """
        # image statistics
        if (self.image_means and not self.image_stds) or (
            self.image_stds and not self.image_means
        ):
            raise ValueError(
                "Mean and std must be either both None, or both specified."
            )

        elif (self.image_means is not None and self.image_stds is not None) and (
            len(self.image_means) != len(self.image_stds)
        ):
            raise ValueError("Mean and std must be specified for each input channel.")

        # target statistics
        if (self.target_means and not self.target_stds) or (
            self.target_stds and not self.target_means
        ):
            raise ValueError(
                "Target mean and std must be either both None, or both specified."
            )

        elif (self.target_means is not None and self.target_stds is not None) and (
            len(self.target_means) != len(self.target_stds)
        ):
            raise ValueError(
                "Mean and std must be either both None, or both specified for each "
                "target channel."
            )

        return self

    @model_validator(mode="after")
    def validate_dimensions(self: Self) -> Self:
        """
        Validate 2D/3D dimensions between axes and patch size.

        If the axes contain `Z` the patch size must have 3 dimensions, otherwise it
        must have 2.

        Returns
        -------
        Self
            Validated data model.

        Raises
        ------
        ValueError
            If the patch size dimensionality does not match the axes.
        """
        if "Z" in self.axes:
            if len(self.patch_size) != 3:
                raise ValueError(
                    f"Patch size must have 3 dimensions if the data is 3D "
                    f"({self.axes})."
                )

        else:
            if len(self.patch_size) != 2:
                raise ValueError(
                    f"Patch size must have 2 dimensions if the data is 2D "
                    f"({self.axes})."
                )

        return self

    def __str__(self) -> str:
        """
        Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def _update(self, **kwargs: Any) -> None:
        """
        Update multiple arguments at once.

        The current values merged with `kwargs` are validated as a whole before
        anything is written, so a failing validation leaves the configuration
        unchanged. The validated (coerced) values are then stored.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments to update.

        Raises
        ------
        ValueError
            If the updated configuration is invalid (pydantic's `ValidationError`
            is a subclass of `ValueError`).
        """
        validated = self.__class__.model_validate({**self.__dict__, **kwargs})
        self.__dict__.update(validated.__dict__)

    def has_n2v_manipulate(self) -> bool:
        """
        Check if the transforms contain N2VManipulate.

        Returns
        -------
        bool
            True if the transforms contain N2VManipulate, False otherwise.
        """
        return any(
            transform.name == SupportedTransform.N2V_MANIPULATE.value
            for transform in self.transforms
        )

    def add_n2v_manipulate(self) -> None:
        """Add N2VManipulate to the end of the transforms, if not already present."""
        if not self.has_n2v_manipulate():
            self.transforms.append(
                N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value)
            )

    def remove_n2v_manipulate(self) -> None:
        """Remove N2VManipulate from the transforms, if present."""
        n2v_name = SupportedTransform.N2V_MANIPULATE.value
        # look the transform up by name rather than assuming it is last, since the
        # list may have been mutated in place after validation
        for index, transform in enumerate(self.transforms):
            if transform.name == n2v_name:
                self.transforms.pop(index)
                break

    def set_means_and_stds(
        self,
        image_means: Union[NDArray, tuple, list, None],
        image_stds: Union[NDArray, tuple, list, None],
        target_means: Optional[Union[NDArray, tuple, list, None]] = None,
        target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
    ) -> None:
        """
        Set mean and standard deviation of the data across channels.

        This method should be used instead setting the fields directly, as it would
        otherwise trigger a validation error.

        Parameters
        ----------
        image_means : numpy.ndarray, tuple or list
            Mean values for normalization.
        image_stds : numpy.ndarray, tuple or list
            Standard deviation values for normalization.
        target_means : numpy.ndarray, tuple or list, optional
            Target mean values for normalization, by default None.
        target_stds : numpy.ndarray, tuple or list, optional
            Target standard deviation values for normalization, by default None.
        """
        # make sure we pass a list
        if image_means is not None:
            image_means = list(image_means)
        if image_stds is not None:
            image_stds = list(image_stds)
        if target_means is not None:
            target_means = list(target_means)
        if target_stds is not None:
            target_stds = list(target_stds)

        self._update(
            image_means=image_means,
            image_stds=image_stds,
            target_means=target_means,
            target_stds=target_stds,
        )

    def set_3D(self, axes: str, patch_size: list[int]) -> None:
        """
        Set 3D parameters.

        Axes and patch size are updated together so that the dimension check
        between them is performed on the final state.

        Parameters
        ----------
        axes : str
            Axes.
        patch_size : list of int
            Patch size.
        """
        self._update(axes=axes, patch_size=patch_size)

    def set_N2V2(self, use_n2v2: bool) -> None:
        """
        Set N2V2.

        Parameters
        ----------
        use_n2v2 : bool
            Whether to use N2V2 (`median` strategy) or not (`uniform` strategy).

        Raises
        ------
        ValueError
            If the N2V pixel manipulate transform is not found in the transforms.
        """
        if use_n2v2:
            self.set_N2V2_strategy("median")
        else:
            self.set_N2V2_strategy("uniform")

    def set_N2V2_strategy(self, strategy: Literal["uniform", "median"]) -> None:
        """
        Set N2V2 strategy.

        Parameters
        ----------
        strategy : Literal["uniform", "median"]
            Strategy to use for N2V2.

        Raises
        ------
        ValueError
            If the N2V pixel manipulate transform is not found in the transforms.
        """
        found_n2v = False

        for transform in self.transforms:
            if transform.name == SupportedTransform.N2V_MANIPULATE.value:
                transform.strategy = strategy
                found_n2v = True

        if not found_n2v:
            transforms = [t.name for t in self.transforms]
            raise ValueError(
                f"N2V_Manipulate transform not found in the transforms list "
                f"({transforms})."
            )

    def set_structN2V_mask(
        self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int
    ) -> None:
        """
        Set structN2V mask parameters.

        Setting `mask_axis` to `none` will disable structN2V.

        Parameters
        ----------
        mask_axis : Literal["horizontal", "vertical", "none"]
            Axis along which to apply the mask. `none` will disable structN2V.
        mask_span : int
            Total span of the mask in pixels.

        Raises
        ------
        ValueError
            If the N2V pixel manipulate transform is not found in the transforms.
        """
        found_n2v = False

        for transform in self.transforms:
            if transform.name == SupportedTransform.N2V_MANIPULATE.value:
                transform.struct_mask_axis = mask_axis
                transform.struct_mask_span = mask_span
                found_n2v = True

        if not found_n2v:
            transforms = [t.name for t in self.transforms]
            raise ValueError(
                f"N2V pixel manipulate transform not found in the transforms "
                f"({transforms})."
            )
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
"""Module containing `FCNAlgorithmConfig` class."""
|
|
2
|
+
|
|
3
|
+
from pprint import pformat
|
|
4
|
+
from typing import Literal, Union
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
7
|
+
from typing_extensions import Self
|
|
8
|
+
|
|
9
|
+
from careamics.config.architectures import CustomModel, UNetModel
|
|
10
|
+
from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class FCNAlgorithmConfig(BaseModel):
    """Algorithm configuration.

    This Pydantic model validates the parameters governing the components of the
    training algorithm: which algorithm, loss function, model architecture, optimizer,
    and learning rate scheduler to use.

    Currently, we only support N2V, CARE, N2N and custom models. The `n2v` algorithm is
    only compatible with `n2v` loss and `UNet` architecture. The `care` and `n2n`
    algorithms cannot use the `n2v` loss. The `custom` algorithm allows you to
    register your own architecture and select it using its name as `name` in the
    custom pydantic model.

    Attributes
    ----------
    algorithm_type : Literal["fcn"]
        Algorithm type, used to differentiate this configuration from LVAE ones.
    algorithm : Literal["n2v", "care", "n2n", "custom"]
        Algorithm to use.
    loss : Literal["n2v", "mae", "mse"]
        Loss function to use.
    model : Union[UNetModel, CustomModel]
        Model architecture to use.
    optimizer : OptimizerModel, optional
        Optimizer to use.
    lr_scheduler : LrSchedulerModel, optional
        Learning rate scheduler to use.

    Raises
    ------
    ValueError
        Algorithm parameter type validation errors.
    ValueError
        If the algorithm, loss and model are not compatible.

    Examples
    --------
    Minimum example:
    >>> from careamics.config import FCNAlgorithmConfig
    >>> config_dict = {
    ...     "algorithm": "n2v",
    ...     "algorithm_type": "fcn",
    ...     "loss": "n2v",
    ...     "model": {
    ...         "architecture": "UNet",
    ...     }
    ... }
    >>> config = FCNAlgorithmConfig(**config_dict)
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        protected_namespaces=(),  # allows to use model_* as a field name
        validate_assignment=True,
        extra="allow",
    )

    # Mandatory fields
    # defined in SupportedAlgorithm
    algorithm_type: Literal["fcn"]
    """Algorithm type must be `fcn` (fully convolutional network) to differentiate this
    configuration from LVAE."""

    algorithm: Literal["n2v", "care", "n2n", "custom"]
    """Name of the algorithm, as defined in SupportedAlgorithm. Use `custom` for custom
    model architecture."""

    loss: Literal["n2v", "mae", "mse"]
    """Loss function to use, as defined in SupportedLoss."""

    model: Union[UNetModel, CustomModel] = Field(discriminator="architecture")
    """Model architecture to use, along with its parameters. Compatible architectures
    are defined in SupportedArchitecture, and their Pydantic models in
    `careamics.config.architectures`."""
    # TODO supported architectures are now all the architectures but does not warn users
    # of the compatibility with the algorithm

    # Optional fields
    optimizer: OptimizerModel = OptimizerModel()
    """Optimizer to use, defined in SupportedOptimizer."""

    lr_scheduler: LrSchedulerModel = LrSchedulerModel()
    """Learning rate scheduler to use, defined in SupportedLrScheduler."""

    @model_validator(mode="after")
    def algorithm_cross_validation(self: Self) -> Self:
        """Validate the algorithm model based on `algorithm`.

        N2V:
        - loss must be n2v
        - model must be a `UNetModel`
        - `in_channels` must be equal to `num_classes`

        CARE and N2N:
        - loss must not be n2v

        All algorithms:
        - `algorithm` is `custom` if and only if the model architecture is `custom`

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If the algorithm, loss and model are not compatible.
        """
        # N2V
        if self.algorithm == "n2v":
            # n2v is only compatible with the n2v loss
            if self.loss != "n2v":
                raise ValueError(
                    f"Algorithm {self.algorithm} only supports loss `n2v`."
                )

            # n2v is only compatible with the UNet model
            if not isinstance(self.model, UNetModel):
                raise ValueError(
                    f"Model for algorithm {self.algorithm} must be a `UNetModel`."
                )

            # n2v requires the number of input and output channels to be the same
            if self.model.in_channels != self.model.num_classes:
                raise ValueError(
                    "N2V requires the same number of input and output channels. Make "
                    "sure that `in_channels` and `num_classes` are the same."
                )

        # supervised algorithms
        elif self.algorithm in ("care", "n2n"):
            if self.loss == "n2v":
                raise ValueError("Supervised algorithms do not support loss `n2v`.")

        # a custom algorithm must come with a custom architecture, and vice versa
        if (self.algorithm == "custom") != (self.model.architecture == "custom"):
            raise ValueError(
                "Algorithm and model architecture must be both `custom` or not."
            )

        return self

    def __str__(self) -> str:
        """Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())
|