careamics 0.0.5-py3-none-any.whl → 0.0.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the package's registry page for more details.

Files changed (111)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +4 -3
  3. careamics/cli/conf.py +1 -2
  4. careamics/cli/main.py +1 -2
  5. careamics/cli/utils.py +3 -3
  6. careamics/config/__init__.py +47 -25
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +38 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +30 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +29 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +6 -1
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +185 -57
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +91 -186
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +1 -2
  25. careamics/config/n2n_configuration.py +101 -0
  26. careamics/config/n2v_configuration.py +266 -0
  27. careamics/config/nm_model.py +1 -2
  28. careamics/config/support/__init__.py +7 -7
  29. careamics/config/support/supported_algorithms.py +5 -4
  30. careamics/config/support/supported_architectures.py +0 -4
  31. careamics/config/transformations/__init__.py +10 -4
  32. careamics/config/transformations/transform_model.py +3 -3
  33. careamics/config/transformations/transform_unions.py +42 -0
  34. careamics/config/validators/__init__.py +12 -1
  35. careamics/config/validators/model_validators.py +84 -0
  36. careamics/config/validators/validator_utils.py +3 -3
  37. careamics/dataset/__init__.py +2 -2
  38. careamics/dataset/dataset_utils/__init__.py +3 -3
  39. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  40. careamics/dataset/dataset_utils/file_utils.py +9 -9
  41. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +3 -3
  63. careamics/lightning/lightning_module.py +11 -7
  64. careamics/lightning/train_data_module.py +36 -45
  65. careamics/losses/__init__.py +3 -3
  66. careamics/lvae_training/calibration.py +64 -57
  67. careamics/lvae_training/dataset/lc_dataset.py +2 -1
  68. careamics/lvae_training/dataset/multich_dataset.py +2 -2
  69. careamics/lvae_training/dataset/types.py +1 -1
  70. careamics/lvae_training/eval_utils.py +123 -128
  71. careamics/model_io/__init__.py +1 -1
  72. careamics/model_io/bioimage/__init__.py +1 -1
  73. careamics/model_io/bioimage/_readme_factory.py +1 -1
  74. careamics/model_io/bioimage/model_description.py +17 -17
  75. careamics/model_io/bmz_io.py +6 -17
  76. careamics/model_io/model_io_utils.py +9 -9
  77. careamics/models/layers.py +16 -16
  78. careamics/models/lvae/likelihoods.py +2 -0
  79. careamics/models/lvae/lvae.py +13 -4
  80. careamics/models/lvae/noise_models.py +280 -217
  81. careamics/models/lvae/stochastic.py +1 -0
  82. careamics/models/model_factory.py +2 -15
  83. careamics/models/unet.py +8 -8
  84. careamics/prediction_utils/__init__.py +1 -1
  85. careamics/prediction_utils/prediction_outputs.py +15 -15
  86. careamics/prediction_utils/stitch_prediction.py +6 -6
  87. careamics/transforms/__init__.py +5 -5
  88. careamics/transforms/compose.py +13 -13
  89. careamics/transforms/n2v_manipulate.py +3 -3
  90. careamics/transforms/pixel_manipulation.py +9 -9
  91. careamics/transforms/xy_random_rotate90.py +4 -4
  92. careamics/utils/__init__.py +5 -5
  93. careamics/utils/context.py +2 -1
  94. careamics/utils/logging.py +11 -10
  95. careamics/utils/metrics.py +25 -0
  96. careamics/utils/plotting.py +78 -0
  97. careamics/utils/torch_utils.py +7 -7
  98. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/METADATA +13 -11
  99. careamics-0.0.7.dist-info/RECORD +178 -0
  100. careamics/config/architectures/custom_model.py +0 -162
  101. careamics/config/architectures/register_model.py +0 -103
  102. careamics/config/configuration_model.py +0 -603
  103. careamics/config/fcn_algorithm_model.py +0 -152
  104. careamics/config/references/__init__.py +0 -45
  105. careamics/config/references/algorithm_descriptions.py +0 -132
  106. careamics/config/references/references.py +0 -39
  107. careamics/config/transformations/transform_union.py +0 -20
  108. careamics-0.0.5.dist-info/RECORD +0 -171
  109. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/WHEEL +0 -0
  110. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/entry_points.txt +0 -0
  111. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/licenses/LICENSE +0 -0
