careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. See the package's registry page for more details.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,604 @@
1
+ """Pydantic CAREamics configuration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from pathlib import Path
7
+ from pprint import pformat
8
+ from typing import Literal, Union
9
+
10
+ import yaml
11
+ from bioimageio.spec.generic.v0_3 import CiteEntry
12
+ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
13
+ from typing_extensions import Self
14
+
15
+ from .data_model import DataConfig
16
+ from .fcn_algorithm_model import FCNAlgorithmConfig
17
+ from .references import (
18
+ CARE,
19
+ CUSTOM,
20
+ N2N,
21
+ N2V,
22
+ N2V2,
23
+ STRUCT_N2V,
24
+ STRUCT_N2V2,
25
+ CAREDescription,
26
+ CARERef,
27
+ N2NDescription,
28
+ N2NRef,
29
+ N2V2Description,
30
+ N2V2Ref,
31
+ N2VDescription,
32
+ N2VRef,
33
+ StructN2V2Description,
34
+ StructN2VDescription,
35
+ StructN2VRef,
36
+ )
37
+ from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform
38
+ from .training_model import TrainingConfig
39
+ from .transformations.n2v_manipulate_model import (
40
+ N2VManipulateModel,
41
+ )
42
+ from .vae_algorithm_model import VAEAlgorithmConfig
43
+
44
+
45
class Configuration(BaseModel):
    """
    CAREamics configuration.

    The configuration defines all parameters used to build and train a CAREamics model.
    These parameters are validated to ensure that they are compatible with each other.

    It contains three sub-configurations:

    - algorithm_config: configuration for the algorithm training, which includes the
        architecture, loss function, optimizer, and other hyperparameters.
    - data_config: configuration for the dataloader, which includes the type of data,
        transformations, mean/std and other parameters.
    - training_config: configuration for the training, which includes the number of
        epochs or the callbacks.

    Attributes
    ----------
    version : Literal["0.1.0"]
        Version of the configuration schema.
    experiment_name : str
        Name of the experiment, used when saving logs and checkpoints.
    algorithm_config : FCNAlgorithmConfig or VAEAlgorithmConfig
        Algorithm configuration, selected using its `algorithm_type` field.
    data_config : DataConfig
        Data configuration.
    training_config : TrainingConfig
        Training configuration.

    Methods
    -------
    set_3D(is_3D: bool, axes: str, patch_size: list[int]) -> None
        Switch configuration between 2D and 3D.
    set_N2V2(use_n2v2: bool) -> None
        Switch N2V algorithm between N2V and N2V2.
    set_structN2V(
        mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int) -> None
        Set StructN2V parameters.
    model_dump(
        exclude_defaults: bool = False, exclude_none: bool = True, **kwargs: dict
        ) -> dict
        Export configuration to a dictionary.

    Raises
    ------
    ValueError
        Configuration parameter type validation errors.
    ValueError
        If the experiment name contains invalid characters or is empty.
    ValueError
        If the algorithm is 3D but there is not "Z" in the data axes, or 2D algorithm
        with "Z" in data axes.
    ValueError
        Algorithm, data or training validation errors.

    Notes
    -----
    We provide convenience methods to create standards configurations, for instance
    for N2V, in the `careamics.config.configuration_factory` module.
    >>> from careamics.config.configuration_factory import create_n2v_configuration
    >>> config = create_n2v_configuration(
    ...     experiment_name="n2v_experiment",
    ...     data_type="array",
    ...     axes="YX",
    ...     patch_size=[64, 64],
    ...     batch_size=32,
    ...     num_epochs=100
    ... )

    The configuration can be exported to a dictionary using the model_dump method:
    >>> config_dict = config.model_dump()

    Configurations can also be exported or imported from yaml files:
    >>> from careamics.config import save_configuration, load_configuration
    >>> path_to_config = save_configuration(config, my_path / "config.yml")
    >>> other_config = load_configuration(path_to_config)

    Examples
    --------
    Minimum example:
    >>> from careamics.config import Configuration
    >>> config_dict = {
    ...         "experiment_name": "N2V_experiment",
    ...         "algorithm_config": {
    ...             "algorithm_type": "fcn",
    ...             "algorithm": "n2v",
    ...             "loss": "n2v",
    ...             "model": {
    ...                 "architecture": "UNet",
    ...             },
    ...         },
    ...         "training_config": {
    ...             "num_epochs": 200,
    ...         },
    ...         "data_config": {
    ...             "data_type": "tiff",
    ...             "patch_size": [64, 64],
    ...             "axes": "SYX",
    ...         },
    ...     }
    >>> config = Configuration(**config_dict)
    """

    model_config = ConfigDict(
        # re-run validators (including the model validators below) on assignment
        validate_assignment=True,
        # `arbitrary_types_allowed` is the actual pydantic option name; the previous
        # key `set_arbitrary_types_allowed` is unknown to pydantic and was ignored.
        arbitrary_types_allowed=True,
    )

    # version
    version: Literal["0.1.0"] = "0.1.0"
    """CAREamics configuration version."""

    # required parameters
    experiment_name: str
    """Name of the experiment, used to name logs and checkpoints."""

    # Sub-configurations
    algorithm_config: Union[FCNAlgorithmConfig, VAEAlgorithmConfig] = Field(
        discriminator="algorithm_type"
    )
    """Algorithm configuration, holding all parameters required to configure the
    model."""

    data_config: DataConfig
    """Data configuration, holding all parameters required to configure the training
    data loader."""

    training_config: TrainingConfig
    """Training configuration, holding all parameters required to configure the
    training process."""

    @field_validator("experiment_name")
    @classmethod
    def no_symbol(cls, name: str) -> str:
        """
        Validate experiment name.

        A valid experiment name is a non-empty string, not made only of whitespace,
        that only contains letters, numbers, underscores, dashes and spaces.

        Parameters
        ----------
        name : str
            Name to validate.

        Returns
        -------
        str
            Validated name.

        Raises
        ------
        ValueError
            If the name is empty or contains invalid characters.
        """
        if len(name) == 0 or name.isspace():
            raise ValueError("Experiment name is empty.")

        # Validate using a regex that it contains only letters, numbers, underscores,
        # dashes and spaces
        if not re.match(r"^[a-zA-Z0-9_\- ]*$", name):
            raise ValueError(
                f"Experiment name contains invalid characters (got {name}). "
                f"Only letters, numbers, underscores, dashes and spaces are allowed."
            )

        return name

    @model_validator(mode="after")
    def validate_3D(self: Self) -> Self:
        """
        Change algorithm dimensions to match data.axes.

        Only for non-custom algorithms.

        Returns
        -------
        Self
            Validated configuration.
        """
        if self.algorithm_config.algorithm != SupportedAlgorithm.CUSTOM:
            if "Z" in self.data_config.axes and not self.algorithm_config.model.is_3D():
                # change algorithm to 3D
                self.algorithm_config.model.set_3D(True)
            elif (
                "Z" not in self.data_config.axes and self.algorithm_config.model.is_3D()
            ):
                # change algorithm to 2D
                self.algorithm_config.model.set_3D(False)

        return self

    @model_validator(mode="after")
    def validate_algorithm_and_data(self: Self) -> Self:
        """
        Validate algorithm and data compatibility.

        In particular, the validation does the following:

        - If N2V is used, it enforces the presence of N2V_Manipulate in the transforms
        - If N2V is used, it sets the pixel manipulation strategy to median for N2V2
            and to uniform otherwise
        - If another algorithm is used, N2V_Manipulate is removed from the transforms

        Returns
        -------
        Self
            Validated configuration.
        """
        if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
            # missing N2V_MANIPULATE
            if not self.data_config.has_n2v_manipulate():
                self.data_config.transforms.append(
                    N2VManipulateModel(
                        name=SupportedTransform.N2V_MANIPULATE.value,
                    )
                )

            median = SupportedPixelManipulation.MEDIAN.value
            uniform = SupportedPixelManipulation.UNIFORM.value
            strategy = median if self.algorithm_config.model.n2v2 else uniform
            self.data_config.set_N2V2_strategy(strategy)
        else:
            # remove N2V manipulate if present
            if self.data_config.has_n2v_manipulate():
                self.data_config.remove_n2v_manipulate()

        return self

    def __str__(self) -> str:
        """
        Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def set_3D(self, is_3D: bool, axes: str, patch_size: list[int]) -> None:
        """
        Set 3D flag and axes.

        Parameters
        ----------
        is_3D : bool
            Whether the algorithm is 3D or not.
        axes : str
            Axes of the data.
        patch_size : list[int]
            Patch size.
        """
        # set the flag and axes (this will not trigger validation at the config level)
        self.algorithm_config.model.set_3D(is_3D)
        self.data_config.set_3D(axes, patch_size)

        # cheap hack: re-assigning triggers `validate_assignment`, which re-runs the
        # model validators on the whole configuration
        self.algorithm_config = self.algorithm_config

    def set_N2V2(self, use_n2v2: bool) -> None:
        """
        Switch N2V algorithm between N2V and N2V2.

        Parameters
        ----------
        use_n2v2 : bool
            Whether to use N2V2 or not.

        Raises
        ------
        ValueError
            If the algorithm is not N2V.
        """
        if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
            self.algorithm_config.model.n2v2 = use_n2v2
            strategy = (
                SupportedPixelManipulation.MEDIAN.value
                if use_n2v2
                else SupportedPixelManipulation.UNIFORM.value
            )
            self.data_config.set_N2V2_strategy(strategy)
        else:
            raise ValueError("N2V2 can only be set for N2V algorithm.")

    def set_structN2V(
        self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int
    ) -> None:
        """
        Set StructN2V parameters.

        Parameters
        ----------
        mask_axis : Literal["horizontal", "vertical", "none"]
            Axis of the structural mask.
        mask_span : int
            Span of the structural mask.
        """
        self.data_config.set_structN2V_mask(mask_axis, mask_span)

    def _uses_struct_n2v(self) -> bool:
        """
        Check whether the N2V manipulation transform uses a structural mask.

        The transform is looked up by name rather than assumed to be the last
        transform of the data configuration.

        Returns
        -------
        bool
            True if an N2V manipulation transform is present and its
            `struct_mask_axis` is not "none", False otherwise.
        """
        n2v_name = SupportedTransform.N2V_MANIPULATE.value
        for transform in reversed(self.data_config.transforms):
            if transform.name == n2v_name:
                return transform.struct_mask_axis != "none"
        return False

    def get_algorithm_flavour(self) -> str:
        """
        Get the algorithm name.

        Returns
        -------
        str
            Algorithm name.
        """
        if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
            use_n2v2 = self.algorithm_config.model.n2v2
            use_structN2V = self._uses_struct_n2v()

            # return the n2v flavour
            if use_n2v2 and use_structN2V:
                return STRUCT_N2V2
            elif use_n2v2:
                return N2V2
            elif use_structN2V:
                return STRUCT_N2V
            else:
                return N2V
        elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N:
            return N2N
        elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE:
            return CARE
        else:
            return CUSTOM

    def get_algorithm_description(self) -> str:
        """
        Return a description of the algorithm.

        This method is used to generate the README of the BioImage Model Zoo export.

        Returns
        -------
        str
            Description of the algorithm, or an empty string if no description is
            available for the algorithm flavour.
        """
        algorithm_flavour = self.get_algorithm_flavour()

        if algorithm_flavour == CUSTOM:
            return f"Custom algorithm, named {self.algorithm_config.model.name}"
        else:  # N2V flavours, N2N and CARE
            if algorithm_flavour == N2V:
                return N2VDescription().description
            elif algorithm_flavour == N2V2:
                return N2V2Description().description
            elif algorithm_flavour == STRUCT_N2V:
                return StructN2VDescription().description
            elif algorithm_flavour == STRUCT_N2V2:
                return StructN2V2Description().description
            elif algorithm_flavour == N2N:
                return N2NDescription().description
            elif algorithm_flavour == CARE:
                return CAREDescription().description

        return ""

    def get_algorithm_citations(self) -> list[CiteEntry]:
        """
        Return a list of citation entries of the current algorithm.

        This is used to generate the model description for the BioImage Model Zoo.

        Returns
        -------
        list[CiteEntry]
            List of citation entries.

        Raises
        ------
        ValueError
            If the algorithm is neither N2V, N2N nor CARE (e.g. custom algorithm).
        """
        if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
            use_n2v2 = self.algorithm_config.model.n2v2
            use_structN2V = self._uses_struct_n2v()

            # return the (struct)N2V(2) references
            if use_n2v2 and use_structN2V:
                return [N2VRef, N2V2Ref, StructN2VRef]
            elif use_n2v2:
                return [N2VRef, N2V2Ref]
            elif use_structN2V:
                return [N2VRef, StructN2VRef]
            else:
                return [N2VRef]
        elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N:
            return [N2NRef]
        elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE:
            return [CARERef]

        raise ValueError("Citation not available for custom algorithm.")

    def get_algorithm_references(self) -> str:
        """
        Get the algorithm references.

        This is used to generate the README of the BioImage Model Zoo export.

        Returns
        -------
        str
            Algorithm references, one per line, each formatted as
            "<text> doi: <doi>". Empty string for non-N2V algorithms.
        """
        if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
            # Same entries and order as the citations: N2V, then N2V2 and/or
            # StructN2V when enabled.
            references = [
                ref.text + " doi: " + ref.doi for ref in self.get_algorithm_citations()
            ]

            # one reference per line; joining with "" would glue them together
            return "\n".join(references)

        # NOTE(review): N2N and CARE have citation entries (see
        # `get_algorithm_citations`) but no README references are produced here.
        return ""

    def get_algorithm_keywords(self) -> list[str]:
        """
        Get algorithm keywords.

        Returns
        -------
        list[str]
            List of keywords.
        """
        if self.algorithm_config.algorithm == SupportedAlgorithm.N2V:
            use_n2v2 = self.algorithm_config.model.n2v2
            use_structN2V = self._uses_struct_n2v()

            keywords = [
                "denoising",
                "restoration",
                "UNet",
                "3D" if "Z" in self.data_config.axes else "2D",
                "CAREamics",
                "pytorch",
                N2V,
            ]

            if use_n2v2:
                keywords.append(N2V2)
            if use_structN2V:
                keywords.append(STRUCT_N2V)
        else:
            keywords = ["CAREamics"]

        return keywords

    def model_dump(
        self,
        exclude_defaults: bool = False,
        exclude_none: bool = True,
        **kwargs: dict,
    ) -> dict:
        """
        Override model_dump method in order to set default values.

        Parameters
        ----------
        exclude_defaults : bool, optional
            Whether to exclude fields with default values or not, by default
            False.
        exclude_none : bool, optional
            Whether to exclude fields with None values or not, by default True.
        **kwargs : dict
            Keyword arguments forwarded to `pydantic.BaseModel.model_dump`.

        Returns
        -------
        dict
            Dictionary containing the model parameters.
        """
        dictionary = super().model_dump(
            exclude_none=exclude_none, exclude_defaults=exclude_defaults, **kwargs
        )

        return dictionary
