careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +16 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +31 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_example.py +89 -0
- careamics/config/configuration_factory.py +597 -0
- careamics/config/configuration_model.py +597 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +743 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +396 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc4.dist-info/METADATA +122 -0
- careamics-0.1.0rc4.dist-info/RECORD +110 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,597 @@
|
|
|
1
|
+
"""Pydantic CAREamics configuration."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import re
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from pprint import pformat
|
|
7
|
+
from typing import Dict, List, Literal, Union
|
|
8
|
+
|
|
9
|
+
import yaml
|
|
10
|
+
from bioimageio.spec.generic.v0_3 import CiteEntry
|
|
11
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
12
|
+
from typing_extensions import Self
|
|
13
|
+
|
|
14
|
+
from .algorithm_model import AlgorithmConfig
|
|
15
|
+
from .data_model import DataConfig
|
|
16
|
+
from .references import (
|
|
17
|
+
CARE,
|
|
18
|
+
CUSTOM,
|
|
19
|
+
N2N,
|
|
20
|
+
N2V,
|
|
21
|
+
N2V2,
|
|
22
|
+
STRUCT_N2V,
|
|
23
|
+
STRUCT_N2V2,
|
|
24
|
+
CAREDescription,
|
|
25
|
+
CARERef,
|
|
26
|
+
N2NDescription,
|
|
27
|
+
N2NRef,
|
|
28
|
+
N2V2Description,
|
|
29
|
+
N2V2Ref,
|
|
30
|
+
N2VDescription,
|
|
31
|
+
N2VRef,
|
|
32
|
+
StructN2V2Description,
|
|
33
|
+
StructN2VDescription,
|
|
34
|
+
StructN2VRef,
|
|
35
|
+
)
|
|
36
|
+
from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform
|
|
37
|
+
from .training_model import TrainingConfig
|
|
38
|
+
from .transformations.n2v_manipulate_model import (
|
|
39
|
+
N2VManipulateModel,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class Configuration(BaseModel):
    """
    CAREamics configuration.

    The configuration defines all parameters used to build and train a CAREamics model.
    These parameters are validated to ensure that they are compatible with each other.

    It contains three sub-configurations:

    - AlgorithmConfig: configuration for the algorithm training, which includes the
        architecture, loss function, optimizer, and other hyperparameters.
    - DataConfig: configuration for the dataloader, which includes the type of data,
        transformations, mean/std and other parameters.
    - TrainingConfig: configuration for the training, which includes the number of
        epochs or the callbacks.

    For non-custom algorithms, the model dimensionality (2D/3D) is automatically
    adapted to the presence of "Z" in the data axes.

    Attributes
    ----------
    version : Literal["0.1.0"]
        Version of the CAREamics configuration.
    experiment_name : str
        Name of the experiment, used when saving logs and checkpoints.
    algorithm_config : AlgorithmConfig
        Algorithm configuration.
    data_config : DataConfig
        Data configuration.
    training_config : TrainingConfig
        Training configuration.

    Methods
    -------
    set_3D(is_3D: bool, axes: str, patch_size: List[int]) -> None
        Switch configuration between 2D and 3D.
    set_N2V2(use_n2v2: bool) -> None
        Switch N2V algorithm between N2V and N2V2.
    set_structN2V(
        mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int) -> None
        Set StructN2V parameters.
    model_dump(
        exclude_defaults: bool = False, exclude_none: bool = True, **kwargs: Dict
        ) -> Dict
        Export configuration to a dictionary.

    Raises
    ------
    ValueError
        Configuration parameter type validation errors.
    ValueError
        If the experiment name contains invalid characters or is empty.
    ValueError
        Algorithm, data or training validation errors.

    Notes
    -----
    We provide convenience methods to create standard configurations, for instance
    for N2V, in the `careamics.config.configuration_factory` module.
    >>> from careamics.config.configuration_factory import create_n2v_configuration
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100
    ... )

    The configuration can be exported to a dictionary using the model_dump method:
    >>> config_dict = config.model_dump()

    Configurations can also be exported or imported from yaml files:
    >>> from careamics.config import save_configuration, load_configuration
    >>> path_to_config = save_configuration(config, my_path / "config.yml")
    >>> other_config = load_configuration(path_to_config)

    Examples
    --------
    Minimum example:
    >>> from careamics.config import Configuration
    >>> config_dict = {
    ...         "experiment_name": "N2V_experiment",
    ...         "algorithm_config": {
    ...             "algorithm": "n2v",
    ...             "loss": "n2v",
    ...             "model": {
    ...                 "architecture": "UNet",
    ...             },
    ...         },
    ...         "training_config": {
    ...             "num_epochs": 200,
    ...         },
    ...         "data_config": {
    ...             "data_type": "tiff",
    ...             "patch_size": [64, 64],
    ...             "axes": "SYX",
    ...         },
    ...     }
    >>> config = Configuration(**config_dict)
    """

    model_config = ConfigDict(
        # re-run validation (including model validators) on attribute assignment
        validate_assignment=True,
        # pydantic v2 key; the former `set_arbitrary_types_allowed` is not a
        # recognised option and was silently ignored
        arbitrary_types_allowed=True,
    )

    # version
    version: Literal["0.1.0"] = Field(
        default="0.1.0", description="Version of the CAREamics configuration."
    )

    # required parameters
    experiment_name: str = Field(
        ..., description="Name of the experiment, used to name logs and checkpoints."
    )

    # Sub-configurations
    algorithm_config: AlgorithmConfig

    data_config: DataConfig
    training_config: TrainingConfig

    @field_validator("experiment_name")
    @classmethod
    def no_symbol(cls, name: str) -> str:
        """
        Validate experiment name.

        A valid experiment name is a non-empty string that only contains letters,
        numbers, underscores, dashes and spaces.

        Parameters
        ----------
        name : str
            Name to validate.

        Returns
        -------
        str
            Validated name.

        Raises
        ------
        ValueError
            If the name is empty or contains invalid characters.
        """
        if len(name) == 0 or name.isspace():
            raise ValueError("Experiment name is empty.")

        # `fullmatch` anchors at the true end of the string. The previous
        # `re.match(r"^...$")` accepted a trailing newline, because `$` also
        # matches right before a final "\n".
        if not re.fullmatch(r"[a-zA-Z0-9_\- ]*", name):
            raise ValueError(
                f"Experiment name contains invalid characters (got {name}). "
                "Only letters, numbers, underscores, dashes and spaces are allowed."
            )

        return name

    @model_validator(mode="after")
    def validate_3D(self: Self) -> Self:
        """
        Change algorithm dimensions to match data.axes.

        Only for non-custom algorithms.

        Returns
        -------
        Self
            Validated configuration.
        """
        if self.algorithm_config.algorithm != SupportedAlgorithm.CUSTOM:
            if "Z" in self.data_config.axes and not self.algorithm_config.model.is_3D():
                # change algorithm to 3D
                self.algorithm_config.model.set_3D(True)
            elif (
                "Z" not in self.data_config.axes and self.algorithm_config.model.is_3D()
            ):
                # change algorithm to 2D
                self.algorithm_config.model.set_3D(False)

        return self

    @model_validator(mode="after")
    def validate_algorithm_and_data(self: Self) -> Self:
        """
        Validate algorithm and data compatibility.

        In particular, the validation does the following:

        - If N2V is used, it enforces the presence of N2V_Manipulate in the transforms
        - If N2V2 is used, it enforces the correct manipulation strategy
        - For other algorithms, N2V_Manipulate is removed from the transforms

        Returns
        -------
        Self
            Validated configuration.
        """
        if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
            # if we have a list of transform (as opposed to Compose)
            if self.data_config.has_transform_list():
                # missing N2V_MANIPULATE
                if not self.data_config.has_n2v_manipulate():
                    self.data_config.transforms.append(
                        N2VManipulateModel(
                            name=SupportedTransform.N2V_MANIPULATE.value,
                        )
                    )

                median = SupportedPixelManipulation.MEDIAN.value
                uniform = SupportedPixelManipulation.UNIFORM.value
                strategy = median if self.algorithm_config.model.n2v2 else uniform
                self.data_config.set_N2V2_strategy(strategy)
        else:
            # if we have a list of transform, remove N2V manipulate if present
            if self.data_config.has_transform_list():
                if self.data_config.has_n2v_manipulate():
                    self.data_config.remove_n2v_manipulate()

        return self

    def __str__(self) -> str:
        """
        Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def set_3D(self, is_3D: bool, axes: str, patch_size: List[int]) -> None:
        """
        Set 3D flag and axes.

        Parameters
        ----------
        is_3D : bool
            Whether the algorithm is 3D or not.
        axes : str
            Axes of the data.
        patch_size : List[int]
            Patch size.
        """
        # set the flag and axes (this will not trigger validation at the config level)
        self.algorithm_config.model.set_3D(is_3D)
        self.data_config.set_3D(axes, patch_size)

        # re-assigning a field triggers validation (validate_assignment=True),
        # which checks the new algorithm/data combination
        self.algorithm_config = self.algorithm_config

    def set_N2V2(self, use_n2v2: bool) -> None:
        """
        Switch N2V algorithm between N2V and N2V2.

        Parameters
        ----------
        use_n2v2 : bool
            Whether to use N2V2 or not.

        Raises
        ------
        ValueError
            If the algorithm is not N2V.
        """
        if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
            self.algorithm_config.model.n2v2 = use_n2v2
            strategy = (
                SupportedPixelManipulation.MEDIAN.value
                if use_n2v2
                else SupportedPixelManipulation.UNIFORM.value
            )
            self.data_config.set_N2V2_strategy(strategy)
        else:
            raise ValueError("N2V2 can only be set for N2V algorithm.")

    def set_structN2V(
        self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int
    ) -> None:
        """
        Set StructN2V parameters.

        Parameters
        ----------
        mask_axis : Literal["horizontal", "vertical", "none"]
            Axis of the structural mask.
        mask_span : int
            Span of the structural mask.
        """
        self.data_config.set_structN2V_mask(mask_axis, mask_span)

    def _uses_struct_n2v(self) -> bool:
        """
        Check whether a structN2V mask is set in the N2V manipulation transform.

        Only a list of transforms can be inspected; for any other transforms
        container (e.g. Compose) no structN2V mask is reported.

        Returns
        -------
        bool
            True if the last N2V manipulation transform in the list has a
            `struct_mask_axis` different from "none", False otherwise.
        """
        if not self.data_config.has_transform_list():
            return False

        # search from the end so that, when the N2V manipulation is the last
        # transform (as appended by `validate_algorithm_and_data`), it is the
        # one inspected
        for transform in reversed(self.data_config.transforms):
            if isinstance(transform, N2VManipulateModel):
                return transform.struct_mask_axis != "none"

        return False

    @staticmethod
    def _format_reference(reference: CiteEntry) -> str:
        """
        Format a citation entry as a single line of text.

        Parameters
        ----------
        reference : CiteEntry
            Citation entry.

        Returns
        -------
        str
            "<text> doi: <doi>", or "<text> url: <url>" if the entry has no DOI.
        """
        if reference.doi is not None:
            return reference.text + " doi: " + reference.doi
        return f"{reference.text} url: {reference.url}"

    def get_algorithm_flavour(self) -> str:
        """
        Get the algorithm name.

        Returns
        -------
        str
            Algorithm name.
        """
        if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
            use_n2v2 = self.algorithm_config.model.n2v2
            use_structN2V = self._uses_struct_n2v()

            # return the n2v flavour
            if use_n2v2 and use_structN2V:
                return STRUCT_N2V2
            elif use_n2v2:
                return N2V2
            elif use_structN2V:
                return STRUCT_N2V
            else:
                return N2V
        elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N:
            return N2N
        elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE:
            return CARE
        else:
            return CUSTOM

    def get_algorithm_description(self) -> str:
        """
        Return a description of the algorithm.

        This method is used to generate the README of the BioImage Model Zoo export.

        Returns
        -------
        str
            Description of the algorithm.
        """
        algorithm_flavour = self.get_algorithm_flavour()

        if algorithm_flavour == CUSTOM:
            return f"Custom algorithm, named {self.algorithm_config.model.name}"
        else:  # N2V flavours, N2N or CARE
            if algorithm_flavour == N2V:
                return N2VDescription().description
            elif algorithm_flavour == N2V2:
                return N2V2Description().description
            elif algorithm_flavour == STRUCT_N2V:
                return StructN2VDescription().description
            elif algorithm_flavour == STRUCT_N2V2:
                return StructN2V2Description().description
            elif algorithm_flavour == N2N:
                return N2NDescription().description
            elif algorithm_flavour == CARE:
                return CAREDescription().description

        return ""

    def get_algorithm_citations(self) -> List[CiteEntry]:
        """
        Return a list of citation entries of the current algorithm.

        This is used to generate the model description for the BioImage Model Zoo.

        Returns
        -------
        List[CiteEntry]
            List of citation entries.

        Raises
        ------
        ValueError
            If the algorithm is a custom algorithm.
        """
        if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
            use_n2v2 = self.algorithm_config.model.n2v2
            use_structN2V = self._uses_struct_n2v()

            # return the (struct)N2V(2) references
            if use_n2v2 and use_structN2V:
                return [N2VRef, N2V2Ref, StructN2VRef]
            elif use_n2v2:
                return [N2VRef, N2V2Ref]
            elif use_structN2V:
                return [N2VRef, StructN2VRef]
            else:
                return [N2VRef]
        elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N:
            return [N2NRef]
        elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE:
            return [CARERef]

        raise ValueError("Citation not available for custom algorithm.")

    def get_algorithm_references(self) -> str:
        """
        Get the algorithm references.

        This is used to generate the README of the BioImage Model Zoo export.

        Returns
        -------
        str
            Algorithm references, one per line, in the same order as
            `get_algorithm_citations`. Empty string for custom algorithms.
        """
        supported = (
            SupportedAlgorithm.N2V,
            SupportedAlgorithm.N2N,
            SupportedAlgorithm.CARE,
        )
        if self.algorithm_config.algorithm not in supported:
            return ""

        # one reference per line; previously multiple references were
        # concatenated without any separator
        return "\n".join(
            self._format_reference(reference)
            for reference in self.get_algorithm_citations()
        )

    def get_algorithm_keywords(self) -> List[str]:
        """
        Get algorithm keywords.

        Returns
        -------
        List[str]
            List of keywords.
        """
        if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
            keywords = [
                "denoising",
                "restoration",
                "UNet",
                "3D" if "Z" in self.data_config.axes else "2D",
                "CAREamics",
                "pytorch",
                N2V,
            ]

            if self.algorithm_config.model.n2v2:
                keywords.append(N2V2)
            if self._uses_struct_n2v():
                keywords.append(STRUCT_N2V)
        else:
            keywords = ["CAREamics"]

        return keywords

    def model_dump(
        self,
        exclude_defaults: bool = False,
        exclude_none: bool = True,
        **kwargs: Dict,
    ) -> Dict:
        """
        Override model_dump method in order to set default values.

        Parameters
        ----------
        exclude_defaults : bool, optional
            Whether to exclude fields with default values or not, by default
            False.
        exclude_none : bool, optional
            Whether to exclude fields with None values or not, by default True.
        **kwargs : Dict
            Keyword arguments forwarded to `pydantic.BaseModel.model_dump`.

        Returns
        -------
        dict
            Dictionary containing the model parameters.
        """
        dictionary = super().model_dump(
            exclude_none=exclude_none, exclude_defaults=exclude_defaults, **kwargs
        )

        return dictionary
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
def load_configuration(path: Union[str, Path]) -> Configuration:
    """
    Load configuration from a yaml file.

    Parameters
    ----------
    path : Union[str, Path]
        Path to the configuration.

    Returns
    -------
    Configuration
        Configuration.

    Raises
    ------
    FileNotFoundError
        If the configuration file does not exist.
    ValueError
        If the file does not contain a YAML mapping (e.g. it is empty), or if
        the content fails configuration validation.
    """
    config_path = Path(path)

    # load dictionary from yaml
    if not config_path.exists():
        raise FileNotFoundError(
            f"Configuration file {path} does not exist in " f" {Path.cwd()!s}"
        )

    # The context manager closes the file handle, which the previous
    # `yaml.load(Path(path).open("r"), ...)` leaked. `safe_load` is the same
    # as `load(..., Loader=yaml.SafeLoader)`.
    with config_path.open("r") as stream:
        dictionary = yaml.safe_load(stream)

    # an empty file yields None, and a list/scalar cannot be unpacked as kwargs
    if not isinstance(dictionary, dict):
        raise ValueError(
            f"Configuration file {path} does not contain a YAML mapping."
        )

    return Configuration(**dictionary)
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
    """
    Write a configuration to disk as a yaml file.

    Parameters
    ----------
    config : Configuration
        Configuration to save.
    path : Union[str, Path]
        Either an existing directory, in which case the configuration is written
        to `config.yml` inside it, or the path of a `.yml`/`.yaml` file. The file
        does not need to exist beforehand; an existing file is overwritten.

    Returns
    -------
    Path
        Path of the written configuration file.

    Raises
    ------
    ValueError
        If the path is neither an existing directory nor a .yml/.yaml file path.
    """
    target = Path(path)

    # `is_dir` is False for paths that do not exist, so only existing
    # directories are redirected to the default file name; everything else
    # must carry a yaml suffix.
    if target.is_dir():
        target = target / "config.yml"
    elif target.suffix not in (".yml", ".yaml"):
        raise ValueError(
            f"Path must be a directory or .yml or .yaml file (got {target})."
        )

    # serialize the configuration, keeping the field declaration order
    with open(target, "w") as stream:
        yaml.dump(
            config.model_dump(), stream, default_flow_style=False, sort_keys=False
        )

    return target
|