careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the registry's advisory page for more details.

Files changed (118)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +239 -28
  3. careamics/cli/conf.py +19 -31
  4. careamics/cli/main.py +112 -12
  5. careamics/cli/utils.py +29 -0
  6. careamics/config/__init__.py +48 -24
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +50 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +42 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +35 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +109 -21
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +58 -198
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +8 -8
  25. careamics/config/loss_model.py +56 -0
  26. careamics/config/n2n_configuration.py +101 -0
  27. careamics/config/n2v_configuration.py +266 -0
  28. careamics/config/nm_model.py +24 -25
  29. careamics/config/support/__init__.py +7 -7
  30. careamics/config/support/supported_algorithms.py +0 -3
  31. careamics/config/support/supported_architectures.py +0 -4
  32. careamics/config/transformations/__init__.py +10 -4
  33. careamics/config/transformations/transform_model.py +3 -3
  34. careamics/config/transformations/transform_unions.py +42 -0
  35. careamics/config/validators/validator_utils.py +3 -3
  36. careamics/dataset/__init__.py +2 -2
  37. careamics/dataset/dataset_utils/__init__.py +3 -3
  38. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  39. careamics/dataset/dataset_utils/file_utils.py +9 -9
  40. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  41. careamics/dataset/dataset_utils/running_stats.py +22 -23
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +2 -2
  63. careamics/lightning/lightning_module.py +69 -34
  64. careamics/lightning/train_data_module.py +41 -27
  65. careamics/losses/__init__.py +3 -3
  66. careamics/losses/loss_factory.py +1 -85
  67. careamics/losses/lvae/losses.py +223 -164
  68. careamics/lvae_training/calibration.py +184 -0
  69. careamics/lvae_training/dataset/config.py +2 -2
  70. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  71. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  72. careamics/lvae_training/dataset/types.py +15 -26
  73. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  74. careamics/lvae_training/eval_utils.py +125 -213
  75. careamics/model_io/__init__.py +1 -1
  76. careamics/model_io/bioimage/__init__.py +1 -1
  77. careamics/model_io/bioimage/_readme_factory.py +26 -34
  78. careamics/model_io/bioimage/cover_factory.py +171 -0
  79. careamics/model_io/bioimage/model_description.py +56 -34
  80. careamics/model_io/bmz_io.py +42 -42
  81. careamics/model_io/model_io_utils.py +9 -9
  82. careamics/models/layers.py +22 -20
  83. careamics/models/lvae/layers.py +348 -975
  84. careamics/models/lvae/likelihoods.py +10 -8
  85. careamics/models/lvae/lvae.py +214 -275
  86. careamics/models/lvae/noise_models.py +179 -112
  87. careamics/models/lvae/stochastic.py +393 -0
  88. careamics/models/lvae/utils.py +82 -73
  89. careamics/models/model_factory.py +2 -15
  90. careamics/models/unet.py +8 -8
  91. careamics/prediction_utils/__init__.py +1 -1
  92. careamics/prediction_utils/prediction_outputs.py +15 -15
  93. careamics/prediction_utils/stitch_prediction.py +6 -6
  94. careamics/transforms/__init__.py +5 -5
  95. careamics/transforms/compose.py +13 -13
  96. careamics/transforms/n2v_manipulate.py +3 -3
  97. careamics/transforms/pixel_manipulation.py +9 -9
  98. careamics/transforms/xy_random_rotate90.py +4 -4
  99. careamics/utils/__init__.py +5 -5
  100. careamics/utils/context.py +2 -1
  101. careamics/utils/lightning_utils.py +57 -0
  102. careamics/utils/logging.py +11 -10
  103. careamics/utils/serializers.py +2 -0
  104. careamics/utils/torch_utils.py +8 -8
  105. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
  106. careamics-0.0.6.dist-info/RECORD +176 -0
  107. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
  108. careamics/config/architectures/custom_model.py +0 -162
  109. careamics/config/architectures/register_model.py +0 -103
  110. careamics/config/configuration_model.py +0 -603
  111. careamics/config/fcn_algorithm_model.py +0 -152
  112. careamics/config/references/__init__.py +0 -45
  113. careamics/config/references/algorithm_descriptions.py +0 -132
  114. careamics/config/references/references.py +0 -39
  115. careamics/config/transformations/transform_union.py +0 -20
  116. careamics-0.0.4.2.dist-info/RECORD +0 -165
  117. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
  118. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
