careamics 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. See the registry's advisory page for this release for more details.

Files changed (111)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +4 -3
  3. careamics/cli/conf.py +1 -2
  4. careamics/cli/main.py +1 -2
  5. careamics/cli/utils.py +3 -3
  6. careamics/config/__init__.py +47 -25
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +38 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +30 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +29 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +6 -1
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +185 -57
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +91 -186
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +1 -2
  25. careamics/config/n2n_configuration.py +101 -0
  26. careamics/config/n2v_configuration.py +266 -0
  27. careamics/config/nm_model.py +1 -2
  28. careamics/config/support/__init__.py +7 -7
  29. careamics/config/support/supported_algorithms.py +5 -4
  30. careamics/config/support/supported_architectures.py +0 -4
  31. careamics/config/transformations/__init__.py +10 -4
  32. careamics/config/transformations/transform_model.py +3 -3
  33. careamics/config/transformations/transform_unions.py +42 -0
  34. careamics/config/validators/__init__.py +12 -1
  35. careamics/config/validators/model_validators.py +84 -0
  36. careamics/config/validators/validator_utils.py +3 -3
  37. careamics/dataset/__init__.py +2 -2
  38. careamics/dataset/dataset_utils/__init__.py +3 -3
  39. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  40. careamics/dataset/dataset_utils/file_utils.py +9 -9
  41. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +3 -3
  63. careamics/lightning/lightning_module.py +11 -7
  64. careamics/lightning/train_data_module.py +36 -45
  65. careamics/losses/__init__.py +3 -3
  66. careamics/lvae_training/calibration.py +64 -57
  67. careamics/lvae_training/dataset/lc_dataset.py +2 -1
  68. careamics/lvae_training/dataset/multich_dataset.py +2 -2
  69. careamics/lvae_training/dataset/types.py +1 -1
  70. careamics/lvae_training/eval_utils.py +123 -128
  71. careamics/model_io/__init__.py +1 -1
  72. careamics/model_io/bioimage/__init__.py +1 -1
  73. careamics/model_io/bioimage/_readme_factory.py +1 -1
  74. careamics/model_io/bioimage/model_description.py +17 -17
  75. careamics/model_io/bmz_io.py +6 -17
  76. careamics/model_io/model_io_utils.py +9 -9
  77. careamics/models/layers.py +16 -16
  78. careamics/models/lvae/likelihoods.py +2 -0
  79. careamics/models/lvae/lvae.py +13 -4
  80. careamics/models/lvae/noise_models.py +280 -217
  81. careamics/models/lvae/stochastic.py +1 -0
  82. careamics/models/model_factory.py +2 -15
  83. careamics/models/unet.py +8 -8
  84. careamics/prediction_utils/__init__.py +1 -1
  85. careamics/prediction_utils/prediction_outputs.py +15 -15
  86. careamics/prediction_utils/stitch_prediction.py +6 -6
  87. careamics/transforms/__init__.py +5 -5
  88. careamics/transforms/compose.py +13 -13
  89. careamics/transforms/n2v_manipulate.py +3 -3
  90. careamics/transforms/pixel_manipulation.py +9 -9
  91. careamics/transforms/xy_random_rotate90.py +4 -4
  92. careamics/utils/__init__.py +5 -5
  93. careamics/utils/context.py +2 -1
  94. careamics/utils/logging.py +11 -10
  95. careamics/utils/metrics.py +25 -0
  96. careamics/utils/plotting.py +78 -0
  97. careamics/utils/torch_utils.py +7 -7
  98. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/METADATA +13 -11
  99. careamics-0.0.7.dist-info/RECORD +178 -0
  100. careamics/config/architectures/custom_model.py +0 -162
  101. careamics/config/architectures/register_model.py +0 -103
  102. careamics/config/configuration_model.py +0 -603
  103. careamics/config/fcn_algorithm_model.py +0 -152
  104. careamics/config/references/__init__.py +0 -45
  105. careamics/config/references/algorithm_descriptions.py +0 -132
  106. careamics/config/references/references.py +0 -39
  107. careamics/config/transformations/transform_union.py +0 -20
  108. careamics-0.0.5.dist-info/RECORD +0 -171
  109. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/WHEEL +0 -0
  110. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/entry_points.txt +0 -0
  111. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/licenses/LICENSE +0 -0
