careamics 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (141)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +726 -0
  3. careamics/config/__init__.py +35 -0
  4. careamics/config/algorithm_model.py +162 -0
  5. careamics/config/architectures/__init__.py +17 -0
  6. careamics/config/architectures/architecture_model.py +37 -0
  7. careamics/config/architectures/custom_model.py +159 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/architectures/vae_model.py +42 -0
  11. careamics/config/callback_model.py +123 -0
  12. careamics/config/configuration_factory.py +575 -0
  13. careamics/config/configuration_model.py +600 -0
  14. careamics/config/data_model.py +502 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/optimizer_models.py +187 -0
  17. careamics/config/references/__init__.py +45 -0
  18. careamics/config/references/algorithm_descriptions.py +132 -0
  19. careamics/config/references/references.py +39 -0
  20. careamics/config/support/__init__.py +31 -0
  21. careamics/config/support/supported_activations.py +26 -0
  22. careamics/config/support/supported_algorithms.py +20 -0
  23. careamics/config/support/supported_architectures.py +20 -0
  24. careamics/config/support/supported_data.py +109 -0
  25. careamics/config/support/supported_loggers.py +10 -0
  26. careamics/config/support/supported_losses.py +27 -0
  27. careamics/config/support/supported_optimizers.py +57 -0
  28. careamics/config/support/supported_pixel_manipulations.py +15 -0
  29. careamics/config/support/supported_struct_axis.py +21 -0
  30. careamics/config/support/supported_transforms.py +11 -0
  31. careamics/config/tile_information.py +65 -0
  32. careamics/config/training_model.py +72 -0
  33. careamics/config/transformations/__init__.py +15 -0
  34. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  35. careamics/config/transformations/normalize_model.py +60 -0
  36. careamics/config/transformations/transform_model.py +45 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  39. careamics/config/validators/__init__.py +5 -0
  40. careamics/config/validators/validator_utils.py +101 -0
  41. careamics/conftest.py +39 -0
  42. careamics/dataset/__init__.py +17 -0
  43. careamics/dataset/dataset_utils/__init__.py +19 -0
  44. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  45. careamics/dataset/dataset_utils/file_utils.py +141 -0
  46. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  47. careamics/dataset/dataset_utils/running_stats.py +186 -0
  48. careamics/dataset/in_memory_dataset.py +310 -0
  49. careamics/dataset/in_memory_pred_dataset.py +88 -0
  50. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  51. careamics/dataset/iterable_dataset.py +295 -0
  52. careamics/dataset/iterable_pred_dataset.py +122 -0
  53. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  54. careamics/dataset/patching/__init__.py +1 -0
  55. careamics/dataset/patching/patching.py +299 -0
  56. careamics/dataset/patching/random_patching.py +201 -0
  57. careamics/dataset/patching/sequential_patching.py +212 -0
  58. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  59. careamics/dataset/tiling/__init__.py +10 -0
  60. careamics/dataset/tiling/collate_tiles.py +33 -0
  61. careamics/dataset/tiling/tiled_patching.py +164 -0
  62. careamics/dataset/zarr_dataset.py +151 -0
  63. careamics/file_io/__init__.py +15 -0
  64. careamics/file_io/read/__init__.py +12 -0
  65. careamics/file_io/read/get_func.py +56 -0
  66. careamics/file_io/read/tiff.py +58 -0
  67. careamics/file_io/read/zarr.py +60 -0
  68. careamics/file_io/write/__init__.py +15 -0
  69. careamics/file_io/write/get_func.py +63 -0
  70. careamics/file_io/write/tiff.py +40 -0
  71. careamics/lightning/__init__.py +17 -0
  72. careamics/lightning/callbacks/__init__.py +11 -0
  73. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  74. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  75. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  76. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  77. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  79. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  80. careamics/lightning/lightning_module.py +276 -0
  81. careamics/lightning/predict_data_module.py +333 -0
  82. careamics/lightning/train_data_module.py +680 -0
  83. careamics/losses/__init__.py +5 -0
  84. careamics/losses/loss_factory.py +49 -0
  85. careamics/losses/losses.py +98 -0
  86. careamics/lvae_training/__init__.py +0 -0
  87. careamics/lvae_training/data_modules.py +1220 -0
  88. careamics/lvae_training/data_utils.py +618 -0
  89. careamics/lvae_training/eval_utils.py +905 -0
  90. careamics/lvae_training/get_config.py +84 -0
  91. careamics/lvae_training/lightning_module.py +701 -0
  92. careamics/lvae_training/metrics.py +214 -0
  93. careamics/lvae_training/train_lvae.py +339 -0
  94. careamics/lvae_training/train_utils.py +121 -0
  95. careamics/model_io/__init__.py +7 -0
  96. careamics/model_io/bioimage/__init__.py +11 -0
  97. careamics/model_io/bioimage/_readme_factory.py +121 -0
  98. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  99. careamics/model_io/bioimage/model_description.py +327 -0
  100. careamics/model_io/bmz_io.py +233 -0
  101. careamics/model_io/model_io_utils.py +83 -0
  102. careamics/models/__init__.py +7 -0
  103. careamics/models/activation.py +37 -0
  104. careamics/models/layers.py +493 -0
  105. careamics/models/lvae/__init__.py +0 -0
  106. careamics/models/lvae/layers.py +1998 -0
  107. careamics/models/lvae/likelihoods.py +312 -0
  108. careamics/models/lvae/lvae.py +985 -0
  109. careamics/models/lvae/noise_models.py +409 -0
  110. careamics/models/lvae/utils.py +395 -0
  111. careamics/models/model_factory.py +52 -0
  112. careamics/models/unet.py +443 -0
  113. careamics/prediction_utils/__init__.py +10 -0
  114. careamics/prediction_utils/prediction_outputs.py +135 -0
  115. careamics/prediction_utils/stitch_prediction.py +98 -0
  116. careamics/transforms/__init__.py +20 -0
  117. careamics/transforms/compose.py +107 -0
  118. careamics/transforms/n2v_manipulate.py +146 -0
  119. careamics/transforms/normalize.py +243 -0
  120. careamics/transforms/pixel_manipulation.py +407 -0
  121. careamics/transforms/struct_mask_parameters.py +20 -0
  122. careamics/transforms/transform.py +24 -0
  123. careamics/transforms/tta.py +88 -0
  124. careamics/transforms/xy_flip.py +123 -0
  125. careamics/transforms/xy_random_rotate90.py +101 -0
  126. careamics/utils/__init__.py +19 -0
  127. careamics/utils/autocorrelation.py +40 -0
  128. careamics/utils/base_enum.py +60 -0
  129. careamics/utils/context.py +66 -0
  130. careamics/utils/logging.py +322 -0
  131. careamics/utils/metrics.py +115 -0
  132. careamics/utils/path_utils.py +26 -0
  133. careamics/utils/ram.py +15 -0
  134. careamics/utils/receptive_field.py +108 -0
  135. careamics/utils/torch_utils.py +127 -0
  136. careamics-0.0.2.dist-info/METADATA +78 -0
  137. careamics-0.0.2.dist-info/RECORD +140 -0
  138. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/WHEEL +1 -1
  139. {careamics-0.0.1.dist-info → careamics-0.0.2.dist-info}/licenses/LICENSE +1 -1
  140. careamics-0.0.1.dist-info/METADATA +0 -46
  141. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,600 @@
