careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +164 -231
- careamics/config/algorithm_model.py +5 -18
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +11 -4
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -5
- careamics/config/configuration_factory.py +27 -41
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +89 -63
- careamics/config/inference_model.py +28 -81
- careamics/config/optimizer_models.py +11 -11
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -16
- careamics/config/tile_information.py +28 -58
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/config/validators/validator_utils.py +1 -1
- careamics/conftest.py +12 -0
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +6 -11
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +88 -154
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +121 -191
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +109 -39
- careamics/dataset/patching/random_patching.py +17 -6
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/validate_patch_dimension.py +7 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +46 -25
- careamics/lightning_module.py +19 -9
- careamics/lightning_prediction_datamodule.py +54 -84
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +3 -3
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +52 -14
- careamics/transforms/normalize.py +171 -48
- careamics/transforms/pixel_manipulation.py +35 -11
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/tta.py +43 -29
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +4 -2
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
- careamics-0.1.0rc7.dist-info/RECORD +130 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/lightning_prediction_loop.py +0 -116
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -74
- careamics/transforms/nd_flip.py +0 -67
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/config/data_model.py
CHANGED
|
@@ -3,8 +3,9 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from pprint import pformat
|
|
6
|
-
from typing import Any,
|
|
6
|
+
from typing import Any, Literal, Optional, Union
|
|
7
7
|
|
|
8
|
+
from numpy.typing import NDArray
|
|
8
9
|
from pydantic import (
|
|
9
10
|
BaseModel,
|
|
10
11
|
ConfigDict,
|
|
@@ -17,16 +18,14 @@ from typing_extensions import Annotated, Self
|
|
|
17
18
|
|
|
18
19
|
from .support import SupportedTransform
|
|
19
20
|
from .transformations.n2v_manipulate_model import N2VManipulateModel
|
|
20
|
-
from .transformations.
|
|
21
|
-
from .transformations.normalize_model import NormalizeModel
|
|
21
|
+
from .transformations.xy_flip_model import XYFlipModel
|
|
22
22
|
from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
|
|
23
23
|
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
24
24
|
|
|
25
25
|
TRANSFORMS_UNION = Annotated[
|
|
26
26
|
Union[
|
|
27
|
-
|
|
27
|
+
XYFlipModel,
|
|
28
28
|
XYRandomRotate90Model,
|
|
29
|
-
NormalizeModel,
|
|
30
29
|
N2VManipulateModel,
|
|
31
30
|
],
|
|
32
31
|
Discriminator("name"), # used to tell the different transform models apart
|
|
@@ -39,7 +38,11 @@ class DataConfig(BaseModel):
|
|
|
39
38
|
|
|
40
39
|
If std is specified, mean must be specified as well. Note that setting the std first
|
|
41
40
|
and then the mean (if they were both `None` before) will raise a validation error.
|
|
42
|
-
Prefer instead `set_mean_and_std` to set both at once.
|
|
41
|
+
Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected
|
|
42
|
+
to be lists of floats, one for each channel. For supervised tasks, the mean and std
|
|
43
|
+
of the target could be different from the input data.
|
|
44
|
+
|
|
45
|
+
All supported transforms are defined in the SupportedTransform enum.
|
|
43
46
|
|
|
44
47
|
Examples
|
|
45
48
|
--------
|
|
@@ -53,10 +56,10 @@ class DataConfig(BaseModel):
|
|
|
53
56
|
... )
|
|
54
57
|
|
|
55
58
|
To change the mean and std of the data:
|
|
56
|
-
>>> data.set_mean_and_std(
|
|
59
|
+
>>> data.set_mean_and_std(image_means=[214.3], image_stds=[84.5])
|
|
57
60
|
|
|
58
61
|
One can pass also a list of transformations, by keyword, using the
|
|
59
|
-
SupportedTransform
|
|
62
|
+
SupportedTransform value:
|
|
60
63
|
>>> from careamics.config.support import SupportedTransform
|
|
61
64
|
>>> data = DataConfig(
|
|
62
65
|
... data_type="tiff",
|
|
@@ -65,12 +68,7 @@ class DataConfig(BaseModel):
|
|
|
65
68
|
... axes="YX",
|
|
66
69
|
... transforms=[
|
|
67
70
|
... {
|
|
68
|
-
... "name":
|
|
69
|
-
... "mean": 167.6,
|
|
70
|
-
... "std": 47.2,
|
|
71
|
-
... },
|
|
72
|
-
... {
|
|
73
|
-
... "name": "NDFlip",
|
|
71
|
+
... "name": "XYFlip",
|
|
74
72
|
... }
|
|
75
73
|
... ]
|
|
76
74
|
... )
|
|
@@ -83,21 +81,26 @@ class DataConfig(BaseModel):
|
|
|
83
81
|
|
|
84
82
|
# Dataset configuration
|
|
85
83
|
data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData
|
|
86
|
-
patch_size: Union[
|
|
84
|
+
patch_size: Union[list[int]] = Field(..., min_length=2, max_length=3)
|
|
87
85
|
batch_size: int = Field(default=1, ge=1, validate_default=True)
|
|
88
86
|
axes: str
|
|
89
87
|
|
|
90
88
|
# Optional fields
|
|
91
|
-
|
|
92
|
-
|
|
89
|
+
image_means: Optional[list[float]] = Field(
|
|
90
|
+
default=None, min_length=0, max_length=32
|
|
91
|
+
)
|
|
92
|
+
image_stds: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
|
|
93
|
+
target_means: Optional[list[float]] = Field(
|
|
94
|
+
default=None, min_length=0, max_length=32
|
|
95
|
+
)
|
|
96
|
+
target_stds: Optional[list[float]] = Field(
|
|
97
|
+
default=None, min_length=0, max_length=32
|
|
98
|
+
)
|
|
93
99
|
|
|
94
|
-
transforms:
|
|
100
|
+
transforms: list[TRANSFORMS_UNION] = Field(
|
|
95
101
|
default=[
|
|
96
102
|
{
|
|
97
|
-
"name": SupportedTransform.
|
|
98
|
-
},
|
|
99
|
-
{
|
|
100
|
-
"name": SupportedTransform.NDFLIP.value,
|
|
103
|
+
"name": SupportedTransform.XY_FLIP.value,
|
|
101
104
|
},
|
|
102
105
|
{
|
|
103
106
|
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
|
|
@@ -114,8 +117,8 @@ class DataConfig(BaseModel):
|
|
|
114
117
|
@field_validator("patch_size")
|
|
115
118
|
@classmethod
|
|
116
119
|
def all_elements_power_of_2_minimum_8(
|
|
117
|
-
cls, patch_list: Union[
|
|
118
|
-
) -> Union[
|
|
120
|
+
cls, patch_list: Union[list[int]]
|
|
121
|
+
) -> Union[list[int]]:
|
|
119
122
|
"""
|
|
120
123
|
Validate patch size.
|
|
121
124
|
|
|
@@ -123,12 +126,12 @@ class DataConfig(BaseModel):
|
|
|
123
126
|
|
|
124
127
|
Parameters
|
|
125
128
|
----------
|
|
126
|
-
patch_list :
|
|
129
|
+
patch_list : list of int
|
|
127
130
|
Patch size.
|
|
128
131
|
|
|
129
132
|
Returns
|
|
130
133
|
-------
|
|
131
|
-
|
|
134
|
+
list of int
|
|
132
135
|
Validated patch size.
|
|
133
136
|
|
|
134
137
|
Raises
|
|
@@ -178,19 +181,19 @@ class DataConfig(BaseModel):
|
|
|
178
181
|
@field_validator("transforms")
|
|
179
182
|
@classmethod
|
|
180
183
|
def validate_prediction_transforms(
|
|
181
|
-
cls, transforms:
|
|
182
|
-
) ->
|
|
184
|
+
cls, transforms: list[TRANSFORMS_UNION]
|
|
185
|
+
) -> list[TRANSFORMS_UNION]:
|
|
183
186
|
"""
|
|
184
187
|
Validate N2VManipulate transform position in the transform list.
|
|
185
188
|
|
|
186
189
|
Parameters
|
|
187
190
|
----------
|
|
188
|
-
transforms :
|
|
191
|
+
transforms : list[Transformations_Union]
|
|
189
192
|
Transforms.
|
|
190
193
|
|
|
191
194
|
Returns
|
|
192
195
|
-------
|
|
193
|
-
|
|
196
|
+
list of transforms
|
|
194
197
|
Validated transforms.
|
|
195
198
|
|
|
196
199
|
Raises
|
|
@@ -202,7 +205,7 @@ class DataConfig(BaseModel):
|
|
|
202
205
|
|
|
203
206
|
if SupportedTransform.N2V_MANIPULATE in transform_list:
|
|
204
207
|
# multiple N2V_MANIPULATE
|
|
205
|
-
if transform_list.count(SupportedTransform.N2V_MANIPULATE) > 1:
|
|
208
|
+
if transform_list.count(SupportedTransform.N2V_MANIPULATE.value) > 1:
|
|
206
209
|
raise ValueError(
|
|
207
210
|
f"Multiple instances of "
|
|
208
211
|
f"{SupportedTransform.N2V_MANIPULATE} transforms "
|
|
@@ -211,7 +214,7 @@ class DataConfig(BaseModel):
|
|
|
211
214
|
|
|
212
215
|
# N2V_MANIPULATE not the last transform
|
|
213
216
|
elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
|
|
214
|
-
index = transform_list.index(SupportedTransform.N2V_MANIPULATE)
|
|
217
|
+
index = transform_list.index(SupportedTransform.N2V_MANIPULATE.value)
|
|
215
218
|
transform = transforms.pop(index)
|
|
216
219
|
transforms.append(transform)
|
|
217
220
|
|
|
@@ -233,29 +236,33 @@ class DataConfig(BaseModel):
|
|
|
233
236
|
If std is not None and mean is None.
|
|
234
237
|
"""
|
|
235
238
|
# check that mean and std are either both None, or both specified
|
|
236
|
-
if (self.
|
|
239
|
+
if (self.image_means and not self.image_stds) or (
|
|
240
|
+
self.image_stds and not self.image_means
|
|
241
|
+
):
|
|
237
242
|
raise ValueError(
|
|
238
243
|
"Mean and std must be either both None, or both specified."
|
|
239
244
|
)
|
|
240
245
|
|
|
241
|
-
|
|
246
|
+
elif (self.image_means is not None and self.image_stds is not None) and (
|
|
247
|
+
len(self.image_means) != len(self.image_stds)
|
|
248
|
+
):
|
|
249
|
+
raise ValueError(
|
|
250
|
+
"Mean and std must be specified for each " "input channel."
|
|
251
|
+
)
|
|
242
252
|
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
253
|
+
if (self.target_means and not self.target_stds) or (
|
|
254
|
+
self.target_stds and not self.target_means
|
|
255
|
+
):
|
|
256
|
+
raise ValueError(
|
|
257
|
+
"Mean and std must be either both None, or both specified "
|
|
258
|
+
)
|
|
247
259
|
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
# search in the transforms for Normalize and update parameters
|
|
255
|
-
for transform in self.transforms:
|
|
256
|
-
if transform.name == SupportedTransform.NORMALIZE.value:
|
|
257
|
-
transform.mean = self.mean
|
|
258
|
-
transform.std = self.std
|
|
260
|
+
elif self.target_means is not None and self.target_stds is not None:
|
|
261
|
+
if len(self.target_means) != len(self.target_stds):
|
|
262
|
+
raise ValueError(
|
|
263
|
+
"Mean and std must be either both None, or both specified for each "
|
|
264
|
+
"target channel."
|
|
265
|
+
)
|
|
259
266
|
|
|
260
267
|
return self
|
|
261
268
|
|
|
@@ -339,7 +346,13 @@ class DataConfig(BaseModel):
|
|
|
339
346
|
if self.has_n2v_manipulate():
|
|
340
347
|
self.transforms.pop(-1)
|
|
341
348
|
|
|
342
|
-
def set_mean_and_std(
|
|
349
|
+
def set_mean_and_std(
|
|
350
|
+
self,
|
|
351
|
+
image_means: Union[NDArray, tuple, list, None],
|
|
352
|
+
image_stds: Union[NDArray, tuple, list, None],
|
|
353
|
+
target_means: Optional[Union[NDArray, tuple, list, None]] = None,
|
|
354
|
+
target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
|
|
355
|
+
) -> None:
|
|
343
356
|
"""
|
|
344
357
|
Set mean and standard deviation of the data.
|
|
345
358
|
|
|
@@ -348,20 +361,33 @@ class DataConfig(BaseModel):
|
|
|
348
361
|
|
|
349
362
|
Parameters
|
|
350
363
|
----------
|
|
351
|
-
|
|
352
|
-
Mean
|
|
353
|
-
|
|
354
|
-
Standard deviation
|
|
364
|
+
image_means : NDArray or tuple or list
|
|
365
|
+
Mean values for normalization.
|
|
366
|
+
image_stds : NDArray or tuple or list
|
|
367
|
+
Standard deviation values for normalization.
|
|
368
|
+
target_means : NDArray or tuple or list, optional
|
|
369
|
+
Target mean values for normalization, by default ().
|
|
370
|
+
target_stds : NDArray or tuple or list, optional
|
|
371
|
+
Target standard deviation values for normalization, by default ().
|
|
355
372
|
"""
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
373
|
+
# make sure we pass a list
|
|
374
|
+
if image_means is not None:
|
|
375
|
+
image_means = list(image_means)
|
|
376
|
+
if image_stds is not None:
|
|
377
|
+
image_stds = list(image_stds)
|
|
378
|
+
if target_means is not None:
|
|
379
|
+
target_means = list(target_means)
|
|
380
|
+
if target_stds is not None:
|
|
381
|
+
target_stds = list(target_stds)
|
|
382
|
+
|
|
383
|
+
self._update(
|
|
384
|
+
image_means=image_means,
|
|
385
|
+
image_stds=image_stds,
|
|
386
|
+
target_means=target_means,
|
|
387
|
+
target_stds=target_stds,
|
|
388
|
+
)
|
|
363
389
|
|
|
364
|
-
def set_3D(self, axes: str, patch_size:
|
|
390
|
+
def set_3D(self, axes: str, patch_size: list[int]) -> None:
|
|
365
391
|
"""
|
|
366
392
|
Set 3D parameters.
|
|
367
393
|
|
|
@@ -369,7 +395,7 @@ class DataConfig(BaseModel):
|
|
|
369
395
|
----------
|
|
370
396
|
axes : str
|
|
371
397
|
Axes.
|
|
372
|
-
patch_size :
|
|
398
|
+
patch_size : list of int
|
|
373
399
|
Patch size.
|
|
374
400
|
"""
|
|
375
401
|
self._update(axes=axes, patch_size=patch_size)
|
|
@@ -2,17 +2,13 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Literal, Optional, Union
|
|
6
6
|
|
|
7
7
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
8
8
|
from typing_extensions import Self
|
|
9
9
|
|
|
10
|
-
from .support import SupportedTransform
|
|
11
|
-
from .transformations.normalize_model import NormalizeModel
|
|
12
10
|
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
13
11
|
|
|
14
|
-
TRANSFORMS_UNION = Union[NormalizeModel]
|
|
15
|
-
|
|
16
12
|
|
|
17
13
|
class InferenceConfig(BaseModel):
|
|
18
14
|
"""Configuration class for the prediction model."""
|
|
@@ -21,26 +17,17 @@ class InferenceConfig(BaseModel):
|
|
|
21
17
|
|
|
22
18
|
# Mandatory fields
|
|
23
19
|
data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData
|
|
24
|
-
tile_size: Optional[Union[
|
|
20
|
+
tile_size: Optional[Union[list[int]]] = Field(
|
|
25
21
|
default=None, min_length=2, max_length=3
|
|
26
22
|
)
|
|
27
|
-
tile_overlap: Optional[Union[
|
|
23
|
+
tile_overlap: Optional[Union[list[int]]] = Field(
|
|
28
24
|
default=None, min_length=2, max_length=3
|
|
29
25
|
)
|
|
30
26
|
|
|
31
27
|
axes: str
|
|
32
28
|
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
transforms: List[TRANSFORMS_UNION] = Field(
|
|
37
|
-
default=[
|
|
38
|
-
{
|
|
39
|
-
"name": SupportedTransform.NORMALIZE.value,
|
|
40
|
-
},
|
|
41
|
-
],
|
|
42
|
-
validate_default=True,
|
|
43
|
-
)
|
|
29
|
+
image_means: list = Field(..., min_length=0, max_length=32)
|
|
30
|
+
image_stds: list = Field(..., min_length=0, max_length=32)
|
|
44
31
|
|
|
45
32
|
# only default TTAs are supported for now
|
|
46
33
|
tta_transforms: bool = Field(default=True)
|
|
@@ -51,8 +38,8 @@ class InferenceConfig(BaseModel):
|
|
|
51
38
|
@field_validator("tile_overlap")
|
|
52
39
|
@classmethod
|
|
53
40
|
def all_elements_non_zero_even(
|
|
54
|
-
cls, tile_overlap: Optional[
|
|
55
|
-
) -> Optional[
|
|
41
|
+
cls, tile_overlap: Optional[list[int]]
|
|
42
|
+
) -> Optional[list[int]]:
|
|
56
43
|
"""
|
|
57
44
|
Validate tile overlap.
|
|
58
45
|
|
|
@@ -60,12 +47,12 @@ class InferenceConfig(BaseModel):
|
|
|
60
47
|
|
|
61
48
|
Parameters
|
|
62
49
|
----------
|
|
63
|
-
tile_overlap :
|
|
50
|
+
tile_overlap : list[int] or None
|
|
64
51
|
Patch size.
|
|
65
52
|
|
|
66
53
|
Returns
|
|
67
54
|
-------
|
|
68
|
-
|
|
55
|
+
list[int] or None
|
|
69
56
|
Validated tile overlap.
|
|
70
57
|
|
|
71
58
|
Raises
|
|
@@ -90,19 +77,19 @@ class InferenceConfig(BaseModel):
|
|
|
90
77
|
@field_validator("tile_size")
|
|
91
78
|
@classmethod
|
|
92
79
|
def tile_min_8_power_of_2(
|
|
93
|
-
cls, tile_list: Optional[
|
|
94
|
-
) -> Optional[
|
|
80
|
+
cls, tile_list: Optional[list[int]]
|
|
81
|
+
) -> Optional[list[int]]:
|
|
95
82
|
"""
|
|
96
83
|
Validate that each entry is greater or equal than 8 and a power of 2.
|
|
97
84
|
|
|
98
85
|
Parameters
|
|
99
86
|
----------
|
|
100
|
-
tile_list :
|
|
87
|
+
tile_list : list of int
|
|
101
88
|
Patch size.
|
|
102
89
|
|
|
103
90
|
Returns
|
|
104
91
|
-------
|
|
105
|
-
|
|
92
|
+
list of int
|
|
106
93
|
Validated patch size.
|
|
107
94
|
|
|
108
95
|
Raises
|
|
@@ -149,39 +136,6 @@ class InferenceConfig(BaseModel):
|
|
|
149
136
|
|
|
150
137
|
return axes
|
|
151
138
|
|
|
152
|
-
@field_validator("transforms")
|
|
153
|
-
@classmethod
|
|
154
|
-
def validate_transforms(
|
|
155
|
-
cls, transforms: List[TRANSFORMS_UNION]
|
|
156
|
-
) -> List[TRANSFORMS_UNION]:
|
|
157
|
-
"""
|
|
158
|
-
Validate that transforms do not have N2V pixel manipulate transforms.
|
|
159
|
-
|
|
160
|
-
Parameters
|
|
161
|
-
----------
|
|
162
|
-
transforms : List[TRANSFORMS_UNION]
|
|
163
|
-
Transforms.
|
|
164
|
-
|
|
165
|
-
Returns
|
|
166
|
-
-------
|
|
167
|
-
List[TRANSFORMS_UNION]
|
|
168
|
-
Validated transforms.
|
|
169
|
-
|
|
170
|
-
Raises
|
|
171
|
-
------
|
|
172
|
-
ValueError
|
|
173
|
-
If transforms contain N2V pixel manipulate transforms.
|
|
174
|
-
"""
|
|
175
|
-
if transforms is not None:
|
|
176
|
-
for transform in transforms:
|
|
177
|
-
if transform.name == SupportedTransform.N2V_MANIPULATE.value:
|
|
178
|
-
raise ValueError(
|
|
179
|
-
"N2V_Manipulate transform is not allowed in "
|
|
180
|
-
"prediction transforms."
|
|
181
|
-
)
|
|
182
|
-
|
|
183
|
-
return transforms
|
|
184
|
-
|
|
185
139
|
@model_validator(mode="after")
|
|
186
140
|
def validate_dimensions(self: Self) -> Self:
|
|
187
141
|
"""
|
|
@@ -228,29 +182,22 @@ class InferenceConfig(BaseModel):
|
|
|
228
182
|
If std is not None and mean is None.
|
|
229
183
|
"""
|
|
230
184
|
# check that mean and std are either both None, or both specified
|
|
231
|
-
if
|
|
185
|
+
if not self.image_means and not self.image_stds:
|
|
186
|
+
raise ValueError("Mean and std must be specified during inference.")
|
|
187
|
+
|
|
188
|
+
if (self.image_means and not self.image_stds) or (
|
|
189
|
+
self.image_stds and not self.image_means
|
|
190
|
+
):
|
|
232
191
|
raise ValueError(
|
|
233
192
|
"Mean and std must be either both None, or both specified."
|
|
234
193
|
)
|
|
235
194
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
Returns
|
|
244
|
-
-------
|
|
245
|
-
Self
|
|
246
|
-
Inference model with mean and std added to the Normalize transform.
|
|
247
|
-
"""
|
|
248
|
-
if self.mean is not None or self.std is not None:
|
|
249
|
-
# search in the transforms for Normalize and update parameters
|
|
250
|
-
for transform in self.transforms:
|
|
251
|
-
if transform.name == SupportedTransform.NORMALIZE.value:
|
|
252
|
-
transform.mean = self.mean
|
|
253
|
-
transform.std = self.std
|
|
195
|
+
elif (self.image_means is not None and self.image_stds is not None) and (
|
|
196
|
+
len(self.image_means) != len(self.image_stds)
|
|
197
|
+
):
|
|
198
|
+
raise ValueError(
|
|
199
|
+
"Mean and std must be specified for each " "input channel."
|
|
200
|
+
)
|
|
254
201
|
|
|
255
202
|
return self
|
|
256
203
|
|
|
@@ -266,7 +213,7 @@ class InferenceConfig(BaseModel):
|
|
|
266
213
|
self.__dict__.update(kwargs)
|
|
267
214
|
self.__class__.model_validate(self.__dict__)
|
|
268
215
|
|
|
269
|
-
def set_3D(self, axes: str, tile_size:
|
|
216
|
+
def set_3D(self, axes: str, tile_size: list[int], tile_overlap: list[int]) -> None:
|
|
270
217
|
"""
|
|
271
218
|
Set 3D parameters.
|
|
272
219
|
|
|
@@ -274,9 +221,9 @@ class InferenceConfig(BaseModel):
|
|
|
274
221
|
----------
|
|
275
222
|
axes : str
|
|
276
223
|
Axes.
|
|
277
|
-
tile_size :
|
|
224
|
+
tile_size : list of int
|
|
278
225
|
Tile size.
|
|
279
|
-
tile_overlap :
|
|
226
|
+
tile_overlap : list of int
|
|
280
227
|
Tile overlap.
|
|
281
228
|
"""
|
|
282
229
|
self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap)
|
|
@@ -1,6 +1,8 @@
|
|
|
1
|
+
"""Optimizers and schedulers Pydantic models."""
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
|
-
from typing import
|
|
5
|
+
from typing import Literal
|
|
4
6
|
|
|
5
7
|
from pydantic import (
|
|
6
8
|
BaseModel,
|
|
@@ -19,8 +21,7 @@ from .support import SupportedOptimizer
|
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
class OptimizerModel(BaseModel):
|
|
22
|
-
"""
|
|
23
|
-
Torch optimizer.
|
|
24
|
+
"""Torch optimizer Pydantic model.
|
|
24
25
|
|
|
25
26
|
Only parameters supported by the corresponding torch optimizer will be taken
|
|
26
27
|
into account. For more details, check:
|
|
@@ -31,7 +32,7 @@ class OptimizerModel(BaseModel):
|
|
|
31
32
|
|
|
32
33
|
Attributes
|
|
33
34
|
----------
|
|
34
|
-
name :
|
|
35
|
+
name : {"Adam", "SGD"}
|
|
35
36
|
Name of the optimizer.
|
|
36
37
|
parameters : dict
|
|
37
38
|
Parameters of the optimizer (see torch documentation).
|
|
@@ -55,7 +56,7 @@ class OptimizerModel(BaseModel):
|
|
|
55
56
|
|
|
56
57
|
@field_validator("parameters")
|
|
57
58
|
@classmethod
|
|
58
|
-
def filter_parameters(cls, user_params: dict, values: ValidationInfo) ->
|
|
59
|
+
def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
|
|
59
60
|
"""
|
|
60
61
|
Validate optimizer parameters.
|
|
61
62
|
|
|
@@ -70,7 +71,7 @@ class OptimizerModel(BaseModel):
|
|
|
70
71
|
|
|
71
72
|
Returns
|
|
72
73
|
-------
|
|
73
|
-
|
|
74
|
+
dict
|
|
74
75
|
Filtered optimizer parameters.
|
|
75
76
|
|
|
76
77
|
Raises
|
|
@@ -115,8 +116,7 @@ class OptimizerModel(BaseModel):
|
|
|
115
116
|
|
|
116
117
|
|
|
117
118
|
class LrSchedulerModel(BaseModel):
|
|
118
|
-
"""
|
|
119
|
-
Torch learning rate scheduler.
|
|
119
|
+
"""Torch learning rate scheduler Pydantic model.
|
|
120
120
|
|
|
121
121
|
Only parameters supported by the corresponding torch lr scheduler will be taken
|
|
122
122
|
into account. For more details, check:
|
|
@@ -127,7 +127,7 @@ class LrSchedulerModel(BaseModel):
|
|
|
127
127
|
|
|
128
128
|
Attributes
|
|
129
129
|
----------
|
|
130
|
-
name :
|
|
130
|
+
name : {"ReduceLROnPlateau", "StepLR"}
|
|
131
131
|
Name of the learning rate scheduler.
|
|
132
132
|
parameters : dict
|
|
133
133
|
Parameters of the learning rate scheduler (see torch documentation).
|
|
@@ -146,7 +146,7 @@ class LrSchedulerModel(BaseModel):
|
|
|
146
146
|
|
|
147
147
|
@field_validator("parameters")
|
|
148
148
|
@classmethod
|
|
149
|
-
def filter_parameters(cls, user_params: dict, values: ValidationInfo) ->
|
|
149
|
+
def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
|
|
150
150
|
"""Filter parameters based on the learning rate scheduler's signature.
|
|
151
151
|
|
|
152
152
|
Parameters
|
|
@@ -158,7 +158,7 @@ class LrSchedulerModel(BaseModel):
|
|
|
158
158
|
|
|
159
159
|
Returns
|
|
160
160
|
-------
|
|
161
|
-
|
|
161
|
+
dict
|
|
162
162
|
Filtered scheduler parameters.
|
|
163
163
|
|
|
164
164
|
Raises
|
|
@@ -14,7 +14,6 @@ __all__ = [
|
|
|
14
14
|
"SupportedPixelManipulation",
|
|
15
15
|
"SupportedTransform",
|
|
16
16
|
"SupportedData",
|
|
17
|
-
"SupportedExtractionStrategy",
|
|
18
17
|
"SupportedStructAxis",
|
|
19
18
|
"SupportedLogger",
|
|
20
19
|
]
|
|
@@ -24,7 +23,6 @@ from .supported_activations import SupportedActivation
|
|
|
24
23
|
from .supported_algorithms import SupportedAlgorithm
|
|
25
24
|
from .supported_architectures import SupportedArchitecture
|
|
26
25
|
from .supported_data import SupportedData
|
|
27
|
-
from .supported_extraction_strategies import SupportedExtractionStrategy
|
|
28
26
|
from .supported_loggers import SupportedLogger
|
|
29
27
|
from .supported_losses import SupportedLoss
|
|
30
28
|
from .supported_optimizers import SupportedOptimizer, SupportedScheduler
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Algorithms supported by CAREamics."""
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
5
|
from careamics.utils import BaseEnum
|
|
@@ -10,9 +12,9 @@ class SupportedAlgorithm(str, BaseEnum):
|
|
|
10
12
|
"""
|
|
11
13
|
|
|
12
14
|
N2V = "n2v"
|
|
13
|
-
CUSTOM = "custom"
|
|
14
15
|
CARE = "care"
|
|
15
16
|
N2N = "n2n"
|
|
17
|
+
CUSTOM = "custom"
|
|
16
18
|
# PN2V = "pn2v"
|
|
17
19
|
# HDN = "hdn"
|
|
18
20
|
# SEG = "segmentation"
|
|
@@ -1,15 +1,15 @@
|
|
|
1
|
+
"""Pixel manipulation methods supported by CAREamics."""
|
|
2
|
+
|
|
1
3
|
from careamics.utils import BaseEnum
|
|
2
4
|
|
|
3
5
|
|
|
4
6
|
class SupportedPixelManipulation(str, BaseEnum):
|
|
5
|
-
"""
|
|
7
|
+
"""Supported Noise2Void pixel manipulations.
|
|
6
8
|
|
|
7
9
|
- Uniform: Replace masked pixel value by a (uniformly) randomly selected neighbor
|
|
8
10
|
pixel value.
|
|
9
11
|
- Median: Replace masked pixel value by the mean of the neighborhood.
|
|
10
12
|
"""
|
|
11
13
|
|
|
12
|
-
# TODO docs
|
|
13
|
-
|
|
14
14
|
UNIFORM = "uniform"
|
|
15
15
|
MEDIAN = "median"
|