@@ -1,27 +1,120 @@
1
1
  """Convenience functions to create configurations for training and inference."""
2
2
 
3
- from typing import Any, Literal, Optional, Union
4
-
5
- from .architectures import UNetModel
6
- from .configuration_model import Configuration
7
- from .data_model import DataConfig
8
- from .fcn_algorithm_model import FCNAlgorithmConfig
9
- from .support import (
3
+ from typing import Annotated, Any, Literal, Optional, Union
4
+
5
+ from pydantic import Discriminator, Tag, TypeAdapter
6
+
7
+ from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
8
+ from careamics.config.architectures import UNetModel
9
+ from careamics.config.care_configuration import CAREConfiguration
10
+ from careamics.config.configuration import Configuration
11
+ from careamics.config.data import DataConfig, N2VDataConfig
12
+ from careamics.config.n2n_configuration import N2NConfiguration
13
+ from careamics.config.n2v_configuration import N2VConfiguration
14
+ from careamics.config.support import (
15
+ SupportedAlgorithm,
10
16
  SupportedArchitecture,
11
17
  SupportedPixelManipulation,
12
18
  SupportedTransform,
13
19
  )
14
- from .training_model import TrainingConfig
15
- from .transformations import (
20
+ from careamics.config.training_model import TrainingConfig
21
+ from careamics.config.transformations import (
22
+ N2V_TRANSFORMS_UNION,
23
+ SPATIAL_TRANSFORMS_UNION,
16
24
  N2VManipulateModel,
17
25
  XYFlipModel,
18
26
  XYRandomRotate90Model,
19
27
  )
20
28
 
21
29
 
22
- def _list_augmentations(
23
- augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]],
24
- ) -> list[Union[XYFlipModel, XYRandomRotate90Model]]:
30
+ def _algorithm_config_discriminator(value: Union[dict, Configuration]) -> str:
31
+ """Discriminate algorithm-specific configurations based on the algorithm.
32
+
33
+ Parameters
34
+ ----------
35
+ value : Any
36
+ Value to discriminate.
37
+
38
+ Returns
39
+ -------
40
+ str
41
+ Discriminator value.
42
+ """
43
+ if isinstance(value, dict):
44
+ return value["algorithm_config"]["algorithm"]
45
+ return value.algorithm_config.algorithm
46
+
47
+
48
+ def configuration_factory(
49
+ configuration: dict[str, Any]
50
+ ) -> Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]:
51
+ """
52
+ Create a configuration for training CAREamics.
53
+
54
+ Parameters
55
+ ----------
56
+ configuration : dict
57
+ Configuration dictionary.
58
+
59
+ Returns
60
+ -------
61
+ N2VConfiguration or N2NConfiguration or CAREConfiguration
62
+ Configuration for training CAREamics.
63
+ """
64
+ adapter: TypeAdapter = TypeAdapter(
65
+ Annotated[
66
+ Union[
67
+ Annotated[N2VConfiguration, Tag(SupportedAlgorithm.N2V.value)],
68
+ Annotated[N2NConfiguration, Tag(SupportedAlgorithm.N2N.value)],
69
+ Annotated[CAREConfiguration, Tag(SupportedAlgorithm.CARE.value)],
70
+ ],
71
+ Discriminator(_algorithm_config_discriminator),
72
+ ]
73
+ )
74
+ return adapter.validate_python(configuration)
75
+
76
+
77
+ def algorithm_factory(
78
+ algorithm: dict[str, Any]
79
+ ) -> Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]:
80
+ """
81
+ Create an algorithm model for training CAREamics.
82
+
83
+ Parameters
84
+ ----------
85
+ algorithm : dict
86
+ Algorithm dictionary.
87
+
88
+ Returns
89
+ -------
90
+ N2VAlgorithm or N2NAlgorithm or CAREAlgorithm
91
+ Algorithm model for training CAREamics.
92
+ """
93
+ adapter: TypeAdapter = TypeAdapter(Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm])
94
+ return adapter.validate_python(algorithm)
95
+
96
+
97
+ def data_factory(data: dict[str, Any]) -> Union[DataConfig, N2VDataConfig]:
98
+ """
99
+ Create a data model for training CAREamics.
100
+
101
+ Parameters
102
+ ----------
103
+ data : dict
104
+ Data dictionary.
105
+
106
+ Returns
107
+ -------
108
+ DataConfig or N2VDataConfig
109
+ Data model for training CAREamics.
110
+ """
111
+ adapter: TypeAdapter = TypeAdapter(Union[DataConfig, N2VDataConfig])
112
+ return adapter.validate_python(data)
113
+
114
+
115
+ def _list_spatial_augmentations(
116
+ augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]],
117
+ ) -> list[SPATIAL_TRANSFORMS_UNION]:
25
118
  """
26
119
  List the augmentations to apply.
27
120
 
@@ -44,7 +137,7 @@ def _list_augmentations(
44
137
  If there are duplicate transforms.
45
138
  """