@@ -2,26 +2,93 @@
2
2
 
3
3
  from typing import Any, Literal, Optional, Union
4
4
 
5
- from .architectures import UNetModel
6
- from .configuration_model import Configuration
7
- from .data_model import DataConfig
8
- from .fcn_algorithm_model import FCNAlgorithmConfig
9
- from .support import (
5
+ from pydantic import TypeAdapter
6
+
7
+ from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
8
+ from careamics.config.architectures import UNetModel
9
+ from careamics.config.care_configuration import CAREConfiguration
10
+ from careamics.config.configuration import Configuration
11
+ from careamics.config.data import DataConfig, N2VDataConfig
12
+ from careamics.config.n2n_configuration import N2NConfiguration
13
+ from careamics.config.n2v_configuration import N2VConfiguration
14
+ from careamics.config.support import (
10
15
  SupportedArchitecture,
11
16
  SupportedPixelManipulation,
12
17
  SupportedTransform,
13
18
  )
14
- from .training_model import TrainingConfig
15
- from .transformations import (
19
+ from careamics.config.training_model import TrainingConfig
20
+ from careamics.config.transformations import (
21
+ N2V_TRANSFORMS_UNION,
22
+ SPATIAL_TRANSFORMS_UNION,
16
23
  N2VManipulateModel,
17
24
  XYFlipModel,
18
25
  XYRandomRotate90Model,
19
26
  )
20
27
 
21
28
 
22
- def _list_augmentations(
23
- augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]],
24
- ) -> list[Union[XYFlipModel, XYRandomRotate90Model]]:
29
# Built once at import time: every TypeAdapter instantiation compiles a new
# validator/serializer, so it must not be recreated on each call.
_CONFIGURATION_ADAPTER: TypeAdapter = TypeAdapter(
    Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]
)


def configuration_factory(
    configuration: dict[str, Any]
) -> Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]:
    """
    Create a configuration for training CAREamics.

    The dictionary is validated against the union of the supported
    configuration models (N2V, N2N and CARE), and an instance of the model
    that accepts it is returned.

    Parameters
    ----------
    configuration : dict
        Configuration dictionary.

    Returns
    -------
    N2VConfiguration or N2NConfiguration or CAREConfiguration
        Configuration for training CAREamics.

    Raises
    ------
    pydantic.ValidationError
        If the dictionary does not validate against any supported configuration.
    """
    return _CONFIGURATION_ADAPTER.validate_python(configuration)
49
+
50
+
51
# Built once at import time: every TypeAdapter instantiation compiles a new
# validator/serializer, so it must not be recreated on each call.
_ALGORITHM_ADAPTER: TypeAdapter = TypeAdapter(
    Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]
)


def algorithm_factory(
    algorithm: dict[str, Any]
) -> Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]:
    """
    Create an algorithm model for training CAREamics.

    The dictionary is validated against the union of the supported algorithm
    models (N2V, N2N and CARE), and an instance of the model that accepts it
    is returned.

    Parameters
    ----------
    algorithm : dict
        Algorithm dictionary.

    Returns
    -------
    N2VAlgorithm or N2NAlgorithm or CAREAlgorithm
        Algorithm model for training CAREamics.

    Raises
    ------
    pydantic.ValidationError
        If the dictionary does not validate against any supported algorithm.
    """
    return _ALGORITHM_ADAPTER.validate_python(algorithm)