1
+ """Pydantic CAREamics configuration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from pathlib import Path
7
+ from pprint import pformat
8
+ from typing import Literal, Union
9
+
10
+ import yaml
11
+ from bioimageio.spec.generic.v0_3 import CiteEntry
12
+ from pydantic import BaseModel, ConfigDict, field_validator, model_validator
13
+ from typing_extensions import Self
14
+
15
+ from .algorithm_model import AlgorithmConfig
16
+ from .data_model import DataConfig
17
+ from .references import (
18
+ CARE,
19
+ CUSTOM,
20
+ N2N,
21
+ N2V,
22
+ N2V2,
23
+ STRUCT_N2V,
24
+ STRUCT_N2V2,
25
+ CAREDescription,
26
+ CARERef,
27
+ N2NDescription,
28
+ N2NRef,
29
+ N2V2Description,
30
+ N2V2Ref,
31
+ N2VDescription,
32
+ N2VRef,
33
+ StructN2V2Description,
34
+ StructN2VDescription,
35
+ StructN2VRef,
36
+ )
37
+ from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform
38
+ from .training_model import TrainingConfig
39
+ from .transformations.n2v_manipulate_model import (
40
+ N2VManipulateModel,
41
+ )
42
+
43
+
44
class Configuration(BaseModel):
    """
    CAREamics configuration.

    The configuration defines all parameters used to build and train a CAREamics model.
    These parameters are validated to ensure that they are compatible with each other.

    It contains three sub-configurations:

    - AlgorithmConfig: configuration for the algorithm training, which includes the
        architecture, loss function, optimizer, and other hyperparameters.
    - DataConfig: configuration for the dataloader, which includes the type of data,
        transformations, mean/std and other parameters.
    - TrainingConfig: configuration for the training, which includes the number of
        epochs or the callbacks.

    For non-custom algorithms, the model dimensionality (2D/3D) is automatically
    adjusted to match the presence of a "Z" axis in the data axes.

    Attributes
    ----------
    experiment_name : str
        Name of the experiment, used when saving logs and checkpoints.
    algorithm_config : AlgorithmConfig
        Algorithm configuration.
    data_config : DataConfig
        Data configuration.
    training_config : TrainingConfig
        Training configuration.

    Methods
    -------
    set_3D(is_3D: bool, axes: str, patch_size: list[int]) -> None
        Switch configuration between 2D and 3D.
    set_N2V2(use_n2v2: bool) -> None
        Switch N2V algorithm between N2V and N2V2.
    set_structN2V(
        mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int) -> None
        Set StructN2V parameters.
    model_dump(
        exclude_defaults: bool = False, exclude_none: bool = True, **kwargs: dict
        ) -> dict
        Export configuration to a dictionary.

    Raises
    ------
    ValueError
        Configuration parameter type validation errors.
    ValueError
        If the experiment name contains invalid characters or is empty.
    ValueError
        Algorithm, data or training validation errors.

    Notes
    -----
    We provide convenience methods to create standard configurations, for instance
    for N2V, in the `careamics.config.configuration_factory` module.
    >>> from careamics.config.configuration_factory import create_n2v_configuration
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100
    ... )

    The configuration can be exported to a dictionary using the model_dump method:
    >>> config_dict = config.model_dump()

    Configurations can also be exported or imported from yaml files:
    >>> from careamics.config import save_configuration, load_configuration
    >>> path_to_config = save_configuration(config, my_path / "config.yml")
    >>> other_config = load_configuration(path_to_config)

    Examples
    --------
    Minimum example:
    >>> from careamics.config import Configuration
    >>> config_dict = {
    ...         "experiment_name": "N2V_experiment",
    ...         "algorithm_config": {
    ...             "algorithm": "n2v",
    ...             "loss": "n2v",
    ...             "model": {
    ...                 "architecture": "UNet",
    ...             },
    ...         },
    ...         "training_config": {
    ...             "num_epochs": 200,
    ...         },
    ...         "data_config": {
    ...             "data_type": "tiff",
    ...             "patch_size": [64, 64],
    ...             "axes": "SYX",
    ...         },
    ...     }
    >>> config = Configuration(**config_dict)
    """

    # `arbitrary_types_allowed` is the pydantic key; the previous
    # `set_arbitrary_types_allowed` is not a pydantic setting and was ignored.
    model_config = ConfigDict(
        validate_assignment=True,
        arbitrary_types_allowed=True,
    )

    # version
    version: Literal["0.1.0"] = "0.1.0"
    """CAREamics configuration version."""

    # required parameters
    experiment_name: str
    """Name of the experiment, used to name logs and checkpoints."""

    # Sub-configurations
    algorithm_config: AlgorithmConfig
    """Algorithm configuration, holding all parameters required to configure the
    model."""

    data_config: DataConfig
    """Data configuration, holding all parameters required to configure the training
    data loader."""

    training_config: TrainingConfig
    """Training configuration, holding all parameters required to configure the
    training process."""

    @field_validator("experiment_name")
    @classmethod
    def no_symbol(cls, name: str) -> str:
        """
        Validate experiment name.

        A valid experiment name is a non-empty string that only contains letters,
        numbers, underscores, dashes and spaces.

        Parameters
        ----------
        name : str
            Name to validate.

        Returns
        -------
        str
            Validated name.

        Raises
        ------
        ValueError
            If the name is empty or contains invalid characters.
        """
        if len(name) == 0 or name.isspace():
            raise ValueError("Experiment name is empty.")

        # `fullmatch` is required: with `re.match(r"^...$")` the `$` anchor also
        # matches before a trailing newline, so names such as "abc\n" slipped
        # through the validation.
        if not re.fullmatch(r"[a-zA-Z0-9_\- ]*", name):
            raise ValueError(
                f"Experiment name contains invalid characters (got {name}). "
                f"Only letters, numbers, underscores, dashes and spaces are allowed."
            )

        return name

    @model_validator(mode="after")
    def validate_3D(self: Self) -> Self:
        """
        Change algorithm dimensions to match data.axes.

        Only for non-custom algorithms: the model is switched to 3D when the data
        axes contain "Z", and to 2D otherwise.

        Returns
        -------
        Self
            Validated configuration.
        """
        if self.algorithm_config.algorithm != SupportedAlgorithm.CUSTOM:
            data_is_3D = "Z" in self.data_config.axes
            if self.algorithm_config.model.is_3D() != data_is_3D:
                self.algorithm_config.model.set_3D(data_is_3D)

        return self

    @model_validator(mode="after")
    def validate_algorithm_and_data(self: Self) -> Self:
        """
        Validate algorithm and data compatibility.

        In particular, the validation does the following:

        - If N2V is used, it enforces the presence of N2V_Manipulate in the
          transforms (appended at the end if missing).
        - If N2V is used, it sets the pixel manipulation strategy matching the
          N2V/N2V2 flag of the model.
        - If another algorithm is used, the N2V_Manipulate transform is removed.

        Returns
        -------
        Self
            Validated configuration.
        """
        if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
            # missing N2V_MANIPULATE
            if not self.data_config.has_n2v_manipulate():
                self.data_config.transforms.append(
                    N2VManipulateModel(
                        name=SupportedTransform.N2V_MANIPULATE.value,
                    )
                )

            self.data_config.set_N2V2_strategy(
                self._n2v_strategy(self.algorithm_config.model.n2v2)
            )
        elif self.data_config.has_n2v_manipulate():
            # N2V manipulation is only meaningful for N2V
            self.data_config.remove_n2v_manipulate()

        return self

    @staticmethod
    def _n2v_strategy(use_n2v2: bool) -> str:
        """
        Return the pixel manipulation strategy for N2V (uniform) or N2V2 (median).

        Parameters
        ----------
        use_n2v2 : bool
            Whether N2V2 is used.

        Returns
        -------
        str
            Value of the corresponding `SupportedPixelManipulation` member.
        """
        if use_n2v2:
            return SupportedPixelManipulation.MEDIAN.value
        return SupportedPixelManipulation.UNIFORM.value

    def _uses_struct_n2v(self) -> bool:
        """
        Return whether the N2V manipulation uses a structural mask.

        Only call this for the N2V algorithm. It reads the last transform of the
        data configuration, which assumes the N2V manipulation is the last
        transform (`validate_algorithm_and_data` appends it at the end when it is
        missing) -- TODO confirm DataConfig enforces this ordering.

        Returns
        -------
        bool
            True if the structural mask axis is not "none".
        """
        return self.data_config.transforms[-1].struct_mask_axis != "none"

    def __str__(self) -> str:
        """
        Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def set_3D(self, is_3D: bool, axes: str, patch_size: list[int]) -> None:
        """
        Set 3D flag and axes.

        Parameters
        ----------
        is_3D : bool
            Whether the algorithm is 3D or not.
        axes : str
            Axes of the data.
        patch_size : list[int]
            Patch size.
        """
        # set the flag and axes (this will not trigger validation at the config level)
        self.algorithm_config.model.set_3D(is_3D)
        self.data_config.set_3D(axes, patch_size)

        # cheap hack: re-assigning a field triggers validation because
        # `validate_assignment=True`
        self.algorithm_config = self.algorithm_config

    def set_N2V2(self, use_n2v2: bool) -> None:
        """
        Switch N2V algorithm between N2V and N2V2.

        Parameters
        ----------
        use_n2v2 : bool
            Whether to use N2V2 or not.

        Raises
        ------
        ValueError
            If the algorithm is not N2V.
        """
        if self.algorithm_config.algorithm != SupportedAlgorithm.N2V:
            raise ValueError("N2V2 can only be set for N2V algorithm.")

        self.algorithm_config.model.n2v2 = use_n2v2
        self.data_config.set_N2V2_strategy(self._n2v_strategy(use_n2v2))

    def set_structN2V(
        self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int
    ) -> None:
        """
        Set StructN2V parameters.

        Parameters
        ----------
        mask_axis : Literal["horizontal", "vertical", "none"]
            Axis of the structural mask.
        mask_span : int
            Span of the structural mask.
        """
        self.data_config.set_structN2V_mask(mask_axis, mask_span)

    def get_algorithm_flavour(self) -> str:
        """
        Get the algorithm name.

        For N2V, the flavour (N2V, N2V2, StructN2V or StructN2V2) is derived from
        the model N2V2 flag and the structural mask of the N2V manipulation.

        Returns
        -------
        str
            Algorithm name.
        """
        algorithm = self.algorithm_config.algorithm

        if algorithm == SupportedAlgorithm.N2V:
            use_n2v2 = self.algorithm_config.model.n2v2
            use_structN2V = self._uses_struct_n2v()

            # return the n2v flavour
            if use_n2v2 and use_structN2V:
                return STRUCT_N2V2
            elif use_n2v2:
                return N2V2
            elif use_structN2V:
                return STRUCT_N2V
            else:
                return N2V
        elif algorithm == SupportedAlgorithm.N2N:
            return N2N
        elif algorithm == SupportedAlgorithm.CARE:
            return CARE
        else:
            return CUSTOM

    def get_algorithm_description(self) -> str:
        """
        Return a description of the algorithm.

        This method is used to generate the README of the BioImage Model Zoo export.

        Returns
        -------
        str
            Description of the algorithm, or an empty string if the flavour has no
            known description.
        """
        algorithm_flavour = self.get_algorithm_flavour()

        if algorithm_flavour == CUSTOM:
            return f"Custom algorithm, named {self.algorithm_config.model.name}"

        # only the matching description class is instantiated
        descriptions = {
            N2V: N2VDescription,
            N2V2: N2V2Description,
            STRUCT_N2V: StructN2VDescription,
            STRUCT_N2V2: StructN2V2Description,
            N2N: N2NDescription,
            CARE: CAREDescription,
        }
        description_cls = descriptions.get(algorithm_flavour)
        if description_cls is None:
            return ""

        return description_cls().description

    def get_algorithm_citations(self) -> list[CiteEntry]:
        """
        Return a list of citation entries of the current algorithm.

        This is used to generate the model description for the BioImage Model Zoo.

        Returns
        -------
        list[CiteEntry]
            List of citation entries.

        Raises
        ------
        ValueError
            If the algorithm is a custom algorithm.
        """
        algorithm = self.algorithm_config.algorithm

        if algorithm == SupportedAlgorithm.N2V:
            use_n2v2 = self.algorithm_config.model.n2v2
            use_structN2V = self._uses_struct_n2v()

            # the (struct)N2V(2) references, always starting with N2V
            citations = [N2VRef]
            if use_n2v2:
                citations.append(N2V2Ref)
            if use_structN2V:
                citations.append(StructN2VRef)
            return citations
        elif algorithm == SupportedAlgorithm.N2N:
            return [N2NRef]
        elif algorithm == SupportedAlgorithm.CARE:
            return [CARERef]

        raise ValueError("Citation not available for custom algorithm.")

    def get_algorithm_references(self) -> str:
        """
        Get the algorithm references.

        This is used to generate the README of the BioImage Model Zoo export. Each
        reference is formatted as "<text> doi: <doi>" and references are separated
        by a blank line so that they render as separate paragraphs in a Markdown
        README (previously they were concatenated without any separator).

        Returns
        -------
        str
            Algorithm references, or an empty string for custom algorithms.
        """
        # custom algorithms have no citations (get_algorithm_citations raises)
        if self.get_algorithm_flavour() == CUSTOM:
            return ""

        references = [
            (ref.text + " doi: " + ref.doi) if ref.doi else ref.text
            for ref in self.get_algorithm_citations()
        ]
        return "\n\n".join(references)

    def get_algorithm_keywords(self) -> list[str]:
        """
        Get algorithm keywords.

        Returns
        -------
        list[str]
            List of keywords.
        """
        if self.algorithm_config.algorithm != SupportedAlgorithm.N2V:
            return ["CAREamics"]

        use_n2v2 = self.algorithm_config.model.n2v2
        use_structN2V = self._uses_struct_n2v()

        keywords = [
            "denoising",
            "restoration",
            "UNet",
            "3D" if "Z" in self.data_config.axes else "2D",
            "CAREamics",
            "pytorch",
            N2V,
        ]

        if use_n2v2:
            keywords.append(N2V2)
        if use_structN2V:
            keywords.append(STRUCT_N2V)

        return keywords

    def model_dump(
        self,
        exclude_defaults: bool = False,
        exclude_none: bool = True,
        **kwargs: dict,
    ) -> dict:
        """
        Override model_dump method in order to set default values.

        Parameters
        ----------
        exclude_defaults : bool, optional
            Whether to exclude fields with default values or not, by default
            False.
        exclude_none : bool, optional
            Whether to exclude fields with None values or not, by default True.
        **kwargs : dict
            Additional keyword arguments forwarded to pydantic's
            `BaseModel.model_dump`.

        Returns
        -------
        dict
            Dictionary containing the model parameters.
        """
        dictionary = super().model_dump(
            exclude_none=exclude_none, exclude_defaults=exclude_defaults, **kwargs
        )

        return dictionary
