careamics 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +17 -2
- careamics/careamist.py +4 -3
- careamics/cli/conf.py +1 -2
- careamics/cli/main.py +1 -2
- careamics/cli/utils.py +3 -3
- careamics/config/__init__.py +47 -25
- careamics/config/algorithms/__init__.py +15 -0
- careamics/config/algorithms/care_algorithm_model.py +50 -0
- careamics/config/algorithms/n2n_algorithm_model.py +42 -0
- careamics/config/algorithms/n2v_algorithm_model.py +35 -0
- careamics/config/algorithms/unet_algorithm_model.py +88 -0
- careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
- careamics/config/architectures/__init__.py +1 -11
- careamics/config/architectures/architecture_model.py +3 -3
- careamics/config/architectures/lvae_model.py +6 -1
- careamics/config/architectures/unet_model.py +1 -0
- careamics/config/care_configuration.py +100 -0
- careamics/config/configuration.py +354 -0
- careamics/config/{configuration_factory.py → configuration_factories.py} +103 -36
- careamics/config/configuration_io.py +85 -0
- careamics/config/data/__init__.py +10 -0
- careamics/config/{data_model.py → data/data_model.py} +58 -198
- careamics/config/data/n2v_data_model.py +193 -0
- careamics/config/likelihood_model.py +1 -2
- careamics/config/n2n_configuration.py +101 -0
- careamics/config/n2v_configuration.py +266 -0
- careamics/config/nm_model.py +1 -2
- careamics/config/support/__init__.py +7 -7
- careamics/config/support/supported_algorithms.py +0 -3
- careamics/config/support/supported_architectures.py +0 -4
- careamics/config/transformations/__init__.py +10 -4
- careamics/config/transformations/transform_model.py +3 -3
- careamics/config/transformations/transform_unions.py +42 -0
- careamics/config/validators/validator_utils.py +3 -3
- careamics/dataset/__init__.py +2 -2
- careamics/dataset/dataset_utils/__init__.py +3 -3
- careamics/dataset/dataset_utils/dataset_utils.py +4 -6
- careamics/dataset/dataset_utils/file_utils.py +9 -9
- careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
- careamics/dataset/in_memory_dataset.py +11 -12
- careamics/dataset/iterable_dataset.py +4 -4
- careamics/dataset/iterable_pred_dataset.py +2 -1
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
- careamics/dataset/patching/random_patching.py +11 -10
- careamics/dataset/patching/sequential_patching.py +26 -26
- careamics/dataset/patching/validate_patch_dimension.py +3 -3
- careamics/dataset/tiling/__init__.py +2 -2
- careamics/dataset/tiling/collate_tiles.py +3 -3
- careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
- careamics/dataset/tiling/tiled_patching.py +11 -10
- careamics/file_io/__init__.py +5 -5
- careamics/file_io/read/__init__.py +1 -1
- careamics/file_io/read/get_func.py +2 -2
- careamics/file_io/write/__init__.py +2 -2
- careamics/lightning/__init__.py +5 -5
- careamics/lightning/callbacks/__init__.py +1 -1
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
- careamics/lightning/callbacks/progress_bar_callback.py +2 -2
- careamics/lightning/lightning_module.py +11 -7
- careamics/lightning/train_data_module.py +26 -26
- careamics/losses/__init__.py +3 -3
- careamics/model_io/__init__.py +1 -1
- careamics/model_io/bioimage/__init__.py +1 -1
- careamics/model_io/bioimage/_readme_factory.py +1 -1
- careamics/model_io/bioimage/model_description.py +17 -17
- careamics/model_io/bmz_io.py +6 -17
- careamics/model_io/model_io_utils.py +9 -9
- careamics/models/layers.py +16 -16
- careamics/models/lvae/lvae.py +0 -3
- careamics/models/model_factory.py +2 -15
- careamics/models/unet.py +8 -8
- careamics/prediction_utils/__init__.py +1 -1
- careamics/prediction_utils/prediction_outputs.py +15 -15
- careamics/prediction_utils/stitch_prediction.py +6 -6
- careamics/transforms/__init__.py +5 -5
- careamics/transforms/compose.py +13 -13
- careamics/transforms/n2v_manipulate.py +3 -3
- careamics/transforms/pixel_manipulation.py +9 -9
- careamics/transforms/xy_random_rotate90.py +4 -4
- careamics/utils/__init__.py +5 -5
- careamics/utils/context.py +2 -1
- careamics/utils/logging.py +11 -10
- careamics/utils/torch_utils.py +7 -7
- {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/METADATA +11 -11
- {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/RECORD +90 -85
- careamics/config/architectures/custom_model.py +0 -162
- careamics/config/architectures/register_model.py +0 -103
- careamics/config/configuration_model.py +0 -603
- careamics/config/fcn_algorithm_model.py +0 -152
- careamics/config/references/__init__.py +0 -45
- careamics/config/references/algorithm_descriptions.py +0 -132
- careamics/config/references/references.py +0 -39
- careamics/config/transformations/transform_union.py +0 -20
- {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/WHEEL +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,354 @@
|
|
|
1
|
+
"""Pydantic CAREamics configuration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from pprint import pformat
|
|
7
|
+
from typing import Any, Literal, Union
|
|
8
|
+
|
|
9
|
+
from bioimageio.spec.generic.v0_3 import CiteEntry
|
|
10
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
11
|
+
from typing_extensions import Self
|
|
12
|
+
|
|
13
|
+
from careamics.config.algorithms import UNetBasedAlgorithm, VAEBasedAlgorithm
|
|
14
|
+
from careamics.config.data import GeneralDataConfig
|
|
15
|
+
from careamics.config.training_model import TrainingConfig
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Configuration(BaseModel):
|
|
19
|
+
"""
|
|
20
|
+
CAREamics configuration.
|
|
21
|
+
|
|
22
|
+
The configuration defines all parameters used to build and train a CAREamics model.
|
|
23
|
+
These parameters are validated to ensure that they are compatible with each other.
|
|
24
|
+
|
|
25
|
+
It contains three sub-configurations:
|
|
26
|
+
|
|
27
|
+
- AlgorithmModel: configuration for the algorithm training, which includes the
|
|
28
|
+
architecture, loss function, optimizer, and other hyperparameters.
|
|
29
|
+
- DataModel: configuration for the dataloader, which includes the type of data,
|
|
30
|
+
transformations, mean/std and other parameters.
|
|
31
|
+
- TrainingModel: configuration for the training, which includes the number of
|
|
32
|
+
epochs or the callbacks.
|
|
33
|
+
|
|
34
|
+
Attributes
|
|
35
|
+
----------
|
|
36
|
+
experiment_name : str
|
|
37
|
+
Name of the experiment, used when saving logs and checkpoints.
|
|
38
|
+
algorithm : AlgorithmModel
|
|
39
|
+
Algorithm configuration.
|
|
40
|
+
data : DataModel
|
|
41
|
+
Data configuration.
|
|
42
|
+
training : TrainingModel
|
|
43
|
+
Training configuration.
|
|
44
|
+
|
|
45
|
+
Methods
|
|
46
|
+
-------
|
|
47
|
+
set_3D(is_3D: bool, axes: str, patch_size: List[int]) -> None
|
|
48
|
+
Switch configuration between 2D and 3D.
|
|
49
|
+
model_dump(
|
|
50
|
+
exclude_defaults: bool = False, exclude_none: bool = True, **kwargs: Dict
|
|
51
|
+
) -> Dict
|
|
52
|
+
Export configuration to a dictionary.
|
|
53
|
+
|
|
54
|
+
Raises
|
|
55
|
+
------
|
|
56
|
+
ValueError
|
|
57
|
+
Configuration parameter type validation errors.
|
|
58
|
+
ValueError
|
|
59
|
+
If the experiment name contains invalid characters or is empty.
|
|
60
|
+
ValueError
|
|
61
|
+
If the algorithm is 3D but there is no "Z" in the data axes, or 2D algorithm
|
|
62
|
+
with "Z" in data axes.
|
|
63
|
+
ValueError
|
|
64
|
+
Algorithm, data or training validation errors.
|
|
65
|
+
|
|
66
|
+
Notes
|
|
67
|
+
-----
|
|
68
|
+
We provide convenience methods to create standard configurations, for instance:
|
|
69
|
+
>>> from careamics.config import create_n2v_configuration
|
|
70
|
+
>>> config = create_n2v_configuration(
|
|
71
|
+
... experiment_name="n2v_experiment",
|
|
72
|
+
... data_type="array",
|
|
73
|
+
... axes="YX",
|
|
74
|
+
... patch_size=[64, 64],
|
|
75
|
+
... batch_size=32,
|
|
76
|
+
... num_epochs=100
|
|
77
|
+
... )
|
|
78
|
+
|
|
79
|
+
The configuration can be exported to a dictionary using the model_dump method:
|
|
80
|
+
>>> config_dict = config.model_dump()
|
|
81
|
+
|
|
82
|
+
Configurations can also be exported or imported from yaml files:
|
|
83
|
+
>>> from careamics.config import save_configuration, load_configuration
|
|
84
|
+
>>> path_to_config = save_configuration(config, my_path / "config.yml")
|
|
85
|
+
>>> other_config = load_configuration(path_to_config)
|
|
86
|
+
|
|
87
|
+
Examples
|
|
88
|
+
--------
|
|
89
|
+
Minimum example:
|
|
90
|
+
>>> from careamics import configuration_factory
|
|
91
|
+
>>> config_dict = {
|
|
92
|
+
... "experiment_name": "N2V_experiment",
|
|
93
|
+
... "algorithm_config": {
|
|
94
|
+
... "algorithm": "n2v",
|
|
95
|
+
... "loss": "n2v",
|
|
96
|
+
... "model": {
|
|
97
|
+
... "architecture": "UNet",
|
|
98
|
+
... },
|
|
99
|
+
... },
|
|
100
|
+
... "training_config": {
|
|
101
|
+
... "num_epochs": 200,
|
|
102
|
+
... },
|
|
103
|
+
... "data_config": {
|
|
104
|
+
... "data_type": "tiff",
|
|
105
|
+
... "patch_size": [64, 64],
|
|
106
|
+
... "axes": "SYX",
|
|
107
|
+
... },
|
|
108
|
+
... }
|
|
109
|
+
>>> config = configuration_factory(config_dict)
|
|
110
|
+
"""
|
|
111
|
+
|
|
112
|
+
model_config = ConfigDict(
|
|
113
|
+
validate_assignment=True,
|
|
114
|
+
arbitrary_types_allowed=True,
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
# version
|
|
118
|
+
version: Literal["0.1.0"] = "0.1.0"
|
|
119
|
+
"""CAREamics configuration version."""
|
|
120
|
+
|
|
121
|
+
# required parameters
|
|
122
|
+
experiment_name: str
|
|
123
|
+
"""Name of the experiment, used to name logs and checkpoints."""
|
|
124
|
+
|
|
125
|
+
# Sub-configurations
|
|
126
|
+
algorithm_config: Union[UNetBasedAlgorithm, VAEBasedAlgorithm] = Field(
|
|
127
|
+
discriminator="algorithm"
|
|
128
|
+
)
|
|
129
|
+
"""Algorithm configuration, holding all parameters required to configure the
|
|
130
|
+
model."""
|
|
131
|
+
|
|
132
|
+
data_config: GeneralDataConfig
|
|
133
|
+
"""Data configuration, holding all parameters required to configure the training
|
|
134
|
+
data loader."""
|
|
135
|
+
|
|
136
|
+
training_config: TrainingConfig
|
|
137
|
+
"""Training configuration, holding all parameters required to configure the
|
|
138
|
+
training process."""
|
|
139
|
+
|
|
140
|
+
@field_validator("experiment_name")
|
|
141
|
+
@classmethod
|
|
142
|
+
def no_symbol(cls, name: str) -> str:
    """
    Validate experiment name.

    A valid experiment name is a non-empty string which only contains letters,
    numbers, underscores, dashes and spaces.

    Parameters
    ----------
    name : str
        Name to validate.

    Returns
    -------
    str
        Validated name.

    Raises
    ------
    ValueError
        If the name is empty or contains invalid characters.
    """
    if len(name) == 0 or name.isspace():
        raise ValueError("Experiment name is empty.")

    # `re.fullmatch` is used rather than `re.match` with a `$` anchor: `$` also
    # matches just before a trailing newline, which would let a name such as
    # "experiment\n" pass the validation.
    if not re.fullmatch(r"[a-zA-Z0-9_\- ]*", name):
        raise ValueError(
            f"Experiment name contains invalid characters (got {name}). "
            f"Only letters, numbers, underscores, dashes and spaces are allowed."
        )

    return name
|
|
176
|
+
|
|
177
|
+
@model_validator(mode="after")
|
|
178
|
+
def validate_3D(self: Self) -> Self:
    """
    Align the dimensionality of the algorithm model with the data axes.

    The model is switched to 3D when "Z" is part of the data axes and to 2D when
    it is not. Nothing is changed when the model already agrees with the axes.

    Returns
    -------
    Self
        Validated configuration.
    """
    data_has_z = "Z" in self.data_config.axes
    model = self.algorithm_config.model

    # only touch the model when its current setting disagrees with the axes
    if data_has_z != bool(model.is_3D()):
        model.set_3D(data_has_z)

    return self
|
|
195
|
+
|
|
196
|
+
def __str__(self) -> str:
    """
    Pretty string representing the configuration.

    Returns
    -------
    str
        The dumped configuration dictionary formatted with `pprint.pformat`.
    """
    dumped = self.model_dump()
    return pformat(dumped)
|
|
206
|
+
|
|
207
|
+
def set_3D(self, is_3D: bool, axes: str, patch_size: list[int]) -> None:
    """
    Switch the configuration between 2D and 3D.

    Parameters
    ----------
    is_3D : bool
        Whether the algorithm is 3D or not.
    axes : str
        Axes of the data.
    patch_size : list[int]
        Patch size.
    """
    algorithm = self.algorithm_config

    # update both sub-configurations in place; mutating them does not trigger
    # validation at the configuration level
    algorithm.model.set_3D(is_3D)
    self.data_config.set_3D(axes, patch_size)

    # cheap hack: re-assigning the attribute triggers validation
    self.algorithm_config = algorithm
|
|
226
|
+
|
|
227
|
+
def get_algorithm_friendly_name(self) -> str:
    """
    Get the algorithm name.

    This implementation has no name to return and always raises.

    Returns
    -------
    str
        Algorithm name.

    Raises
    ------
    ValueError
        Always, in this implementation.
    """
    message = "Unknown algorithm."
    raise ValueError(message)
|
|
237
|
+
|
|
238
|
+
def get_algorithm_description(self) -> str:
    """
    Return a description of the algorithm.

    This method is used to generate the README of the BioImage Model Zoo export.
    This implementation has no description to return and always raises.

    Returns
    -------
    str
        Description of the algorithm.

    Raises
    ------
    ValueError
        Always, in this implementation.
    """
    message = "No algorithm description available."
    raise ValueError(message)
|
|
250
|
+
|
|
251
|
+
def get_algorithm_citations(self) -> list[CiteEntry]:
    """
    Return a list of citation entries of the current algorithm.

    This is used to generate the model description for the BioImage Model Zoo.
    This implementation has no citations to return and always raises.

    Returns
    -------
    list[CiteEntry]
        List of citation entries.

    Raises
    ------
    ValueError
        Always, in this implementation.
    """
    message = "No algorithm citations available."
    raise ValueError(message)
|
|
263
|
+
|
|
264
|
+
def get_algorithm_references(self) -> str:
    """
    Get the algorithm references.

    This is used to generate the README of the BioImage Model Zoo export.
    This implementation has no references to return and always raises.

    Returns
    -------
    str
        Algorithm references.

    Raises
    ------
    ValueError
        Always, in this implementation.
    """
    message = "No algorithm references available."
    raise ValueError(message)
|
|
276
|
+
|
|
277
|
+
def get_algorithm_keywords(self) -> list[str]:
    """
    Get algorithm keywords.

    Returns
    -------
    list[str]
        List of keywords; this implementation only returns the "CAREamics" keyword.
    """
    keywords = ["CAREamics"]
    return keywords
|
|
287
|
+
|
|
288
|
+
def model_dump(
    self,
    *,
    mode: Literal["json", "python"] | str = "python",
    include: Any | None = None,
    exclude: Any | None = None,
    context: Any | None = None,
    by_alias: bool = False,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = True,
    round_trip: bool = False,
    warnings: bool | Literal["none", "warn", "error"] = True,
    serialize_as_any: bool = False,
) -> dict:
    """
    Export the configuration to a dictionary, excluding `None` values by default.

    This override forwards every argument unchanged to the parent `model_dump`;
    the only difference is that `exclude_none` defaults to True here.

    Parameters
    ----------
    mode : Literal['json', 'python'] | str, default='python'
        The serialization format.
    include : Any | None, default=None
        Attributes to include.
    exclude : Any | None, default=None
        Attributes to exclude.
    context : Any | None, default=None
        Additional context to pass to the serialization functions.
    by_alias : bool, default=False
        Whether to use attribute aliases.
    exclude_unset : bool, default=False
        Whether to exclude fields that are not set.
    exclude_defaults : bool, default=False
        Whether to exclude fields that have default values.
    exclude_none : bool, default=True
        Whether to exclude fields that have None values.
    round_trip : bool, default=False
        Whether to dump and load the data to ensure that the output is a valid
        representation.
    warnings : bool | Literal['none', 'warn', 'error'], default=True
        Whether to emit warnings.
    serialize_as_any : bool, default=False
        Whether to serialize all types as Any.

    Returns
    -------
    dict
        Dictionary containing the model parameters.
    """
    # collect the options once and hand them over to the parent implementation
    dump_options = {
        "mode": mode,
        "include": include,
        "exclude": exclude,
        "context": context,
        "by_alias": by_alias,
        "exclude_unset": exclude_unset,
        "exclude_defaults": exclude_defaults,
        "exclude_none": exclude_none,
        "round_trip": round_trip,
        "warnings": warnings,
        "serialize_as_any": serialize_as_any,
    }
    return super().model_dump(**dump_options)
|
|
@@ -2,26 +2,93 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Any, Literal, Optional, Union
|
|
4
4
|
|
|
5
|
-
from
|
|
6
|
-
|
|
7
|
-
from .
|
|
8
|
-
from .
|
|
9
|
-
from .
|
|
5
|
+
from pydantic import TypeAdapter
|
|
6
|
+
|
|
7
|
+
from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
|
|
8
|
+
from careamics.config.architectures import UNetModel
|
|
9
|
+
from careamics.config.care_configuration import CAREConfiguration
|
|
10
|
+
from careamics.config.configuration import Configuration
|
|
11
|
+
from careamics.config.data import DataConfig, N2VDataConfig
|
|
12
|
+
from careamics.config.n2n_configuration import N2NConfiguration
|
|
13
|
+
from careamics.config.n2v_configuration import N2VConfiguration
|
|
14
|
+
from careamics.config.support import (
|
|
10
15
|
SupportedArchitecture,
|
|
11
16
|
SupportedPixelManipulation,
|
|
12
17
|
SupportedTransform,
|
|
13
18
|
)
|
|
14
|
-
from .training_model import TrainingConfig
|
|
15
|
-
from .transformations import (
|
|
19
|
+
from careamics.config.training_model import TrainingConfig
|
|
20
|
+
from careamics.config.transformations import (
|
|
21
|
+
N2V_TRANSFORMS_UNION,
|
|
22
|
+
SPATIAL_TRANSFORMS_UNION,
|
|
16
23
|
N2VManipulateModel,
|
|
17
24
|
XYFlipModel,
|
|
18
25
|
XYRandomRotate90Model,
|
|
19
26
|
)
|
|
20
27
|
|
|
21
28
|
|
|
22
|
-
def
|
|
23
|
-
|
|
24
|
-
) ->
|
|
29
|
+
def configuration_factory(
    configuration: dict[str, Any]
) -> Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]:
    """
    Build a validated CAREamics configuration from a dictionary.

    The dictionary is validated with a pydantic `TypeAdapter` against the union of
    the N2V, N2N and CARE configurations.

    Parameters
    ----------
    configuration : dict
        Configuration dictionary.

    Returns
    -------
    N2VConfiguration or N2NConfiguration or CAREConfiguration
        Configuration for training CAREamics.
    """
    supported_configurations = Union[
        N2VConfiguration, N2NConfiguration, CAREConfiguration
    ]
    validator: TypeAdapter = TypeAdapter(supported_configurations)

    return validator.validate_python(configuration)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def algorithm_factory(
    algorithm: dict[str, Any]
) -> Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]:
    """
    Build a validated algorithm model from a dictionary.

    The dictionary is validated with a pydantic `TypeAdapter` against the union of
    the N2V, N2N and CARE algorithm models.

    Parameters
    ----------
    algorithm : dict
        Algorithm dictionary.

    Returns
    -------
    N2VAlgorithm or N2NAlgorithm or CAREAlgorithm
        Algorithm model for training CAREamics.
    """
    supported_algorithms = Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]
    validator: TypeAdapter = TypeAdapter(supported_algorithms)

    return validator.validate_python(algorithm)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def data_factory(data: dict[str, Any]) -> Union[DataConfig, N2VDataConfig]:
    """
    Build a validated data model from a dictionary.

    The dictionary is validated with a pydantic `TypeAdapter` against the union of
    `DataConfig` and `N2VDataConfig`.

    Parameters
    ----------
    data : dict
        Data dictionary.

    Returns
    -------
    DataConfig or N2VDataConfig
        Data model for training CAREamics.
    """
    supported_data_configs = Union[DataConfig, N2VDataConfig]
    validator: TypeAdapter = TypeAdapter(supported_data_configs)

    return validator.validate_python(data)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _list_spatial_augmentations(
|
|
90
|
+
augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]],
|
|
91
|
+
) -> list[SPATIAL_TRANSFORMS_UNION]:
|
|
25
92
|
"""
|
|
26
93
|
List the augmentations to apply.
|
|
27
94
|
|
|
@@ -44,7 +111,7 @@ def _list_augmentations(
|
|
|
44
111
|
If there are duplicate transforms.
|
|
45
112
|
"""
|
|
46
113
|
if augmentations is None:
|
|
47
|
-
transform_list: list[
|
|
114
|
+
transform_list: list[SPATIAL_TRANSFORMS_UNION] = [
|
|
48
115
|
XYFlipModel(),
|
|
49
116
|
XYRandomRotate90Model(),
|
|
50
117
|
]
|
|
@@ -123,7 +190,7 @@ def _create_configuration(
|
|
|
123
190
|
patch_size: list[int],
|
|
124
191
|
batch_size: int,
|
|
125
192
|
num_epochs: int,
|
|
126
|
-
augmentations: list[
|
|
193
|
+
augmentations: Union[list[N2V_TRANSFORMS_UNION], list[SPATIAL_TRANSFORMS_UNION]],
|
|
127
194
|
independent_channels: bool,
|
|
128
195
|
loss: Literal["n2v", "mae", "mse"],
|
|
129
196
|
n_channels_in: int,
|
|
@@ -188,21 +255,21 @@ def _create_configuration(
|
|
|
188
255
|
)
|
|
189
256
|
|
|
190
257
|
# algorithm model
|
|
191
|
-
algorithm_config =
|
|
192
|
-
algorithm
|
|
193
|
-
loss
|
|
194
|
-
model
|
|
195
|
-
|
|
258
|
+
algorithm_config = {
|
|
259
|
+
"algorithm": algorithm,
|
|
260
|
+
"loss": loss,
|
|
261
|
+
"model": unet_model,
|
|
262
|
+
}
|
|
196
263
|
|
|
197
264
|
# data model
|
|
198
|
-
data =
|
|
199
|
-
data_type
|
|
200
|
-
axes
|
|
201
|
-
patch_size
|
|
202
|
-
batch_size
|
|
203
|
-
transforms
|
|
204
|
-
dataloader_params
|
|
205
|
-
|
|
265
|
+
data = {
|
|
266
|
+
"data_type": data_type,
|
|
267
|
+
"axes": axes,
|
|
268
|
+
"patch_size": patch_size,
|
|
269
|
+
"batch_size": batch_size,
|
|
270
|
+
"transforms": augmentations,
|
|
271
|
+
"dataloader_params": dataloader_params,
|
|
272
|
+
}
|
|
206
273
|
|
|
207
274
|
# training model
|
|
208
275
|
training = TrainingConfig(
|
|
@@ -212,14 +279,14 @@ def _create_configuration(
|
|
|
212
279
|
)
|
|
213
280
|
|
|
214
281
|
# create configuration
|
|
215
|
-
configuration =
|
|
216
|
-
experiment_name
|
|
217
|
-
algorithm_config
|
|
218
|
-
data_config
|
|
219
|
-
training_config
|
|
220
|
-
|
|
282
|
+
configuration = {
|
|
283
|
+
"experiment_name": experiment_name,
|
|
284
|
+
"algorithm_config": algorithm_config,
|
|
285
|
+
"data_config": data,
|
|
286
|
+
"training_config": training,
|
|
287
|
+
}
|
|
221
288
|
|
|
222
|
-
return configuration
|
|
289
|
+
return configuration_factory(configuration)
|
|
223
290
|
|
|
224
291
|
|
|
225
292
|
# TODO reconsider naming once we officially support LVAE approaches
|
|
@@ -306,7 +373,7 @@ def _create_supervised_configuration(
|
|
|
306
373
|
n_channels_out = n_channels_in
|
|
307
374
|
|
|
308
375
|
# augmentations
|
|
309
|
-
|
|
376
|
+
spatial_transform_list = _list_spatial_augmentations(augmentations)
|
|
310
377
|
|
|
311
378
|
return _create_configuration(
|
|
312
379
|
algorithm=algorithm,
|
|
@@ -316,7 +383,7 @@ def _create_supervised_configuration(
|
|
|
316
383
|
patch_size=patch_size,
|
|
317
384
|
batch_size=batch_size,
|
|
318
385
|
num_epochs=num_epochs,
|
|
319
|
-
augmentations=
|
|
386
|
+
augmentations=spatial_transform_list,
|
|
320
387
|
independent_channels=independent_channels,
|
|
321
388
|
loss=loss,
|
|
322
389
|
n_channels_in=n_channels_in,
|
|
@@ -853,7 +920,7 @@ def create_n2v_configuration(
|
|
|
853
920
|
n_channels = 1
|
|
854
921
|
|
|
855
922
|
# augmentations
|
|
856
|
-
|
|
923
|
+
spatial_transforms = _list_spatial_augmentations(augmentations)
|
|
857
924
|
|
|
858
925
|
# create the N2VManipulate transform using the supplied parameters
|
|
859
926
|
n2v_transform = N2VManipulateModel(
|
|
@@ -868,7 +935,7 @@ def create_n2v_configuration(
|
|
|
868
935
|
struct_mask_axis=struct_n2v_axis,
|
|
869
936
|
struct_mask_span=struct_n2v_span,
|
|
870
937
|
)
|
|
871
|
-
transform_list
|
|
938
|
+
transform_list: list[N2V_TRANSFORMS_UNION] = spatial_transforms + [n2v_transform]
|
|
872
939
|
|
|
873
940
|
return _create_configuration(
|
|
874
941
|
algorithm="n2v",
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""I/O functions for Configuration objects."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
from careamics.config import Configuration, configuration_factory
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def load_configuration(path: Union[str, Path]) -> Configuration:
    """
    Load configuration from a yaml file.

    The file is parsed with the yaml safe loader and the resulting dictionary is
    passed to `configuration_factory`.

    Parameters
    ----------
    path : str or Path
        Path to the configuration.

    Returns
    -------
    Configuration
        Configuration.

    Raises
    ------
    FileNotFoundError
        If the configuration file does not exist.
    """
    config_path = Path(path)
    if not config_path.exists():
        raise FileNotFoundError(
            f"Configuration file {path} does not exist in " f" {Path.cwd()!s}"
        )

    # load dictionary from yaml; the context manager guarantees that the file
    # handle is closed, even when parsing fails
    with config_path.open("r") as stream:
        dictionary = yaml.load(stream, Loader=yaml.SafeLoader)

    return configuration_factory(dictionary)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
    """
    Save configuration to path.

    When `path` is an existing folder, the configuration is written to a
    `config.yml` file inside it. Otherwise `path` must carry a `.yml` or `.yaml`
    extension, whether or not the file already exists.

    Parameters
    ----------
    config : Configuration
        Configuration to save.
    path : str or Path
        Path to an existing folder in which to save the configuration, or to a
        valid configuration file path (uses a .yml or .yaml extension).

    Returns
    -------
    Path
        Path object representing the configuration.

    Raises
    ------
    ValueError
        If the path does not point to an existing directory or .yml file.
    """
    target = Path(path)

    if target.exists() and target.is_dir():
        # existing folder: fall back to the default file name
        target = target / "config.yml"
    elif target.suffix not in (".yml", ".yaml"):
        # existing or future file: only yaml extensions are accepted
        raise ValueError(
            f"Path must be a directory or .yml or .yaml file (got {target})."
        )

    # save configuration as dictionary to yaml
    with open(target, "w") as stream:
        yaml.dump(
            config.model_dump(),
            stream,
            default_flow_style=False,
            sort_keys=False,
        )

    return target
|