careamics 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +39 -28
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +391 -0
- careamics/cli/main.py +134 -0
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +170 -0
- careamics/config/configuration_factory.py +481 -170
- careamics/config/configuration_model.py +6 -3
- careamics/config/data_model.py +31 -20
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +35 -45
- careamics/config/likelihood_model.py +60 -0
- careamics/config/nm_model.py +127 -0
- careamics/config/optimizer_models.py +3 -1
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/support/supported_optimizers.py +1 -1
- careamics/config/support/supported_transforms.py +1 -0
- careamics/config/training_model.py +35 -6
- careamics/config/transformations/__init__.py +4 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/transformations/transform_union.py +20 -0
- careamics/config/vae_algorithm_model.py +137 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +364 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +367 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +4 -4
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +15 -0
- careamics/lvae_training/dataset/config.py +123 -0
- careamics/lvae_training/dataset/lc_dataset.py +267 -0
- careamics/lvae_training/{data_modules.py → dataset/multich_dataset.py} +375 -501
- careamics/lvae_training/dataset/multifile_dataset.py +334 -0
- careamics/lvae_training/dataset/types.py +43 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +232 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +109 -64
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +20 -7
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +190 -129
- careamics/models/lvae/lvae.py +60 -148
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/compose.py +90 -15
- careamics/transforms/n2v_manipulate.py +6 -2
- careamics/transforms/normalize.py +14 -3
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/transforms/xy_flip.py +16 -6
- careamics/transforms/xy_random_rotate90.py +16 -7
- careamics/utils/metrics.py +277 -24
- careamics/utils/serializers.py +60 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/METADATA +5 -4
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/RECORD +85 -60
- careamics-0.0.4.dist-info/entry_points.txt +2 -0
- careamics/config/architectures/vae_model.py +0 -42
- careamics/lvae_training/data_utils.py +0 -618
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -9,11 +9,11 @@ from typing import Literal, Union
|
|
|
9
9
|
|
|
10
10
|
import yaml
|
|
11
11
|
from bioimageio.spec.generic.v0_3 import CiteEntry
|
|
12
|
-
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
|
|
12
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
13
13
|
from typing_extensions import Self
|
|
14
14
|
|
|
15
|
-
from .algorithm_model import AlgorithmConfig
|
|
16
15
|
from .data_model import DataConfig
|
|
16
|
+
from .fcn_algorithm_model import FCNAlgorithmConfig
|
|
17
17
|
from .references import (
|
|
18
18
|
CARE,
|
|
19
19
|
CUSTOM,
|
|
@@ -39,6 +39,7 @@ from .training_model import TrainingConfig
|
|
|
39
39
|
from .transformations.n2v_manipulate_model import (
|
|
40
40
|
N2VManipulateModel,
|
|
41
41
|
)
|
|
42
|
+
from .vae_algorithm_model import VAEAlgorithmConfig
|
|
42
43
|
|
|
43
44
|
|
|
44
45
|
class Configuration(BaseModel):
|
|
@@ -155,7 +156,9 @@ class Configuration(BaseModel):
|
|
|
155
156
|
"""Name of the experiment, used to name logs and checkpoints."""
|
|
156
157
|
|
|
157
158
|
# Sub-configurations
|
|
158
|
-
algorithm_config:
|
|
159
|
+
algorithm_config: Union[FCNAlgorithmConfig, VAEAlgorithmConfig] = Field(
|
|
160
|
+
discriminator="algorithm"
|
|
161
|
+
)
|
|
159
162
|
"""Algorithm configuration, holding all parameters required to configure the
|
|
160
163
|
model."""
|
|
161
164
|
|
careamics/config/data_model.py
CHANGED
|
@@ -5,31 +5,44 @@ from __future__ import annotations
|
|
|
5
5
|
from pprint import pformat
|
|
6
6
|
from typing import Any, Literal, Optional, Union
|
|
7
7
|
|
|
8
|
+
import numpy as np
|
|
8
9
|
from numpy.typing import NDArray
|
|
9
10
|
from pydantic import (
|
|
10
11
|
BaseModel,
|
|
11
12
|
ConfigDict,
|
|
12
|
-
Discriminator,
|
|
13
13
|
Field,
|
|
14
|
+
PlainSerializer,
|
|
14
15
|
field_validator,
|
|
15
16
|
model_validator,
|
|
16
17
|
)
|
|
17
18
|
from typing_extensions import Annotated, Self
|
|
18
19
|
|
|
19
20
|
from .support import SupportedTransform
|
|
20
|
-
from .transformations
|
|
21
|
-
from .transformations.xy_flip_model import XYFlipModel
|
|
22
|
-
from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
|
|
21
|
+
from .transformations import TRANSFORMS_UNION, N2VManipulateModel
|
|
23
22
|
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
24
23
|
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
24
|
+
|
|
25
|
+
def np_float_to_scientific_str(x: float) -> str:
|
|
26
|
+
"""Return a string scientific representation of a float.
|
|
27
|
+
|
|
28
|
+
In particular, this method is used to serialize floats to strings, allowing
|
|
29
|
+
numpy.float32 to be passed in the Pydantic model and written to a yaml file as str.
|
|
30
|
+
|
|
31
|
+
Parameters
|
|
32
|
+
----------
|
|
33
|
+
x : float
|
|
34
|
+
Input value.
|
|
35
|
+
|
|
36
|
+
Returns
|
|
37
|
+
-------
|
|
38
|
+
str
|
|
39
|
+
Scientific string representation of the input value.
|
|
40
|
+
"""
|
|
41
|
+
return np.format_float_scientific(x, precision=7)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
Float = Annotated[float, PlainSerializer(np_float_to_scientific_str, return_type=str)]
|
|
45
|
+
"""Annotated float type, used to serialize floats to strings."""
|
|
33
46
|
|
|
34
47
|
|
|
35
48
|
class DataConfig(BaseModel):
|
|
@@ -94,20 +107,20 @@ class DataConfig(BaseModel):
|
|
|
94
107
|
"""Batch size for training."""
|
|
95
108
|
|
|
96
109
|
# Optional fields
|
|
97
|
-
image_means: Optional[list[
|
|
110
|
+
image_means: Optional[list[Float]] = Field(
|
|
98
111
|
default=None, min_length=0, max_length=32
|
|
99
112
|
)
|
|
100
113
|
"""Means of the data across channels, used for normalization."""
|
|
101
114
|
|
|
102
|
-
image_stds: Optional[list[
|
|
115
|
+
image_stds: Optional[list[Float]] = Field(default=None, min_length=0, max_length=32)
|
|
103
116
|
"""Standard deviations of the data across channels, used for normalization."""
|
|
104
117
|
|
|
105
|
-
target_means: Optional[list[
|
|
118
|
+
target_means: Optional[list[Float]] = Field(
|
|
106
119
|
default=None, min_length=0, max_length=32
|
|
107
120
|
)
|
|
108
121
|
"""Means of the target data across channels, used for normalization."""
|
|
109
122
|
|
|
110
|
-
target_stds: Optional[list[
|
|
123
|
+
target_stds: Optional[list[Float]] = Field(
|
|
111
124
|
default=None, min_length=0, max_length=32
|
|
112
125
|
)
|
|
113
126
|
"""Standard deviations of the target data across channels, used for
|
|
@@ -265,9 +278,7 @@ class DataConfig(BaseModel):
|
|
|
265
278
|
elif (self.image_means is not None and self.image_stds is not None) and (
|
|
266
279
|
len(self.image_means) != len(self.image_stds)
|
|
267
280
|
):
|
|
268
|
-
raise ValueError(
|
|
269
|
-
"Mean and std must be specified for each " "input channel."
|
|
270
|
-
)
|
|
281
|
+
raise ValueError("Mean and std must be specified for each input channel.")
|
|
271
282
|
|
|
272
283
|
if (self.target_means and not self.target_stds) or (
|
|
273
284
|
self.target_stds and not self.target_means
|
|
@@ -380,7 +391,7 @@ class DataConfig(BaseModel):
|
|
|
380
391
|
|
|
381
392
|
Parameters
|
|
382
393
|
----------
|
|
383
|
-
image_means : numpy.ndarray
|
|
394
|
+
image_means : numpy.ndarray, tuple or list
|
|
384
395
|
Mean values for normalization.
|
|
385
396
|
image_stds : numpy.ndarray, tuple or list
|
|
386
397
|
Standard deviation values for normalization.
|
|
@@ -1,6 +1,4 @@
|
|
|
1
|
-
"""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
1
|
+
"""Module containing `FCNAlgorithmConfig` class."""
|
|
4
2
|
|
|
5
3
|
from pprint import pformat
|
|
6
4
|
from typing import Literal, Union
|
|
@@ -8,11 +6,11 @@ from typing import Literal, Union
|
|
|
8
6
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
9
7
|
from typing_extensions import Self
|
|
10
8
|
|
|
11
|
-
from .architectures import CustomModel, UNetModel
|
|
12
|
-
from .optimizer_models import LrSchedulerModel, OptimizerModel
|
|
9
|
+
from careamics.config.architectures import CustomModel, UNetModel
|
|
10
|
+
from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel
|
|
13
11
|
|
|
14
12
|
|
|
15
|
-
class
|
|
13
|
+
class FCNAlgorithmConfig(BaseModel):
|
|
16
14
|
"""Algorithm configuration.
|
|
17
15
|
|
|
18
16
|
This Pydantic model validates the parameters governing the components of the
|
|
@@ -26,11 +24,11 @@ class AlgorithmConfig(BaseModel):
|
|
|
26
24
|
|
|
27
25
|
Attributes
|
|
28
26
|
----------
|
|
29
|
-
algorithm :
|
|
27
|
+
algorithm : {"n2v", "care", "n2n", "custom"}
|
|
30
28
|
Algorithm to use.
|
|
31
|
-
loss :
|
|
29
|
+
loss : {"n2v", "mae", "mse"}
|
|
32
30
|
Loss function to use.
|
|
33
|
-
model :
|
|
31
|
+
model : UNetModel or CustomModel
|
|
34
32
|
Model architecture to use.
|
|
35
33
|
optimizer : OptimizerModel, optional
|
|
36
34
|
Optimizer to use.
|
|
@@ -47,7 +45,7 @@ class AlgorithmConfig(BaseModel):
|
|
|
47
45
|
Examples
|
|
48
46
|
--------
|
|
49
47
|
Minimum example:
|
|
50
|
-
>>> from careamics.config import
|
|
48
|
+
>>> from careamics.config import FCNAlgorithmConfig
|
|
51
49
|
>>> config_dict = {
|
|
52
50
|
... "algorithm": "n2v",
|
|
53
51
|
... "loss": "n2v",
|
|
@@ -55,58 +53,37 @@ class AlgorithmConfig(BaseModel):
|
|
|
55
53
|
... "architecture": "UNet",
|
|
56
54
|
... }
|
|
57
55
|
... }
|
|
58
|
-
>>> config =
|
|
59
|
-
|
|
60
|
-
Using a custom model:
|
|
61
|
-
>>> from torch import nn, ones
|
|
62
|
-
>>> from careamics.config import AlgorithmConfig, register_model
|
|
63
|
-
...
|
|
64
|
-
>>> @register_model(name="linear_model")
|
|
65
|
-
... class LinearModel(nn.Module):
|
|
66
|
-
... def __init__(self, in_features, out_features, *args, **kwargs):
|
|
67
|
-
... super().__init__()
|
|
68
|
-
... self.in_features = in_features
|
|
69
|
-
... self.out_features = out_features
|
|
70
|
-
... self.weight = nn.Parameter(ones(in_features, out_features))
|
|
71
|
-
... self.bias = nn.Parameter(ones(out_features))
|
|
72
|
-
... def forward(self, input):
|
|
73
|
-
... return (input @ self.weight) + self.bias
|
|
74
|
-
...
|
|
75
|
-
>>> config_dict = {
|
|
76
|
-
... "algorithm": "custom",
|
|
77
|
-
... "loss": "mse",
|
|
78
|
-
... "model": {
|
|
79
|
-
... "architecture": "Custom",
|
|
80
|
-
... "name": "linear_model",
|
|
81
|
-
... "in_features": 10,
|
|
82
|
-
... "out_features": 5,
|
|
83
|
-
... }
|
|
84
|
-
... }
|
|
85
|
-
>>> config = AlgorithmConfig(**config_dict)
|
|
56
|
+
>>> config = FCNAlgorithmConfig(**config_dict)
|
|
86
57
|
"""
|
|
87
58
|
|
|
88
59
|
# Pydantic class configuration
|
|
89
60
|
model_config = ConfigDict(
|
|
90
61
|
protected_namespaces=(), # allows to use model_* as a field name
|
|
91
62
|
validate_assignment=True,
|
|
63
|
+
extra="allow",
|
|
92
64
|
)
|
|
93
65
|
|
|
94
66
|
# Mandatory fields
|
|
95
|
-
algorithm: Literal["n2v", "care", "n2n", "custom"]
|
|
96
|
-
"""Name of the algorithm, as defined in SupportedAlgorithm.
|
|
67
|
+
algorithm: Literal["n2v", "care", "n2n", "custom"]
|
|
68
|
+
"""Name of the algorithm, as defined in SupportedAlgorithm. Use `custom` for custom
|
|
69
|
+
model architecture."""
|
|
97
70
|
|
|
98
71
|
loss: Literal["n2v", "mae", "mse"]
|
|
99
72
|
"""Loss function to use, as defined in SupportedLoss."""
|
|
100
73
|
|
|
101
|
-
model: Union[UNetModel,
|
|
102
|
-
"""Model architecture to use,
|
|
74
|
+
model: Union[UNetModel, CustomModel] = Field(discriminator="architecture")
|
|
75
|
+
"""Model architecture to use, along with its parameters. Compatible architectures
|
|
76
|
+
are defined in SupportedArchitecture, and their Pydantic models in
|
|
77
|
+
`careamics.config.architectures`."""
|
|
78
|
+
# TODO supported architectures are now all the architectures but does not warn users
|
|
79
|
+
# of the compatibility with the algorithm
|
|
103
80
|
|
|
104
81
|
# Optional fields
|
|
105
82
|
optimizer: OptimizerModel = OptimizerModel()
|
|
106
83
|
"""Optimizer to use, defined in SupportedOptimizer."""
|
|
107
84
|
|
|
108
85
|
lr_scheduler: LrSchedulerModel = LrSchedulerModel()
|
|
109
|
-
"""Learning rate scheduler to use, defined in
|
|
86
|
+
"""Learning rate scheduler to use, defined in SupportedLrScheduler."""
|
|
110
87
|
|
|
111
88
|
@model_validator(mode="after")
|
|
112
89
|
def algorithm_cross_validation(self: Self) -> Self:
|
|
@@ -146,8 +123,10 @@ class AlgorithmConfig(BaseModel):
|
|
|
146
123
|
if self.loss == "n2v":
|
|
147
124
|
raise ValueError("Supervised algorithms do not support loss `n2v`.")
|
|
148
125
|
|
|
149
|
-
if
|
|
150
|
-
raise ValueError(
|
|
126
|
+
if (self.algorithm == "custom") != (self.model.architecture == "custom"):
|
|
127
|
+
raise ValueError(
|
|
128
|
+
"Algorithm and model architecture must be both `custom` or not."
|
|
129
|
+
)
|
|
151
130
|
|
|
152
131
|
return self
|
|
153
132
|
|
|
@@ -160,3 +139,14 @@ class AlgorithmConfig(BaseModel):
|
|
|
160
139
|
Pretty string.
|
|
161
140
|
"""
|
|
162
141
|
return pformat(self.model_dump())
|
|
142
|
+
|
|
143
|
+
@classmethod
|
|
144
|
+
def get_compatible_algorithms(cls) -> list[str]:
|
|
145
|
+
"""Get the list of compatible algorithms.
|
|
146
|
+
|
|
147
|
+
Returns
|
|
148
|
+
-------
|
|
149
|
+
list of str
|
|
150
|
+
List of compatible algorithms.
|
|
151
|
+
"""
|
|
152
|
+
return ["n2v", "care", "n2n"]
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Likelihood model."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal, Optional, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field, PlainSerializer, PlainValidator
|
|
8
|
+
from typing_extensions import Annotated
|
|
9
|
+
|
|
10
|
+
from careamics.models.lvae.noise_models import (
|
|
11
|
+
GaussianMixtureNoiseModel,
|
|
12
|
+
MultiChannelNoiseModel,
|
|
13
|
+
)
|
|
14
|
+
from careamics.utils.serializers import _array_to_json, _to_torch
|
|
15
|
+
|
|
16
|
+
NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel]
|
|
17
|
+
|
|
18
|
+
# TODO: this is a temporary solution to serialize and deserialize tensor fields
|
|
19
|
+
# in pydantic models. Specifically, the aim is to enable saving and loading configs
|
|
20
|
+
# with such tensors to/from JSON files during, resp., training and evaluation.
|
|
21
|
+
Tensor = Annotated[
|
|
22
|
+
Union[np.ndarray, torch.Tensor],
|
|
23
|
+
PlainSerializer(_array_to_json, return_type=str),
|
|
24
|
+
PlainValidator(_to_torch),
|
|
25
|
+
]
|
|
26
|
+
"""Annotated tensor type, used to serialize arrays or tensors to JSON strings
|
|
27
|
+
and deserialize them back to tensors."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class GaussianLikelihoodConfig(BaseModel):
|
|
31
|
+
"""Gaussian likelihood configuration."""
|
|
32
|
+
|
|
33
|
+
model_config = ConfigDict(validate_assignment=True)
|
|
34
|
+
|
|
35
|
+
predict_logvar: Optional[Literal["pixelwise"]] = None
|
|
36
|
+
"""If `pixelwise`, log-variance is computed for each pixel, else log-variance
|
|
37
|
+
is not computed."""
|
|
38
|
+
|
|
39
|
+
logvar_lowerbound: Union[float, None] = None
|
|
40
|
+
"""The lowerbound value for log-variance."""
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class NMLikelihoodConfig(BaseModel):
|
|
44
|
+
"""Noise model likelihood configuration."""
|
|
45
|
+
|
|
46
|
+
model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
|
|
47
|
+
|
|
48
|
+
# TODO remove and use as parameters to the likelihood functions?
|
|
49
|
+
data_mean: Tensor = torch.zeros(1)
|
|
50
|
+
"""The mean of the data, used to unnormalize data for noise model evaluation.
|
|
51
|
+
Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
|
|
52
|
+
|
|
53
|
+
# TODO remove and use as parameters to the likelihood functions?
|
|
54
|
+
data_std: Tensor = torch.ones(1)
|
|
55
|
+
"""The standard deviation of the data, used to unnormalize data for noise
|
|
56
|
+
model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
|
|
57
|
+
|
|
58
|
+
# TODO: serialization/deserialization for this
|
|
59
|
+
noise_model: Optional[NoiseModel] = Field(default=None, exclude=True)
|
|
60
|
+
"""The noise model instance used to compute the likelihood."""
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""Noise models config."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Literal, Optional, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
from pydantic import (
|
|
9
|
+
BaseModel,
|
|
10
|
+
ConfigDict,
|
|
11
|
+
Field,
|
|
12
|
+
PlainSerializer,
|
|
13
|
+
PlainValidator,
|
|
14
|
+
model_validator,
|
|
15
|
+
)
|
|
16
|
+
from typing_extensions import Annotated, Self
|
|
17
|
+
|
|
18
|
+
from careamics.utils.serializers import _array_to_json, _to_numpy
|
|
19
|
+
|
|
20
|
+
# TODO: this is a temporary solution to serialize and deserialize array fields
|
|
21
|
+
# in pydantic models. Specifically, the aim is to enable saving and loading configs
|
|
22
|
+
# with such arrays to/from JSON files during, resp., training and evaluation.
|
|
23
|
+
Array = Annotated[
|
|
24
|
+
Union[np.ndarray, torch.Tensor],
|
|
25
|
+
PlainSerializer(_array_to_json, return_type=str),
|
|
26
|
+
PlainValidator(_to_numpy),
|
|
27
|
+
]
|
|
28
|
+
"""Annotated array type, used to serialize arrays or tensors to JSON strings
|
|
29
|
+
and deserialize them back to arrays."""
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# TODO: add histogram-based noise model
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class GaussianMixtureNMConfig(BaseModel):
|
|
36
|
+
"""Gaussian mixture noise model."""
|
|
37
|
+
|
|
38
|
+
model_config = ConfigDict(
|
|
39
|
+
protected_namespaces=(),
|
|
40
|
+
validate_assignment=True,
|
|
41
|
+
arbitrary_types_allowed=True,
|
|
42
|
+
extra="allow",
|
|
43
|
+
)
|
|
44
|
+
# model type
|
|
45
|
+
model_type: Literal["GaussianMixtureNoiseModel"]
|
|
46
|
+
|
|
47
|
+
path: Optional[Union[Path, str]] = None
|
|
48
|
+
"""Path to the directory where the trained noise model (*.npz) is saved in the
|
|
49
|
+
`train` method."""
|
|
50
|
+
|
|
51
|
+
# TODO remove and use as parameters to the NM functions?
|
|
52
|
+
signal: Optional[Union[str, Path, np.ndarray]] = Field(default=None, exclude=True)
|
|
53
|
+
"""Path to the file containing signal or respective numpy array."""
|
|
54
|
+
|
|
55
|
+
# TODO remove and use as parameters to the NM functions?
|
|
56
|
+
observation: Optional[Union[str, Path, np.ndarray]] = Field(
|
|
57
|
+
default=None, exclude=True
|
|
58
|
+
)
|
|
59
|
+
"""Path to the file containing observation or respective numpy array."""
|
|
60
|
+
|
|
61
|
+
weight: Optional[Array] = None
|
|
62
|
+
"""A [3*n_gaussian, n_coeff] sized array containing the values of the weights
|
|
63
|
+
describing the GMM noise model, with each row corresponding to one
|
|
64
|
+
parameter of each gaussian, namely [mean, standard deviation and weight].
|
|
65
|
+
Specifically, rows are organized as follows:
|
|
66
|
+
- first n_gaussian rows correspond to the means
|
|
67
|
+
- next n_gaussian rows correspond to the weights
|
|
68
|
+
- last n_gaussian rows correspond to the standard deviations
|
|
69
|
+
If `weight=None`, the weight array is initialized using the `min_signal`
|
|
70
|
+
and `max_signal` parameters."""
|
|
71
|
+
|
|
72
|
+
n_gaussian: int = Field(default=1, ge=1)
|
|
73
|
+
"""Number of gaussians used for the GMM."""
|
|
74
|
+
|
|
75
|
+
n_coeff: int = Field(default=2, ge=2)
|
|
76
|
+
"""Number of coefficients to describe the functional relationship between gaussian
|
|
77
|
+
parameters and the signal. 2 implies a linear relationship, 3 implies a quadratic
|
|
78
|
+
relationship and so on."""
|
|
79
|
+
|
|
80
|
+
min_signal: float = Field(default=0.0, ge=0.0)
|
|
81
|
+
"""Minimum signal intensity expected in the image."""
|
|
82
|
+
|
|
83
|
+
max_signal: float = Field(default=1.0, ge=0.0)
|
|
84
|
+
"""Maximum signal intensity expected in the image."""
|
|
85
|
+
|
|
86
|
+
min_sigma: float = Field(default=200.0, ge=0.0) # TODO took from nb in pn2v
|
|
87
|
+
"""Minimum value of `standard deviation` allowed in the GMM.
|
|
88
|
+
All values of `standard deviation` below this are clamped to this value."""
|
|
89
|
+
|
|
90
|
+
tol: float = Field(default=1e-10)
|
|
91
|
+
"""Tolerance used in the computation of the noise model likelihood."""
|
|
92
|
+
|
|
93
|
+
@model_validator(mode="after")
|
|
94
|
+
def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
|
|
95
|
+
"""Validate paths provided in the config.
|
|
96
|
+
|
|
97
|
+
Returns
|
|
98
|
+
-------
|
|
99
|
+
Self
|
|
100
|
+
Returns itself.
|
|
101
|
+
"""
|
|
102
|
+
if self.path and (self.signal is not None or self.observation is not None):
|
|
103
|
+
raise ValueError(
|
|
104
|
+
"Either only 'path' to pre-trained noise model should be"
|
|
105
|
+
"provided or only signal and observation in form of paths"
|
|
106
|
+
"or numpy arrays."
|
|
107
|
+
)
|
|
108
|
+
if not self.path and (self.signal is None or self.observation is None):
|
|
109
|
+
raise ValueError(
|
|
110
|
+
"Either only 'path' to pre-trained noise model should be"
|
|
111
|
+
"provided or only signal and observation in form of paths"
|
|
112
|
+
"or numpy arrays."
|
|
113
|
+
)
|
|
114
|
+
return self
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# The noise model is given by a set of GMMs, one for each target
|
|
118
|
+
# e.g., 2 target channels, 2 noise models
|
|
119
|
+
class MultiChannelNMConfig(BaseModel):
|
|
120
|
+
"""Noise Model config aggregating noise models for single output channels."""
|
|
121
|
+
|
|
122
|
+
# TODO: check that this model config is OK
|
|
123
|
+
model_config = ConfigDict(
|
|
124
|
+
validate_assignment=True, arbitrary_types_allowed=True, extra="allow"
|
|
125
|
+
)
|
|
126
|
+
noise_models: list[GaussianMixtureNMConfig]
|
|
127
|
+
"""List of noise models, one for each target channel."""
|
|
@@ -44,7 +44,9 @@ class OptimizerModel(BaseModel):
|
|
|
44
44
|
)
|
|
45
45
|
|
|
46
46
|
# Mandatory field
|
|
47
|
-
name: Literal["Adam", "SGD"] = Field(
|
|
47
|
+
name: Literal["Adam", "SGD", "Adamax"] = Field(
|
|
48
|
+
default="Adam", validate_default=True
|
|
49
|
+
)
|
|
48
50
|
"""Name of the optimizer, supported optimizers are defined in SupportedOptimizer."""
|
|
49
51
|
|
|
50
52
|
# Optional parameters, empty dict default value to allow filtering dictionary
|
|
@@ -6,15 +6,28 @@ from careamics.utils import BaseEnum
|
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class SupportedAlgorithm(str, BaseEnum):
|
|
9
|
-
"""Algorithms available in CAREamics.
|
|
10
|
-
|
|
11
|
-
# TODO
|
|
12
|
-
"""
|
|
9
|
+
"""Algorithms available in CAREamics."""
|
|
13
10
|
|
|
14
11
|
N2V = "n2v"
|
|
12
|
+
"""Noise2Void algorithm, a self-supervised approach based on blind denoising."""
|
|
13
|
+
|
|
15
14
|
CARE = "care"
|
|
15
|
+
"""Content-aware image restoration, a supervised algorithm used for a variety
|
|
16
|
+
of tasks."""
|
|
17
|
+
|
|
16
18
|
N2N = "n2n"
|
|
19
|
+
"""Noise2Noise algorithm, a self-supervised denoising scheme based on comparing
|
|
20
|
+
noisy images of the same sample."""
|
|
21
|
+
|
|
22
|
+
MUSPLIT = "musplit"
|
|
23
|
+
"""An image splitting approach based on ladder VAE architectures."""
|
|
24
|
+
|
|
25
|
+
DENOISPLIT = "denoisplit"
|
|
26
|
+
"""An image splitting and denoising approach based on ladder VAE architectures."""
|
|
27
|
+
|
|
17
28
|
CUSTOM = "custom"
|
|
29
|
+
"""Custom algorithm, used for cases where a custom architecture is provided."""
|
|
30
|
+
|
|
18
31
|
# PN2V = "pn2v"
|
|
19
32
|
# HDN = "hdn"
|
|
20
33
|
# SEG = "segmentation"
|
|
@@ -4,17 +4,14 @@ from careamics.utils import BaseEnum
|
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
class SupportedArchitecture(str, BaseEnum):
|
|
7
|
-
"""Supported architectures.
|
|
7
|
+
"""Supported architectures."""
|
|
8
8
|
|
|
9
|
-
|
|
9
|
+
UNET = "UNet"
|
|
10
|
+
"""UNet architecture used with N2V, CARE and Noise2Noise."""
|
|
10
11
|
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
- Custom: custom model registered with `@register_model` decorator
|
|
14
|
-
"""
|
|
12
|
+
LVAE = "LVAE"
|
|
13
|
+
"""Ladder Variational Autoencoder used for muSplit and denoiSplit."""
|
|
15
14
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
"Custom" # TODO all the others tags are small letters, except the architect
|
|
20
|
-
)
|
|
15
|
+
CUSTOM = "custom"
|
|
16
|
+
"""Keyword used for custom architectures provided by users and only compatible
|
|
17
|
+
with `FCNAlgorithmConfig` configuration."""
|
|
@@ -22,6 +22,8 @@ class SupportedLoss(str, BaseEnum):
|
|
|
22
22
|
N2V = "n2v"
|
|
23
23
|
# PN2V = "pn2v"
|
|
24
24
|
# HDN = "hdn"
|
|
25
|
+
MUSPLIT = "musplit"
|
|
26
|
+
DENOISPLIT = "denoisplit"
|
|
27
|
+
DENOISPLIT_MUSPLIT = "denoisplit_musplit"
|
|
25
28
|
# CE = "ce"
|
|
26
29
|
# DICE = "dice"
|
|
27
|
-
# CUSTOM = "custom" # TODO create mechanism for that
|
|
@@ -3,13 +3,9 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from pprint import pformat
|
|
6
|
-
from typing import Literal, Optional
|
|
6
|
+
from typing import Literal, Optional, Union
|
|
7
7
|
|
|
8
|
-
from pydantic import
|
|
9
|
-
BaseModel,
|
|
10
|
-
ConfigDict,
|
|
11
|
-
Field,
|
|
12
|
-
)
|
|
8
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
13
9
|
|
|
14
10
|
from .callback_model import CheckpointModel, EarlyStoppingModel
|
|
15
11
|
|
|
@@ -37,6 +33,20 @@ class TrainingConfig(BaseModel):
|
|
|
37
33
|
num_epochs: int = Field(default=20, ge=1)
|
|
38
34
|
"""Number of epochs, greater than 0."""
|
|
39
35
|
|
|
36
|
+
precision: Literal["64", "32", "16-mixed", "bf16-mixed"] = Field(default="32")
|
|
37
|
+
"""Numerical precision"""
|
|
38
|
+
max_steps: int = Field(default=-1, ge=-1)
|
|
39
|
+
"""Maximum number of steps to train for. -1 means no limit."""
|
|
40
|
+
check_val_every_n_epoch: int = Field(default=1, ge=1)
|
|
41
|
+
"""Validation step frequency."""
|
|
42
|
+
enable_progress_bar: bool = Field(default=True)
|
|
43
|
+
"""Whether to enable the progress bar."""
|
|
44
|
+
accumulate_grad_batches: int = Field(default=1, ge=1)
|
|
45
|
+
"""Number of batches to accumulate gradients over before stepping the optimizer."""
|
|
46
|
+
gradient_clip_val: Optional[Union[int, float]] = None
|
|
47
|
+
"""The value to which to clip the gradient"""
|
|
48
|
+
gradient_clip_algorithm: Literal["value", "norm"] = "norm"
|
|
49
|
+
"""The algorithm to use for gradient clipping (see lightning `Trainer`)."""
|
|
40
50
|
logger: Optional[Literal["wandb", "tensorboard"]] = None
|
|
41
51
|
"""Logger to use during training. If None, no logger will be used. Available
|
|
42
52
|
loggers are defined in SupportedLogger."""
|
|
@@ -70,3 +80,22 @@ class TrainingConfig(BaseModel):
|
|
|
70
80
|
Whether the logger is defined or not.
|
|
71
81
|
"""
|
|
72
82
|
return self.logger is not None
|
|
83
|
+
|
|
84
|
+
@field_validator("max_steps")
|
|
85
|
+
@classmethod
|
|
86
|
+
def validate_max_steps(cls, max_steps: int) -> int:
|
|
87
|
+
"""Validate the max_steps parameter.
|
|
88
|
+
|
|
89
|
+
Parameters
|
|
90
|
+
----------
|
|
91
|
+
max_steps : int
|
|
92
|
+
Maximum number of steps to train for. -1 means no limit.
|
|
93
|
+
|
|
94
|
+
Returns
|
|
95
|
+
-------
|
|
96
|
+
int
|
|
97
|
+
Validated max_steps.
|
|
98
|
+
"""
|
|
99
|
+
if max_steps == 0:
|
|
100
|
+
raise ValueError("max_steps must be greater than 0. Use -1 for no limit.")
|
|
101
|
+
return max_steps
|
|
@@ -5,11 +5,14 @@ __all__ = [
|
|
|
5
5
|
"XYFlipModel",
|
|
6
6
|
"NormalizeModel",
|
|
7
7
|
"XYRandomRotate90Model",
|
|
8
|
-
"
|
|
8
|
+
"TransformModel",
|
|
9
|
+
"TRANSFORMS_UNION",
|
|
9
10
|
]
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
from .n2v_manipulate_model import N2VManipulateModel
|
|
13
14
|
from .normalize_model import NormalizeModel
|
|
15
|
+
from .transform_model import TransformModel
|
|
16
|
+
from .transform_union import TRANSFORMS_UNION
|
|
14
17
|
from .xy_flip_model import XYFlipModel
|
|
15
18
|
from .xy_random_rotate90_model import XYRandomRotate90Model
|
|
@@ -33,7 +33,7 @@ class N2VManipulateModel(TransformModel):
|
|
|
33
33
|
|
|
34
34
|
name: Literal["N2VManipulate"] = "N2VManipulate"
|
|
35
35
|
roi_size: int = Field(default=11, ge=3, le=21)
|
|
36
|
-
masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=
|
|
36
|
+
masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=10.0)
|
|
37
37
|
strategy: Literal["uniform", "median"] = Field(default="uniform")
|
|
38
38
|
struct_mask_axis: Literal["horizontal", "vertical", "none"] = Field(default="none")
|
|
39
39
|
struct_mask_span: int = Field(default=5, ge=3, le=15)
|