careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/careamist.py +163 -266
- careamics/config/algorithm_model.py +0 -15
- careamics/config/architectures/custom_model.py +3 -3
- careamics/config/configuration_example.py +0 -3
- careamics/config/configuration_factory.py +23 -25
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +80 -50
- careamics/config/inference_model.py +29 -17
- careamics/config/optimizer_models.py +7 -7
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +26 -58
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +0 -9
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +66 -171
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +92 -249
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +54 -25
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/lightning_datamodule.py +1 -6
- careamics/lightning_module.py +11 -7
- careamics/lightning_prediction_datamodule.py +52 -72
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0

careamics/config/algorithm_model.py
CHANGED
@@ -134,21 +134,6 @@ class AlgorithmConfig(BaseModel):
                     "sure that `in_channels` and `num_classes` are the same."
                 )
 
-        # N2N
-        if self.algorithm == "n2n":
-            # n2n is only compatible with the UNet model
-            if not isinstance(self.model, UNetModel):
-                raise ValueError(
-                    f"Model for algorithm {self.algorithm} must be a `UNetModel`."
-                )
-
-            # n2n requires the number of input and output channels to be the same
-            if self.model.in_channels != self.model.num_classes:
-                raise ValueError(
-                    "N2N requires the same number of input and output channels. Make "
-                    "sure that `in_channels` and `num_classes` are the same."
-                )
-
         if self.algorithm == "care" or self.algorithm == "n2n":
             if self.loss == "n2v":
                 raise ValueError("Supervised algorithms do not support loss `n2v`.")

careamics/config/architectures/custom_model.py
CHANGED
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from pprint import pformat
-from typing import Any,
+from typing import Any, Literal
 
 from pydantic import ConfigDict, field_validator, model_validator
 from torch.nn import Module
@@ -136,7 +136,7 @@ class CustomModel(ArchitectureModel):
         """
         return pformat(self.model_dump())
 
-    def model_dump(self, **kwargs: Any) ->
+    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
         """Dump the model configuration.
 
         Parameters
@@ -146,7 +146,7 @@ class CustomModel(ArchitectureModel):
 
         Returns
         -------
-
+        dict[str, Any]
            Model configuration.
        """
         model_dict = super().model_dump()

careamics/config/configuration_factory.py
CHANGED
@@ -107,9 +107,6 @@ def _create_supervised_configuration(
     # augmentations
     if use_augmentations:
         transforms: List[Dict[str, Any]] = [
-            {
-                "name": SupportedTransform.NORMALIZE.value,
-            },
             {
                 "name": SupportedTransform.XY_FLIP.value,
             },
@@ -118,11 +115,7 @@ def _create_supervised_configuration(
             },
         ]
     else:
-        transforms = [
-            {
-                "name": SupportedTransform.NORMALIZE.value,
-            },
-        ]
+        transforms = []
 
     # data model
     data = DataConfig(
@@ -250,7 +243,8 @@ def create_n2n_configuration(
     use_augmentations: bool = True,
     independent_channels: bool = False,
     loss: Literal["mae", "mse"] = "mae",
-
+    n_channels_in: int = 1,
+    n_channels_out: int = -1,
     logger: Literal["wandb", "tensorboard", "none"] = "none",
     model_kwargs: Optional[dict] = None,
 ) -> Configuration:
@@ -260,10 +254,13 @@ def create_n2n_configuration(
     If "Z" is present in `axes`, then `path_size` must be a list of length 3, otherwise
     2.
 
-    If "C" is present in `axes`, then you need to set `
+    If "C" is present in `axes`, then you need to set `n_channels_in` to the number of
     channels. Likewise, if you set the number of channels, then "C" must be present in
     `axes`.
 
+    To set the number of output channels, use the `n_channels_out` parameter. If it is
+    not specified, it will be assumed to be equal to `n_channels_in`.
+
     By default, all channels are trained together. To train all channels independently,
     set `independent_channels` to True.
 
@@ -290,8 +287,10 @@ def create_n2n_configuration(
         Whether to train all channels independently, by default False.
     loss : Literal["mae", "mse"], optional
         Loss function to use, by default "mae".
-
-        Number of channels
+    n_channels_in : int, optional
+        Number of channels in, by default 1.
+    n_channels_out : int, optional
+        Number of channels out, by default -1.
     logger : Literal["wandb", "tensorboard", "none"], optional
         Logger to use, by default "none".
     model_kwargs : dict, optional
@@ -302,6 +301,9 @@
     Configuration
         Configuration for training Noise2Noise.
     """
+    if n_channels_out == -1:
+        n_channels_out = n_channels_in
+
     return _create_supervised_configuration(
         algorithm="n2n",
         experiment_name=experiment_name,
@@ -313,8 +315,8 @@
         use_augmentations=use_augmentations,
         independent_channels=independent_channels,
         loss=loss,
-        n_channels_in=
-        n_channels_out=
+        n_channels_in=n_channels_in,
+        n_channels_out=n_channels_out,
         logger=logger,
         model_kwargs=model_kwargs,
     )
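
For context, a minimal sketch of how the new channel parameters could be used. Only the loss, channel, logger and model_kwargs parameters appear in the hunks above; the remaining keyword names (experiment_name, data_type, axes, patch_size, batch_size, num_epochs) are assumed from the existing factory API, so treat this as an illustration rather than the exact signature:

from careamics.config import create_n2n_configuration

# Hypothetical call: a three-channel Noise2Noise setup. n_channels_out defaults
# to -1, which the new code resolves to n_channels_in before building the config.
config = create_n2n_configuration(
    experiment_name="n2n_rgb",   # assumed existing parameter
    data_type="tiff",            # assumed existing parameter
    axes="SCYX",                 # "C" must be present when channels are set
    patch_size=[64, 64],         # assumed existing parameter
    batch_size=8,                # assumed existing parameter
    num_epochs=20,               # assumed existing parameter
    n_channels_in=3,             # new in 0.1.0rc7
    # n_channels_out omitted -> falls back to n_channels_in (3)
)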
@@ -522,9 +524,6 @@ def create_n2v_configuration(
     # augmentations
     if use_augmentations:
         transforms: List[Dict[str, Any]] = [
-            {
-                "name": SupportedTransform.NORMALIZE.value,
-            },
             {
                 "name": SupportedTransform.XY_FLIP.value,
             },
@@ -533,11 +532,7 @@
             },
         ]
     else:
-        transforms = [
-            {
-                "name": SupportedTransform.NORMALIZE.value,
-            },
-        ]
+        transforms = []
 
     # n2v2 and structn2v
     nv2_transform = {
@@ -618,7 +613,10 @@ def create_inference_configuration(
     InferenceConfiguration
         Configuration used to configure CAREamicsPredictData.
     """
-    if
+    if (
+        configuration.data_config.image_means is None
+        or configuration.data_config.image_stds is None
+    ):
         raise ValueError("Mean and std must be provided in the configuration.")
 
     # tile size for UNets
@@ -648,8 +646,8 @@
         tile_size=tile_size,
         tile_overlap=tile_overlap,
         axes=axes or configuration.data_config.axes,
-
-
+        image_means=configuration.data_config.image_means,
+        image_stds=configuration.data_config.image_stds,
         tta_transforms=tta_transforms,
         batch_size=batch_size,
     )
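
A hedged sketch of create_inference_configuration after this change; only the `configuration`, `tile_size`, `tile_overlap`, `tta_transforms` and `batch_size` names are visible in the hunks above, so the exact signature may differ:

from careamics.config import create_inference_configuration

# The training configuration must now carry image_means/image_stds in its
# data_config, otherwise the factory raises
# "Mean and std must be provided in the configuration."
pred_config = create_inference_configuration(
    configuration=config,
    tile_size=[256, 256],
    tile_overlap=[48, 48],
    tta_transforms=True,
    batch_size=1,
)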

careamics/config/configuration_model.py
CHANGED
@@ -5,7 +5,7 @@ from __future__ import annotations
 import re
 from pathlib import Path
 from pprint import pformat
-from typing import
+from typing import Literal, Union
 
 import yaml
 from bioimageio.spec.generic.v0_3 import CiteEntry
@@ -269,7 +269,7 @@ class Configuration(BaseModel):
         """
         return pformat(self.model_dump())
 
-    def set_3D(self, is_3D: bool, axes: str, patch_size:
+    def set_3D(self, is_3D: bool, axes: str, patch_size: list[int]) -> None:
         """
         Set 3D flag and axes.
 
@@ -279,7 +279,7 @@ class Configuration(BaseModel):
             Whether the algorithm is 3D or not.
         axes : str
             Axes of the data.
-        patch_size :
+        patch_size : list[int]
             Patch size.
         """
         # set the flag and axes (this will not trigger validation at the config level)
@@ -389,7 +389,7 @@ class Configuration(BaseModel):
 
         return ""
 
-    def get_algorithm_citations(self) ->
+    def get_algorithm_citations(self) -> list[CiteEntry]:
         """
         Return a list of citation entries of the current algorithm.
 
@@ -455,13 +455,13 @@ class Configuration(BaseModel):
 
         return ""
 
-    def get_algorithm_keywords(self) ->
+    def get_algorithm_keywords(self) -> list[str]:
         """
         Get algorithm keywords.
 
         Returns
         -------
-
+        list[str]
             List of keywords.
         """
         if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
@@ -491,8 +491,8 @@ class Configuration(BaseModel):
         self,
         exclude_defaults: bool = False,
         exclude_none: bool = True,
-        **kwargs:
-    ) ->
+        **kwargs: dict,
+    ) -> dict:
         """
         Override model_dump method in order to set default values.
 
@@ -503,7 +503,7 @@ class Configuration(BaseModel):
             True.
         exclude_none : bool, optional
            Whether to exclude fields with None values or not, by default True.
-        **kwargs :
+        **kwargs : dict
            Keyword arguments.
 
        Returns
@@ -524,7 +524,7 @@ def load_configuration(path: Union[str, Path]) -> Configuration:
 
     Parameters
     ----------
-    path :
+    path : str or Path
         Path to the configuration.
 
     Returns
@@ -556,7 +556,7 @@ def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
     ----------
     config : Configuration
         Configuration to save.
-    path :
+    path : str or Path
         Path to a existing folder in which to save the configuration or to an existing
         configuration file.
 
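
A short usage sketch of the touched Configuration helpers, assuming both functions are re-exported from careamics.config and using a hypothetical YAML path:

from careamics.config import load_configuration, save_configuration

config = load_configuration("n2v_config.yml")  # accepts a str or a Path

# set_3D now annotates patch_size as list[int]
config.set_3D(is_3D=True, axes="ZYX", patch_size=[16, 64, 64])

save_configuration(config, "n2v_config_3d.yml")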
careamics/config/data_model.py
CHANGED
@@ -3,8 +3,9 @@
 from __future__ import annotations
 
 from pprint import pformat
-from typing import Any,
+from typing import Any, Literal, Optional, Union
 
+from numpy.typing import NDArray
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -17,7 +18,6 @@ from typing_extensions import Annotated, Self
 
 from .support import SupportedTransform
 from .transformations.n2v_manipulate_model import N2VManipulateModel
-from .transformations.normalize_model import NormalizeModel
 from .transformations.xy_flip_model import XYFlipModel
 from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
 from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
@@ -26,7 +26,6 @@ TRANSFORMS_UNION = Annotated[
     Union[
         XYFlipModel,
         XYRandomRotate90Model,
-        NormalizeModel,
         N2VManipulateModel,
     ],
     Discriminator("name"),  # used to tell the different transform models apart
@@ -39,7 +38,9 @@ class DataConfig(BaseModel):
 
     If std is specified, mean must be specified as well. Note that setting the std first
     and then the mean (if they were both `None` before) will raise a validation error.
-    Prefer instead `set_mean_and_std` to set both at once.
+    Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected
+    to be lists of floats, one for each channel. For supervised tasks, the mean and std
+    of the target could be different from the input data.
 
     All supported transforms are defined in the SupportedTransform enum.
 
@@ -55,7 +56,7 @@ class DataConfig(BaseModel):
     ... )
 
     To change the mean and std of the data:
-    >>> data.set_mean_and_std(
+    >>> data.set_mean_and_std(image_means=[214.3], image_stds=[84.5])
 
     One can pass also a list of transformations, by keyword, using the
     SupportedTransform value:
@@ -67,11 +68,6 @@ class DataConfig(BaseModel):
     ...     axes="YX",
     ...     transforms=[
     ...         {
-    ...             "name": SupportedTransform.NORMALIZE.value,
-    ...             "mean": 167.6,
-    ...             "std": 47.2,
-    ...         },
-    ...         {
     ...             "name": "XYFlip",
     ...         }
     ...     ]
@@ -85,19 +81,24 @@ class DataConfig(BaseModel):
 
     # Dataset configuration
     data_type: Literal["array", "tiff", "custom"]  # As defined in SupportedData
-    patch_size: Union[
+    patch_size: Union[list[int]] = Field(..., min_length=2, max_length=3)
     batch_size: int = Field(default=1, ge=1, validate_default=True)
     axes: str
 
     # Optional fields
-
-
+    image_means: Optional[list[float]] = Field(
+        default=None, min_length=0, max_length=32
+    )
+    image_stds: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
+    target_means: Optional[list[float]] = Field(
+        default=None, min_length=0, max_length=32
+    )
+    target_stds: Optional[list[float]] = Field(
+        default=None, min_length=0, max_length=32
+    )
 
-    transforms:
+    transforms: list[TRANSFORMS_UNION] = Field(
         default=[
-            {
-                "name": SupportedTransform.NORMALIZE.value,
-            },
             {
                 "name": SupportedTransform.XY_FLIP.value,
             },
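
A minimal sketch of the new per-channel statistics fields on DataConfig; the channel count and values are made up for illustration:

from careamics.config.data_model import DataConfig

data_config = DataConfig(
    data_type="tiff",
    patch_size=[128, 128],
    axes="SCYX",
    # Optional per-channel statistics: both lists must have the same length,
    # otherwise the validator shown in the hunks below raises an error.
    image_means=[167.6, 201.3],
    image_stds=[47.2, 55.8],
)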
@@ -116,8 +117,8 @@ class DataConfig(BaseModel):
     @field_validator("patch_size")
     @classmethod
     def all_elements_power_of_2_minimum_8(
-        cls, patch_list: Union[
-    ) -> Union[
+        cls, patch_list: Union[list[int]]
+    ) -> Union[list[int]]:
         """
         Validate patch size.
 
@@ -125,12 +126,12 @@ class DataConfig(BaseModel):
 
         Parameters
         ----------
-        patch_list :
+        patch_list : list of int
            Patch size.
 
         Returns
         -------
-
+        list of int
            Validated patch size.
 
         Raises
@@ -180,19 +181,19 @@ class DataConfig(BaseModel):
     @field_validator("transforms")
     @classmethod
     def validate_prediction_transforms(
-        cls, transforms:
-    ) ->
+        cls, transforms: list[TRANSFORMS_UNION]
+    ) -> list[TRANSFORMS_UNION]:
         """
         Validate N2VManipulate transform position in the transform list.
 
         Parameters
         ----------
-        transforms :
+        transforms : list[Transformations_Union]
            Transforms.
 
         Returns
         -------
-
+        list of transforms
            Validated transforms.
 
         Raises
@@ -235,29 +236,33 @@ class DataConfig(BaseModel):
             If std is not None and mean is None.
         """
         # check that mean and std are either both None, or both specified
-        if (self.
+        if (self.image_means and not self.image_stds) or (
+            self.image_stds and not self.image_means
+        ):
             raise ValueError(
                 "Mean and std must be either both None, or both specified."
             )
 
-
+        elif (self.image_means is not None and self.image_stds is not None) and (
+            len(self.image_means) != len(self.image_stds)
+        ):
+            raise ValueError(
+                "Mean and std must be specified for each " "input channel."
+            )
 
-
-
-
-
+        if (self.target_means and not self.target_stds) or (
+            self.target_stds and not self.target_means
+        ):
+            raise ValueError(
+                "Mean and std must be either both None, or both specified "
+            )
 
-
-
-
-
-
-
-        # search in the transforms for Normalize and update parameters
-        for transform in self.transforms:
-            if transform.name == SupportedTransform.NORMALIZE.value:
-                transform.mean = self.mean
-                transform.std = self.std
+        elif self.target_means is not None and self.target_stds is not None:
+            if len(self.target_means) != len(self.target_stds):
+                raise ValueError(
+                    "Mean and std must be either both None, or both specified for each "
+                    "target channel."
+                )
 
         return self
 
@@ -341,7 +346,13 @@ class DataConfig(BaseModel):
         if self.has_n2v_manipulate():
             self.transforms.pop(-1)
 
-    def set_mean_and_std(
+    def set_mean_and_std(
+        self,
+        image_means: Union[NDArray, tuple, list, None],
+        image_stds: Union[NDArray, tuple, list, None],
+        target_means: Optional[Union[NDArray, tuple, list, None]] = None,
+        target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
+    ) -> None:
         """
         Set mean and standard deviation of the data.
 
@@ -350,14 +361,33 @@ class DataConfig(BaseModel):
 
         Parameters
         ----------
-
-            Mean
-
-            Standard deviation
+        image_means : NDArray or tuple or list
+            Mean values for normalization.
+        image_stds : NDArray or tuple or list
+            Standard deviation values for normalization.
+        target_means : NDArray or tuple or list, optional
+            Target mean values for normalization, by default ().
+        target_stds : NDArray or tuple or list, optional
+            Target standard deviation values for normalization, by default ().
         """
-
+        # make sure we pass a list
+        if image_means is not None:
+            image_means = list(image_means)
+        if image_stds is not None:
+            image_stds = list(image_stds)
+        if target_means is not None:
+            target_means = list(target_means)
+        if target_stds is not None:
+            target_stds = list(target_stds)
+
+        self._update(
+            image_means=image_means,
+            image_stds=image_stds,
+            target_means=target_means,
+            target_stds=target_stds,
+        )
 
-    def set_3D(self, axes: str, patch_size:
+    def set_3D(self, axes: str, patch_size: list[int]) -> None:
         """
         Set 3D parameters.
 
@@ -365,7 +395,7 @@ class DataConfig(BaseModel):
         ----------
         axes : str
             Axes.
-        patch_size :
+        patch_size : list of int
            Patch size.
        """
         self._update(axes=axes, patch_size=patch_size)
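
The new set_mean_and_std signature accepts NumPy arrays, tuples or lists and coerces them to lists before calling _update; a short sketch reusing the data_config instance from the earlier example:

import numpy as np

# Per-channel statistics, e.g. computed by the dataset's running statistics.
image_means = np.array([167.6, 201.3])
image_stds = np.array([47.2, 55.8])

# Arrays are converted to lists internally before the model is updated.
data_config.set_mean_and_std(image_means=image_means, image_stds=image_stds)

# For supervised training, target statistics can be set at the same time.
data_config.set_mean_and_std(
    image_means=image_means,
    image_stds=image_stds,
    target_means=[0.0, 0.0],
    target_stds=[1.0, 1.0],
)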

careamics/config/inference_model.py
CHANGED
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any,
+from typing import Any, Literal, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 from typing_extensions import Self
@@ -17,17 +17,17 @@ class InferenceConfig(BaseModel):
 
     # Mandatory fields
     data_type: Literal["array", "tiff", "custom"]  # As defined in SupportedData
-    tile_size: Optional[Union[
+    tile_size: Optional[Union[list[int]]] = Field(
         default=None, min_length=2, max_length=3
     )
-    tile_overlap: Optional[Union[
+    tile_overlap: Optional[Union[list[int]]] = Field(
         default=None, min_length=2, max_length=3
     )
 
     axes: str
 
-
-
+    image_means: list = Field(..., min_length=0, max_length=32)
+    image_stds: list = Field(..., min_length=0, max_length=32)
 
     # only default TTAs are supported for now
     tta_transforms: bool = Field(default=True)
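
image_means and image_stds are now required list fields on InferenceConfig; a minimal sketch (single channel, values made up):

from careamics.config.inference_model import InferenceConfig

inference_config = InferenceConfig(
    data_type="tiff",
    tile_size=[256, 256],    # each entry must be >= 8 and a power of 2
    tile_overlap=[48, 48],   # each entry must be non-zero and even
    axes="YX",
    image_means=[167.6],     # one value per channel
    image_stds=[47.2],
    tta_transforms=True,
)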
@@ -38,8 +38,8 @@ class InferenceConfig(BaseModel):
     @field_validator("tile_overlap")
     @classmethod
     def all_elements_non_zero_even(
-        cls, tile_overlap: Optional[
-    ) -> Optional[
+        cls, tile_overlap: Optional[list[int]]
+    ) -> Optional[list[int]]:
         """
         Validate tile overlap.
 
@@ -47,12 +47,12 @@ class InferenceConfig(BaseModel):
 
         Parameters
         ----------
-        tile_overlap :
+        tile_overlap : list[int] or None
            Patch size.
 
         Returns
         -------
-
+        list[int] or None
            Validated tile overlap.
 
         Raises
@@ -77,19 +77,19 @@ class InferenceConfig(BaseModel):
     @field_validator("tile_size")
     @classmethod
     def tile_min_8_power_of_2(
-        cls, tile_list: Optional[
-    ) -> Optional[
+        cls, tile_list: Optional[list[int]]
+    ) -> Optional[list[int]]:
         """
         Validate that each entry is greater or equal than 8 and a power of 2.
 
         Parameters
         ----------
-        tile_list :
+        tile_list : list of int
            Patch size.
 
         Returns
         -------
-
+        list of int
            Validated patch size.
 
         Raises
@@ -182,11 +182,23 @@ class InferenceConfig(BaseModel):
             If std is not None and mean is None.
         """
         # check that mean and std are either both None, or both specified
-        if
+        if not self.image_means and not self.image_stds:
+            raise ValueError("Mean and std must be specified during inference.")
+
+        if (self.image_means and not self.image_stds) or (
+            self.image_stds and not self.image_means
+        ):
             raise ValueError(
                 "Mean and std must be either both None, or both specified."
             )
 
+        elif (self.image_means is not None and self.image_stds is not None) and (
+            len(self.image_means) != len(self.image_stds)
+        ):
+            raise ValueError(
+                "Mean and std must be specified for each " "input channel."
+            )
+
         return self
 
     def _update(self, **kwargs: Any) -> None:
@@ -201,7 +213,7 @@ class InferenceConfig(BaseModel):
         self.__dict__.update(kwargs)
         self.__class__.model_validate(self.__dict__)
 
-    def set_3D(self, axes: str, tile_size:
+    def set_3D(self, axes: str, tile_size: list[int], tile_overlap: list[int]) -> None:
         """
         Set 3D parameters.
 
@@ -209,9 +221,9 @@ class InferenceConfig(BaseModel):
         ----------
         axes : str
             Axes.
-        tile_size :
+        tile_size : list of int
            Tile size.
-        tile_overlap :
+        tile_overlap : list of int
            Tile overlap.
        """
         self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap)