careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (134)
  1. careamics/__init__.py +16 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +31 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_example.py +89 -0
  16. careamics/config/configuration_factory.py +597 -0
  17. careamics/config/configuration_model.py +597 -0
  18. careamics/config/data_model.py +555 -0
  19. careamics/config/inference_model.py +283 -0
  20. careamics/config/noise_models.py +162 -0
  21. careamics/config/optimizer_models.py +181 -0
  22. careamics/config/references/__init__.py +45 -0
  23. careamics/config/references/algorithm_descriptions.py +131 -0
  24. careamics/config/references/references.py +38 -0
  25. careamics/config/support/__init__.py +33 -0
  26. careamics/config/support/supported_activations.py +24 -0
  27. careamics/config/support/supported_algorithms.py +18 -0
  28. careamics/config/support/supported_architectures.py +18 -0
  29. careamics/config/support/supported_data.py +82 -0
  30. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  31. careamics/config/support/supported_loggers.py +8 -0
  32. careamics/config/support/supported_losses.py +25 -0
  33. careamics/config/support/supported_optimizers.py +55 -0
  34. careamics/config/support/supported_pixel_manipulations.py +15 -0
  35. careamics/config/support/supported_struct_axis.py +19 -0
  36. careamics/config/support/supported_transforms.py +23 -0
  37. careamics/config/tile_information.py +104 -0
  38. careamics/config/training_model.py +65 -0
  39. careamics/config/transformations/__init__.py +14 -0
  40. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  41. careamics/config/transformations/nd_flip_model.py +32 -0
  42. careamics/config/transformations/normalize_model.py +31 -0
  43. careamics/config/transformations/transform_model.py +44 -0
  44. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  45. careamics/config/validators/__init__.py +5 -0
  46. careamics/config/validators/validator_utils.py +100 -0
  47. careamics/conftest.py +26 -0
  48. careamics/dataset/__init__.py +5 -0
  49. careamics/dataset/dataset_utils/__init__.py +19 -0
  50. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  51. careamics/dataset/dataset_utils/file_utils.py +140 -0
  52. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  53. careamics/dataset/dataset_utils/read_utils.py +25 -0
  54. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  55. careamics/dataset/in_memory_dataset.py +323 -134
  56. careamics/dataset/iterable_dataset.py +416 -0
  57. careamics/dataset/patching/__init__.py +8 -0
  58. careamics/dataset/patching/patch_transform.py +44 -0
  59. careamics/dataset/patching/patching.py +212 -0
  60. careamics/dataset/patching/random_patching.py +190 -0
  61. careamics/dataset/patching/sequential_patching.py +206 -0
  62. careamics/dataset/patching/tiled_patching.py +158 -0
  63. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  64. careamics/dataset/zarr_dataset.py +149 -0
  65. careamics/lightning_datamodule.py +743 -0
  66. careamics/lightning_module.py +292 -0
  67. careamics/lightning_prediction_datamodule.py +396 -0
  68. careamics/lightning_prediction_loop.py +116 -0
  69. careamics/losses/__init__.py +4 -1
  70. careamics/losses/loss_factory.py +24 -14
  71. careamics/losses/losses.py +65 -5
  72. careamics/losses/noise_model_factory.py +40 -0
  73. careamics/losses/noise_models.py +524 -0
  74. careamics/model_io/__init__.py +8 -0
  75. careamics/model_io/bioimage/__init__.py +11 -0
  76. careamics/model_io/bioimage/_readme_factory.py +120 -0
  77. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  78. careamics/model_io/bioimage/model_description.py +318 -0
  79. careamics/model_io/bmz_io.py +231 -0
  80. careamics/model_io/model_io_utils.py +80 -0
  81. careamics/models/__init__.py +4 -1
  82. careamics/models/activation.py +35 -0
  83. careamics/models/layers.py +244 -0
  84. careamics/models/model_factory.py +21 -221
  85. careamics/models/unet.py +46 -20
  86. careamics/prediction/__init__.py +1 -3
  87. careamics/prediction/stitch_prediction.py +73 -0
  88. careamics/transforms/__init__.py +41 -0
  89. careamics/transforms/n2v_manipulate.py +113 -0
  90. careamics/transforms/nd_flip.py +93 -0
  91. careamics/transforms/normalize.py +109 -0
  92. careamics/transforms/pixel_manipulation.py +383 -0
  93. careamics/transforms/struct_mask_parameters.py +18 -0
  94. careamics/transforms/tta.py +74 -0
  95. careamics/transforms/xy_random_rotate90.py +95 -0
  96. careamics/utils/__init__.py +10 -12
  97. careamics/utils/base_enum.py +32 -0
  98. careamics/utils/context.py +22 -2
  99. careamics/utils/metrics.py +0 -46
  100. careamics/utils/path_utils.py +24 -0
  101. careamics/utils/ram.py +13 -0
  102. careamics/utils/receptive_field.py +102 -0
  103. careamics/utils/running_stats.py +43 -0
  104. careamics/utils/torch_utils.py +112 -75
  105. careamics-0.1.0rc4.dist-info/METADATA +122 -0
  106. careamics-0.1.0rc4.dist-info/RECORD +110 -0
  107. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
  108. careamics/bioimage/__init__.py +0 -15
  109. careamics/bioimage/docs/Noise2Void.md +0 -5
  110. careamics/bioimage/docs/__init__.py +0 -1
  111. careamics/bioimage/io.py +0 -182
  112. careamics/bioimage/rdf.py +0 -105
  113. careamics/config/algorithm.py +0 -231
  114. careamics/config/config.py +0 -297
  115. careamics/config/config_filter.py +0 -44
  116. careamics/config/data.py +0 -194
  117. careamics/config/torch_optim.py +0 -118
  118. careamics/config/training.py +0 -534
  119. careamics/dataset/dataset_utils.py +0 -111
  120. careamics/dataset/patching.py +0 -492
  121. careamics/dataset/prepare_dataset.py +0 -175
  122. careamics/dataset/tiff_dataset.py +0 -212
  123. careamics/engine.py +0 -1014
  124. careamics/manipulation/__init__.py +0 -4
  125. careamics/manipulation/pixel_manipulation.py +0 -158
  126. careamics/prediction/prediction_utils.py +0 -106
  127. careamics/utils/ascii_logo.txt +0 -9
  128. careamics/utils/augment.py +0 -65
  129. careamics/utils/normalization.py +0 -55
  130. careamics/utils/validators.py +0 -170
  131. careamics/utils/wandb.py +0 -121
  132. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  133. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  134. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,597 @@
1
+ """Pydantic CAREamics configuration."""
2
+ from __future__ import annotations
3
+
4
+ import re
5
+ from pathlib import Path
6
+ from pprint import pformat
7
+ from typing import Dict, List, Literal, Union
8
+
9
+ import yaml
10
+ from bioimageio.spec.generic.v0_3 import CiteEntry
11
+ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
12
+ from typing_extensions import Self
13
+
14
+ from .algorithm_model import AlgorithmConfig
15
+ from .data_model import DataConfig
16
+ from .references import (
17
+ CARE,
18
+ CUSTOM,
19
+ N2N,
20
+ N2V,
21
+ N2V2,
22
+ STRUCT_N2V,
23
+ STRUCT_N2V2,
24
+ CAREDescription,
25
+ CARERef,
26
+ N2NDescription,
27
+ N2NRef,
28
+ N2V2Description,
29
+ N2V2Ref,
30
+ N2VDescription,
31
+ N2VRef,
32
+ StructN2V2Description,
33
+ StructN2VDescription,
34
+ StructN2VRef,
35
+ )
36
+ from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform
37
+ from .training_model import TrainingConfig
38
+ from .transformations.n2v_manipulate_model import (
39
+ N2VManipulateModel,
40
+ )
41
+
42
+
43
+ class Configuration(BaseModel):
44
+ """
45
+ CAREamics configuration.
46
+
47
+ The configuration defines all parameters used to build and train a CAREamics model.
48
+ These parameters are validated to ensure that they are compatible with each other.
49
+
50
+ It contains three sub-configurations:
51
+
52
+ - AlgorithmModel: configuration for the algorithm training, which includes the
53
+ architecture, loss function, optimizer, and other hyperparameters.
54
+ - DataModel: configuration for the dataloader, which includes the type of data,
55
+ transformations, mean/std and other parameters.
56
+ - TrainingModel: configuration for the training, which includes the number of
57
+ epochs or the callbacks.
58
+
59
+ Attributes
60
+ ----------
61
+ experiment_name : str
62
+ Name of the experiment, used when saving logs and checkpoints.
63
+ algorithm : AlgorithmModel
64
+ Algorithm configuration.
65
+ data : DataModel
66
+ Data configuration.
67
+ training : TrainingModel
68
+ Training configuration.
69
+
70
+ Methods
71
+ -------
72
+ set_3D(is_3D: bool, axes: str, patch_size: List[int]) -> None
73
+ Switch configuration between 2D and 3D.
74
+ set_N2V2(use_n2v2: bool) -> None
75
+ Switch N2V algorithm between N2V and N2V2.
76
+ set_structN2V(
77
+ mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int) -> None
78
+ Set StructN2V parameters.
79
+ model_dump(
80
+ exclude_defaults: bool = False, exclude_none: bool = True, **kwargs: Dict
81
+ ) -> Dict
82
+ Export configuration to a dictionary.
83
+
84
+ Raises
85
+ ------
86
+ ValueError
87
+ Configuration parameter type validation errors.
88
+ ValueError
89
+ If the experiment name contains invalid characters or is empty.
90
+ ValueError
91
+ If the algorithm is 3D but there is not "Z" in the data axes, or 2D algorithm
92
+ with "Z" in data axes.
93
+ ValueError
94
+ Algorithm, data or training validation errors.
95
+
96
+ Notes
97
+ -----
98
+ We provide convenience methods to create standards configurations, for instance
99
+ for N2V, in the `careamics.config.configuration_factory` module.
100
+ >>> from careamics.config.configuration_factory import create_n2v_configuration
101
+ >>> config = create_n2v_configuration(
102
+ ... experiment_name="n2v_experiment",
103
+ ... data_type="array",
104
+ ... axes="YX",
105
+ ... patch_size=[64, 64],
106
+ ... batch_size=32,
107
+ ... num_epochs=100
108
+ ... )
109
+
110
+ The configuration can be exported to a dictionary using the model_dump method:
111
+ >>> config_dict = config.model_dump()
112
+
113
+ Configurations can also be exported or imported from yaml files:
114
+ >>> from careamics.config import save_configuration, load_configuration
115
+ >>> path_to_config = save_configuration(config, my_path / "config.yml")
116
+ >>> other_config = load_configuration(path_to_config)
117
+
118
+ Examples
119
+ --------
120
+ Minimum example:
121
+ >>> from careamics.config import Configuration
122
+ >>> config_dict = {
123
+ ... "experiment_name": "N2V_experiment",
124
+ ... "algorithm_config": {
125
+ ... "algorithm": "n2v",
126
+ ... "loss": "n2v",
127
+ ... "model": {
128
+ ... "architecture": "UNet",
129
+ ... },
130
+ ... },
131
+ ... "training_config": {
132
+ ... "num_epochs": 200,
133
+ ... },
134
+ ... "data_config": {
135
+ ... "data_type": "tiff",
136
+ ... "patch_size": [64, 64],
137
+ ... "axes": "SYX",
138
+ ... },
139
+ ... }
140
+ >>> config = Configuration(**config_dict)
141
+ """
142
+
143
+ model_config = ConfigDict(
144
+ validate_assignment=True,
145
+ set_arbitrary_types_allowed=True,
146
+ )
147
+
148
+ # version
149
+ version: Literal["0.1.0"] = Field(
150
+ default="0.1.0", description="Version of the CAREamics configuration."
151
+ )
152
+
153
+ # required parameters
154
+ experiment_name: str = Field(
155
+ ..., description="Name of the experiment, used to name logs and checkpoints."
156
+ )
157
+
158
+ # Sub-configurations
159
+ algorithm_config: AlgorithmConfig
160
+
161
+ data_config: DataConfig
162
+ training_config: TrainingConfig
163
+
164
+ @field_validator("experiment_name")
165
+ @classmethod
166
+ def no_symbol(cls, name: str) -> str:
167
+ """
168
+ Validate experiment name.
169
+
170
+ A valid experiment name is a non-empty string with only contains letters,
171
+ numbers, underscores, dashes and spaces.
172
+
173
+ Parameters
174
+ ----------
175
+ name : str
176
+ Name to validate.
177
+
178
+ Returns
179
+ -------
180
+ str
181
+ Validated name.
182
+
183
+ Raises
184
+ ------
185
+ ValueError
186
+ If the name is empty or contains invalid characters.
187
+ """
188
+ if len(name) == 0 or name.isspace():
189
+ raise ValueError("Experiment name is empty.")
190
+
191
+ # Validate using a regex that it contains only letters, numbers, underscores,
192
+ # dashes and spaces
193
+ if not re.match(r"^[a-zA-Z0-9_\- ]*$", name):
194
+ raise ValueError(
195
+ f"Experiment name contains invalid characters (got {name}). "
196
+ f"Only letters, numbers, underscores, dashes and spaces are allowed."
197
+ )
198
+
199
+ return name
200
+
201
+ @model_validator(mode="after")
202
+ def validate_3D(self: Self) -> Self:
203
+ """
204
+ Change algorithm dimensions to match data.axes.
205
+
206
+ Only for non-custom algorithms.
207
+
208
+ Returns
209
+ -------
210
+ Self
211
+ Validated configuration.
212
+ """
213
+ if self.algorithm_config.algorithm != SupportedAlgorithm.CUSTOM:
214
+ if "Z" in self.data_config.axes and not self.algorithm_config.model.is_3D():
215
+ # change algorithm to 3D
216
+ self.algorithm_config.model.set_3D(True)
217
+ elif (
218
+ "Z" not in self.data_config.axes and self.algorithm_config.model.is_3D()
219
+ ):
220
+ # change algorithm to 2D
221
+ self.algorithm_config.model.set_3D(False)
222
+
223
+ return self
224
+
225
+ @model_validator(mode="after")
226
+ def validate_algorithm_and_data(self: Self) -> Self:
227
+ """
228
+ Validate algorithm and data compatibility.
229
+
230
+ In particular, the validation does the following:
231
+
232
+ - If N2V is used, it enforces the presence of N2V_Maniuplate in the transforms
233
+ - If N2V2 is used, it enforces the correct manipulation strategy
234
+
235
+ Returns
236
+ -------
237
+ Self
238
+ Validated configuration.
239
+ """
240
+ if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
241
+ # if we have a list of transform (as opposed to Compose)
242
+ if self.data_config.has_transform_list():
243
+ # missing N2V_MANIPULATE
244
+ if not self.data_config.has_n2v_manipulate():
245
+ self.data_config.transforms.append(
246
+ N2VManipulateModel(
247
+ name=SupportedTransform.N2V_MANIPULATE.value,
248
+ )
249
+ )
250
+
251
+ median = SupportedPixelManipulation.MEDIAN.value
252
+ uniform = SupportedPixelManipulation.UNIFORM.value
253
+ strategy = median if self.algorithm_config.model.n2v2 else uniform
254
+ self.data_config.set_N2V2_strategy(strategy)
255
+ else:
256
+ # if we have a list of transform, remove N2V manipulate if present
257
+ if self.data_config.has_transform_list():
258
+ if self.data_config.has_n2v_manipulate():
259
+ self.data_config.remove_n2v_manipulate()
260
+
261
+ return self
262
+
263
+ def __str__(self) -> str:
264
+ """
265
+ Pretty string reprensenting the configuration.
266
+
267
+ Returns
268
+ -------
269
+ str
270
+ Pretty string.
271
+ """
272
+ return pformat(self.model_dump())
273
+
274
+ def set_3D(self, is_3D: bool, axes: str, patch_size: List[int]) -> None:
275
+ """
276
+ Set 3D flag and axes.
277
+
278
+ Parameters
279
+ ----------
280
+ is_3D : bool
281
+ Whether the algorithm is 3D or not.
282
+ axes : str
283
+ Axes of the data.
284
+ patch_size : List[int]
285
+ Patch size.
286
+ """
287
+ # set the flag and axes (this will not trigger validation at the config level)
288
+ self.algorithm_config.model.set_3D(is_3D)
289
+ self.data_config.set_3D(axes, patch_size)
290
+
291
+ # cheap hack: trigger validation
292
+ self.algorithm_config = self.algorithm_config
293
+
294
+ def set_N2V2(self, use_n2v2: bool) -> None:
295
+ """
296
+ Switch N2V algorithm between N2V and N2V2.
297
+
298
+ Parameters
299
+ ----------
300
+ use_n2v2 : bool
301
+ Whether to use N2V2 or not.
302
+
303
+ Raises
304
+ ------
305
+ ValueError
306
+ If the algorithm is not N2V.
307
+ """
308
+ if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
309
+ self.algorithm_config.model.n2v2 = use_n2v2
310
+ strategy = (
311
+ SupportedPixelManipulation.MEDIAN.value
312
+ if use_n2v2
313
+ else SupportedPixelManipulation.UNIFORM.value
314
+ )
315
+ self.data_config.set_N2V2_strategy(strategy)
316
+ else:
317
+ raise ValueError("N2V2 can only be set for N2V algorithm.")
318
+
319
+ def set_structN2V(
320
+ self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int
321
+ ) -> None:
322
+ """
323
+ Set StructN2V parameters.
324
+
325
+ Parameters
326
+ ----------
327
+ mask_axis : Literal["horizontal", "vertical", "none"]
328
+ Axis of the structural mask.
329
+ mask_span : int
330
+ Span of the structural mask.
331
+ """
332
+ self.data_config.set_structN2V_mask(mask_axis, mask_span)
333
+
334
+ def get_algorithm_flavour(self) -> str:
335
+ """
336
+ Get the algorithm name.
337
+
338
+ Returns
339
+ -------
340
+ str
341
+ Algorithm name.
342
+ """
343
+ if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
344
+ use_n2v2 = self.algorithm_config.model.n2v2
345
+ use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none"
346
+
347
+ # return the n2v flavour
348
+ if use_n2v2 and use_structN2V:
349
+ return STRUCT_N2V2
350
+ elif use_n2v2:
351
+ return N2V2
352
+ elif use_structN2V:
353
+ return STRUCT_N2V
354
+ else:
355
+ return N2V
356
+ elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N:
357
+ return N2N
358
+ elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE:
359
+ return CARE
360
+ else:
361
+ return CUSTOM
362
+
363
+ def get_algorithm_description(self) -> str:
364
+ """
365
+ Return a description of the algorithm.
366
+
367
+ This method is used to generate the README of the BioImage Model Zoo export.
368
+
369
+ Returns
370
+ -------
371
+ str
372
+ Description of the algorithm.
373
+ """
374
+ algorithm_flavour = self.get_algorithm_flavour()
375
+
376
+ if algorithm_flavour == CUSTOM:
377
+ return f"Custom algorithm, named {self.algorithm_config.model.name}"
378
+ else: # currently only N2V flavours
379
+ if algorithm_flavour == N2V:
380
+ return N2VDescription().description
381
+ elif algorithm_flavour == N2V2:
382
+ return N2V2Description().description
383
+ elif algorithm_flavour == STRUCT_N2V:
384
+ return StructN2VDescription().description
385
+ elif algorithm_flavour == STRUCT_N2V2:
386
+ return StructN2V2Description().description
387
+ elif algorithm_flavour == N2N:
388
+ return N2NDescription().description
389
+ elif algorithm_flavour == CARE:
390
+ return CAREDescription().description
391
+
392
+ return ""
393
+
394
+ def get_algorithm_citations(self) -> List[CiteEntry]:
395
+ """
396
+ Return a list of citation entries of the current algorithm.
397
+
398
+ This is used to generate the model description for the BioImage Model Zoo.
399
+
400
+ Returns
401
+ -------
402
+ List[CiteEntry]
403
+ List of citation entries.
404
+ """
405
+ if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
406
+ use_n2v2 = self.algorithm_config.model.n2v2
407
+ use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none"
408
+
409
+ # return the (struct)N2V(2) references
410
+ if use_n2v2 and use_structN2V:
411
+ return [N2VRef, N2V2Ref, StructN2VRef]
412
+ elif use_n2v2:
413
+ return [N2VRef, N2V2Ref]
414
+ elif use_structN2V:
415
+ return [N2VRef, StructN2VRef]
416
+ else:
417
+ return [N2VRef]
418
+ elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N:
419
+ return [N2NRef]
420
+ elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE:
421
+ return [CARERef]
422
+
423
+ raise ValueError("Citation not available for custom algorithm.")
424
+
425
+ def get_algorithm_references(self) -> str:
426
+ """
427
+ Get the algorithm references.
428
+
429
+ This is used to generate the README of the BioImage Model Zoo export.
430
+
431
+ Returns
432
+ -------
433
+ str
434
+ Algorithm references.
435
+ """
436
+ if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
437
+ use_n2v2 = self.algorithm_config.model.n2v2
438
+ use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none"
439
+
440
+ references = [
441
+ N2VRef.text + " doi: " + N2VRef.doi,
442
+ N2V2Ref.text + " doi: " + N2V2Ref.doi,
443
+ StructN2VRef.text + " doi: " + StructN2VRef.doi,
444
+ ]
445
+
446
+ # return the (struct)N2V(2) references
447
+ if use_n2v2 and use_structN2V:
448
+ return "".join(references)
449
+ elif use_n2v2:
450
+ references.pop(-1)
451
+ return "".join(references)
452
+ elif use_structN2V:
453
+ references.pop(-2)
454
+ return "".join(references)
455
+ else:
456
+ return references[0]
457
+
458
+ return ""
459
+
460
+ def get_algorithm_keywords(self) -> List[str]:
461
+ """
462
+ Get algorithm keywords.
463
+
464
+ Returns
465
+ -------
466
+ List[str]
467
+ List of keywords.
468
+ """
469
+ if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
470
+ use_n2v2 = self.algorithm_config.model.n2v2
471
+ use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none"
472
+
473
+ keywords = [
474
+ "denoising",
475
+ "restoration",
476
+ "UNet",
477
+ "3D" if "Z" in self.data_config.axes else "2D",
478
+ "CAREamics",
479
+ "pytorch",
480
+ N2V,
481
+ ]
482
+
483
+ if use_n2v2:
484
+ keywords.append(N2V2)
485
+ if use_structN2V:
486
+ keywords.append(STRUCT_N2V)
487
+ else:
488
+ keywords = ["CAREamics"]
489
+
490
+ return keywords
491
+
492
+ def model_dump(
493
+ self,
494
+ exclude_defaults: bool = False,
495
+ exclude_none: bool = True,
496
+ **kwargs: Dict,
497
+ ) -> Dict:
498
+ """
499
+ Override model_dump method in order to set default values.
500
+
501
+ Parameters
502
+ ----------
503
+ exclude_defaults : bool, optional
504
+ Whether to exclude fields with default values or not, by default
505
+ True.
506
+ exclude_none : bool, optional
507
+ Whether to exclude fields with None values or not, by default True.
508
+ **kwargs : Dict
509
+ Keyword arguments.
510
+
511
+ Returns
512
+ -------
513
+ dict
514
+ Dictionary containing the model parameters.
515
+ """
516
+ dictionary = super().model_dump(
517
+ exclude_none=exclude_none, exclude_defaults=exclude_defaults, **kwargs
518
+ )
519
+
520
+ return dictionary
521
+
522
+
523
def load_configuration(path: Union[str, Path]) -> Configuration:
    """
    Load configuration from a yaml file.

    Parameters
    ----------
    path : Union[str, Path]
        Path to the configuration.

    Returns
    -------
    Configuration
        Configuration.

    Raises
    ------
    FileNotFoundError
        If the configuration file does not exist.
    ValueError
        If the file does not contain a YAML mapping (e.g. it is empty), or if its
        content fails the configuration validation.
    """
    config_path = Path(path)

    # load dictionary from yaml
    if not config_path.exists():
        raise FileNotFoundError(
            f"Configuration file {path} does not exist in " f" {Path.cwd()!s}"
        )

    # Use a context manager so the file handle is closed once parsed. SafeLoader
    # avoids constructing arbitrary Python objects from the file.
    with config_path.open("r") as f:
        dictionary = yaml.load(f, Loader=yaml.SafeLoader)

    # An empty file yields None and a scalar/list yields a non-dict; both would
    # otherwise fail with an opaque TypeError on `**dictionary`.
    if not isinstance(dictionary, dict):
        raise ValueError(
            f"Configuration file {path} does not contain a YAML mapping "
            f"(got {type(dictionary).__name__})."
        )

    return Configuration(**dictionary)
551
+
552
+
553
def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
    """
    Save configuration to path.

    Parameters
    ----------
    config : Configuration
        Configuration to save.
    path : Union[str, Path]
        Path to an existing folder in which to save the configuration (as
        `config.yml`), or to a `.yml`/`.yaml` file. The file does not need to
        exist; if it does, it is overwritten.

    Returns
    -------
    Path
        Path to the saved configuration file.

    Raises
    ------
    ValueError
        If the path is neither an existing directory nor a .yml/.yaml file path.
    """
    # make sure path is a Path object
    config_path = Path(path)

    # `is_dir()` is False for non-existing paths, so they fall through to the same
    # suffix check as existing files.
    if config_path.is_dir():
        config_path = Path(config_path, "config.yml")
    elif config_path.suffix not in (".yml", ".yaml"):
        raise ValueError(
            f"Path must be a directory or .yml or .yaml file (got {config_path})."
        )

    # save configuration as dictionary to yaml
    with open(config_path, "w") as f:
        # dump configuration, keeping the field order of the model
        yaml.dump(config.model_dump(), f, default_flow_style=False, sort_keys=False)

    return config_path