careamics/config/configuration_model.py (deleted, +0 -603)
@@ -1,603 +0,0 @@
1
- """Pydantic CAREamics configuration."""
2
-
3
- from __future__ import annotations
4
-
5
- import re
6
- from pathlib import Path
7
- from pprint import pformat
8
- from typing import Literal, Union
9
-
10
- import yaml
11
- from bioimageio.spec.generic.v0_3 import CiteEntry
12
- from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
13
- from typing_extensions import Self
14
-
15
- from .data_model import DataConfig
16
- from .fcn_algorithm_model import FCNAlgorithmConfig
17
- from .references import (
18
- CARE,
19
- CUSTOM,
20
- N2N,
21
- N2V,
22
- N2V2,
23
- STRUCT_N2V,
24
- STRUCT_N2V2,
25
- CAREDescription,
26
- CARERef,
27
- N2NDescription,
28
- N2NRef,
29
- N2V2Description,
30
- N2V2Ref,
31
- N2VDescription,
32
- N2VRef,
33
- StructN2V2Description,
34
- StructN2VDescription,
35
- StructN2VRef,
36
- )
37
- from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform
38
- from .training_model import TrainingConfig
39
- from .transformations.n2v_manipulate_model import (
40
- N2VManipulateModel,
41
- )
42
- from .vae_algorithm_model import VAEAlgorithmConfig
43
-
44
-
45
- class Configuration(BaseModel):
46
- """
47
- CAREamics configuration.
48
-
49
- The configuration defines all parameters used to build and train a CAREamics model.
50
- These parameters are validated to ensure that they are compatible with each other.
51
-
52
- It contains three sub-configurations:
53
-
54
- - AlgorithmModel: configuration for the algorithm training, which includes the
55
- architecture, loss function, optimizer, and other hyperparameters.
56
- - DataModel: configuration for the dataloader, which includes the type of data,
57
- transformations, mean/std and other parameters.
58
- - TrainingModel: configuration for the training, which includes the number of
59
- epochs or the callbacks.
60
-
61
- Attributes
62
- ----------
63
- experiment_name : str
64
- Name of the experiment, used when saving logs and checkpoints.
65
- algorithm : AlgorithmModel
66
- Algorithm configuration.
67
- data : DataModel
68
- Data configuration.
69
- training : TrainingModel
70
- Training configuration.
71
-
72
- Methods
73
- -------
74
- set_3D(is_3D: bool, axes: str, patch_size: List[int]) -> None
75
- Switch configuration between 2D and 3D.
76
- set_N2V2(use_n2v2: bool) -> None
77
- Switch N2V algorithm between N2V and N2V2.
78
- set_structN2V(
79
- mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int) -> None
80
- Set StructN2V parameters.
81
- model_dump(
82
- exclude_defaults: bool = False, exclude_none: bool = True, **kwargs: Dict
83
- ) -> Dict
84
- Export configuration to a dictionary.
85
-
86
- Raises
87
- ------
88
- ValueError
89
- Configuration parameter type validation errors.
90
- ValueError
91
- If the experiment name contains invalid characters or is empty.
92
- ValueError
93
- If the algorithm is 3D but there is not "Z" in the data axes, or 2D algorithm
94
- with "Z" in data axes.
95
- ValueError
96
- Algorithm, data or training validation errors.
97
-
98
- Notes
99
- -----
100
- We provide convenience methods to create standards configurations, for instance
101
- for N2V, in the `careamics.config.configuration_factory` module.
102
- >>> from careamics.config.configuration_factory import create_n2v_configuration
103
- >>> config = create_n2v_configuration(
104
- ... experiment_name="n2v_experiment",
105
- ... data_type="array",
106
- ... axes="YX",
107
- ... patch_size=[64, 64],
108
- ... batch_size=32,
109
- ... num_epochs=100
110
- ... )
111
-
112
- The configuration can be exported to a dictionary using the model_dump method:
113
- >>> config_dict = config.model_dump()
114
-
115
- Configurations can also be exported or imported from yaml files:
116
- >>> from careamics.config import save_configuration, load_configuration
117
- >>> path_to_config = save_configuration(config, my_path / "config.yml")
118
- >>> other_config = load_configuration(path_to_config)
119
-
120
- Examples
121
- --------
122
- Minimum example:
123
- >>> from careamics.config import Configuration
124
- >>> config_dict = {
125
- ... "experiment_name": "N2V_experiment",
126
- ... "algorithm_config": {
127
- ... "algorithm": "n2v",
128
- ... "loss": "n2v",
129
- ... "model": {
130
- ... "architecture": "UNet",
131
- ... },
132
- ... },
133
- ... "training_config": {
134
- ... "num_epochs": 200,
135
- ... },
136
- ... "data_config": {
137
- ... "data_type": "tiff",
138
- ... "patch_size": [64, 64],
139
- ... "axes": "SYX",
140
- ... },
141
- ... }
142
- >>> config = Configuration(**config_dict)
143
- """
144
-
145
- model_config = ConfigDict(
146
- validate_assignment=True,
147
- set_arbitrary_types_allowed=True,
148
- )
149
-
150
- # version
151
- version: Literal["0.1.0"] = "0.1.0"
152
- """CAREamics configuration version."""
153
-
154
- # required parameters
155
- experiment_name: str
156
- """Name of the experiment, used to name logs and checkpoints."""
157
-
158
- # Sub-configurations
159
- algorithm_config: Union[FCNAlgorithmConfig, VAEAlgorithmConfig] = Field(
160
- discriminator="algorithm"
161
- )
162
- """Algorithm configuration, holding all parameters required to configure the
163
- model."""
164
-
165
- data_config: DataConfig
166
- """Data configuration, holding all parameters required to configure the training
167
- data loader."""
168
-
169
- training_config: TrainingConfig
170
- """Training configuration, holding all parameters required to configure the
171
- training process."""
172
-
173
- @field_validator("experiment_name")
174
- @classmethod
175
- def no_symbol(cls, name: str) -> str:
176
- """
177
- Validate experiment name.
178
-
179
- A valid experiment name is a non-empty string with only contains letters,
180
- numbers, underscores, dashes and spaces.
181
-
182
- Parameters
183
- ----------
184
- name : str
185
- Name to validate.
186
-
187
- Returns
188
- -------
189
- str
190
- Validated name.
191
-
192
- Raises
193
- ------
194
- ValueError
195
- If the name is empty or contains invalid characters.
196
- """
197
- if len(name) == 0 or name.isspace():
198
- raise ValueError("Experiment name is empty.")
199
-
200
- # Validate using a regex that it contains only letters, numbers, underscores,
201
- # dashes and spaces
202
- if not re.match(r"^[a-zA-Z0-9_\- ]*$", name):
203
- raise ValueError(
204
- f"Experiment name contains invalid characters (got {name}). "
205
- f"Only letters, numbers, underscores, dashes and spaces are allowed."
206
- )
207
-
208
- return name
209
-
210
- @model_validator(mode="after")
211
- def validate_3D(self: Self) -> Self:
212
- """
213
- Change algorithm dimensions to match data.axes.
214
-
215
- Only for non-custom algorithms.
216
-
217
- Returns
218
- -------
219
- Self
220
- Validated configuration.
221
- """
222
- if self.algorithm_config.algorithm != SupportedAlgorithm.CUSTOM:
223
- if "Z" in self.data_config.axes and not self.algorithm_config.model.is_3D():
224
- # change algorithm to 3D
225
- self.algorithm_config.model.set_3D(True)
226
- elif (
227
- "Z" not in self.data_config.axes and self.algorithm_config.model.is_3D()
228
- ):
229
- # change algorithm to 2D
230
- self.algorithm_config.model.set_3D(False)
231
-
232
- return self
233
-
234
- @model_validator(mode="after")
235
- def validate_algorithm_and_data(self: Self) -> Self:
236
- """
237
- Validate algorithm and data compatibility.
238
-
239
- In particular, the validation does the following:
240
-
241
- - If N2V is used, it enforces the presence of N2V_Maniuplate in the transforms
242
- - If N2V2 is used, it enforces the correct manipulation strategy
243
-
244
- Returns
245
- -------
246
- Self
247
- Validated configuration.
248
- """
249
- if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
250
- # missing N2V_MANIPULATE
251
- if not self.data_config.has_n2v_manipulate():
252
- self.data_config.transforms.append(
253
- N2VManipulateModel(
254
- name=SupportedTransform.N2V_MANIPULATE.value,
255
- )
256
- )
257
-
258
- median = SupportedPixelManipulation.MEDIAN.value
259
- uniform = SupportedPixelManipulation.UNIFORM.value
260
- strategy = median if self.algorithm_config.model.n2v2 else uniform
261
- self.data_config.set_N2V2_strategy(strategy)
262
- else:
263
- # remove N2V manipulate if present
264
- if self.data_config.has_n2v_manipulate():
265
- self.data_config.remove_n2v_manipulate()
266
-
267
- return self
268
-
269
- def __str__(self) -> str:
270
- """
271
- Pretty string reprensenting the configuration.
272
-
273
- Returns
274
- -------
275
- str
276
- Pretty string.
277
- """
278
- return pformat(self.model_dump())
279
-
280
- def set_3D(self, is_3D: bool, axes: str, patch_size: list[int]) -> None:
281
- """
282
- Set 3D flag and axes.
283
-
284
- Parameters
285
- ----------
286
- is_3D : bool
287
- Whether the algorithm is 3D or not.
288
- axes : str
289
- Axes of the data.
290
- patch_size : list[int]
291
- Patch size.
292
- """
293
- # set the flag and axes (this will not trigger validation at the config level)
294
- self.algorithm_config.model.set_3D(is_3D)
295
- self.data_config.set_3D(axes, patch_size)
296
-
297
- # cheap hack: trigger validation
298
- self.algorithm_config = self.algorithm_config
299
-
300
- def set_N2V2(self, use_n2v2: bool) -> None:
301
- """
302
- Switch N2V algorithm between N2V and N2V2.
303
-
304
- Parameters
305
- ----------
306
- use_n2v2 : bool
307
- Whether to use N2V2 or not.
308
-
309
- Raises
310
- ------
311
- ValueError
312
- If the algorithm is not N2V.
313
- """
314
- if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
315
- self.algorithm_config.model.n2v2 = use_n2v2
316
- strategy = (
317
- SupportedPixelManipulation.MEDIAN.value
318
- if use_n2v2
319
- else SupportedPixelManipulation.UNIFORM.value
320
- )
321
- self.data_config.set_N2V2_strategy(strategy)
322
- else:
323
- raise ValueError("N2V2 can only be set for N2V algorithm.")
324
-
325
- def set_structN2V(
326
- self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int
327
- ) -> None:
328
- """
329
- Set StructN2V parameters.
330
-
331
- Parameters
332
- ----------
333
- mask_axis : Literal["horizontal", "vertical", "none"]
334
- Axis of the structural mask.
335
- mask_span : int
336
- Span of the structural mask.
337
- """
338
- self.data_config.set_structN2V_mask(mask_axis, mask_span)
339
-
340
- def get_algorithm_flavour(self) -> str:
341
- """
342
- Get the algorithm name.
343
-
344
- Returns
345
- -------
346
- str
347
- Algorithm name.
348
- """
349
- if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
350
- use_n2v2 = self.algorithm_config.model.n2v2
351
- use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none"
352
-
353
- # return the n2v flavour
354
- if use_n2v2 and use_structN2V:
355
- return STRUCT_N2V2
356
- elif use_n2v2:
357
- return N2V2
358
- elif use_structN2V:
359
- return STRUCT_N2V
360
- else:
361
- return N2V
362
- elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N:
363
- return N2N
364
- elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE:
365
- return CARE
366
- else:
367
- return CUSTOM
368
-
369
- def get_algorithm_description(self) -> str:
370
- """
371
- Return a description of the algorithm.
372
-
373
- This method is used to generate the README of the BioImage Model Zoo export.
374
-
375
- Returns
376
- -------
377
- str
378
- Description of the algorithm.
379
- """
380
- algorithm_flavour = self.get_algorithm_flavour()
381
-
382
- if algorithm_flavour == CUSTOM:
383
- return f"Custom algorithm, named {self.algorithm_config.model.name}"
384
- else: # currently only N2V flavours
385
- if algorithm_flavour == N2V:
386
- return N2VDescription().description
387
- elif algorithm_flavour == N2V2:
388
- return N2V2Description().description
389
- elif algorithm_flavour == STRUCT_N2V:
390
- return StructN2VDescription().description
391
- elif algorithm_flavour == STRUCT_N2V2:
392
- return StructN2V2Description().description
393
- elif algorithm_flavour == N2N:
394
- return N2NDescription().description
395
- elif algorithm_flavour == CARE:
396
- return CAREDescription().description
397
-
398
- return ""
399
-
400
- def get_algorithm_citations(self) -> list[CiteEntry]:
401
- """
402
- Return a list of citation entries of the current algorithm.
403
-
404
- This is used to generate the model description for the BioImage Model Zoo.
405
-
406
- Returns
407
- -------
408
- List[CiteEntry]
409
- List of citation entries.
410
- """
411
- if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
412
- use_n2v2 = self.algorithm_config.model.n2v2
413
- use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none"
414
-
415
- # return the (struct)N2V(2) references
416
- if use_n2v2 and use_structN2V:
417
- return [N2VRef, N2V2Ref, StructN2VRef]
418
- elif use_n2v2:
419
- return [N2VRef, N2V2Ref]
420
- elif use_structN2V:
421
- return [N2VRef, StructN2VRef]
422
- else:
423
- return [N2VRef]
424
- elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N:
425
- return [N2NRef]
426
- elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE:
427
- return [CARERef]
428
-
429
- raise ValueError("Citation not available for custom algorithm.")
430
-
431
- def get_algorithm_references(self) -> str:
432
- """
433
- Get the algorithm references.
434
-
435
- This is used to generate the README of the BioImage Model Zoo export.
436
-
437
- Returns
438
- -------
439
- str
440
- Algorithm references.
441
- """
442
- if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
443
- use_n2v2 = self.algorithm_config.model.n2v2
444
- use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none"
445
-
446
- references = [
447
- N2VRef.text + " doi: " + N2VRef.doi,
448
- N2V2Ref.text + " doi: " + N2V2Ref.doi,
449
- StructN2VRef.text + " doi: " + StructN2VRef.doi,
450
- ]
451
-
452
- # return the (struct)N2V(2) references
453
- if use_n2v2 and use_structN2V:
454
- return "".join(references)
455
- elif use_n2v2:
456
- references.pop(-1)
457
- return "".join(references)
458
- elif use_structN2V:
459
- references.pop(-2)
460
- return "".join(references)
461
- else:
462
- return references[0]
463
-
464
- return ""
465
-
466
- def get_algorithm_keywords(self) -> list[str]:
467
- """
468
- Get algorithm keywords.
469
-
470
- Returns
471
- -------
472
- list[str]
473
- List of keywords.
474
- """
475
- if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
476
- use_n2v2 = self.algorithm_config.model.n2v2
477
- use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none"
478
-
479
- keywords = [
480
- "denoising",
481
- "restoration",
482
- "UNet",
483
- "3D" if "Z" in self.data_config.axes else "2D",
484
- "CAREamics",
485
- "pytorch",
486
- N2V,
487
- ]
488
-
489
- if use_n2v2:
490
- keywords.append(N2V2)
491
- if use_structN2V:
492
- keywords.append(STRUCT_N2V)
493
- else:
494
- keywords = ["CAREamics"]
495
-
496
- return keywords
497
-
498
- def model_dump(
499
- self,
500
- exclude_defaults: bool = False,
501
- exclude_none: bool = True,
502
- **kwargs: dict,
503
- ) -> dict:
504
- """
505
- Override model_dump method in order to set default values.
506
-
507
- Parameters
508
- ----------
509
- exclude_defaults : bool, optional
510
- Whether to exclude fields with default values or not, by default
511
- True.
512
- exclude_none : bool, optional
513
- Whether to exclude fields with None values or not, by default True.
514
- **kwargs : dict
515
- Keyword arguments.
516
-
517
- Returns
518
- -------
519
- dict
520
- Dictionary containing the model parameters.
521
- """
522
- dictionary = super().model_dump(
523
- exclude_none=exclude_none, exclude_defaults=exclude_defaults, **kwargs
524
- )
525
-
526
- return dictionary
527
-
528
-
529
- def load_configuration(path: Union[str, Path]) -> Configuration:
530
- """
531
- Load configuration from a yaml file.
532
-
533
- Parameters
534
- ----------
535
- path : str or Path
536
- Path to the configuration.
537
-
538
- Returns
539
- -------
540
- Configuration
541
- Configuration.
542
-
543
- Raises
544
- ------
545
- FileNotFoundError
546
- If the configuration file does not exist.
547
- """
548
- # load dictionary from yaml
549
- if not Path(path).exists():
550
- raise FileNotFoundError(
551
- f"Configuration file {path} does not exist in " f" {Path.cwd()!s}"
552
- )
553
-
554
- dictionary = yaml.load(Path(path).open("r"), Loader=yaml.SafeLoader)
555
-
556
- return Configuration(**dictionary)
557
-
558
-
559
- def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
560
- """
561
- Save configuration to path.
562
-
563
- Parameters
564
- ----------
565
- config : Configuration
566
- Configuration to save.
567
- path : str or Path
568
- Path to a existing folder in which to save the configuration, or to a valid
569
- configuration file path (uses a .yml or .yaml extension).
570
-
571
- Returns
572
- -------
573
- Path
574
- Path object representing the configuration.
575
-
576
- Raises
577
- ------
578
- ValueError
579
- If the path does not point to an existing directory or .yml file.
580
- """
581
- # make sure path is a Path object
582
- config_path = Path(path)
583
-
584
- # check if path is pointing to an existing directory or .yml file
585
- if config_path.exists():
586
- if config_path.is_dir():
587
- config_path = Path(config_path, "config.yml")
588
- elif config_path.suffix != ".yml" and config_path.suffix != ".yaml":
589
- raise ValueError(
590
- f"Path must be a directory or .yml or .yaml file (got {config_path})."
591
- )
592
- else:
593
- if config_path.suffix != ".yml" and config_path.suffix != ".yaml":
594
- raise ValueError(
595
- f"Path must be a directory or .yml or .yaml file (got {config_path})."
596
- )
597
-
598
- # save configuration as dictionary to yaml
599
- with open(config_path, "w") as f:
600
- # dump configuration
601
- yaml.dump(config.model_dump(), f, default_flow_style=False, sort_keys=False)
602
-
603
- return config_path