+ return dictionary
524
+
525
+
526
def load_configuration(path: Union[str, Path]) -> Configuration:
    """
    Load configuration from a yaml file.

    Parameters
    ----------
    path : str or Path
        Path to the configuration.

    Returns
    -------
    Configuration
        Configuration.

    Raises
    ------
    FileNotFoundError
        If the configuration file does not exist.
    ValueError
        If the file does not contain a yaml mapping (e.g. it is empty).
    """
    config_path = Path(path)

    # load dictionary from yaml
    if not config_path.exists():
        raise FileNotFoundError(
            f"Configuration file {path} does not exist in " f" {Path.cwd()!s}"
        )

    # The context manager closes the file handle; previously it was left open.
    # SafeLoader only constructs plain Python objects.
    with config_path.open("r", encoding="utf-8") as f:
        dictionary = yaml.load(f, Loader=yaml.SafeLoader)

    # an empty file yields None, which would fail with an opaque TypeError below
    if not isinstance(dictionary, dict):
        raise ValueError(
            f"Configuration file {path} does not contain a valid configuration "
            f"mapping."
        )

    return Configuration(**dictionary)
554
+
555
+
556
def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
    """
    Save configuration to path.

    Parameters
    ----------
    config : Configuration
        Configuration to save.
    path : str or Path
        Path to an existing folder in which to save the configuration (as
        `config.yml`), or path to a `.yml`/`.yaml` file, which does not need to
        exist yet and is overwritten if it does.

    Returns
    -------
    Path
        Path object representing the configuration.

    Raises
    ------
    ValueError
        If the path does not point to an existing directory or a .yml/.yaml file.
    """
    # make sure path is a Path object
    config_path = Path(path)

    # `is_dir()` is False for non-existing paths, so every path that is not an
    # existing directory must carry a yaml suffix.
    if config_path.is_dir():
        config_path = config_path / "config.yml"
    elif config_path.suffix not in (".yml", ".yaml"):
        raise ValueError(
            f"Path must be a directory or .yml or .yaml file (got {config_path})."
        )

    # save configuration as dictionary to yaml
    with open(config_path, "w", encoding="utf-8") as f:
        # dump configuration, keeping the field order of the model
        yaml.dump(config.model_dump(), f, default_flow_style=False, sort_keys=False)

    return config_path