careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +1 -14
- careamics/careamist.py +212 -294
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -15
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +5 -3
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +18 -98
- careamics/config/configuration_model.py +23 -18
- careamics/config/data_model.py +103 -54
- careamics/config/inference_model.py +41 -19
- careamics/config/optimizer_models.py +13 -7
- careamics/config/support/supported_data.py +29 -4
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +36 -58
- careamics/config/training_model.py +5 -1
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -7
- careamics/dataset/dataset_utils/file_utils.py +2 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +84 -173
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +97 -250
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/patching.py +97 -52
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +6 -3
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +137 -0
- careamics/prediction_utils/stitch_prediction.py +103 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
- careamics-0.1.0rc8.dist-info/RECORD +135 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
- careamics/config/configuration_example.py +0 -89
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc6.dist-info/RECORD +0 -107
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
careamics/config/data_model.py
CHANGED
|
@@ -3,8 +3,9 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from pprint import pformat
|
|
6
|
-
from typing import Any,
|
|
6
|
+
from typing import Any, Literal, Optional, Union
|
|
7
7
|
|
|
8
|
+
from numpy.typing import NDArray
|
|
8
9
|
from pydantic import (
|
|
9
10
|
BaseModel,
|
|
10
11
|
ConfigDict,
|
|
@@ -17,7 +18,6 @@ from typing_extensions import Annotated, Self
|
|
|
17
18
|
|
|
18
19
|
from .support import SupportedTransform
|
|
19
20
|
from .transformations.n2v_manipulate_model import N2VManipulateModel
|
|
20
|
-
from .transformations.normalize_model import NormalizeModel
|
|
21
21
|
from .transformations.xy_flip_model import XYFlipModel
|
|
22
22
|
from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
|
|
23
23
|
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
@@ -26,7 +26,6 @@ TRANSFORMS_UNION = Annotated[
|
|
|
26
26
|
Union[
|
|
27
27
|
XYFlipModel,
|
|
28
28
|
XYRandomRotate90Model,
|
|
29
|
-
NormalizeModel,
|
|
30
29
|
N2VManipulateModel,
|
|
31
30
|
],
|
|
32
31
|
Discriminator("name"), # used to tell the different transform models apart
|
|
@@ -39,7 +38,9 @@ class DataConfig(BaseModel):
|
|
|
39
38
|
|
|
40
39
|
If std is specified, mean must be specified as well. Note that setting the std first
|
|
41
40
|
and then the mean (if they were both `None` before) will raise a validation error.
|
|
42
|
-
Prefer instead `set_mean_and_std` to set both at once.
|
|
41
|
+
Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected
|
|
42
|
+
to be lists of floats, one for each channel. For supervised tasks, the mean and std
|
|
43
|
+
of the target could be different from the input data.
|
|
43
44
|
|
|
44
45
|
All supported transforms are defined in the SupportedTransform enum.
|
|
45
46
|
|
|
@@ -54,8 +55,8 @@ class DataConfig(BaseModel):
|
|
|
54
55
|
... axes="YX"
|
|
55
56
|
... )
|
|
56
57
|
|
|
57
|
-
To change the
|
|
58
|
-
>>> data.
|
|
58
|
+
To change the image_means and image_stds of the data:
|
|
59
|
+
>>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])
|
|
59
60
|
|
|
60
61
|
One can pass also a list of transformations, by keyword, using the
|
|
61
62
|
SupportedTransform value:
|
|
@@ -67,11 +68,6 @@ class DataConfig(BaseModel):
|
|
|
67
68
|
... axes="YX",
|
|
68
69
|
... transforms=[
|
|
69
70
|
... {
|
|
70
|
-
... "name": SupportedTransform.NORMALIZE.value,
|
|
71
|
-
... "mean": 167.6,
|
|
72
|
-
... "std": 47.2,
|
|
73
|
-
... },
|
|
74
|
-
... {
|
|
75
71
|
... "name": "XYFlip",
|
|
76
72
|
... }
|
|
77
73
|
... ]
|
|
@@ -84,20 +80,41 @@ class DataConfig(BaseModel):
|
|
|
84
80
|
)
|
|
85
81
|
|
|
86
82
|
# Dataset configuration
|
|
87
|
-
data_type: Literal["array", "tiff", "custom"]
|
|
88
|
-
|
|
89
|
-
|
|
83
|
+
data_type: Literal["array", "tiff", "custom"]
|
|
84
|
+
"""Type of input data, numpy.ndarray (array) or paths (tiff and custom), as defined
|
|
85
|
+
in SupportedData."""
|
|
86
|
+
|
|
90
87
|
axes: str
|
|
88
|
+
"""Axes of the data, as defined in SupportedAxes."""
|
|
89
|
+
|
|
90
|
+
patch_size: Union[list[int]] = Field(..., min_length=2, max_length=3)
|
|
91
|
+
"""Patch size, as used during training."""
|
|
92
|
+
|
|
93
|
+
batch_size: int = Field(default=1, ge=1, validate_default=True)
|
|
94
|
+
"""Batch size for training."""
|
|
91
95
|
|
|
92
96
|
# Optional fields
|
|
93
|
-
|
|
94
|
-
|
|
97
|
+
image_means: Optional[list[float]] = Field(
|
|
98
|
+
default=None, min_length=0, max_length=32
|
|
99
|
+
)
|
|
100
|
+
"""Means of the data across channels, used for normalization."""
|
|
101
|
+
|
|
102
|
+
image_stds: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
|
|
103
|
+
"""Standard deviations of the data across channels, used for normalization."""
|
|
95
104
|
|
|
96
|
-
|
|
105
|
+
target_means: Optional[list[float]] = Field(
|
|
106
|
+
default=None, min_length=0, max_length=32
|
|
107
|
+
)
|
|
108
|
+
"""Means of the target data across channels, used for normalization."""
|
|
109
|
+
|
|
110
|
+
target_stds: Optional[list[float]] = Field(
|
|
111
|
+
default=None, min_length=0, max_length=32
|
|
112
|
+
)
|
|
113
|
+
"""Standard deviations of the target data across channels, used for
|
|
114
|
+
normalization."""
|
|
115
|
+
|
|
116
|
+
transforms: list[TRANSFORMS_UNION] = Field(
|
|
97
117
|
default=[
|
|
98
|
-
{
|
|
99
|
-
"name": SupportedTransform.NORMALIZE.value,
|
|
100
|
-
},
|
|
101
118
|
{
|
|
102
119
|
"name": SupportedTransform.XY_FLIP.value,
|
|
103
120
|
},
|
|
@@ -110,14 +127,17 @@ class DataConfig(BaseModel):
|
|
|
110
127
|
],
|
|
111
128
|
validate_default=True,
|
|
112
129
|
)
|
|
130
|
+
"""List of transformations to apply to the data, available transforms are defined
|
|
131
|
+
in SupportedTransform. The default values are set for Noise2Void."""
|
|
113
132
|
|
|
114
133
|
dataloader_params: Optional[dict] = None
|
|
134
|
+
"""Dictionary of PyTorch dataloader parameters."""
|
|
115
135
|
|
|
116
136
|
@field_validator("patch_size")
|
|
117
137
|
@classmethod
|
|
118
138
|
def all_elements_power_of_2_minimum_8(
|
|
119
|
-
cls, patch_list: Union[
|
|
120
|
-
) -> Union[
|
|
139
|
+
cls, patch_list: Union[list[int]]
|
|
140
|
+
) -> Union[list[int]]:
|
|
121
141
|
"""
|
|
122
142
|
Validate patch size.
|
|
123
143
|
|
|
@@ -125,12 +145,12 @@ class DataConfig(BaseModel):
|
|
|
125
145
|
|
|
126
146
|
Parameters
|
|
127
147
|
----------
|
|
128
|
-
patch_list :
|
|
148
|
+
patch_list : list of int
|
|
129
149
|
Patch size.
|
|
130
150
|
|
|
131
151
|
Returns
|
|
132
152
|
-------
|
|
133
|
-
|
|
153
|
+
list of int
|
|
134
154
|
Validated patch size.
|
|
135
155
|
|
|
136
156
|
Raises
|
|
@@ -180,19 +200,19 @@ class DataConfig(BaseModel):
|
|
|
180
200
|
@field_validator("transforms")
|
|
181
201
|
@classmethod
|
|
182
202
|
def validate_prediction_transforms(
|
|
183
|
-
cls, transforms:
|
|
184
|
-
) ->
|
|
203
|
+
cls, transforms: list[TRANSFORMS_UNION]
|
|
204
|
+
) -> list[TRANSFORMS_UNION]:
|
|
185
205
|
"""
|
|
186
206
|
Validate N2VManipulate transform position in the transform list.
|
|
187
207
|
|
|
188
208
|
Parameters
|
|
189
209
|
----------
|
|
190
|
-
transforms :
|
|
210
|
+
transforms : list[Transformations_Union]
|
|
191
211
|
Transforms.
|
|
192
212
|
|
|
193
213
|
Returns
|
|
194
214
|
-------
|
|
195
|
-
|
|
215
|
+
list of transforms
|
|
196
216
|
Validated transforms.
|
|
197
217
|
|
|
198
218
|
Raises
|
|
@@ -235,29 +255,33 @@ class DataConfig(BaseModel):
|
|
|
235
255
|
If std is not None and mean is None.
|
|
236
256
|
"""
|
|
237
257
|
# check that mean and std are either both None, or both specified
|
|
238
|
-
if (self.
|
|
258
|
+
if (self.image_means and not self.image_stds) or (
|
|
259
|
+
self.image_stds and not self.image_means
|
|
260
|
+
):
|
|
239
261
|
raise ValueError(
|
|
240
262
|
"Mean and std must be either both None, or both specified."
|
|
241
263
|
)
|
|
242
264
|
|
|
243
|
-
|
|
265
|
+
elif (self.image_means is not None and self.image_stds is not None) and (
|
|
266
|
+
len(self.image_means) != len(self.image_stds)
|
|
267
|
+
):
|
|
268
|
+
raise ValueError(
|
|
269
|
+
"Mean and std must be specified for each " "input channel."
|
|
270
|
+
)
|
|
244
271
|
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
272
|
+
if (self.target_means and not self.target_stds) or (
|
|
273
|
+
self.target_stds and not self.target_means
|
|
274
|
+
):
|
|
275
|
+
raise ValueError(
|
|
276
|
+
"Mean and std must be either both None, or both specified "
|
|
277
|
+
)
|
|
249
278
|
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
# search in the transforms for Normalize and update parameters
|
|
257
|
-
for transform in self.transforms:
|
|
258
|
-
if transform.name == SupportedTransform.NORMALIZE.value:
|
|
259
|
-
transform.mean = self.mean
|
|
260
|
-
transform.std = self.std
|
|
279
|
+
elif self.target_means is not None and self.target_stds is not None:
|
|
280
|
+
if len(self.target_means) != len(self.target_stds):
|
|
281
|
+
raise ValueError(
|
|
282
|
+
"Mean and std must be either both None, or both specified for each "
|
|
283
|
+
"target channel."
|
|
284
|
+
)
|
|
261
285
|
|
|
262
286
|
return self
|
|
263
287
|
|
|
@@ -341,23 +365,48 @@ class DataConfig(BaseModel):
|
|
|
341
365
|
if self.has_n2v_manipulate():
|
|
342
366
|
self.transforms.pop(-1)
|
|
343
367
|
|
|
344
|
-
def
|
|
368
|
+
def set_means_and_stds(
|
|
369
|
+
self,
|
|
370
|
+
image_means: Union[NDArray, tuple, list, None],
|
|
371
|
+
image_stds: Union[NDArray, tuple, list, None],
|
|
372
|
+
target_means: Optional[Union[NDArray, tuple, list, None]] = None,
|
|
373
|
+
target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
|
|
374
|
+
) -> None:
|
|
345
375
|
"""
|
|
346
|
-
Set mean and standard deviation of the data.
|
|
376
|
+
Set mean and standard deviation of the data across channels.
|
|
347
377
|
|
|
348
378
|
This method should be used instead setting the fields directly, as it would
|
|
349
379
|
otherwise trigger a validation error.
|
|
350
380
|
|
|
351
381
|
Parameters
|
|
352
382
|
----------
|
|
353
|
-
|
|
354
|
-
Mean
|
|
355
|
-
|
|
356
|
-
Standard deviation
|
|
383
|
+
image_means : numpy.ndarray ,tuple or list
|
|
384
|
+
Mean values for normalization.
|
|
385
|
+
image_stds : numpy.ndarray, tuple or list
|
|
386
|
+
Standard deviation values for normalization.
|
|
387
|
+
target_means : numpy.ndarray, tuple or list, optional
|
|
388
|
+
Target mean values for normalization, by default ().
|
|
389
|
+
target_stds : numpy.ndarray, tuple or list, optional
|
|
390
|
+
Target standard deviation values for normalization, by default ().
|
|
357
391
|
"""
|
|
358
|
-
|
|
392
|
+
# make sure we pass a list
|
|
393
|
+
if image_means is not None:
|
|
394
|
+
image_means = list(image_means)
|
|
395
|
+
if image_stds is not None:
|
|
396
|
+
image_stds = list(image_stds)
|
|
397
|
+
if target_means is not None:
|
|
398
|
+
target_means = list(target_means)
|
|
399
|
+
if target_stds is not None:
|
|
400
|
+
target_stds = list(target_stds)
|
|
401
|
+
|
|
402
|
+
self._update(
|
|
403
|
+
image_means=image_means,
|
|
404
|
+
image_stds=image_stds,
|
|
405
|
+
target_means=target_means,
|
|
406
|
+
target_stds=target_stds,
|
|
407
|
+
)
|
|
359
408
|
|
|
360
|
-
def set_3D(self, axes: str, patch_size:
|
|
409
|
+
def set_3D(self, axes: str, patch_size: list[int]) -> None:
|
|
361
410
|
"""
|
|
362
411
|
Set 3D parameters.
|
|
363
412
|
|
|
@@ -365,7 +414,7 @@ class DataConfig(BaseModel):
|
|
|
365
414
|
----------
|
|
366
415
|
axes : str
|
|
367
416
|
Axes.
|
|
368
|
-
patch_size :
|
|
417
|
+
patch_size : list of int
|
|
369
418
|
Patch size.
|
|
370
419
|
"""
|
|
371
420
|
self._update(axes=axes, patch_size=patch_size)
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from typing import Any,
|
|
5
|
+
from typing import Any, Literal, Optional, Union
|
|
6
6
|
|
|
7
7
|
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
8
8
|
from typing_extensions import Self
|
|
@@ -15,31 +15,41 @@ class InferenceConfig(BaseModel):
|
|
|
15
15
|
|
|
16
16
|
model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
|
|
17
17
|
|
|
18
|
-
# Mandatory fields
|
|
19
18
|
data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData
|
|
20
|
-
|
|
19
|
+
"""Type of input data: numpy.ndarray (array) or path (tiff or custom)."""
|
|
20
|
+
|
|
21
|
+
tile_size: Optional[Union[list[int]]] = Field(
|
|
21
22
|
default=None, min_length=2, max_length=3
|
|
22
23
|
)
|
|
23
|
-
tile_overlap
|
|
24
|
+
"""Tile size of prediction, only effective if `tile_overlap` is specified."""
|
|
25
|
+
|
|
26
|
+
tile_overlap: Optional[Union[list[int]]] = Field(
|
|
24
27
|
default=None, min_length=2, max_length=3
|
|
25
28
|
)
|
|
29
|
+
"""Overlap between tiles, only effective if `tile_size` is specified."""
|
|
26
30
|
|
|
27
31
|
axes: str
|
|
32
|
+
"""Data axes (TSCZYX) in the order of the input data."""
|
|
28
33
|
|
|
29
|
-
|
|
30
|
-
|
|
34
|
+
image_means: list = Field(..., min_length=0, max_length=32)
|
|
35
|
+
"""Mean values for each input channel."""
|
|
31
36
|
|
|
32
|
-
|
|
37
|
+
image_stds: list = Field(..., min_length=0, max_length=32)
|
|
38
|
+
"""Standard deviation values for each input channel."""
|
|
39
|
+
|
|
40
|
+
# TODO only default TTAs are supported for now
|
|
33
41
|
tta_transforms: bool = Field(default=True)
|
|
42
|
+
"""Whether to apply test-time augmentation (all 90 degrees rotations and flips)."""
|
|
34
43
|
|
|
35
44
|
# Dataloader parameters
|
|
36
45
|
batch_size: int = Field(default=1, ge=1)
|
|
46
|
+
"""Batch size for prediction."""
|
|
37
47
|
|
|
38
48
|
@field_validator("tile_overlap")
|
|
39
49
|
@classmethod
|
|
40
50
|
def all_elements_non_zero_even(
|
|
41
|
-
cls, tile_overlap: Optional[
|
|
42
|
-
) -> Optional[
|
|
51
|
+
cls, tile_overlap: Optional[list[int]]
|
|
52
|
+
) -> Optional[list[int]]:
|
|
43
53
|
"""
|
|
44
54
|
Validate tile overlap.
|
|
45
55
|
|
|
@@ -47,12 +57,12 @@ class InferenceConfig(BaseModel):
|
|
|
47
57
|
|
|
48
58
|
Parameters
|
|
49
59
|
----------
|
|
50
|
-
tile_overlap :
|
|
60
|
+
tile_overlap : list[int] or None
|
|
51
61
|
Patch size.
|
|
52
62
|
|
|
53
63
|
Returns
|
|
54
64
|
-------
|
|
55
|
-
|
|
65
|
+
list[int] or None
|
|
56
66
|
Validated tile overlap.
|
|
57
67
|
|
|
58
68
|
Raises
|
|
@@ -77,19 +87,19 @@ class InferenceConfig(BaseModel):
|
|
|
77
87
|
@field_validator("tile_size")
|
|
78
88
|
@classmethod
|
|
79
89
|
def tile_min_8_power_of_2(
|
|
80
|
-
cls, tile_list: Optional[
|
|
81
|
-
) -> Optional[
|
|
90
|
+
cls, tile_list: Optional[list[int]]
|
|
91
|
+
) -> Optional[list[int]]:
|
|
82
92
|
"""
|
|
83
93
|
Validate that each entry is greater or equal than 8 and a power of 2.
|
|
84
94
|
|
|
85
95
|
Parameters
|
|
86
96
|
----------
|
|
87
|
-
tile_list :
|
|
97
|
+
tile_list : list of int
|
|
88
98
|
Patch size.
|
|
89
99
|
|
|
90
100
|
Returns
|
|
91
101
|
-------
|
|
92
|
-
|
|
102
|
+
list of int
|
|
93
103
|
Validated patch size.
|
|
94
104
|
|
|
95
105
|
Raises
|
|
@@ -182,11 +192,23 @@ class InferenceConfig(BaseModel):
|
|
|
182
192
|
If std is not None and mean is None.
|
|
183
193
|
"""
|
|
184
194
|
# check that mean and std are either both None, or both specified
|
|
185
|
-
if
|
|
195
|
+
if not self.image_means and not self.image_stds:
|
|
196
|
+
raise ValueError("Mean and std must be specified during inference.")
|
|
197
|
+
|
|
198
|
+
if (self.image_means and not self.image_stds) or (
|
|
199
|
+
self.image_stds and not self.image_means
|
|
200
|
+
):
|
|
186
201
|
raise ValueError(
|
|
187
202
|
"Mean and std must be either both None, or both specified."
|
|
188
203
|
)
|
|
189
204
|
|
|
205
|
+
elif (self.image_means is not None and self.image_stds is not None) and (
|
|
206
|
+
len(self.image_means) != len(self.image_stds)
|
|
207
|
+
):
|
|
208
|
+
raise ValueError(
|
|
209
|
+
"Mean and std must be specified for each " "input channel."
|
|
210
|
+
)
|
|
211
|
+
|
|
190
212
|
return self
|
|
191
213
|
|
|
192
214
|
def _update(self, **kwargs: Any) -> None:
|
|
@@ -201,7 +223,7 @@ class InferenceConfig(BaseModel):
|
|
|
201
223
|
self.__dict__.update(kwargs)
|
|
202
224
|
self.__class__.model_validate(self.__dict__)
|
|
203
225
|
|
|
204
|
-
def set_3D(self, axes: str, tile_size:
|
|
226
|
+
def set_3D(self, axes: str, tile_size: list[int], tile_overlap: list[int]) -> None:
|
|
205
227
|
"""
|
|
206
228
|
Set 3D parameters.
|
|
207
229
|
|
|
@@ -209,9 +231,9 @@ class InferenceConfig(BaseModel):
|
|
|
209
231
|
----------
|
|
210
232
|
axes : str
|
|
211
233
|
Axes.
|
|
212
|
-
tile_size :
|
|
234
|
+
tile_size : list of int
|
|
213
235
|
Tile size.
|
|
214
|
-
tile_overlap :
|
|
236
|
+
tile_overlap : list of int
|
|
215
237
|
Tile overlap.
|
|
216
238
|
"""
|
|
217
239
|
self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap)
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Literal
|
|
6
6
|
|
|
7
7
|
from pydantic import (
|
|
8
8
|
BaseModel,
|
|
@@ -32,7 +32,7 @@ class OptimizerModel(BaseModel):
|
|
|
32
32
|
|
|
33
33
|
Attributes
|
|
34
34
|
----------
|
|
35
|
-
name :
|
|
35
|
+
name : {"Adam", "SGD"}
|
|
36
36
|
Name of the optimizer.
|
|
37
37
|
parameters : dict
|
|
38
38
|
Parameters of the optimizer (see torch documentation).
|
|
@@ -45,6 +45,7 @@ class OptimizerModel(BaseModel):
|
|
|
45
45
|
|
|
46
46
|
# Mandatory field
|
|
47
47
|
name: Literal["Adam", "SGD"] = Field(default="Adam", validate_default=True)
|
|
48
|
+
"""Name of the optimizer, supported optimizers are defined in SupportedOptimizer."""
|
|
48
49
|
|
|
49
50
|
# Optional parameters, empty dict default value to allow filtering dictionary
|
|
50
51
|
parameters: dict = Field(
|
|
@@ -53,10 +54,11 @@ class OptimizerModel(BaseModel):
|
|
|
53
54
|
},
|
|
54
55
|
validate_default=True,
|
|
55
56
|
)
|
|
57
|
+
"""Parameters of the optimizer, see PyTorch documentation for more details."""
|
|
56
58
|
|
|
57
59
|
@field_validator("parameters")
|
|
58
60
|
@classmethod
|
|
59
|
-
def filter_parameters(cls, user_params: dict, values: ValidationInfo) ->
|
|
61
|
+
def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
|
|
60
62
|
"""
|
|
61
63
|
Validate optimizer parameters.
|
|
62
64
|
|
|
@@ -71,7 +73,7 @@ class OptimizerModel(BaseModel):
|
|
|
71
73
|
|
|
72
74
|
Returns
|
|
73
75
|
-------
|
|
74
|
-
|
|
76
|
+
dict
|
|
75
77
|
Filtered optimizer parameters.
|
|
76
78
|
|
|
77
79
|
Raises
|
|
@@ -127,7 +129,7 @@ class LrSchedulerModel(BaseModel):
|
|
|
127
129
|
|
|
128
130
|
Attributes
|
|
129
131
|
----------
|
|
130
|
-
name :
|
|
132
|
+
name : {"ReduceLROnPlateau", "StepLR"}
|
|
131
133
|
Name of the learning rate scheduler.
|
|
132
134
|
parameters : dict
|
|
133
135
|
Parameters of the learning rate scheduler (see torch documentation).
|
|
@@ -140,13 +142,17 @@ class LrSchedulerModel(BaseModel):
|
|
|
140
142
|
|
|
141
143
|
# Mandatory field
|
|
142
144
|
name: Literal["ReduceLROnPlateau", "StepLR"] = Field(default="ReduceLROnPlateau")
|
|
145
|
+
"""Name of the learning rate scheduler, supported schedulers are defined in
|
|
146
|
+
SupportedScheduler."""
|
|
143
147
|
|
|
144
148
|
# Optional parameters
|
|
145
149
|
parameters: dict = Field(default={}, validate_default=True)
|
|
150
|
+
"""Parameters of the learning rate scheduler, see PyTorch documentation for more
|
|
151
|
+
details."""
|
|
146
152
|
|
|
147
153
|
@field_validator("parameters")
|
|
148
154
|
@classmethod
|
|
149
|
-
def filter_parameters(cls, user_params: dict, values: ValidationInfo) ->
|
|
155
|
+
def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
|
|
150
156
|
"""Filter parameters based on the learning rate scheduler's signature.
|
|
151
157
|
|
|
152
158
|
Parameters
|
|
@@ -158,7 +164,7 @@ class LrSchedulerModel(BaseModel):
|
|
|
158
164
|
|
|
159
165
|
Returns
|
|
160
166
|
-------
|
|
161
|
-
|
|
167
|
+
dict
|
|
162
168
|
Filtered scheduler parameters.
|
|
163
169
|
|
|
164
170
|
Raises
|
|
@@ -60,9 +60,9 @@ class SupportedData(str, BaseEnum):
|
|
|
60
60
|
return super()._missing_(value)
|
|
61
61
|
|
|
62
62
|
@classmethod
|
|
63
|
-
def
|
|
63
|
+
def get_extension_pattern(cls, data_type: Union[str, SupportedData]) -> str:
|
|
64
64
|
"""
|
|
65
|
-
Path.rglob and fnmatch compatible extension.
|
|
65
|
+
Get Path.rglob and fnmatch compatible extension.
|
|
66
66
|
|
|
67
67
|
Parameters
|
|
68
68
|
----------
|
|
@@ -72,13 +72,38 @@ class SupportedData(str, BaseEnum):
|
|
|
72
72
|
Returns
|
|
73
73
|
-------
|
|
74
74
|
str
|
|
75
|
-
Corresponding extension.
|
|
75
|
+
Corresponding extension pattern.
|
|
76
76
|
"""
|
|
77
77
|
if data_type == cls.ARRAY:
|
|
78
|
-
raise NotImplementedError(f"Data {data_type}
|
|
78
|
+
raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
|
|
79
79
|
elif data_type == cls.TIFF:
|
|
80
80
|
return "*.tif*"
|
|
81
81
|
elif data_type == cls.CUSTOM:
|
|
82
82
|
return "*.*"
|
|
83
83
|
else:
|
|
84
84
|
raise ValueError(f"Data type {data_type} is not supported.")
|
|
85
|
+
|
|
86
|
+
@classmethod
|
|
87
|
+
def get_extension(cls, data_type: Union[str, SupportedData]) -> str:
|
|
88
|
+
"""
|
|
89
|
+
Get file extension of corresponding data type.
|
|
90
|
+
|
|
91
|
+
Parameters
|
|
92
|
+
----------
|
|
93
|
+
data_type : str or SupportedData
|
|
94
|
+
Data type.
|
|
95
|
+
|
|
96
|
+
Returns
|
|
97
|
+
-------
|
|
98
|
+
str
|
|
99
|
+
Corresponding extension.
|
|
100
|
+
"""
|
|
101
|
+
if data_type == cls.ARRAY:
|
|
102
|
+
raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
|
|
103
|
+
elif data_type == cls.TIFF:
|
|
104
|
+
return ".tiff"
|
|
105
|
+
elif data_type == cls.CUSTOM:
|
|
106
|
+
# TODO: improve this message
|
|
107
|
+
raise NotImplementedError("Custom extensions have to be passed elsewhere.")
|
|
108
|
+
else:
|
|
109
|
+
raise ValueError(f"Data type {data_type} is not supported.")
|