careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +14 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +27 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_factory.py +460 -0
- careamics/config/configuration_model.py +596 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +665 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +390 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc3.dist-info/METADATA +122 -0
- careamics-0.1.0rc3.dist-info/RECORD +109 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
careamics/config/config.py
DELETED
@@ -1,297 +0,0 @@
-"""Pydantic CAREamics configuration."""
-from __future__ import annotations
-
-import re
-from pathlib import Path
-from typing import Dict, List, Union
-
-import yaml
-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    field_validator,
-    model_validator,
-)
-
-# ignore typing-only-first-party-import in this file (flake8)
-from .algorithm import Algorithm  # noqa: TCH001
-from .config_filter import paths_to_str
-from .data import Data  # noqa: TCH001
-from .training import Training  # noqa: TCH001
-
-
-class Configuration(BaseModel):
-    """
-    CAREamics configuration.
-
-    To change the configuration from 2D to 3D, we recommend using the following method:
-    >>> set_3D(is_3D, axes)
-
-    Attributes
-    ----------
-    experiment_name : str
-        Name of the experiment.
-    working_directory : Union[str, Path]
-        Path to the working directory.
-    algorithm : Algorithm
-        Algorithm configuration.
-    training : Training
-        Training configuration.
-    """
-
-    model_config = ConfigDict(validate_assignment=True)
-
-    # required parameters
-    experiment_name: str
-    working_directory: Path
-
-    # Sub-configurations
-    algorithm: Algorithm
-    data: Data
-    training: Training
-
-    def set_3D(self, is_3D: bool, axes: str) -> None:
-        """
-        Set 3D flag and axes.
-
-        Parameters
-        ----------
-        is_3D : bool
-            Whether the algorithm is 3D or not.
-        axes : str
-            Axes of the data.
-        """
-        # set the flag and axes (this will not trigger validation at the config level)
-        self.algorithm.is_3D = is_3D
-        self.data.axes = axes
-
-        # cheap hack: trigger validation
-        self.algorithm = self.algorithm
-
-    @field_validator("experiment_name")
-    def no_symbol(cls, name: str) -> str:
-        """
-        Validate experiment name.
-
-        A valid experiment name is a non-empty string that only contains letters,
-        numbers, underscores, dashes and spaces.
-
-        Parameters
-        ----------
-        name : str
-            Name to validate.
-
-        Returns
-        -------
-        str
-            Validated name.
-
-        Raises
-        ------
-        ValueError
-            If the name is empty or contains invalid characters.
-        """
-        if len(name) == 0 or name.isspace():
-            raise ValueError("Experiment name is empty.")
-
-        # Validate using a regex that it contains only letters, numbers, underscores,
-        # dashes and spaces
-        if not re.match(r"^[a-zA-Z0-9_\- ]*$", name):
-            raise ValueError(
-                f"Experiment name contains invalid characters (got {name}). "
-                f"Only letters, numbers, underscores, dashes and spaces are allowed."
-            )
-
-        return name
-
-    @field_validator("working_directory")
-    def parent_directory_exists(cls, workdir: Union[str, Path]) -> Path:
-        """
-        Validate working directory.
-
-        A valid working directory is a directory whose parent directory exists. If the
-        working directory does not exist itself, it is then created.
-
-        Parameters
-        ----------
-        workdir : Union[str, Path]
-            Working directory to validate.
-
-        Returns
-        -------
-        Path
-            Validated working directory.
-
-        Raises
-        ------
-        ValueError
-            If the working directory is not a directory, or if the parent directory
-            does not exist.
-        """
-        path = Path(workdir)
-
-        # check if it is a directory
-        if path.exists() and not path.is_dir():
-            raise ValueError(f"Working directory is not a directory (got {workdir}).")
-
-        # check if parent directory exists
-        if not path.parent.exists():
-            raise ValueError(
-                f"Parent directory of working directory does not exist (got {workdir})."
-            )
-
-        # create directory if it does not exist already
-        path.mkdir(exist_ok=True)
-
-        return path
-
-    @model_validator(mode="after")
-    def validate_3D(cls, config: Configuration) -> Configuration:
-        """
-        Check 3D flag validity.
-
-        Check that the algorithm is_3D flag is compatible with the axes in the
-        data configuration.
-
-        Parameters
-        ----------
-        config : Configuration
-            Configuration to validate.
-
-        Returns
-        -------
-        Configuration
-            Validated configuration.
-
-        Raises
-        ------
-        ValueError
-            If the algorithm is 3D but the data axes are not, or if the algorithm is
-            not 3D but the data axes are.
-        """
-        # check that is_3D and axes are compatible
-        if config.algorithm.is_3D and "Z" not in config.data.axes:
-            raise ValueError(
-                f"Algorithm is 3D but data axes are not (got axes {config.data.axes})."
-            )
-        elif not config.algorithm.is_3D and "Z" in config.data.axes:
-            raise ValueError(
-                f"Algorithm is not 3D but data axes are (got axes {config.data.axes})."
-            )
-
-        return config
-
-    def model_dump(
-        self, exclude_optionals: bool = True, *args: List, **kwargs: Dict
-    ) -> Dict:
-        """
-        Override model_dump method.
-
-        The purpose is to ensure smooth export and import to yaml. It includes:
-        - remove entries with None value.
-        - remove optional values if they have the default value.
-
-        Parameters
-        ----------
-        exclude_optionals : bool, optional
-            Whether to exclude optional fields with default values or not, by default
-            True.
-        *args : List
-            Positional arguments, unused.
-        **kwargs : Dict
-            Keyword arguments, unused.
-
-        Returns
-        -------
-        dict
-            Dictionary containing the model parameters.
-        """
-        dictionary = super().model_dump(exclude_none=True)
-
-        # remove paths
-        dictionary = paths_to_str(dictionary)
-
-        dictionary["algorithm"] = self.algorithm.model_dump(
-            exclude_optionals=exclude_optionals
-        )
-        dictionary["data"] = self.data.model_dump()
-
-        dictionary["training"] = self.training.model_dump(
-            exclude_optionals=exclude_optionals
-        )
-
-        return dictionary
-
-
-def load_configuration(path: Union[str, Path]) -> Configuration:
-    """
-    Load configuration from a yaml file.
-
-    Parameters
-    ----------
-    path : Union[str, Path]
-        Path to the configuration.
-
-    Returns
-    -------
-    Configuration
-        Configuration.
-
-    Raises
-    ------
-    FileNotFoundError
-        If the configuration file does not exist.
-    """
-    # load dictionary from yaml
-    if not Path(path).exists():
-        raise FileNotFoundError(
-            f"Configuration file {path} does not exist in {Path.cwd()!s}"
-        )
-
-    dictionary = yaml.load(Path(path).open("r"), Loader=yaml.SafeLoader)
-
-    return Configuration(**dictionary)
-
-
-def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
-    """
-    Save configuration to path.
-
-    Parameters
-    ----------
-    config : Configuration
-        Configuration to save.
-    path : Union[str, Path]
-        Path to an existing folder in which to save the configuration, or to an
-        existing configuration file.
-
-    Returns
-    -------
-    Path
-        Path object representing the configuration.
-
-    Raises
-    ------
-    ValueError
-        If the path does not point to an existing directory or .yml file.
-    """
-    # make sure path is a Path object
-    config_path = Path(path)
-
-    # check if path is pointing to an existing directory or .yml file
-    if config_path.exists():
-        if config_path.is_dir():
-            config_path = Path(config_path, "config.yml")
-        elif config_path.suffix != ".yml":
-            raise ValueError(
-                f"Path must be a directory or .yml file (got {config_path})."
-            )
-    else:
-        if config_path.suffix != ".yml":
-            raise ValueError(f"Path must be a .yml file (got {config_path}).")
-
-    # save configuration as dictionary to yaml
-    with open(config_path, "w") as f:
-        yaml.dump(config.model_dump(), f, default_flow_style=False)
-
-    return config_path
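For orientation, here is a minimal sketch of how the removed rc2 entry points fit together, assuming an rc2 install and an existing configuration file (the "config.yml" name is illustrative):

from careamics.config.config import load_configuration, save_configuration

# Load and validate a configuration from disk; raises FileNotFoundError
# if the yaml file does not exist.
config = load_configuration("config.yml")

# Flip the experiment to 3D: set_3D assigns algorithm.is_3D and data.axes,
# then re-assigns self.algorithm to re-trigger validate_3D, so mismatches
# such as is_3D=True without a "Z" axis raise immediately.
config.set_3D(True, "ZYX")

# The overridden model_dump() strips None entries and default optionals,
# so the saved yaml stays minimal.
save_configuration(config, config.working_directory)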
careamics/config/config_filter.py
DELETED
@@ -1,44 +0,0 @@
-"""Convenience functions to filter dictionaries resulting from a Pydantic export."""
-from pathlib import Path
-from typing import Dict
-
-
-def paths_to_str(dictionary: dict) -> dict:
-    """
-    Replace Path objects in a dictionary by str.
-
-    Parameters
-    ----------
-    dictionary : dict
-        Dictionary to modify.
-
-    Returns
-    -------
-    dict
-        Modified dictionary.
-    """
-    for k in dictionary.keys():
-        if isinstance(dictionary[k], Path):
-            dictionary[k] = str(dictionary[k])
-
-    return dictionary
-
-
-def remove_default_optionals(dictionary: Dict, default: Dict) -> None:
-    """
-    Remove default arguments from a dictionary.
-
-    The method removes arguments if they are equal to the provided default ones.
-
-    Parameters
-    ----------
-    dictionary : dict
-        Dictionary to modify.
-    default : dict
-        Dictionary containing the default values.
-    """
-    dict_copy = dictionary.copy()
-    for k in dict_copy.keys():
-        if k in default.keys():
-            if dict_copy[k] == default[k]:
-                del dictionary[k]
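Worth noting, since it is easy to misuse: paths_to_str returns the mutated dictionary, while remove_default_optionals mutates in place and returns None. A small sketch with hypothetical keys:

from pathlib import Path

from careamics.config.config_filter import paths_to_str, remove_default_optionals

dump = {"working_directory": Path("runs"), "use_wandb": False, "num_epochs": 100}

# Path values are stringified so the dictionary can be written to yaml.
dump = paths_to_str(dump)

# Entries equal to the supplied defaults are dropped in place (no return value),
# leaving only the values the user actually changed.
remove_default_optionals(dump, {"use_wandb": True, "num_epochs": 100})
assert dump == {"working_directory": "runs", "use_wandb": False}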
careamics/config/data.py
DELETED
@@ -1,194 +0,0 @@
-"""Data configuration."""
-from __future__ import annotations
-
-from enum import Enum
-from typing import Dict, List, Optional
-
-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    Field,
-    field_validator,
-    model_validator,
-)
-
-from careamics.utils import check_axes_validity
-
-
-class SupportedExtension(str, Enum):
-    """
-    Supported extensions for input data.
-
-    Currently supported:
-        - tif/tiff: .tiff files.
-    """
-
-    TIFF = "tiff"  # TODO these should be a single one
-    TIF = "tif"
-
-    @classmethod
-    def _missing_(cls, value: object) -> str:
-        """
-        Override default behaviour for missing values.
-
-        This method is called when `value` is not found in the enum values. It converts
-        `value` to lowercase, removes "." if it is the first character and tries to
-        match it with enum values.
-
-        Parameters
-        ----------
-        value : object
-            Value to be matched with enum values.
-
-        Returns
-        -------
-        str
-            Matched enum value.
-        """
-        if isinstance(value, str):
-            lower_value = value.lower()
-
-            if lower_value.startswith("."):
-                lower_value = lower_value[1:]
-
-            # attempt to match lowercase value with enum values
-            for member in cls:
-                if member.value == lower_value:
-                    return member
-
-        # still missing
-        return super()._missing_(value)
-
-
-class Data(BaseModel):
-    """
-    Data configuration.
-
-    If std is specified, mean must be specified as well. Note that setting the std
-    first and then the mean (if they were both `None` before) will raise a validation
-    error. Prefer instead the following:
-    >>> set_mean_and_std(mean, std)
-
-    Attributes
-    ----------
-    in_memory : bool
-        Whether to load the data in memory or not.
-    data_format : SupportedExtension
-        Extension of the data, without period.
-    mean : Optional[float]
-        Expected data mean.
-    std : Optional[float]
-        Expected data standard deviation.
-    axes : str
-        Axes of the data.
-    """
-
-    # Pydantic class configuration
-    model_config = ConfigDict(
-        use_enum_values=True,
-        validate_assignment=True,
-    )
-
-    # Mandatory fields
-    in_memory: bool
-    data_format: SupportedExtension
-    axes: str
-
-    # Optional fields
-    mean: Optional[float] = Field(default=None, ge=0)
-    std: Optional[float] = Field(default=None, gt=0)
-
-    def set_mean_and_std(self, mean: float, std: float) -> None:
-        """
-        Set mean and standard deviation of the data.
-
-        This method is preferred to setting the fields directly, as it ensures that the
-        mean is set first, then the std; thus avoiding a validation error being thrown.
-
-        Parameters
-        ----------
-        mean : float
-            Mean of the data.
-        std : float
-            Standard deviation of the data.
-        """
-        self.mean = mean
-        self.std = std
-
-    @field_validator("axes")
-    def valid_axes(cls, axes: str) -> str:
-        """
-        Validate axes.
-
-        Axes must be a subset of STZYX, must contain YX, be in the right order
-        and not contain both S and T.
-
-        Parameters
-        ----------
-        axes : str
-            Axes of the training data.
-
-        Returns
-        -------
-        str
-            Validated axes of the training data.
-
-        Raises
-        ------
-        ValueError
-            If axes are not valid.
-        """
-        # validate axes
-        check_axes_validity(axes)
-
-        return axes
-
-    @model_validator(mode="after")
-    def std_only_with_mean(cls, data_model: Data) -> Data:
-        """
-        Check that mean and std are either both None, or both specified.
-
-        If we enforce both None or both specified, we cannot set the values one by one
-        due to the ConfigDict enforcing the validation on assignment. Therefore, we
-        check only when the std is not None and the mean is None.
-
-        Parameters
-        ----------
-        data_model : Data
-            Data model.
-
-        Returns
-        -------
-        Data
-            Validated data model.
-
-        Raises
-        ------
-        ValueError
-            If std is not None and mean is None.
-        """
-        if data_model.std is not None and data_model.mean is None:
-            raise ValueError("Cannot have std non None if mean is None.")
-
-        return data_model
-
-    def model_dump(self, *args: List, **kwargs: Dict) -> dict:
-        """
-        Override model_dump method.
-
-        The purpose is to ensure smooth export and import to yaml. It includes:
-        - remove entries with None value.
-
-        Parameters
-        ----------
-        *args : List
-            Positional arguments, unused.
-        **kwargs : Dict
-            Keyword arguments, unused.
-
-        Returns
-        -------
-        dict
-            Dictionary containing the model parameters.
-        """
-        return super().model_dump(exclude_none=True)
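The ordering constraint between mean and std is the subtle part of this model: with validate_assignment=True, every assignment re-runs std_only_with_mean. A short sketch, assuming the rc2 module is still importable:

from careamics.config.data import Data

data = Data(in_memory=True, data_format="tif", axes="YX")

# Setting std while mean is still None re-triggers std_only_with_mean and
# raises a pydantic ValidationError:
# data.std = 42.0

# set_mean_and_std assigns mean first, then std, so validation passes.
data.set_mean_and_std(128.5, 42.0)

# SupportedExtension._missing_ normalizes extensions on assignment, so
# ".TIFF" is accepted and stored as "tiff" (use_enum_values=True).
data.data_format = ".TIFF"
assert data.data_format == "tiff"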
careamics/config/torch_optim.py
DELETED
@@ -1,118 +0,0 @@
-"""Convenience functions to instantiate torch.optim optimizers and schedulers."""
-import inspect
-from enum import Enum
-from typing import Dict
-
-from torch import optim
-
-
-class TorchOptimizer(str, Enum):
-    """
-    Supported optimizers.
-
-    Currently only supports Adam and SGD.
-    """
-
-    # ASGD = "ASGD"
-    # Adadelta = "Adadelta"
-    # Adagrad = "Adagrad"
-    Adam = "Adam"
-    # AdamW = "AdamW"
-    # Adamax = "Adamax"
-    # LBFGS = "LBFGS"
-    # NAdam = "NAdam"
-    # RAdam = "RAdam"
-    # RMSprop = "RMSprop"
-    # Rprop = "Rprop"
-    SGD = "SGD"
-    # SparseAdam = "SparseAdam"
-
-
-# TODO: Test which schedulers are compatible and if not, how to make them compatible
-# (if we want to support them)
-class TorchLRScheduler(str, Enum):
-    """
-    Supported learning rate schedulers.
-
-    Currently only supports ReduceLROnPlateau and StepLR.
-    """
-
-    # ChainedScheduler = "ChainedScheduler"
-    # ConstantLR = "ConstantLR"
-    # CosineAnnealingLR = "CosineAnnealingLR"
-    # CosineAnnealingWarmRestarts = "CosineAnnealingWarmRestarts"
-    # CyclicLR = "CyclicLR"
-    # ExponentialLR = "ExponentialLR"
-    # LambdaLR = "LambdaLR"
-    # LinearLR = "LinearLR"
-    # MultiStepLR = "MultiStepLR"
-    # MultiplicativeLR = "MultiplicativeLR"
-    # OneCycleLR = "OneCycleLR"
-    # PolynomialLR = "PolynomialLR"
-    ReduceLROnPlateau = "ReduceLROnPlateau"
-    # SequentialLR = "SequentialLR"
-    StepLR = "StepLR"
-
-
-def get_parameters(
-    func: type,
-    user_params: dict,
-) -> dict:
-    """
-    Filter parameters according to the function signature.
-
-    Parameters
-    ----------
-    func : type
-        Class object.
-    user_params : Dict
-        User provided parameters.
-
-    Returns
-    -------
-    Dict
-        Parameters matching `func`'s signature.
-    """
-    # Get the list of all default parameters
-    default_params = list(inspect.signature(func).parameters.keys())
-
-    # Filter matching parameters
-    params_to_be_used = set(user_params.keys()) & set(default_params)
-
-    return {key: user_params[key] for key in params_to_be_used}
-
-
-def get_optimizers() -> Dict[str, str]:
-    """
-    Return the list of all optimizers available in torch.optim.
-
-    Returns
-    -------
-    Dict
-        Optimizers available in torch.optim.
-    """
-    optims = {}
-    for name, obj in inspect.getmembers(optim):
-        if inspect.isclass(obj) and issubclass(obj, optim.Optimizer):
-            if name != "Optimizer":
-                optims[name] = name
-    return optims
-
-
-def get_schedulers() -> Dict[str, str]:
-    """
-    Return the list of all schedulers available in torch.optim.lr_scheduler.
-
-    Returns
-    -------
-    Dict
-        Schedulers available in torch.optim.lr_scheduler.
-    """
-    schedulers = {}
-    for name, obj in inspect.getmembers(optim.lr_scheduler):
-        if inspect.isclass(obj) and issubclass(obj, optim.lr_scheduler.LRScheduler):
-            if "LRScheduler" not in name:
-                schedulers[name] = name
-        elif name == "ReduceLROnPlateau":  # somewhat not a subclass of LRScheduler
-            schedulers[name] = name
-    return schedulers
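A sketch of the intended use of get_parameters: filter user-supplied hyperparameters against the chosen optimizer's signature before instantiation (the parameter values are hypothetical):

import torch
from torch import optim

from careamics.config.torch_optim import TorchOptimizer, get_parameters

# "momentum" is an SGD argument, not an Adam one, and is filtered out.
user_params = {"lr": 1e-4, "betas": (0.9, 0.999), "momentum": 0.9}

optimizer_cls = getattr(optim, TorchOptimizer.Adam.value)
adam_kwargs = get_parameters(optimizer_cls, user_params)

model = torch.nn.Linear(8, 1)
optimizer = optimizer_cls(model.parameters(), **adam_kwargs)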