46
139
  if augmentations is None:
47
- transform_list: list[Union[XYFlipModel, XYRandomRotate90Model]] = [
140
+ transform_list: list[SPATIAL_TRANSFORMS_UNION] = [
48
141
  XYFlipModel(),
49
142
  XYRandomRotate90Model(),
50
143
  ]
@@ -123,7 +216,7 @@ def _create_configuration(
123
216
  patch_size: list[int],
124
217
  batch_size: int,
125
218
  num_epochs: int,
126
- augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]],
219
+ augmentations: Union[list[N2V_TRANSFORMS_UNION], list[SPATIAL_TRANSFORMS_UNION]],
127
220
  independent_channels: bool,
128
221
  loss: Literal["n2v", "mae", "mse"],
129
222
  n_channels_in: int,
@@ -131,7 +224,8 @@ def _create_configuration(
131
224
  logger: Literal["wandb", "tensorboard", "none"],
132
225
  use_n2v2: bool = False,
133
226
  model_params: Optional[dict] = None,
134
- dataloader_params: Optional[dict] = None,
227
+ train_dataloader_params: Optional[dict[str, Any]] = None,
228
+ val_dataloader_params: Optional[dict[str, Any]] = None,
135
229
  ) -> Configuration:
136
230
  """
137
231
  Create a configuration for training N2V, CARE or Noise2Noise.
@@ -169,8 +263,10 @@ def _create_configuration(
169
263
  Whether to use N2V2, by default False.
170
264
  model_params : dict
171
265
  UNetModel parameters.
172
- dataloader_params : dict
173
- Parameters for the dataloader, see PyTorch notes, by default None.
266
+ train_dataloader_params : dict
267
+ Parameters for the training dataloader, see PyTorch notes, by default None.
268
+ val_dataloader_params : dict
269
+ Parameters for the validation dataloader, see PyTorch notes, by default None.
174
270
 
175
271
  Returns
176
272
  -------
@@ -188,21 +284,25 @@ def _create_configuration(
188
284
  )
189
285
 
190
286
  # algorithm model
191
- algorithm_config = FCNAlgorithmConfig(
192
- algorithm=algorithm,
193
- loss=loss,
194
- model=unet_model,
195
- )
287
+ algorithm_config = {
288
+ "algorithm": algorithm,
289
+ "loss": loss,
290
+ "model": unet_model,
291
+ }
196
292
 
197
293
  # data model
198
- data = DataConfig(
199
- data_type=data_type,
200
- axes=axes,
201
- patch_size=patch_size,
202
- batch_size=batch_size,
203
- transforms=augmentations,
204
- dataloader_params=dataloader_params,
205
- )
294
+ data = {
295
+ "data_type": data_type,
296
+ "axes": axes,
297
+ "patch_size": patch_size,
298
+ "batch_size": batch_size,
299
+ "transforms": augmentations,
300
+ }
301
+ # Don't override defaults set in DataConfig class
302
+ if train_dataloader_params is not None:
303
+ data["train_dataloader_params"] = train_dataloader_params
304
+ if val_dataloader_params is not None:
305
+ data["val_dataloader_params"] = val_dataloader_params
206
306
 
207
307
  # training model
208
308
  training = TrainingConfig(
@@ -212,14 +312,14 @@ def _create_configuration(
212
312
  )
213
313
 
214
314
  # create configuration
215
- configuration = Configuration(
216
- experiment_name=experiment_name,
217
- algorithm_config=algorithm_config,
218
- data_config=data,
219
- training_config=training,
220
- )
315
+ configuration = {
316
+ "experiment_name": experiment_name,
317
+ "algorithm_config": algorithm_config,
318
+ "data_config": data,
319
+ "training_config": training,
320
+ }
221
321
 
222
- return configuration
322
+ return configuration_factory(configuration)
223
323
 
224
324
 
225
325
  # TODO reconsider naming once we officially support LVAE approaches
@@ -238,7 +338,8 @@ def _create_supervised_configuration(
238
338
  n_channels_out: Optional[int] = None,
239
339
  logger: Literal["wandb", "tensorboard", "none"] = "none",
240
340
  model_params: Optional[dict] = None,
241
- dataloader_params: Optional[dict] = None,
341
+ train_dataloader_params: Optional[dict[str, Any]] = None,
342
+ val_dataloader_params: Optional[dict[str, Any]] = None,
242
343
  ) -> Configuration:
243
344
  """
