careamics 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (98)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +4 -3
  3. careamics/cli/conf.py +1 -2
  4. careamics/cli/main.py +1 -2
  5. careamics/cli/utils.py +3 -3
  6. careamics/config/__init__.py +47 -25
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +50 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +42 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +35 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +6 -1
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +103 -36
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +58 -198
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +1 -2
  25. careamics/config/n2n_configuration.py +101 -0
  26. careamics/config/n2v_configuration.py +266 -0
  27. careamics/config/nm_model.py +1 -2
  28. careamics/config/support/__init__.py +7 -7
  29. careamics/config/support/supported_algorithms.py +0 -3
  30. careamics/config/support/supported_architectures.py +0 -4
  31. careamics/config/transformations/__init__.py +10 -4
  32. careamics/config/transformations/transform_model.py +3 -3
  33. careamics/config/transformations/transform_unions.py +42 -0
  34. careamics/config/validators/validator_utils.py +3 -3
  35. careamics/dataset/__init__.py +2 -2
  36. careamics/dataset/dataset_utils/__init__.py +3 -3
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  38. careamics/dataset/dataset_utils/file_utils.py +9 -9
  39. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  40. careamics/dataset/in_memory_dataset.py +11 -12
  41. careamics/dataset/iterable_dataset.py +4 -4
  42. careamics/dataset/iterable_pred_dataset.py +2 -1
  43. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  44. careamics/dataset/patching/random_patching.py +11 -10
  45. careamics/dataset/patching/sequential_patching.py +26 -26
  46. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  47. careamics/dataset/tiling/__init__.py +2 -2
  48. careamics/dataset/tiling/collate_tiles.py +3 -3
  49. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  50. careamics/dataset/tiling/tiled_patching.py +11 -10
  51. careamics/file_io/__init__.py +5 -5
  52. careamics/file_io/read/__init__.py +1 -1
  53. careamics/file_io/read/get_func.py +2 -2
  54. careamics/file_io/write/__init__.py +2 -2
  55. careamics/lightning/__init__.py +5 -5
  56. careamics/lightning/callbacks/__init__.py +1 -1
  57. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  58. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  60. careamics/lightning/callbacks/progress_bar_callback.py +2 -2
  61. careamics/lightning/lightning_module.py +11 -7
  62. careamics/lightning/train_data_module.py +26 -26
  63. careamics/losses/__init__.py +3 -3
  64. careamics/model_io/__init__.py +1 -1
  65. careamics/model_io/bioimage/__init__.py +1 -1
  66. careamics/model_io/bioimage/_readme_factory.py +1 -1
  67. careamics/model_io/bioimage/model_description.py +17 -17
  68. careamics/model_io/bmz_io.py +6 -17
  69. careamics/model_io/model_io_utils.py +9 -9
  70. careamics/models/layers.py +16 -16
  71. careamics/models/lvae/lvae.py +0 -3
  72. careamics/models/model_factory.py +2 -15
  73. careamics/models/unet.py +8 -8
  74. careamics/prediction_utils/__init__.py +1 -1
  75. careamics/prediction_utils/prediction_outputs.py +15 -15
  76. careamics/prediction_utils/stitch_prediction.py +6 -6
  77. careamics/transforms/__init__.py +5 -5
  78. careamics/transforms/compose.py +13 -13
  79. careamics/transforms/n2v_manipulate.py +3 -3
  80. careamics/transforms/pixel_manipulation.py +9 -9
  81. careamics/transforms/xy_random_rotate90.py +4 -4
  82. careamics/utils/__init__.py +5 -5
  83. careamics/utils/context.py +2 -1
  84. careamics/utils/logging.py +11 -10
  85. careamics/utils/torch_utils.py +7 -7
  86. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/METADATA +11 -11
  87. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/RECORD +90 -85
  88. careamics/config/architectures/custom_model.py +0 -162
  89. careamics/config/architectures/register_model.py +0 -103
  90. careamics/config/configuration_model.py +0 -603
  91. careamics/config/fcn_algorithm_model.py +0 -152
  92. careamics/config/references/__init__.py +0 -45
  93. careamics/config/references/algorithm_descriptions.py +0 -132
  94. careamics/config/references/references.py +0 -39
  95. careamics/config/transformations/transform_union.py +0 -20
  96. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/WHEEL +0 -0
  97. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
  98. {careamics-0.0.5.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,354 @@
1
+ """Pydantic CAREamics configuration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from pprint import pformat
7
+ from typing import Any, Literal, Union
8
+
9
+ from bioimageio.spec.generic.v0_3 import CiteEntry
10
+ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
11
+ from typing_extensions import Self
12
+
13
+ from careamics.config.algorithms import UNetBasedAlgorithm, VAEBasedAlgorithm
14
+ from careamics.config.data import GeneralDataConfig
15
+ from careamics.config.training_model import TrainingConfig
16
+
17
+
18
class Configuration(BaseModel):
    """
    CAREamics configuration.

    The configuration defines all parameters used to build and train a CAREamics model.
    These parameters are validated to ensure that they are compatible with each other.

    It contains three sub-configurations:

    - algorithm_config: configuration for the algorithm training, which includes the
      architecture, loss function, optimizer, and other hyperparameters.
    - data_config: configuration for the dataloader, which includes the type of data,
      transformations, mean/std and other parameters.
    - training_config: configuration for the training, which includes the number of
      epochs or the callbacks.

    Attributes
    ----------
    version : Literal["0.1.0"]
        CAREamics configuration version.
    experiment_name : str
        Name of the experiment, used when saving logs and checkpoints.
    algorithm_config : UNetBasedAlgorithm or VAEBasedAlgorithm
        Algorithm configuration, discriminated by its `algorithm` field.
    data_config : GeneralDataConfig
        Data configuration.
    training_config : TrainingConfig
        Training configuration.

    Methods
    -------
    set_3D(is_3D: bool, axes: str, patch_size: list[int]) -> None
        Switch configuration between 2D and 3D.
    model_dump(
        exclude_defaults: bool = False, exclude_none: bool = True, **kwargs: dict
        ) -> dict
        Export configuration to a dictionary.

    Raises
    ------
    ValueError
        Configuration parameter type validation errors.
    ValueError
        If the experiment name contains invalid characters or is empty.
    ValueError
        Algorithm, data or training validation errors.

    Notes
    -----
    The 2D/3D setting of the algorithm model is not checked against the data axes
    but adjusted to them during validation: the model is switched to 3D if
    `data_config.axes` contains "Z", and to 2D otherwise.

    We provide convenience methods to create standard configurations, for instance:
    >>> from careamics.config import create_n2v_configuration
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100
    ... )

    The configuration can be exported to a dictionary using the model_dump method:
    >>> config_dict = config.model_dump()

    Configurations can also be exported or imported from yaml files:
    >>> from careamics.config import save_configuration, load_configuration
    >>> path_to_config = save_configuration(config, my_path / "config.yml")
    >>> other_config = load_configuration(path_to_config)

    Examples
    --------
    Minimum example:
    >>> from careamics import configuration_factory
    >>> config_dict = {
    ...     "experiment_name": "N2V_experiment",
    ...     "algorithm_config": {
    ...         "algorithm": "n2v",
    ...         "loss": "n2v",
    ...         "model": {
    ...             "architecture": "UNet",
    ...         },
    ...     },
    ...     "training_config": {
    ...         "num_epochs": 200,
    ...     },
    ...     "data_config": {
    ...         "data_type": "tiff",
    ...         "patch_size": [64, 64],
    ...         "axes": "SYX",
    ...     },
    ... }
    >>> config = configuration_factory(config_dict)
    """

    model_config = ConfigDict(
        validate_assignment=True,
        arbitrary_types_allowed=True,
    )

    # version
    version: Literal["0.1.0"] = "0.1.0"
    """CAREamics configuration version."""

    # required parameters
    experiment_name: str
    """Name of the experiment, used to name logs and checkpoints."""

    # Sub-configurations
    algorithm_config: Union[UNetBasedAlgorithm, VAEBasedAlgorithm] = Field(
        discriminator="algorithm"
    )
    """Algorithm configuration, holding all parameters required to configure the
    model."""

    data_config: GeneralDataConfig
    """Data configuration, holding all parameters required to configure the training
    data loader."""

    training_config: TrainingConfig
    """Training configuration, holding all parameters required to configure the
    training process."""

    @field_validator("experiment_name")
    @classmethod
    def no_symbol(cls, name: str) -> str:
        """
        Validate experiment name.

        A valid experiment name is a non-empty string that only contains letters,
        numbers, underscores, dashes and spaces.

        Parameters
        ----------
        name : str
            Name to validate.

        Returns
        -------
        str
            Validated name.

        Raises
        ------
        ValueError
            If the name is empty, only whitespace, or contains invalid characters.
        """
        if len(name) == 0 or name.isspace():
            raise ValueError("Experiment name is empty.")

        # `re.fullmatch` is required here: with `re.match` and a `^...$` pattern,
        # `$` also matches just before a trailing newline, which would let names
        # such as "exp\n" through despite the newline being a forbidden character.
        if re.fullmatch(r"[a-zA-Z0-9_\- ]*", name) is None:
            raise ValueError(
                f"Experiment name contains invalid characters (got {name}). "
                f"Only letters, numbers, underscores, dashes and spaces are allowed."
            )

        return name

    @model_validator(mode="after")
    def validate_3D(self: Self) -> Self:
        """
        Change algorithm dimensions to match data.axes.

        The algorithm model is switched to 3D when the data axes contain "Z" and
        to 2D when they do not; no error is raised for a mismatch.

        Returns
        -------
        Self
            Validated configuration.
        """
        if "Z" in self.data_config.axes and not self.algorithm_config.model.is_3D():
            # change algorithm to 3D
            self.algorithm_config.model.set_3D(True)
        elif "Z" not in self.data_config.axes and self.algorithm_config.model.is_3D():
            # change algorithm to 2D
            self.algorithm_config.model.set_3D(False)

        return self

    def __str__(self) -> str:
        """
        Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string of the `model_dump` dictionary.
        """
        return pformat(self.model_dump())

    def set_3D(self, is_3D: bool, axes: str, patch_size: list[int]) -> None:
        """
        Set 3D flag and axes.

        Parameters
        ----------
        is_3D : bool
            Whether the algorithm is 3D or not.
        axes : str
            Axes of the data.
        patch_size : list[int]
            Patch size.
        """
        # set the flag and axes (this will not trigger validation at the config level)
        self.algorithm_config.model.set_3D(is_3D)
        self.data_config.set_3D(axes, patch_size)

        # cheap hack: re-assigning a field triggers validation because the model is
        # configured with `validate_assignment=True`
        self.algorithm_config = self.algorithm_config

    def get_algorithm_friendly_name(self) -> str:
        """
        Get the algorithm name.

        The base configuration carries no algorithm-specific information, so this
        implementation always raises; algorithm-specific configurations override it.

        Returns
        -------
        str
            Algorithm name.

        Raises
        ------
        ValueError
            Always, in the base configuration.
        """
        raise ValueError("Unknown algorithm.")

    def get_algorithm_description(self) -> str:
        """
        Return a description of the algorithm.

        This method is used to generate the README of the BioImage Model Zoo export.
        The base implementation always raises.

        Returns
        -------
        str
            Description of the algorithm.

        Raises
        ------
        ValueError
            Always, in the base configuration.
        """
        raise ValueError("No algorithm description available.")

    def get_algorithm_citations(self) -> list[CiteEntry]:
        """
        Return a list of citation entries of the current algorithm.

        This is used to generate the model description for the BioImage Model Zoo.
        The base implementation always raises.

        Returns
        -------
        list[CiteEntry]
            List of citation entries.

        Raises
        ------
        ValueError
            Always, in the base configuration.
        """
        raise ValueError("No algorithm citations available.")

    def get_algorithm_references(self) -> str:
        """
        Get the algorithm references.

        This is used to generate the README of the BioImage Model Zoo export.
        The base implementation always raises.

        Returns
        -------
        str
            Algorithm references.

        Raises
        ------
        ValueError
            Always, in the base configuration.
        """
        raise ValueError("No algorithm references available.")

    def get_algorithm_keywords(self) -> list[str]:
        """
        Get algorithm keywords.

        Returns
        -------
        list[str]
            List of keywords.
        """
        return ["CAREamics"]

    def model_dump(
        self,
        *,
        mode: Literal["json", "python"] | str = "python",
        include: Any | None = None,
        exclude: Any | None = None,
        context: Any | None = None,
        by_alias: bool = False,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = True,
        round_trip: bool = False,
        warnings: bool | Literal["none", "warn", "error"] = True,
        serialize_as_any: bool = False,
    ) -> dict:
        """
        Override model_dump method in order to set default values.

        As opposed to the parent model_dump method, this method sets exclude none by
        default.

        Parameters
        ----------
        mode : Literal['json', 'python'] | str, default='python'
            The serialization format.
        include : Any | None, default=None
            Attributes to include.
        exclude : Any | None, default=None
            Attributes to exclude.
        context : Any | None, default=None
            Additional context to pass to the serialization functions.
        by_alias : bool, default=False
            Whether to use attribute aliases.
        exclude_unset : bool, default=False
            Whether to exclude fields that are not set.
        exclude_defaults : bool, default=False
            Whether to exclude fields that have default values.
        exclude_none : bool, default=True
            Whether to exclude fields that have None values.
        round_trip : bool, default=False
            Whether to dump and load the data to ensure that the output is a valid
            representation.
        warnings : bool | Literal['none', 'warn', 'error'], default=True
            Whether to emit warnings.
        serialize_as_any : bool, default=False
            Whether to serialize all types as Any.

        Returns
        -------
        dict
            Dictionary containing the model parameters.
        """
        dictionary = super().model_dump(
            mode=mode,
            include=include,
            exclude=exclude,
            context=context,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            round_trip=round_trip,
            warnings=warnings,
            serialize_as_any=serialize_as_any,
        )

        return dictionary
@@ -2,26 +2,93 @@
2
2
 
3
3
  from typing import Any, Literal, Optional, Union
4
4
 
5
- from .architectures import UNetModel
6
- from .configuration_model import Configuration
7
- from .data_model import DataConfig
8
- from .fcn_algorithm_model import FCNAlgorithmConfig
9
- from .support import (
5
+ from pydantic import TypeAdapter
6
+
7
+ from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm
8
+ from careamics.config.architectures import UNetModel
9
+ from careamics.config.care_configuration import CAREConfiguration
10
+ from careamics.config.configuration import Configuration
11
+ from careamics.config.data import DataConfig, N2VDataConfig
12
+ from careamics.config.n2n_configuration import N2NConfiguration
13
+ from careamics.config.n2v_configuration import N2VConfiguration
14
+ from careamics.config.support import (
10
15
  SupportedArchitecture,
11
16
  SupportedPixelManipulation,
12
17
  SupportedTransform,
13
18
  )
14
- from .training_model import TrainingConfig
15
- from .transformations import (
19
+ from careamics.config.training_model import TrainingConfig
20
+ from careamics.config.transformations import (
21
+ N2V_TRANSFORMS_UNION,
22
+ SPATIAL_TRANSFORMS_UNION,
16
23
  N2VManipulateModel,
17
24
  XYFlipModel,
18
25
  XYRandomRotate90Model,
19
26
  )
20
27
 
21
28
 
22
- def _list_augmentations(
23
- augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]],
24
- ) -> list[Union[XYFlipModel, XYRandomRotate90Model]]:
29
def configuration_factory(
    configuration: dict[str, Any]
) -> Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]:
    """
    Validate a dictionary into a CAREamics training configuration.

    The dictionary is validated by a pydantic `TypeAdapter` against the union of
    the N2V, N2N and CARE configuration classes, and the validated instance is
    returned.

    Parameters
    ----------
    configuration : dict
        Configuration dictionary.

    Returns
    -------
    N2VConfiguration or N2NConfiguration or CAREConfiguration
        Configuration for training CAREamics.
    """
    candidate_types = Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]
    validator = TypeAdapter(candidate_types)
    validated = validator.validate_python(configuration)
    return validated
49
+
50
+
51
def algorithm_factory(
    algorithm: dict[str, Any]
) -> Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]:
    """
    Validate a dictionary into a CAREamics algorithm model.

    The dictionary is validated by a pydantic `TypeAdapter` against the union of
    the N2V, N2N and CARE algorithm classes, and the validated instance is
    returned.

    Parameters
    ----------
    algorithm : dict
        Algorithm dictionary.

    Returns
    -------
    N2VAlgorithm or N2NAlgorithm or CAREAlgorithm
        Algorithm model for training CAREamics.
    """
    candidate_types = Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]
    validator = TypeAdapter(candidate_types)
    validated = validator.validate_python(algorithm)
    return validated
69
+
70
+
71
def data_factory(data: dict[str, Any]) -> Union[DataConfig, N2VDataConfig]:
    """
    Validate a dictionary into a CAREamics data model.

    The dictionary is validated by a pydantic `TypeAdapter` against the union of
    `DataConfig` and `N2VDataConfig`, and the validated instance is returned.

    Parameters
    ----------
    data : dict
        Data dictionary.

    Returns
    -------
    DataConfig or N2VDataConfig
        Data model for training CAREamics.
    """
    candidate_types = Union[DataConfig, N2VDataConfig]
    validator = TypeAdapter(candidate_types)
    validated = validator.validate_python(data)
    return validated
87
+
88
+
89
+ def _list_spatial_augmentations(
90
+ augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]],
91
+ ) -> list[SPATIAL_TRANSFORMS_UNION]:
25
92
  """