528
+
529
+
530
def load_configuration(path: Union[str, Path]) -> Configuration:
    """
    Load configuration from a yaml file.

    Parameters
    ----------
    path : str or Path
        Path to the configuration.

    Returns
    -------
    Configuration
        Configuration.

    Raises
    ------
    FileNotFoundError
        If the configuration file does not exist.
    ValueError
        If the file does not contain a yaml mapping (e.g. it is empty).
    """
    config_path = Path(path)

    # load dictionary from yaml
    if not config_path.exists():
        raise FileNotFoundError(
            f"Configuration file {path} does not exist in " f" {Path.cwd()!s}"
        )

    # `with` guarantees the file handle is closed; `safe_load` is equivalent to
    # `yaml.load(..., Loader=yaml.SafeLoader)` and only builds plain Python objects.
    with config_path.open("r") as f:
        dictionary = yaml.safe_load(f)

    # an empty file yields None and a scalar/list yields a non-dict, neither of
    # which can be unpacked into the Configuration constructor
    if not isinstance(dictionary, dict):
        raise ValueError(
            f"Configuration file {path} does not contain a configuration dictionary "
            f"(got {type(dictionary).__name__})."
        )

    return Configuration(**dictionary)
558
+
559
+
560
def save_configuration(config: Configuration, path: Union[str, Path]) -> Path:
    """
    Save configuration to path.

    Parameters
    ----------
    config : Configuration
        Configuration to save.
    path : str or Path
        Either a path to an existing directory, in which case the configuration is
        saved as `config.yml` inside it, or a path to a `.yml`/`.yaml` file, which
        is created or overwritten.

    Returns
    -------
    Path
        Path of the saved configuration file.

    Raises
    ------
    ValueError
        If the path is neither an existing directory nor a path with a .yml or
        .yaml suffix.
    """
    # make sure path is a Path object
    config_path = Path(path)

    # an existing directory receives a default file name, anything else must be a
    # yaml file path (existing or not)
    if config_path.is_dir():
        config_path = config_path / "config.yml"
    elif config_path.suffix not in (".yml", ".yaml"):
        raise ValueError(
            f"Path must be a directory or .yml or .yaml file (got {config_path})."
        )

    # save configuration as dictionary to yaml, preserving field order
    with open(config_path, "w") as f:
        yaml.dump(config.model_dump(), f, default_flow_style=False, sort_keys=False)

    return config_path