69
+
70
+
71
# Built once at import time: every TypeAdapter instantiation compiles a new
# validator/serializer, so it must not be recreated on each call.
_DATA_ADAPTER: TypeAdapter = TypeAdapter(Union[DataConfig, N2VDataConfig])


def data_factory(data: dict[str, Any]) -> Union[DataConfig, N2VDataConfig]:
    """
    Create a data model for training CAREamics.

    The dictionary is validated against the union of the supported data
    models, and an instance of the model that accepts it is returned.

    Parameters
    ----------
    data : dict
        Data dictionary.

    Returns
    -------
    DataConfig or N2VDataConfig
        Data model for training CAREamics.

    Raises
    ------
    pydantic.ValidationError
        If the dictionary does not validate against any supported data model.
    """
    return _DATA_ADAPTER.validate_python(data)
87
+
88
+
89
+ def _list_spatial_augmentations(
90
+ augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]],
91
+ ) -> list[SPATIAL_TRANSFORMS_UNION]:
25
92
  """
26
93
  List the augmentations to apply.
27
94
 
@@ -44,7 +111,7 @@ def _list_augmentations(
44
111
  If there are duplicate transforms.
45
112
  """
46
113
  if augmentations is None:
47
- transform_list: list[Union[XYFlipModel, XYRandomRotate90Model]] = [
114
+ transform_list: list[SPATIAL_TRANSFORMS_UNION] = [
48
115
  XYFlipModel(),
49
116
  XYRandomRotate90Model(),
50
117
  ]
@@ -123,7 +190,7 @@ def _create_configuration(
123
190
  patch_size: list[int],
124
191
  batch_size: int,
125
192
  num_epochs: int,
126
- augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]],
193
+ augmentations: Union[list[N2V_TRANSFORMS_UNION], list[SPATIAL_TRANSFORMS_UNION]],
127
194
  independent_channels: bool,
128
195
  loss: Literal["n2v", "mae", "mse"],
129
196
  n_channels_in: int,
@@ -188,21 +255,21 @@ def _create_configuration(
188
255
  )
189
256
 
190
257
  # algorithm model
191
- algorithm_config = FCNAlgorithmConfig(
192
- algorithm=algorithm,
193
- loss=loss,
194
- model=unet_model,
195
- )
258
+ algorithm_config = {
259
+ "algorithm": algorithm,
260
+ "loss": loss,
261
+ "model": unet_model,
262
+ }
196
263
 
197
264
  # data model
198
- data = DataConfig(
199
- data_type=data_type,
200
- axes=axes,
201
- patch_size=patch_size,
202
- batch_size=batch_size,
203
- transforms=augmentations,
204
- dataloader_params=dataloader_params,
205
- )
265
+ data = {
266
+ "data_type": data_type,
267
+ "axes": axes,
268
+ "patch_size": patch_size,
269
+ "batch_size": batch_size,
270
+ "transforms": augmentations,
271
+ "dataloader_params": dataloader_params,
272
+ }
206
273
 
207
274
  # training model
