careamics 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +14 -11
- careamics/config/__init__.py +7 -3
- careamics/config/architectures/__init__.py +2 -2
- careamics/config/architectures/architecture_model.py +1 -1
- careamics/config/architectures/custom_model.py +11 -8
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/configuration_factory.py +11 -3
- careamics/config/configuration_model.py +7 -3
- careamics/config/data_model.py +33 -8
- careamics/config/{algorithm_model.py → fcn_algorithm_model.py} +28 -43
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/support/supported_activations.py +1 -0
- careamics/config/support/supported_algorithms.py +17 -4
- careamics/config/support/supported_architectures.py +8 -11
- careamics/config/support/supported_losses.py +3 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/file_io/read/tiff.py +1 -1
- careamics/lightning/__init__.py +3 -2
- careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
- careamics/lightning/lightning_module.py +365 -9
- careamics/lightning/predict_data_module.py +2 -2
- careamics/lightning/train_data_module.py +2 -2
- careamics/losses/__init__.py +11 -1
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/{losses.py → fcn/losses.py} +1 -1
- careamics/losses/loss_factory.py +112 -6
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/{data_utils.py → dataset/data_utils.py} +277 -194
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/{data_modules.py → dataset/vae_dataset.py} +306 -472
- careamics/lvae_training/get_config.py +1 -1
- careamics/lvae_training/train_lvae.py +6 -3
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +2 -2
- careamics/model_io/bmz_io.py +19 -6
- careamics/model_io/model_io_utils.py +16 -4
- careamics/models/__init__.py +1 -3
- careamics/models/activation.py +2 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +21 -21
- careamics/models/lvae/likelihoods.py +180 -128
- careamics/models/lvae/lvae.py +52 -136
- careamics/models/lvae/noise_models.py +318 -186
- careamics/models/lvae/utils.py +2 -2
- careamics/models/model_factory.py +22 -7
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/stitch_prediction.py +16 -2
- careamics/transforms/pixel_manipulation.py +1 -1
- careamics/utils/metrics.py +74 -1
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/METADATA +2 -2
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/RECORD +63 -49
- careamics/config/architectures/vae_model.py +0 -42
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/WHEEL +0 -0
- {careamics-0.0.2.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +0 -0
careamics/careamist.py
CHANGED
|
@@ -13,10 +13,7 @@ from pytorch_lightning.callbacks import (
|
|
|
13
13
|
)
|
|
14
14
|
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
|
|
15
15
|
|
|
16
|
-
from careamics.config import
|
|
17
|
-
Configuration,
|
|
18
|
-
load_configuration,
|
|
19
|
-
)
|
|
16
|
+
from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration
|
|
20
17
|
from careamics.config.support import (
|
|
21
18
|
SupportedAlgorithm,
|
|
22
19
|
SupportedArchitecture,
|
|
@@ -25,7 +22,7 @@ from careamics.config.support import (
|
|
|
25
22
|
)
|
|
26
23
|
from careamics.dataset.dataset_utils import reshape_array
|
|
27
24
|
from careamics.lightning import (
|
|
28
|
-
|
|
25
|
+
FCNModule,
|
|
29
26
|
HyperParametersCallback,
|
|
30
27
|
PredictDataModule,
|
|
31
28
|
ProgressBarCallback,
|
|
@@ -148,9 +145,12 @@ class CAREamist:
|
|
|
148
145
|
self.cfg = source
|
|
149
146
|
|
|
150
147
|
# instantiate model
|
|
151
|
-
self.
|
|
152
|
-
|
|
153
|
-
|
|
148
|
+
if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
|
|
149
|
+
self.model = FCNModule(
|
|
150
|
+
algorithm_config=self.cfg.algorithm_config,
|
|
151
|
+
)
|
|
152
|
+
else:
|
|
153
|
+
raise NotImplementedError("Architecture not supported.")
|
|
154
154
|
|
|
155
155
|
# path to configuration file or model
|
|
156
156
|
else:
|
|
@@ -164,9 +164,12 @@ class CAREamist:
|
|
|
164
164
|
self.cfg = load_configuration(source)
|
|
165
165
|
|
|
166
166
|
# instantiate model
|
|
167
|
-
self.
|
|
168
|
-
|
|
169
|
-
|
|
167
|
+
if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
|
|
168
|
+
self.model = FCNModule(
|
|
169
|
+
algorithm_config=self.cfg.algorithm_config,
|
|
170
|
+
) # type: ignore
|
|
171
|
+
else:
|
|
172
|
+
raise NotImplementedError("Architecture not supported.")
|
|
170
173
|
|
|
171
174
|
# attempt loading a pre-trained model
|
|
172
175
|
else:
|
careamics/config/__init__.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""Configuration module."""
|
|
2
2
|
|
|
3
3
|
__all__ = [
|
|
4
|
-
"
|
|
4
|
+
"FCNAlgorithmConfig",
|
|
5
|
+
"VAEAlgorithmConfig",
|
|
5
6
|
"DataConfig",
|
|
6
7
|
"Configuration",
|
|
7
8
|
"CheckpointModel",
|
|
@@ -15,9 +16,9 @@ __all__ = [
|
|
|
15
16
|
"register_model",
|
|
16
17
|
"CustomModel",
|
|
17
18
|
"clear_custom_models",
|
|
19
|
+
"GaussianMixtureNMConfig",
|
|
20
|
+
"MultiChannelNMConfig",
|
|
18
21
|
]
|
|
19
|
-
|
|
20
|
-
from .algorithm_model import AlgorithmConfig
|
|
21
22
|
from .architectures import CustomModel, clear_custom_models, register_model
|
|
22
23
|
from .callback_model import CheckpointModel
|
|
23
24
|
from .configuration_factory import (
|
|
@@ -31,5 +32,8 @@ from .configuration_model import (
|
|
|
31
32
|
save_configuration,
|
|
32
33
|
)
|
|
33
34
|
from .data_model import DataConfig
|
|
35
|
+
from .fcn_algorithm_model import FCNAlgorithmConfig
|
|
34
36
|
from .inference_model import InferenceConfig
|
|
37
|
+
from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
|
|
35
38
|
from .training_model import TrainingConfig
|
|
39
|
+
from .vae_algorithm_model import VAEAlgorithmConfig
|
|
@@ -4,7 +4,7 @@ __all__ = [
|
|
|
4
4
|
"ArchitectureModel",
|
|
5
5
|
"CustomModel",
|
|
6
6
|
"UNetModel",
|
|
7
|
-
"
|
|
7
|
+
"LVAEModel",
|
|
8
8
|
"clear_custom_models",
|
|
9
9
|
"get_custom_model",
|
|
10
10
|
"register_model",
|
|
@@ -12,6 +12,6 @@ __all__ = [
|
|
|
12
12
|
|
|
13
13
|
from .architecture_model import ArchitectureModel
|
|
14
14
|
from .custom_model import CustomModel
|
|
15
|
+
from .lvae_model import LVAEModel
|
|
15
16
|
from .register_model import clear_custom_models, get_custom_model, register_model
|
|
16
17
|
from .unet_model import UNetModel
|
|
17
|
-
from .vae_model import VAEModel
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import inspect
|
|
5
6
|
from pprint import pformat
|
|
6
7
|
from typing import Any, Literal
|
|
7
8
|
|
|
@@ -23,12 +24,13 @@ class CustomModel(ArchitectureModel):
|
|
|
23
24
|
|
|
24
25
|
Attributes
|
|
25
26
|
----------
|
|
26
|
-
architecture : Literal["
|
|
27
|
-
Discriminator for the custom model, must be set to "
|
|
27
|
+
architecture : Literal["custom"]
|
|
28
|
+
Discriminator for the custom model, must be set to "custom".
|
|
28
29
|
name : str
|
|
29
30
|
Name of the custom model.
|
|
30
31
|
parameters : CustomParametersModel
|
|
31
|
-
|
|
32
|
+
All parameters, required for the initialization of the torch module have to be
|
|
33
|
+
passed here.
|
|
32
34
|
|
|
33
35
|
Raises
|
|
34
36
|
------
|
|
@@ -57,7 +59,7 @@ class CustomModel(ArchitectureModel):
|
|
|
57
59
|
...
|
|
58
60
|
>>> # Create a configuration
|
|
59
61
|
>>> config_dict = {
|
|
60
|
-
... "architecture": "
|
|
62
|
+
... "architecture": "custom",
|
|
61
63
|
... "name": "my_linear",
|
|
62
64
|
... "in_features": 10,
|
|
63
65
|
... "out_features": 5,
|
|
@@ -71,10 +73,9 @@ class CustomModel(ArchitectureModel):
|
|
|
71
73
|
)
|
|
72
74
|
|
|
73
75
|
# discriminator used for choosing the pydantic model in Model
|
|
74
|
-
architecture: Literal["
|
|
76
|
+
architecture: Literal["custom"]
|
|
75
77
|
"""Name of the architecture."""
|
|
76
78
|
|
|
77
|
-
# name of the custom model
|
|
78
79
|
name: str
|
|
79
80
|
"""Name of the custom model."""
|
|
80
81
|
|
|
@@ -120,10 +121,12 @@ class CustomModel(ArchitectureModel):
|
|
|
120
121
|
get_custom_model(self.name)(**self.model_dump())
|
|
121
122
|
except Exception as e:
|
|
122
123
|
raise ValueError(
|
|
123
|
-
f"
|
|
124
|
+
f"while passing parameters to the model {e}. Verify that all "
|
|
124
125
|
f"mandatory parameters are provided, and that either the {e} accepts "
|
|
125
126
|
f"*args and **kwargs in its __init__() method, or that no additional"
|
|
126
|
-
f"parameter is provided."
|
|
127
|
+
f"parameter is provided. Trace: "
|
|
128
|
+
f"filename: {inspect.trace()[-1].filename}, function: "
|
|
129
|
+
f"{inspect.trace()[-1].function}, line: {inspect.trace()[-1].lineno}"
|
|
127
130
|
) from None
|
|
128
131
|
|
|
129
132
|
return self
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""LVAE Pydantic model."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field, field_validator, model_validator
|
|
6
|
+
from typing_extensions import Self
|
|
7
|
+
|
|
8
|
+
from .architecture_model import ArchitectureModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# TODO: it is quite confusing to call this LVAEModel, as it is basically a config
|
|
12
|
+
class LVAEModel(ArchitectureModel):
|
|
13
|
+
"""LVAE model."""
|
|
14
|
+
|
|
15
|
+
model_config = ConfigDict(validate_assignment=True, validate_default=True)
|
|
16
|
+
|
|
17
|
+
architecture: Literal["LVAE"]
|
|
18
|
+
input_shape: int = Field(default=64, ge=8, le=1024)
|
|
19
|
+
multiscale_count: int = Field(default=5) # TODO clarify
|
|
20
|
+
# 0 - off, len(z_dims) + 1 # TODO can/should be le to z_dims len + 1
|
|
21
|
+
z_dims: list = Field(default=[128, 128, 128, 128])
|
|
22
|
+
output_channels: int = Field(default=1, ge=1)
|
|
23
|
+
encoder_n_filters: int = Field(default=64, ge=8, le=1024)
|
|
24
|
+
decoder_n_filters: int = Field(default=64, ge=8, le=1024)
|
|
25
|
+
encoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
|
|
26
|
+
decoder_dropout: float = Field(default=0.1, ge=0.0, le=0.9)
|
|
27
|
+
nonlinearity: Literal[
|
|
28
|
+
"None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU", "ELU"
|
|
29
|
+
] = Field(
|
|
30
|
+
default="ELU",
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
predict_logvar: Literal[None, "pixelwise"] = None
|
|
34
|
+
|
|
35
|
+
# TODO this parameter is exessive -> Remove & refactor
|
|
36
|
+
enable_noise_model: bool = Field(
|
|
37
|
+
default=True,
|
|
38
|
+
)
|
|
39
|
+
analytical_kl: bool = Field(
|
|
40
|
+
default=False,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
@field_validator("encoder_n_filters")
|
|
44
|
+
@classmethod
|
|
45
|
+
def validate_encoder_even(cls, encoder_n_filters: int) -> int:
|
|
46
|
+
"""
|
|
47
|
+
Validate that num_channels_init is even.
|
|
48
|
+
|
|
49
|
+
Parameters
|
|
50
|
+
----------
|
|
51
|
+
encoder_n_filters : int
|
|
52
|
+
Number of channels.
|
|
53
|
+
|
|
54
|
+
Returns
|
|
55
|
+
-------
|
|
56
|
+
int
|
|
57
|
+
Validated number of channels.
|
|
58
|
+
|
|
59
|
+
Raises
|
|
60
|
+
------
|
|
61
|
+
ValueError
|
|
62
|
+
If the number of channels is odd.
|
|
63
|
+
"""
|
|
64
|
+
# if odd
|
|
65
|
+
if encoder_n_filters % 2 != 0:
|
|
66
|
+
raise ValueError(
|
|
67
|
+
f"Number of channels for the bottom layer must be even"
|
|
68
|
+
f" (got {encoder_n_filters})."
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
return encoder_n_filters
|
|
72
|
+
|
|
73
|
+
@field_validator("decoder_n_filters")
|
|
74
|
+
@classmethod
|
|
75
|
+
def validate_decoder_even(cls, decoder_n_filters: int) -> int:
|
|
76
|
+
"""
|
|
77
|
+
Validate that num_channels_init is even.
|
|
78
|
+
|
|
79
|
+
Parameters
|
|
80
|
+
----------
|
|
81
|
+
decoder_n_filters : int
|
|
82
|
+
Number of channels.
|
|
83
|
+
|
|
84
|
+
Returns
|
|
85
|
+
-------
|
|
86
|
+
int
|
|
87
|
+
Validated number of channels.
|
|
88
|
+
|
|
89
|
+
Raises
|
|
90
|
+
------
|
|
91
|
+
ValueError
|
|
92
|
+
If the number of channels is odd.
|
|
93
|
+
"""
|
|
94
|
+
# if odd
|
|
95
|
+
if decoder_n_filters % 2 != 0:
|
|
96
|
+
raise ValueError(
|
|
97
|
+
f"Number of channels for the bottom layer must be even"
|
|
98
|
+
f" (got {decoder_n_filters})."
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
return decoder_n_filters
|
|
102
|
+
|
|
103
|
+
@field_validator("z_dims")
|
|
104
|
+
def validate_z_dims(cls, z_dims: tuple) -> tuple:
|
|
105
|
+
"""
|
|
106
|
+
Validate the z_dims.
|
|
107
|
+
|
|
108
|
+
Parameters
|
|
109
|
+
----------
|
|
110
|
+
z_dims : tuple
|
|
111
|
+
Tuple of z dimensions.
|
|
112
|
+
|
|
113
|
+
Returns
|
|
114
|
+
-------
|
|
115
|
+
tuple
|
|
116
|
+
Validated z dimensions.
|
|
117
|
+
|
|
118
|
+
Raises
|
|
119
|
+
------
|
|
120
|
+
ValueError
|
|
121
|
+
If the number of z dimensions is not 4.
|
|
122
|
+
"""
|
|
123
|
+
if len(z_dims) < 2:
|
|
124
|
+
raise ValueError(
|
|
125
|
+
f"Number of z dimensions must be at least 2 (got {len(z_dims)})."
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
return z_dims
|
|
129
|
+
|
|
130
|
+
@model_validator(mode="after")
|
|
131
|
+
def validate_multiscale_count(cls, self: Self) -> Self:
|
|
132
|
+
"""
|
|
133
|
+
Validate the multiscale count.
|
|
134
|
+
|
|
135
|
+
Parameters
|
|
136
|
+
----------
|
|
137
|
+
self : Self
|
|
138
|
+
The model.
|
|
139
|
+
|
|
140
|
+
Returns
|
|
141
|
+
-------
|
|
142
|
+
Self
|
|
143
|
+
The validated model.
|
|
144
|
+
"""
|
|
145
|
+
# if self.multiscale_count != 0:
|
|
146
|
+
# if self.multiscale_count != len(self.z_dims) - 1:
|
|
147
|
+
# raise ValueError(
|
|
148
|
+
# f"Multiscale count must be 0 or equal to the number of Z "
|
|
149
|
+
# f"dims - 1 (got {self.multiscale_count} and {len(self.z_dims)})."
|
|
150
|
+
# )
|
|
151
|
+
|
|
152
|
+
return self
|
|
153
|
+
|
|
154
|
+
def set_3D(self, is_3D: bool) -> None:
|
|
155
|
+
"""
|
|
156
|
+
Set 3D model by setting the `conv_dims` parameters.
|
|
157
|
+
|
|
158
|
+
Parameters
|
|
159
|
+
----------
|
|
160
|
+
is_3D : bool
|
|
161
|
+
Whether the algorithm is 3D or not.
|
|
162
|
+
"""
|
|
163
|
+
raise NotImplementedError("VAE is not implemented yet.")
|
|
164
|
+
|
|
165
|
+
def is_3D(self) -> bool:
|
|
166
|
+
"""
|
|
167
|
+
Return whether the model is 3D or not.
|
|
168
|
+
|
|
169
|
+
Returns
|
|
170
|
+
-------
|
|
171
|
+
bool
|
|
172
|
+
Whether the model is 3D or not.
|
|
173
|
+
"""
|
|
174
|
+
raise NotImplementedError("VAE is not implemented yet.")
|
|
@@ -2,10 +2,10 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Any, Dict, List, Literal, Optional
|
|
4
4
|
|
|
5
|
-
from .algorithm_model import AlgorithmConfig
|
|
6
5
|
from .architectures import UNetModel
|
|
7
6
|
from .configuration_model import Configuration
|
|
8
7
|
from .data_model import DataConfig
|
|
8
|
+
from .fcn_algorithm_model import FCNAlgorithmConfig
|
|
9
9
|
from .support import (
|
|
10
10
|
SupportedAlgorithm,
|
|
11
11
|
SupportedArchitecture,
|
|
@@ -16,7 +16,9 @@ from .support import (
|
|
|
16
16
|
from .training_model import TrainingConfig
|
|
17
17
|
|
|
18
18
|
|
|
19
|
+
# TODO rename ?
|
|
19
20
|
def _create_supervised_configuration(
|
|
21
|
+
algorithm_type: Literal["fcn"],
|
|
20
22
|
algorithm: Literal["care", "n2n"],
|
|
21
23
|
experiment_name: str,
|
|
22
24
|
data_type: Literal["array", "tiff", "custom"],
|
|
@@ -37,6 +39,8 @@ def _create_supervised_configuration(
|
|
|
37
39
|
|
|
38
40
|
Parameters
|
|
39
41
|
----------
|
|
42
|
+
algorithm_type : Literal["fcn"]
|
|
43
|
+
Type of the algorithm.
|
|
40
44
|
algorithm : Literal["care", "n2n"]
|
|
41
45
|
Algorithm to use.
|
|
42
46
|
experiment_name : str
|
|
@@ -97,7 +101,8 @@ def _create_supervised_configuration(
|
|
|
97
101
|
)
|
|
98
102
|
|
|
99
103
|
# algorithm model
|
|
100
|
-
algorithm =
|
|
104
|
+
algorithm = FCNAlgorithmConfig(
|
|
105
|
+
algorithm_type=algorithm_type,
|
|
101
106
|
algorithm=algorithm,
|
|
102
107
|
loss=loss,
|
|
103
108
|
model=unet_model,
|
|
@@ -215,6 +220,7 @@ def create_care_configuration(
|
|
|
215
220
|
n_channels_out = n_channels_in
|
|
216
221
|
|
|
217
222
|
return _create_supervised_configuration(
|
|
223
|
+
algorithm_type="fcn",
|
|
218
224
|
algorithm="care",
|
|
219
225
|
experiment_name=experiment_name,
|
|
220
226
|
data_type=data_type,
|
|
@@ -304,6 +310,7 @@ def create_n2n_configuration(
|
|
|
304
310
|
n_channels_out = n_channels_in
|
|
305
311
|
|
|
306
312
|
return _create_supervised_configuration(
|
|
313
|
+
algorithm_type="fcn",
|
|
307
314
|
algorithm="n2n",
|
|
308
315
|
experiment_name=experiment_name,
|
|
309
316
|
data_type=data_type,
|
|
@@ -514,7 +521,8 @@ def create_n2v_configuration(
|
|
|
514
521
|
)
|
|
515
522
|
|
|
516
523
|
# algorithm model
|
|
517
|
-
algorithm =
|
|
524
|
+
algorithm = FCNAlgorithmConfig(
|
|
525
|
+
algorithm_type="fcn",
|
|
518
526
|
algorithm=SupportedAlgorithm.N2V.value,
|
|
519
527
|
loss=SupportedLoss.N2V.value,
|
|
520
528
|
model=unet_model,
|
|
@@ -9,11 +9,11 @@ from typing import Literal, Union
|
|
|
9
9
|
|
|
10
10
|
import yaml
|
|
11
11
|
from bioimageio.spec.generic.v0_3 import CiteEntry
|
|
12
|
-
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
|
|
12
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
13
13
|
from typing_extensions import Self
|
|
14
14
|
|
|
15
|
-
from .algorithm_model import AlgorithmConfig
|
|
16
15
|
from .data_model import DataConfig
|
|
16
|
+
from .fcn_algorithm_model import FCNAlgorithmConfig
|
|
17
17
|
from .references import (
|
|
18
18
|
CARE,
|
|
19
19
|
CUSTOM,
|
|
@@ -39,6 +39,7 @@ from .training_model import TrainingConfig
|
|
|
39
39
|
from .transformations.n2v_manipulate_model import (
|
|
40
40
|
N2VManipulateModel,
|
|
41
41
|
)
|
|
42
|
+
from .vae_algorithm_model import VAEAlgorithmConfig
|
|
42
43
|
|
|
43
44
|
|
|
44
45
|
class Configuration(BaseModel):
|
|
@@ -123,6 +124,7 @@ class Configuration(BaseModel):
|
|
|
123
124
|
>>> config_dict = {
|
|
124
125
|
... "experiment_name": "N2V_experiment",
|
|
125
126
|
... "algorithm_config": {
|
|
127
|
+
... "algorithm_type": "fcn",
|
|
126
128
|
... "algorithm": "n2v",
|
|
127
129
|
... "loss": "n2v",
|
|
128
130
|
... "model": {
|
|
@@ -155,7 +157,9 @@ class Configuration(BaseModel):
|
|
|
155
157
|
"""Name of the experiment, used to name logs and checkpoints."""
|
|
156
158
|
|
|
157
159
|
# Sub-configurations
|
|
158
|
-
algorithm_config:
|
|
160
|
+
algorithm_config: Union[FCNAlgorithmConfig, VAEAlgorithmConfig] = Field(
|
|
161
|
+
discriminator="algorithm_type"
|
|
162
|
+
)
|
|
159
163
|
"""Algorithm configuration, holding all parameters required to configure the
|
|
160
164
|
model."""
|
|
161
165
|
|
careamics/config/data_model.py
CHANGED
|
@@ -5,12 +5,14 @@ from __future__ import annotations
|
|
|
5
5
|
from pprint import pformat
|
|
6
6
|
from typing import Any, Literal, Optional, Union
|
|
7
7
|
|
|
8
|
+
import numpy as np
|
|
8
9
|
from numpy.typing import NDArray
|
|
9
10
|
from pydantic import (
|
|
10
11
|
BaseModel,
|
|
11
12
|
ConfigDict,
|
|
12
13
|
Discriminator,
|
|
13
14
|
Field,
|
|
15
|
+
PlainSerializer,
|
|
14
16
|
field_validator,
|
|
15
17
|
model_validator,
|
|
16
18
|
)
|
|
@@ -22,6 +24,30 @@ from .transformations.xy_flip_model import XYFlipModel
|
|
|
22
24
|
from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
|
|
23
25
|
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
|
|
24
26
|
|
|
27
|
+
|
|
28
|
+
def np_float_to_scientific_str(x: float) -> str:
|
|
29
|
+
"""Return a string scientific representation of a float.
|
|
30
|
+
|
|
31
|
+
In particular, this method is used to serialize floats to strings, allowing
|
|
32
|
+
numpy.float32 to be passed in the Pydantic model and written to a yaml file as str.
|
|
33
|
+
|
|
34
|
+
Parameters
|
|
35
|
+
----------
|
|
36
|
+
x : float
|
|
37
|
+
Input value.
|
|
38
|
+
|
|
39
|
+
Returns
|
|
40
|
+
-------
|
|
41
|
+
str
|
|
42
|
+
Scientific string representation of the input value.
|
|
43
|
+
"""
|
|
44
|
+
return np.format_float_scientific(x, precision=7)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
Float = Annotated[float, PlainSerializer(np_float_to_scientific_str, return_type=str)]
|
|
48
|
+
"""Annotated float type, used to serialize floats to strings."""
|
|
49
|
+
|
|
50
|
+
|
|
25
51
|
TRANSFORMS_UNION = Annotated[
|
|
26
52
|
Union[
|
|
27
53
|
XYFlipModel,
|
|
@@ -30,6 +56,7 @@ TRANSFORMS_UNION = Annotated[
|
|
|
30
56
|
],
|
|
31
57
|
Discriminator("name"), # used to tell the different transform models apart
|
|
32
58
|
]
|
|
59
|
+
"""Available transforms in CAREamics."""
|
|
33
60
|
|
|
34
61
|
|
|
35
62
|
class DataConfig(BaseModel):
|
|
@@ -94,20 +121,20 @@ class DataConfig(BaseModel):
|
|
|
94
121
|
"""Batch size for training."""
|
|
95
122
|
|
|
96
123
|
# Optional fields
|
|
97
|
-
image_means: Optional[list[
|
|
124
|
+
image_means: Optional[list[Float]] = Field(
|
|
98
125
|
default=None, min_length=0, max_length=32
|
|
99
126
|
)
|
|
100
127
|
"""Means of the data across channels, used for normalization."""
|
|
101
128
|
|
|
102
|
-
image_stds: Optional[list[
|
|
129
|
+
image_stds: Optional[list[Float]] = Field(default=None, min_length=0, max_length=32)
|
|
103
130
|
"""Standard deviations of the data across channels, used for normalization."""
|
|
104
131
|
|
|
105
|
-
target_means: Optional[list[
|
|
132
|
+
target_means: Optional[list[Float]] = Field(
|
|
106
133
|
default=None, min_length=0, max_length=32
|
|
107
134
|
)
|
|
108
135
|
"""Means of the target data across channels, used for normalization."""
|
|
109
136
|
|
|
110
|
-
target_stds: Optional[list[
|
|
137
|
+
target_stds: Optional[list[Float]] = Field(
|
|
111
138
|
default=None, min_length=0, max_length=32
|
|
112
139
|
)
|
|
113
140
|
"""Standard deviations of the target data across channels, used for
|
|
@@ -265,9 +292,7 @@ class DataConfig(BaseModel):
|
|
|
265
292
|
elif (self.image_means is not None and self.image_stds is not None) and (
|
|
266
293
|
len(self.image_means) != len(self.image_stds)
|
|
267
294
|
):
|
|
268
|
-
raise ValueError(
|
|
269
|
-
"Mean and std must be specified for each " "input channel."
|
|
270
|
-
)
|
|
295
|
+
raise ValueError("Mean and std must be specified for each input channel.")
|
|
271
296
|
|
|
272
297
|
if (self.target_means and not self.target_stds) or (
|
|
273
298
|
self.target_stds and not self.target_means
|
|
@@ -380,7 +405,7 @@ class DataConfig(BaseModel):
|
|
|
380
405
|
|
|
381
406
|
Parameters
|
|
382
407
|
----------
|
|
383
|
-
image_means : numpy.ndarray
|
|
408
|
+
image_means : numpy.ndarray, tuple or list
|
|
384
409
|
Mean values for normalization.
|
|
385
410
|
image_stds : numpy.ndarray, tuple or list
|
|
386
411
|
Standard deviation values for normalization.
|
|
@@ -1,6 +1,4 @@
|
|
|
1
|
-
"""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
1
|
+
"""Module containing `FCNAlgorithmConfig` class."""
|
|
4
2
|
|
|
5
3
|
from pprint import pformat
|
|
6
4
|
from typing import Literal, Union
|
|
@@ -8,11 +6,11 @@ from typing import Literal, Union
|
|
|
8
6
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
9
7
|
from typing_extensions import Self
|
|
10
8
|
|
|
11
|
-
from .architectures import CustomModel, UNetModel
|
|
12
|
-
from .optimizer_models import LrSchedulerModel, OptimizerModel
|
|
9
|
+
from careamics.config.architectures import CustomModel, UNetModel
|
|
10
|
+
from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel
|
|
13
11
|
|
|
14
12
|
|
|
15
|
-
class
|
|
13
|
+
class FCNAlgorithmConfig(BaseModel):
|
|
16
14
|
"""Algorithm configuration.
|
|
17
15
|
|
|
18
16
|
This Pydantic model validates the parameters governing the components of the
|
|
@@ -30,7 +28,7 @@ class AlgorithmConfig(BaseModel):
|
|
|
30
28
|
Algorithm to use.
|
|
31
29
|
loss : Literal["n2v", "mae", "mse"]
|
|
32
30
|
Loss function to use.
|
|
33
|
-
model : Union[UNetModel,
|
|
31
|
+
model : Union[UNetModel, LVAEModel, CustomModel]
|
|
34
32
|
Model architecture to use.
|
|
35
33
|
optimizer : OptimizerModel, optional
|
|
36
34
|
Optimizer to use.
|
|
@@ -47,66 +45,51 @@ class AlgorithmConfig(BaseModel):
|
|
|
47
45
|
Examples
|
|
48
46
|
--------
|
|
49
47
|
Minimum example:
|
|
50
|
-
>>> from careamics.config import
|
|
48
|
+
>>> from careamics.config import FCNAlgorithmConfig
|
|
51
49
|
>>> config_dict = {
|
|
52
50
|
... "algorithm": "n2v",
|
|
51
|
+
... "algorithm_type": "fcn",
|
|
53
52
|
... "loss": "n2v",
|
|
54
53
|
... "model": {
|
|
55
54
|
... "architecture": "UNet",
|
|
56
55
|
... }
|
|
57
56
|
... }
|
|
58
|
-
>>> config =
|
|
59
|
-
|
|
60
|
-
Using a custom model:
|
|
61
|
-
>>> from torch import nn, ones
|
|
62
|
-
>>> from careamics.config import AlgorithmConfig, register_model
|
|
63
|
-
...
|
|
64
|
-
>>> @register_model(name="linear_model")
|
|
65
|
-
... class LinearModel(nn.Module):
|
|
66
|
-
... def __init__(self, in_features, out_features, *args, **kwargs):
|
|
67
|
-
... super().__init__()
|
|
68
|
-
... self.in_features = in_features
|
|
69
|
-
... self.out_features = out_features
|
|
70
|
-
... self.weight = nn.Parameter(ones(in_features, out_features))
|
|
71
|
-
... self.bias = nn.Parameter(ones(out_features))
|
|
72
|
-
... def forward(self, input):
|
|
73
|
-
... return (input @ self.weight) + self.bias
|
|
74
|
-
...
|
|
75
|
-
>>> config_dict = {
|
|
76
|
-
... "algorithm": "custom",
|
|
77
|
-
... "loss": "mse",
|
|
78
|
-
... "model": {
|
|
79
|
-
... "architecture": "Custom",
|
|
80
|
-
... "name": "linear_model",
|
|
81
|
-
... "in_features": 10,
|
|
82
|
-
... "out_features": 5,
|
|
83
|
-
... }
|
|
84
|
-
... }
|
|
85
|
-
>>> config = AlgorithmConfig(**config_dict)
|
|
57
|
+
>>> config = FCNAlgorithmConfig(**config_dict)
|
|
86
58
|
"""
|
|
87
59
|
|
|
88
60
|
# Pydantic class configuration
|
|
89
61
|
model_config = ConfigDict(
|
|
90
62
|
protected_namespaces=(), # allows to use model_* as a field name
|
|
91
63
|
validate_assignment=True,
|
|
64
|
+
extra="allow",
|
|
92
65
|
)
|
|
93
66
|
|
|
94
67
|
# Mandatory fields
|
|
95
|
-
|
|
96
|
-
|
|
68
|
+
# defined in SupportedAlgorithm
|
|
69
|
+
algorithm_type: Literal["fcn"]
|
|
70
|
+
"""Algorithm type must be `fcn` (fully convolutional network) to differentiate this
|
|
71
|
+
configuration from LVAE."""
|
|
72
|
+
|
|
73
|
+
algorithm: Literal["n2v", "care", "n2n", "custom"]
|
|
74
|
+
"""Name of the algorithm, as defined in SupportedAlgorithm. Use `custom` for custom
|
|
75
|
+
model architecture."""
|
|
97
76
|
|
|
98
77
|
loss: Literal["n2v", "mae", "mse"]
|
|
99
78
|
"""Loss function to use, as defined in SupportedLoss."""
|
|
100
79
|
|
|
101
|
-
model: Union[UNetModel,
|
|
102
|
-
"""Model architecture to use,
|
|
80
|
+
model: Union[UNetModel, CustomModel] = Field(discriminator="architecture")
|
|
81
|
+
"""Model architecture to use, along with its parameters. Compatible architectures
|
|
82
|
+
are defined in SupportedArchitecture, and their Pydantic models in
|
|
83
|
+
`careamics.config.architectures`."""
|
|
84
|
+
# TODO supported architectures are now all the architectures but does not warn users
|
|
85
|
+
# of the compatibility with the algorithm
|
|
103
86
|
|
|
104
87
|
# Optional fields
|
|
105
88
|
optimizer: OptimizerModel = OptimizerModel()
|
|
106
89
|
"""Optimizer to use, defined in SupportedOptimizer."""
|
|
107
90
|
|
|
108
91
|
lr_scheduler: LrSchedulerModel = LrSchedulerModel()
|
|
109
|
-
"""Learning rate scheduler to use, defined in
|
|
92
|
+
"""Learning rate scheduler to use, defined in SupportedLrScheduler."""
|
|
110
93
|
|
|
111
94
|
@model_validator(mode="after")
|
|
112
95
|
def algorithm_cross_validation(self: Self) -> Self:
|
|
@@ -146,8 +129,10 @@ class AlgorithmConfig(BaseModel):
|
|
|
146
129
|
if self.loss == "n2v":
|
|
147
130
|
raise ValueError("Supervised algorithms do not support loss `n2v`.")
|
|
148
131
|
|
|
149
|
-
if
|
|
150
|
-
raise ValueError(
|
|
132
|
+
if (self.algorithm == "custom") != (self.model.architecture == "custom"):
|
|
133
|
+
raise ValueError(
|
|
134
|
+
"Algorithm and model architecture must be both `custom` or not."
|
|
135
|
+
)
|
|
151
136
|
|
|
152
137
|
return self
|
|
153
138
|
|