26
93
  List the augmentations to apply.
27
94
 
@@ -44,7 +111,7 @@ def _list_augmentations(
44
111
  If there are duplicate transforms.
45
112
  """
46
113
  if augmentations is None:
47
- transform_list: list[Union[XYFlipModel, XYRandomRotate90Model]] = [
114
+ transform_list: list[SPATIAL_TRANSFORMS_UNION] = [
48
115
  XYFlipModel(),
49
116
  XYRandomRotate90Model(),
50
117
  ]
@@ -123,7 +190,7 @@ def _create_configuration(
123
190
  patch_size: list[int],
124
191
  batch_size: int,
125
192
  num_epochs: int,
126
- augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]],
193
+ augmentations: Union[list[N2V_TRANSFORMS_UNION], list[SPATIAL_TRANSFORMS_UNION]],
127
194
  independent_channels: bool,
128
195
  loss: Literal["n2v", "mae", "mse"],
129
196
  n_channels_in: int,
@@ -188,21 +255,21 @@ def _create_configuration(
188
255
  )
189
256
 
190
257
  # algorithm model
191
- algorithm_config = FCNAlgorithmConfig(
192
- algorithm=algorithm,
193
- loss=loss,
194
- model=unet_model,
195
- )
258
+ algorithm_config = {
259
+ "algorithm": algorithm,
260
+ "loss": loss,
261
+ "model": unet_model,
262
+ }
196
263
 
197
264
  # data model
198
- data = DataConfig(
199
- data_type=data_type,
200
- axes=axes,
201
- patch_size=patch_size,
202
- batch_size=batch_size,
203
- transforms=augmentations,
204
- dataloader_params=dataloader_params,
205
- )
265
+ data = {
266
+ "data_type": data_type,
267
+ "axes": axes,
268
+ "patch_size": patch_size,
269
+ "batch_size": batch_size,
270
+ "transforms": augmentations,
271
+ "dataloader_params": dataloader_params,
272
+ }
206
273
 
207
274
  # training model
208
275
  training = TrainingConfig(
@@ -212,14 +279,14 @@ def _create_configuration(
212
279
  )
213
280
 
214
281
  # create configuration
215
- configuration = Configuration(
216
- experiment_name=experiment_name,
217
- algorithm_config=algorithm_config,
218
- data_config=data,
219
- training_config=training,
220
- )
282
+ configuration = {
283
+ "experiment_name": experiment_name,
284
+ "algorithm_config": algorithm_config,
285
+ "data_config": data,
286
+ "training_config": training,
287
+ }
221
288
 
222
- return configuration
289
+ return configuration_factory(configuration)
223
290
 
224
291
 
225
292
  # TODO reconsider naming once we officially support LVAE approaches
@@ -306,7 +373,7 @@ def _create_supervised_configuration(
306
373
  n_channels_out = n_channels_in
307
374
 
308
375
  # augmentations
309
- transform_list = _list_augmentations(augmentations)
376
+ spatial_transform_list = _list_spatial_augmentations(augmentations)
310
377
 
311
378
  return _create_configuration(
312
379
  algorithm=algorithm,
@@ -316,7 +383,7 @@ def _create_supervised_configuration(
316
383
  patch_size=patch_size,
317
384
  batch_size=batch_size,
318
385
  num_epochs=num_epochs,
319
- augmentations=transform_list,
386
+ augmentations=spatial_transform_list,
320
387
  independent_channels=independent_channels,
321
388
  loss=loss,
322
389
  n_channels_in=n_channels_in,
@@ -853,7 +920,7 @@ def create_n2v_configuration(
853
920
  n_channels = 1
854
921
 
855
922
  # augmentations
856
- transform_list = _list_augmentations(augmentations)
923
+ spatial_transforms = _list_spatial_augmentations(augmentations)
857
924
 
858
925
  # create the N2VManipulate transform using the supplied parameters
859
926
  n2v_transform = N2VManipulateModel(
@@ -868,7 +935,7 @@ def create_n2v_configuration(
868
935
  struct_mask_axis=struct_n2v_axis,
869
936
  struct_mask_span=struct_n2v_span,
870
937
  )
871
- transform_list.append(n2v_transform)
938
+ transform_list: list[N2V_TRANSFORMS_UNION] = spatial_transforms + [n2v_transform]
872
939
 
873
940
  return _create_configuration(
874
941
  algorithm="n2v",
@@ -0,0 +1,85 @@
1
+ """I/O functions for Configuration objects."""
2
+
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import yaml
7
+
8
+ from careamics.config import Configuration, configuration_factory
9
+
10
+
11
def load_configuration(path: Union[str, Path]) -> Configuration:
    """
    Load configuration from a yaml file.

    Parameters
    ----------
    path : str or Path
        Path to the configuration.

    Returns
    -------
    Configuration
        Configuration, as returned by `configuration_factory`.

    Raises
    ------
    FileNotFoundError
        If the configuration file does not exist.
    ValueError
        If the file does not contain a YAML mapping (e.g. it is empty), or if the
        mapping is not a valid configuration.
    """
    config_path = Path(path)

    # load dictionary from yaml
    if not config_path.exists():
        raise FileNotFoundError(
            f"Configuration file {path} does not exist in " f" {Path.cwd()!s}"
        )

    # Read inside a context manager so the file handle is closed as soon as
    # parsing finishes. SafeLoader only builds plain Python objects (dicts, lists,
    # scalars), so a configuration file cannot instantiate arbitrary classes.
    with config_path.open("r") as config_file:
        dictionary = yaml.load(config_file, Loader=yaml.SafeLoader)

    # An empty file yields None and a scalar/list document yields a non-dict;
    # report that directly instead of an opaque union validation error.
    if not isinstance(dictionary, dict):
        raise ValueError(
            f"Configuration file {path} does not contain a YAML mapping "
            f"(got {type(dictionary).__name__})."
        )

    return configuration_factory(dictionary)
39
+
40
+
41
def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
    """
    Write a configuration to disk as YAML.

    Parameters
    ----------
    config : Configuration
        Configuration to save.
    path : str or Path
        Either an existing directory, in which case the file is written as
        `config.yml` inside it, or a file path with a `.yml` or `.yaml` extension.

    Returns
    -------
    Path
        Path of the written configuration file.

    Raises
    ------
    ValueError
        If the path is neither an existing directory nor a `.yml`/`.yaml` file.
    """
    target = Path(path)

    # An existing directory gets the default file name; any other path (existing
    # file or not) must already carry a YAML extension.
    if target.is_dir():
        target = target / "config.yml"
    elif target.suffix not in (".yml", ".yaml"):
        raise ValueError(
            f"Path must be a directory or .yml or .yaml file (got {target})."
        )

    # Keys are written in model order (sort_keys=False) in block style.
    with open(target, "w") as handle:
        yaml.dump(
            config.model_dump(), handle, default_flow_style=False, sort_keys=False
        )

    return target
@@ -0,0 +1,10 @@
1
+ """Data Pydantic configuration models."""
2
+
3
+ __all__ = [
4
+ "DataConfig",
5
+ "GeneralDataConfig",
6
+ "N2VDataConfig",
7
+ ]
8
+
9
+ from .data_model import DataConfig, GeneralDataConfig
10
+ from .n2v_data_model import N2VDataConfig