208
275
  training = TrainingConfig(
@@ -212,14 +279,14 @@ def _create_configuration(
212
279
  )
213
280
 
214
281
  # create configuration
215
- configuration = Configuration(
216
- experiment_name=experiment_name,
217
- algorithm_config=algorithm_config,
218
- data_config=data,
219
- training_config=training,
220
- )
282
+ configuration = {
283
+ "experiment_name": experiment_name,
284
+ "algorithm_config": algorithm_config,
285
+ "data_config": data,
286
+ "training_config": training,
287
+ }
221
288
 
222
- return configuration
289
+ return configuration_factory(configuration)
223
290
 
224
291
 
225
292
  # TODO reconsider naming once we officially support LVAE approaches
@@ -234,8 +301,8 @@ def _create_supervised_configuration(
234
301
  augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
235
302
  independent_channels: bool = True,
236
303
  loss: Literal["mae", "mse"] = "mae",
237
- n_channels_in: int = 1,
238
- n_channels_out: int = 1,
304
+ n_channels_in: Optional[int] = None,
305
+ n_channels_out: Optional[int] = None,
239
306
  logger: Literal["wandb", "tensorboard", "none"] = "none",
240
307
  model_params: Optional[dict] = None,
241
308
  dataloader_params: Optional[dict] = None,
@@ -267,10 +334,10 @@ def _create_supervised_configuration(
267
334
  Whether to train all channels independently, by default False.
268
335
  loss : Literal["mae", "mse"], optional
269
336
  Loss function to use, by default "mae".
270
- n_channels_in : int, optional
271
- Number of channels in, by default 1.
272
- n_channels_out : int, optional
273
- Number of channels out, by default 1.
337
+ n_channels_in : int or None, default=None
338
+ Number of channels in.
339
+ n_channels_out : int or None, default=None
340
+ Number of channels out.
274
341
  logger : Literal["wandb", "tensorboard", "none"], optional
275
342
  Logger to use, by default "none".
276
343
  model_params : dict, optional
@@ -282,21 +349,31 @@ def _create_supervised_configuration(
282
349
  -------
283
350
  Configuration
284
351
  Configuration for training CARE or Noise2Noise.
352
+
353
+ Raises
354
+ ------
355
+ ValueError
356
+ If the number of channels is not specified when using channels.
357
+ ValueError
358
+ If the number of channels is specified but "C" is not in the axes.
285
359
  """
286
360
  # if there are channels, we need to specify their number
287
- if "C" in axes and n_channels_in == 1:
288
- raise ValueError(
289
- f"Number of channels in must be specified when using channels "
290
- f"(got {n_channels_in} channel)."
291
- )
292
- elif "C" not in axes and n_channels_in > 1:
361
+ if "C" in axes and n_channels_in is None:
362
+ raise ValueError("Number of channels in must be specified when using channels ")
363
+ elif "C" not in axes and (n_channels_in is not None and n_channels_in > 1):
293
364
  raise ValueError(
294
365
  f"C is not present in the axes, but number of channels is specified "
295
366
  f"(got {n_channels_in} channels)."
296
367
  )
297
368
 
369
+ if n_channels_in is None:
370
+ n_channels_in = 1
371
+
372
+ if n_channels_out is None:
373
+ n_channels_out = n_channels_in
374
+
298
375
  # augmentations
299
- transform_list = _list_augmentations(augmentations)
376
+ spatial_transform_list = _list_spatial_augmentations(augmentations)
300
377
 
301
378
  return _create_configuration(
302
379
  algorithm=algorithm,
@@ -306,7 +383,7 @@ def _create_supervised_configuration(
306
383
  patch_size=patch_size,
307
384
  batch_size=batch_size,
308
385
  num_epochs=num_epochs,
309
- augmentations=transform_list,
386
+ augmentations=spatial_transform_list,
310
387
  independent_channels=independent_channels,
311
388
  loss=loss,
312
389
  n_channels_in=n_channels_in,
@@ -327,8 +404,8 @@ def create_care_configuration(
327
404
  augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
328
405
  independent_channels: bool = True,
329
406
  loss: Literal["mae", "mse"] = "mae",
330
- n_channels_in: int = 1,
331
- n_channels_out: int = -1,
407
+ n_channels_in: Optional[int] = None,
408
+ n_channels_out: Optional[int] = None,
332
409
  logger: Literal["wandb", "tensorboard", "none"] = "none",
333
410
  model_params: Optional[dict] = None,
334
411
  dataloader_params: Optional[dict] = None,
@@ -374,16 +451,16 @@ def create_care_configuration(
374
451
  and XYRandomRotate90 (in XY) to the images.
375
452
  independent_channels : bool, optional
376
453
  Whether to train all channels independently, by default False.
377
- loss : Literal["mae", "mse"], optional
378
- Loss function to use, by default "mae".
379
- n_channels_in : int, optional
380
- Number of channels in, by default 1.
381
- n_channels_out : int, optional
382
- Number of channels out, by default -1.
383
- logger : Literal["wandb", "tensorboard", "none"], optional
384
- Logger to use, by default "none".
385
- model_params : dict, optional
386
- UNetModel parameters, by default None.
454
+ loss : Literal["mae", "mse"], default="mae"
455
+ Loss function to use.
456
+ n_channels_in : int or None, default=None
457
+ Number of channels in.
458
+ n_channels_out : int or None, default=None
459
+ Number of channels out.
460
+ logger : Literal["wandb", "tensorboard", "none"], default="none"
461
+ Logger to use.
462
+ model_params : dict, default=None
463
+ UNetModel parameters.
387
464
  dataloader_params : dict, optional
388
465
  Parameters for the dataloader, see PyTorch notes, by default None.
389
466
 
@@ -459,9 +536,6 @@ def create_care_configuration(
459
536
  ... n_channels_out=1 # if applicable
460
537
  ... )
461
538
  """
462
- if n_channels_out == -1:
463
- n_channels_out = n_channels_in
464
-
465
539
  return _create_supervised_configuration(
466
540
  algorithm="care",
467
541
  experiment_name=experiment_name,
@@ -491,8 +565,8 @@ def create_n2n_configuration(
491
565
  augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
492
566
  independent_channels: bool = True,
493
567
  loss: Literal["mae", "mse"] = "mae",
494
- n_channels_in: int = 1,
495
- n_channels_out: int = -1,
568
+ n_channels_in: Optional[int] = None,
569
+ n_channels_out: Optional[int] = None,
496
570
  logger: Literal["wandb", "tensorboard", "none"] = "none",
497
571
  model_params: Optional[dict] = None,
498
572
  dataloader_params: Optional[dict] = None,
@@ -540,10 +614,10 @@ def create_n2n_configuration(
540
614
  Whether to train all channels independently, by default False.
541
615
  loss : Literal["mae", "mse"], optional
542
616
  Loss function to use, by default "mae".
543
- n_channels_in : int, optional
544
- Number of channels in, by default 1.
545
- n_channels_out : int, optional
546
- Number of channels out, by default -1.
617
+ n_channels_in : int or None, default=None
618
+ Number of channels in.
619
+ n_channels_out : int or None, default=None
620
+ Number of channels out.
547
621
  logger : Literal["wandb", "tensorboard", "none"], optional
548
622
  Logger to use, by default "none".
549
623
  model_params : dict, optional
@@ -623,9 +697,6 @@ def create_n2n_configuration(
623
697
  ... n_channels_out=1 # if applicable
624
698
  ... )
625
699
  """
626
- if n_channels_out == -1:
627
- n_channels_out = n_channels_in
628
-
629
700
  return _create_supervised_configuration(
630
701
  algorithm="n2n",
631
702
  experiment_name=experiment_name,
@@ -655,7 +726,7 @@ def create_n2v_configuration(
655
726
  augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]] = None,
656
727
  independent_channels: bool = True,
657
728
  use_n2v2: bool = False,
658
- n_channels: int = 1,
729
+ n_channels: Optional[int] = None,
659
730
  roi_size: int = 11,
660
731
  masked_pixel_percentage: float = 0.2,
661
732
  struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
@@ -727,8 +798,8 @@ def create_n2v_configuration(
727
798
  Whether to train all channels together, by default True.
728
799
  use_n2v2 : bool, optional
729
800
  Whether to use N2V2, by default False.
730
- n_channels : int, optional
731
- Number of channels (in and out), by default 1.
801
+ n_channels : int or None, default=None
802
+ Number of channels (in and out).
732
803
  roi_size : int, optional
733
804
  N2V pixel manipulation area, by default 11.
734
805
  masked_pixel_percentage : float, optional
@@ -837,19 +908,19 @@ def create_n2v_configuration(
837
908
  ... )
838
909
  """
839
910
  # if there are channels, we need to specify their number
840
- if "C" in axes and n_channels == 1:
841
- raise ValueError(
842
- f"Number of channels must be specified when using channels "
843
- f"(got {n_channels} channel)."
844
- )
845
- elif "C" not in axes and n_channels > 1:
911
+ if "C" in axes and n_channels is None:
912
+ raise ValueError("Number of channels must be specified when using channels.")
913
+ elif "C" not in axes and (n_channels is not None and n_channels > 1):
846
914
  raise ValueError(
847
915
  f"C is not present in the axes, but number of channels is specified "
848
916
  f"(got {n_channels} channel)."
849
917
  )
850
918
 
919
+ if n_channels is None:
920
+ n_channels = 1
921
+
851
922
  # augmentations
852
- transform_list = _list_augmentations(augmentations)
923
+ spatial_transforms = _list_spatial_augmentations(augmentations)
853
924
 
854
925
  # create the N2VManipulate transform using the supplied parameters
855
926
  n2v_transform = N2VManipulateModel(
@@ -864,7 +935,7 @@ def create_n2v_configuration(
864
935
  struct_mask_axis=struct_n2v_axis,
865
936
  struct_mask_span=struct_n2v_span,
866
937
  )
867
- transform_list.append(n2v_transform)
938
+ transform_list: list[N2V_TRANSFORMS_UNION] = spatial_transforms + [n2v_transform]
868
939
 
869
940
  return _create_configuration(
870
941
  algorithm="n2v",
@@ -0,0 +1,85 @@
1
+ """I/O functions for Configuration objects."""
2
+
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import yaml
7
+
8
+ from careamics.config import Configuration, configuration_factory
9
+
10
+
11
def load_configuration(path: Union[str, Path]) -> Configuration:
    """
    Load configuration from a yaml file.

    Parameters
    ----------
    path : str or Path
        Path to the configuration.

    Returns
    -------
    Configuration
        Configuration, validated by `configuration_factory`.

    Raises
    ------
    FileNotFoundError
        If the configuration file does not exist.
    """
    config_path = Path(path)
    if not config_path.exists():
        raise FileNotFoundError(
            f"Configuration file {path} does not exist in {Path.cwd()!s}"
        )

    # Read inside a context manager so the file handle is always closed.
    # Binary mode lets PyYAML detect the encoding (UTF-8/UTF-16) itself
    # instead of relying on the platform's locale encoding.
    with config_path.open("rb") as stream:
        dictionary = yaml.safe_load(stream)

    return configuration_factory(dictionary)
39
+
40
+
41
def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
    """
    Write a configuration to disk as YAML.

    Parameters
    ----------
    config : Configuration
        Configuration to save.
    path : str or Path
        Either an existing directory, in which case the configuration is written
        to ``config.yml`` inside it, or a file path ending in ``.yml`` or
        ``.yaml`` (the file itself need not exist yet).

    Returns
    -------
    Path
        Path of the written configuration file.

    Raises
    ------
    ValueError
        If the path is neither an existing directory nor a path with a ``.yml``
        or ``.yaml`` extension.
    """
    target = Path(path)

    # An existing directory gets the default file name; any other path, existing
    # or not, must already carry a YAML extension.
    if target.is_dir():
        target = Path(target, "config.yml")
    elif target.suffix not in (".yml", ".yaml"):
        raise ValueError(
            f"Path must be a directory or .yml or .yaml file (got {target})."
        )

    # Dump the model as a plain dictionary, keeping the field order.
    with open(target, "w") as stream:
        yaml.dump(
            config.model_dump(),
            stream,
            default_flow_style=False,
            sort_keys=False,
        )

    return target
@@ -0,0 +1,10 @@
1
+ """Data Pydantic configuration models."""
2
+
3
+ __all__ = [
4
+ "DataConfig",
5
+ "GeneralDataConfig",
6
+ "N2VDataConfig",
7
+ ]
8
+
9
+ from .data_model import DataConfig, GeneralDataConfig
10
+ from .n2v_data_model import N2VDataConfig