244
345
  Create a configuration for training CARE or Noise2Noise.
@@ -275,8 +376,10 @@ def _create_supervised_configuration(
275
376
  Logger to use, by default "none".
276
377
  model_params : dict, optional
277
378
  UNetModel parameters, by default {}.
278
- dataloader_params : dict, optional
279
- Parameters for the dataloader, see PyTorch notes, by default None.
379
+ train_dataloader_params : dict
380
+ Parameters for the training dataloader, see PyTorch notes, by default None.
381
+ val_dataloader_params : dict
382
+ Parameters for the validation dataloader, see PyTorch notes, by default None.
280
383
 
281
384
  Returns
282
385
  -------
@@ -306,7 +409,7 @@ def _create_supervised_configuration(
306
409
  n_channels_out = n_channels_in
307
410
 
308
411
  # augmentations
309
- transform_list = _list_augmentations(augmentations)
412
+ spatial_transform_list = _list_spatial_augmentations(augmentations)
310
413
 
311
414
  return _create_configuration(
312
415
  algorithm=algorithm,
@@ -316,14 +419,15 @@ def _create_supervised_configuration(
316
419
  patch_size=patch_size,
317
420
  batch_size=batch_size,
318
421
  num_epochs=num_epochs,
319
- augmentations=transform_list,
422
+ augmentations=spatial_transform_list,
320
423
  independent_channels=independent_channels,
321
424
  loss=loss,
322
425
  n_channels_in=n_channels_in,
323
426
  n_channels_out=n_channels_out,
324
427
  logger=logger,
325
428
  model_params=model_params,
326
- dataloader_params=dataloader_params,
429
+ train_dataloader_params=train_dataloader_params,
430
+ val_dataloader_params=val_dataloader_params,
327
431
  )
328
432
 
329
433
 
@@ -341,7 +445,8 @@ def create_care_configuration(
341
445
  n_channels_out: Optional[int] = None,
342
446
  logger: Literal["wandb", "tensorboard", "none"] = "none",
343
447
  model_params: Optional[dict] = None,
344
- dataloader_params: Optional[dict] = None,
448
+ train_dataloader_params: Optional[dict[str, Any]] = None,
449
+ val_dataloader_params: Optional[dict[str, Any]] = None,
345
450
  ) -> Configuration:
346
451
  """
347
452
  Create a configuration for training CARE.
@@ -394,8 +499,14 @@ def create_care_configuration(
394
499
  Logger to use.
395
500
  model_params : dict, default=None
396
501
  UNetModel parameters.
397
- dataloader_params : dict, optional
398
- Parameters for the dataloader, see PyTorch notes, by default None.
502
+ train_dataloader_params : dict, optional
503
+ Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
504
+ If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
505
+ the `GeneralDataConfig`.
506
+ val_dataloader_params : dict, optional
507
+ Parameters for the validation dataloader, see PyTorch the docs for `DataLoader`.
508
+ If left as `None`, the empty dict `{}` will be used, this is set in the
509
+ `GeneralDataConfig`.
399
510
 
400
511
  Returns
401
512
  -------
@@ -484,7 +595,8 @@ def create_care_configuration(
484
595
  n_channels_out=n_channels_out,
485
596
  logger=logger,
486
597
  model_params=model_params,
487
- dataloader_params=dataloader_params,
598
+ train_dataloader_params=train_dataloader_params,
599
+ val_dataloader_params=val_dataloader_params,
488
600
  )
489
601
 
490
602
 
@@ -502,7 +614,8 @@ def create_n2n_configuration(
502
614
  n_channels_out: Optional[int] = None,
503
615
  logger: Literal["wandb", "tensorboard", "none"] = "none",
504
616
  model_params: Optional[dict] = None,
505
- dataloader_params: Optional[dict] = None,
617
+ train_dataloader_params: Optional[dict[str, Any]] = None,
618
+ val_dataloader_params: Optional[dict[str, Any]] = None,
506
619
  ) -> Configuration:
507
620
  """
508
621
  Create a configuration for training Noise2Noise.
@@ -555,8 +668,14 @@ def create_n2n_configuration(
555
668
  Logger to use, by default "none".
556
669
  model_params : dict, optional
557
670
  UNetModel parameters, by default {}.
558
- dataloader_params : dict, optional
559
- Parameters for the dataloader, see PyTorch notes, by default None.
671
+ train_dataloader_params : dict, optional
672
+ Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
673
+ If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
674
+ the `GeneralDataConfig`.
675
+ val_dataloader_params : dict, optional
676
+ Parameters for the validation dataloader, see PyTorch the docs for `DataLoader`.
677
+ If left as `None`, the empty dict `{}` will be used, this is set in the
678
+ `GeneralDataConfig`.
560
679
 
561
680
  Returns
562
681
  -------
@@ -645,7 +764,8 @@ def create_n2n_configuration(
645
764
  n_channels_out=n_channels_out,
646
765
  logger=logger,
647
766
  model_params=model_params,
648
- dataloader_params=dataloader_params,
767
+ train_dataloader_params=train_dataloader_params,
768
+ val_dataloader_params=val_dataloader_params,
649
769
  )
650
770
 
651
771
 
@@ -666,7 +786,8 @@ def create_n2v_configuration(
666
786
  struct_n2v_span: int = 5,
667
787
  logger: Literal["wandb", "tensorboard", "none"] = "none",
668
788
  model_params: Optional[dict] = None,
669
- dataloader_params: Optional[dict] = None,
789
+ train_dataloader_params: Optional[dict[str, Any]] = None,
790
+ val_dataloader_params: Optional[dict[str, Any]] = None,
670
791
  ) -> Configuration:
671
792
  """
672
793
  Create a configuration for training Noise2Void.
@@ -745,8 +866,14 @@ def create_n2v_configuration(
745
866
  Logger to use, by default "none".
746
867
  model_params : dict, optional
747
868
  UNetModel parameters, by default None.
748
- dataloader_params : dict, optional
749
- Parameters for the dataloader, see PyTorch notes, by default None.
869
+ train_dataloader_params : dict, optional
870
+ Parameters for the training dataloader, see the PyTorch docs for `DataLoader`.
871
+ If left as `None`, the dict `{"shuffle": True}` will be used, this is set in
872
+ the `GeneralDataConfig`.
873
+ val_dataloader_params : dict, optional
874
+ Parameters for the validation dataloader, see PyTorch the docs for `DataLoader`.
875
+ If left as `None`, the empty dict `{}` will be used, this is set in the
876
+ `GeneralDataConfig`.
750
877
 
751
878
  Returns
752
879
  -------
@@ -853,7 +980,7 @@ def create_n2v_configuration(
853
980
  n_channels = 1
854
981
 
855
982
  # augmentations
856
- transform_list = _list_augmentations(augmentations)
983
+ spatial_transforms = _list_spatial_augmentations(augmentations)
857
984
 
858
985
  # create the N2VManipulate transform using the supplied parameters
859
986
  n2v_transform = N2VManipulateModel(
@@ -868,7 +995,7 @@ def create_n2v_configuration(
868
995
  struct_mask_axis=struct_n2v_axis,
869
996
  struct_mask_span=struct_n2v_span,
870
997
  )
871
- transform_list.append(n2v_transform)
998
+ transform_list: list[N2V_TRANSFORMS_UNION] = spatial_transforms + [n2v_transform]
872
999
 
873
1000
  return _create_configuration(
874
1001
  algorithm="n2v",
@@ -886,5 +1013,6 @@ def create_n2v_configuration(
886
1013
  n_channels_out=n_channels,
887
1014
  logger=logger,
888
1015
  model_params=model_params,
889
- dataloader_params=dataloader_params,
1016
+ train_dataloader_params=train_dataloader_params,
1017
+ val_dataloader_params=val_dataloader_params,
890
1018
  )
@@ -0,0 +1,85 @@
1
+ """I/O functions for Configuration objects."""
2
+
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import yaml
7
+
8
+ from careamics.config import Configuration, configuration_factory
9
+
10
+
11
+ def load_configuration(path: Union[str, Path]) -> Configuration:
12
+ """
13
+ Load configuration from a yaml file.
14
+
15
+ Parameters
16
+ ----------
17
+ path : str or Path
18
+ Path to the configuration.
19
+
20
+ Returns
21
+ -------
22
+ Configuration
23
+ Configuration.
24
+
25
+ Raises
26
+ ------
27
+ FileNotFoundError
28
+ If the configuration file does not exist.
29
+ """
30
+ # load dictionary from yaml
31
+ if not Path(path).exists():
32
+ raise FileNotFoundError(
33
+ f"Configuration file {path} does not exist in " f" {Path.cwd()!s}"
34
+ )
35
+
36
+ dictionary = yaml.load(Path(path).open("r"), Loader=yaml.SafeLoader)
37
+
38
+ return configuration_factory(dictionary)
39
+
40
+
41
+ def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
42
+ """
43
+ Save configuration to path.
44
+
45
+ Parameters
46
+ ----------
47
+ config : Configuration
48
+ Configuration to save.
49
+ path : str or Path
50
+ Path to a existing folder in which to save the configuration, or to a valid
51
+ configuration file path (uses a .yml or .yaml extension).
52
+
53
+ Returns
54
+ -------
55
+ Path
56
+ Path object representing the configuration.
57
+
58
+ Raises
59
+ ------
60
+ ValueError
61
+ If the path does not point to an existing directory or .yml file.
62
+ """
63
+ # make sure path is a Path object
64
+ config_path = Path(path)
65
+
66
+ # check if path is pointing to an existing directory or .yml file
67
+ if config_path.exists():
68
+ if config_path.is_dir():
69
+ config_path = Path(config_path, "config.yml")
70
+ elif config_path.suffix != ".yml" and config_path.suffix != ".yaml":
71
+ raise ValueError(
72
+ f"Path must be a directory or .yml or .yaml file (got {config_path})."
73
+ )
74
+ else:
75
+ if config_path.suffix != ".yml" and config_path.suffix != ".yaml":
76
+ raise ValueError(
77
+ f"Path must be a directory or .yml or .yaml file (got {config_path})."
78
+ )
79
+
80
+ # save configuration as dictionary to yaml
81
+ with open(config_path, "w") as f:
82
+ # dump configuration
83
+ yaml.dump(config.model_dump(), f, default_flow_style=False, sort_keys=False)
84
+
85
+ return config_path
@@ -0,0 +1,10 @@
1
+ """Data Pydantic configuration models."""
2
+
3
+ __all__ = [
4
+ "DataConfig",
5
+ "GeneralDataConfig",
6
+ "N2VDataConfig",
7
+ ]
8
+
9
+ from .data_model import DataConfig, GeneralDataConfig
10
+ from .n2v_data_model import N2VDataConfig