careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,527 @@
1
+ """Data configuration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pprint import pformat
6
+ from typing import Any, Literal, Optional, Union
7
+
8
+ import numpy as np
9
+ from numpy.typing import NDArray
10
+ from pydantic import (
11
+ BaseModel,
12
+ ConfigDict,
13
+ Discriminator,
14
+ Field,
15
+ PlainSerializer,
16
+ field_validator,
17
+ model_validator,
18
+ )
19
+ from typing_extensions import Annotated, Self
20
+
21
+ from .support import SupportedTransform
22
+ from .transformations.n2v_manipulate_model import N2VManipulateModel
23
+ from .transformations.xy_flip_model import XYFlipModel
24
+ from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
25
+ from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
26
+
27
+
28
def np_float_to_scientific_str(x: float) -> str:
    """Serialize a float as a scientific-notation string.

    This function is the serializer attached to the `Float` annotated type, so that
    numpy floats (e.g. `numpy.float32`) accepted by the Pydantic models can be
    written to a yaml file as strings.

    Parameters
    ----------
    x : float
        Value to serialize.

    Returns
    -------
    str
        `x` written in scientific notation, with at most 7 digits after the
        decimal point.
    """
    scientific = np.format_float_scientific(x, precision=7)
    return scientific
45
+
46
+
47
# Values annotated with `Float` are serialized through `np_float_to_scientific_str`,
# so that numpy floats end up as plain strings when the model is dumped.
Float = Annotated[
    float,
    PlainSerializer(np_float_to_scientific_str, return_type=str),
]
"""Annotated float type, used to serialize floats to strings."""


# The `name` field of each entry is used as discriminator to tell the different
# transform models apart.
TRANSFORMS_UNION = Annotated[
    Union[XYFlipModel, XYRandomRotate90Model, N2VManipulateModel],
    Discriminator("name"),
]
"""Available transforms in CAREamics."""
60
+
61
+
62
class DataConfig(BaseModel):
    """
    Data configuration.

    If std is specified, mean must be specified as well. Note that setting the std
    first and then the mean (if they were both `None` before) will raise a
    validation error. Prefer instead `set_means_and_stds` to set both at once.
    Means and stds are expected to be lists of floats, one for each channel. For
    supervised tasks, the mean and std of the target could be different from the
    input data.

    All supported transforms are defined in the SupportedTransform enum.

    Examples
    --------
    Minimum example:

    >>> data = DataConfig(
    ...     data_type="array", # defined in SupportedData
    ...     patch_size=[128, 128],
    ...     batch_size=4,
    ...     axes="YX"
    ... )

    To change the image_means and image_stds of the data:
    >>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])

    One can pass also a list of transformations, by keyword, using the
    SupportedTransform value:
    >>> from careamics.config.support import SupportedTransform
    >>> data = DataConfig(
    ...     data_type="tiff",
    ...     patch_size=[128, 128],
    ...     batch_size=4,
    ...     axes="YX",
    ...     transforms=[
    ...         {
    ...             "name": "XYFlip",
    ...         }
    ...     ]
    ... )
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        validate_assignment=True,
    )

    # Dataset configuration
    data_type: Literal["array", "tiff", "custom"]
    """Type of input data, numpy.ndarray (array) or paths (tiff and custom), as defined
    in SupportedData."""

    axes: str
    """Axes of the data, as defined in SupportedAxes."""

    patch_size: list[int] = Field(..., min_length=2, max_length=3)
    """Patch size, as used during training."""

    batch_size: int = Field(default=1, ge=1, validate_default=True)
    """Batch size for training."""

    # Optional fields
    image_means: Optional[list[Float]] = Field(
        default=None, min_length=0, max_length=32
    )
    """Means of the data across channels, used for normalization."""

    image_stds: Optional[list[Float]] = Field(default=None, min_length=0, max_length=32)
    """Standard deviations of the data across channels, used for normalization."""

    target_means: Optional[list[Float]] = Field(
        default=None, min_length=0, max_length=32
    )
    """Means of the target data across channels, used for normalization."""

    target_stds: Optional[list[Float]] = Field(
        default=None, min_length=0, max_length=32
    )
    """Standard deviations of the target data across channels, used for
    normalization."""

    transforms: list[TRANSFORMS_UNION] = Field(
        default=[
            {
                "name": SupportedTransform.XY_FLIP.value,
            },
            {
                "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
            },
            {
                "name": SupportedTransform.N2V_MANIPULATE.value,
            },
        ],
        validate_default=True,
    )
    """List of transformations to apply to the data, available transforms are defined
    in SupportedTransform. The default values are set for Noise2Void."""

    dataloader_params: Optional[dict] = None
    """Dictionary of PyTorch dataloader parameters."""

    @field_validator("patch_size")
    @classmethod
    def all_elements_power_of_2_minimum_8(cls, patch_list: list[int]) -> list[int]:
        """
        Validate patch size.

        Patch size must be powers of 2 and minimum 8.

        Parameters
        ----------
        patch_list : list of int
            Patch size.

        Returns
        -------
        list of int
            Validated patch size.

        Raises
        ------
        ValueError
            If the patch size is smaller than 8.
        ValueError
            If the patch size is not a power of 2.
        """
        patch_size_ge_than_8_power_of_2(patch_list)

        return patch_list

    @field_validator("axes")
    @classmethod
    def axes_valid(cls, axes: str) -> str:
        """
        Validate axes.

        Axes must:
        - be a combination of 'STCZYX'
        - not contain duplicates
        - contain at least 2 contiguous axes: X and Y
        - contain at most 4 axes
        - not contain both S and T axes

        Parameters
        ----------
        axes : str
            Axes to validate.

        Returns
        -------
        str
            Validated axes.

        Raises
        ------
        ValueError
            If axes are not valid.
        """
        # Validate axes
        check_axes_validity(axes)

        return axes

    @field_validator("transforms")
    @classmethod
    def validate_prediction_transforms(
        cls, transforms: list[TRANSFORMS_UNION]
    ) -> list[TRANSFORMS_UNION]:
        """
        Validate N2VManipulate transform position in the transform list.

        At most one N2VManipulate transform is allowed. If present and not already
        last, it is moved to the end of the list (in place).

        Parameters
        ----------
        transforms : list[Transformations_Union]
            Transforms.

        Returns
        -------
        list of transforms
            Validated transforms.

        Raises
        ------
        ValueError
            If multiple instances of N2VManipulate are found.
        """
        n2v_name = SupportedTransform.N2V_MANIPULATE.value
        transform_names = [t.name for t in transforms]

        if n2v_name in transform_names:
            # multiple N2V_MANIPULATE
            if transform_names.count(n2v_name) > 1:
                raise ValueError(
                    f"Multiple instances of "
                    f"{SupportedTransform.N2V_MANIPULATE} transforms "
                    f"are not allowed."
                )

            # N2V_MANIPULATE must be applied last: move it to the end if needed
            if transform_names[-1] != n2v_name:
                index = transform_names.index(n2v_name)
                transforms.append(transforms.pop(index))

        return transforms

    @model_validator(mode="after")
    def std_only_with_mean(self: Self) -> Self:
        """
        Check that mean and std are either both None, or both specified.

        The check is applied separately to the image and target statistics, and
        also requires means and stds to have the same number of channels.

        Returns
        -------
        Self
            Validated data model.

        Raises
        ------
        ValueError
            If only one of mean and std is specified, or if their lengths differ.
        """
        # check that mean and std are either both None, or both specified
        if (self.image_means and not self.image_stds) or (
            self.image_stds and not self.image_means
        ):
            raise ValueError(
                "Mean and std must be either both None, or both specified."
            )

        elif (self.image_means is not None and self.image_stds is not None) and (
            len(self.image_means) != len(self.image_stds)
        ):
            raise ValueError("Mean and std must be specified for each input channel.")

        if (self.target_means and not self.target_stds) or (
            self.target_stds and not self.target_means
        ):
            raise ValueError(
                "Mean and std must be either both None, or both specified."
            )

        elif self.target_means is not None and self.target_stds is not None:
            if len(self.target_means) != len(self.target_stds):
                raise ValueError(
                    "Mean and std must be either both None, or both specified for each "
                    "target channel."
                )

        return self

    @model_validator(mode="after")
    def validate_dimensions(self: Self) -> Self:
        """
        Validate 2D/3D dimensions between axes and patch size.

        Data whose axes contain `Z` is 3D and requires a 3D patch size; any other
        data is 2D and requires a 2D patch size.

        Returns
        -------
        Self
            Validated data model.

        Raises
        ------
        ValueError
            If the patch size dimensionality does not match the axes.
        """
        if "Z" in self.axes:
            if len(self.patch_size) != 3:
                raise ValueError(
                    f"Patch size must have 3 dimensions if the data is 3D "
                    f"({self.axes})."
                )

        elif len(self.patch_size) != 2:
            raise ValueError(
                f"Patch size must have 2 dimensions if the data is 2D "
                f"({self.axes})."
            )

        return self

    def __str__(self) -> str:
        """
        Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())

    def _update(self, **kwargs: Any) -> None:
        """
        Update multiple arguments at once.

        The merged configuration is validated as a whole before any field is
        modified. This allows interdependent fields (e.g. means and stds, axes and
        patch size) to be changed together, and a failed validation leaves the
        configuration unchanged.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments to update.

        Raises
        ------
        ValueError
            If the updated configuration is not valid (Pydantic validation error).
        """
        # validate first so that an invalid update does not leave a half-updated
        # model behind; the validated copy is only used as a check and discarded
        self.__class__.model_validate({**self.__dict__, **kwargs})
        self.__dict__.update(kwargs)

    def has_n2v_manipulate(self) -> bool:
        """
        Check if the transforms contain N2VManipulate.

        Returns
        -------
        bool
            True if the transforms contain N2VManipulate, False otherwise.
        """
        return any(
            transform.name == SupportedTransform.N2V_MANIPULATE.value
            for transform in self.transforms
        )

    def add_n2v_manipulate(self) -> None:
        """Add N2VManipulate at the end of the transforms, if not already present."""
        if not self.has_n2v_manipulate():
            self.transforms.append(
                N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value)
            )

    def remove_n2v_manipulate(self) -> None:
        """
        Remove N2VManipulate from the transforms.

        The transform is removed by name rather than by position. The list therefore
        stays correct even if it was mutated in place after validation, for instance
        by appending another transform after N2VManipulate.
        """
        n2v_name = SupportedTransform.N2V_MANIPULATE.value
        # slice assignment keeps the same list object
        self.transforms[:] = [t for t in self.transforms if t.name != n2v_name]

    def set_means_and_stds(
        self,
        image_means: Union[NDArray, tuple, list, None],
        image_stds: Union[NDArray, tuple, list, None],
        target_means: Optional[Union[NDArray, tuple, list, None]] = None,
        target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
    ) -> None:
        """
        Set mean and standard deviation of the data across channels.

        This method should be used instead of setting the fields directly, as doing
        so would otherwise trigger a validation error.

        Parameters
        ----------
        image_means : numpy.ndarray, tuple or list
            Mean values for normalization.
        image_stds : numpy.ndarray, tuple or list
            Standard deviation values for normalization.
        target_means : numpy.ndarray, tuple or list, optional
            Target mean values for normalization, by default None.
        target_stds : numpy.ndarray, tuple or list, optional
            Target standard deviation values for normalization, by default None.
        """
        # make sure we pass a list
        if image_means is not None:
            image_means = list(image_means)
        if image_stds is not None:
            image_stds = list(image_stds)
        if target_means is not None:
            target_means = list(target_means)
        if target_stds is not None:
            target_stds = list(target_stds)

        self._update(
            image_means=image_means,
            image_stds=image_stds,
            target_means=target_means,
            target_stds=target_stds,
        )

    def set_3D(self, axes: str, patch_size: list[int]) -> None:
        """
        Set 3D parameters.

        Axes and patch size are updated together so that they are validated
        against each other.

        Parameters
        ----------
        axes : str
            Axes.
        patch_size : list of int
            Patch size.
        """
        self._update(axes=axes, patch_size=patch_size)

    def set_N2V2(self, use_n2v2: bool) -> None:
        """
        Set N2V2.

        Parameters
        ----------
        use_n2v2 : bool
            Whether to use N2V2 (`median` strategy) or not (`uniform` strategy).

        Raises
        ------
        ValueError
            If the N2V pixel manipulate transform is not found in the transforms.
        """
        if use_n2v2:
            self.set_N2V2_strategy("median")
        else:
            self.set_N2V2_strategy("uniform")

    def set_N2V2_strategy(self, strategy: Literal["uniform", "median"]) -> None:
        """
        Set N2V2 strategy.

        Parameters
        ----------
        strategy : Literal["uniform", "median"]
            Strategy to use for N2V2.

        Raises
        ------
        ValueError
            If the N2V pixel manipulate transform is not found in the transforms.
        """
        found_n2v = False

        for transform in self.transforms:
            if transform.name == SupportedTransform.N2V_MANIPULATE.value:
                transform.strategy = strategy
                found_n2v = True

        if not found_n2v:
            transforms = [t.name for t in self.transforms]
            raise ValueError(
                f"N2V_Manipulate transform not found in the transforms list "
                f"({transforms})."
            )

    def set_structN2V_mask(
        self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int
    ) -> None:
        """
        Set structN2V mask parameters.

        Setting `mask_axis` to `none` will disable structN2V.

        Parameters
        ----------
        mask_axis : Literal["horizontal", "vertical", "none"]
            Axis along which to apply the mask. `none` will disable structN2V.
        mask_span : int
            Total span of the mask in pixels.

        Raises
        ------
        ValueError
            If the N2V pixel manipulate transform is not found in the transforms.
        """
        found_n2v = False

        for transform in self.transforms:
            if transform.name == SupportedTransform.N2V_MANIPULATE.value:
                transform.struct_mask_axis = mask_axis
                transform.struct_mask_span = mask_span
                found_n2v = True

        if not found_n2v:
            transforms = [t.name for t in self.transforms]
            raise ValueError(
                f"N2V pixel manipulate transform not found in the transforms "
                f"({transforms})."
            )
@@ -0,0 +1,147 @@
1
+ """Module containing `FCNAlgorithmConfig` class."""
2
+
3
+ from pprint import pformat
4
+ from typing import Literal, Union
5
+
6
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
7
+ from typing_extensions import Self
8
+
9
+ from careamics.config.architectures import CustomModel, UNetModel
10
+ from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel
11
+
12
+
13
class FCNAlgorithmConfig(BaseModel):
    """Algorithm configuration.

    This Pydantic model validates the parameters governing the components of the
    training algorithm: which algorithm, loss function, model architecture, optimizer,
    and learning rate scheduler to use.

    Currently, we only support N2V, CARE, N2N and custom models. The `n2v` algorithm is
    only compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm
    allows you to register your own architecture and select it using its name as
    `name` in the custom pydantic model.

    Attributes
    ----------
    algorithm_type : Literal["fcn"]
        Algorithm type, always `fcn` (fully convolutional network).
    algorithm : Literal["n2v", "care", "n2n", "custom"]
        Algorithm to use.
    loss : Literal["n2v", "mae", "mse"]
        Loss function to use.
    model : Union[UNetModel, CustomModel]
        Model architecture to use.
    optimizer : OptimizerModel, optional
        Optimizer to use.
    lr_scheduler : LrSchedulerModel, optional
        Learning rate scheduler to use.

    Raises
    ------
    ValueError
        Algorithm parameter type validation errors.
    ValueError
        If the algorithm, loss and model are not compatible.

    Examples
    --------
    Minimum example:
    >>> from careamics.config import FCNAlgorithmConfig
    >>> config_dict = {
    ...     "algorithm": "n2v",
    ...     "algorithm_type": "fcn",
    ...     "loss": "n2v",
    ...     "model": {
    ...         "architecture": "UNet",
    ...     }
    ... }
    >>> config = FCNAlgorithmConfig(**config_dict)
    """

    # Pydantic class configuration
    model_config = ConfigDict(
        protected_namespaces=(),  # allows to use model_* as a field name
        validate_assignment=True,
        extra="allow",
    )

    # Mandatory fields
    # defined in SupportedAlgorithm
    algorithm_type: Literal["fcn"]
    """Algorithm type must be `fcn` (fully convolutional network) to differentiate this
    configuration from LVAE."""

    algorithm: Literal["n2v", "care", "n2n", "custom"]
    """Name of the algorithm, as defined in SupportedAlgorithm. Use `custom` for custom
    model architecture."""

    loss: Literal["n2v", "mae", "mse"]
    """Loss function to use, as defined in SupportedLoss."""

    model: Union[UNetModel, CustomModel] = Field(discriminator="architecture")
    """Model architecture to use, along with its parameters. Compatible architectures
    are defined in SupportedArchitecture, and their Pydantic models in
    `careamics.config.architectures`."""
    # TODO supported architectures are now all the architectures but does not warn users
    # of the compatibility with the algorithm

    # Optional fields
    optimizer: OptimizerModel = OptimizerModel()
    """Optimizer to use, defined in SupportedOptimizer."""

    lr_scheduler: LrSchedulerModel = LrSchedulerModel()
    """Learning rate scheduler to use, defined in SupportedLrScheduler."""

    @model_validator(mode="after")
    def algorithm_cross_validation(self: Self) -> Self:
        """Validate the algorithm model based on `algorithm`.

        N2V:
            - loss must be `n2v`
            - model must be a `UNetModel`
            - model `in_channels` must equal `num_classes`

        CARE and N2N:
            - loss must not be `n2v`

        All algorithms:
            - `algorithm` is `custom` if and only if the model architecture is
              `custom`

        Returns
        -------
        Self
            The validated model.

        Raises
        ------
        ValueError
            If the algorithm, loss and model are not compatible.
        """
        # N2V
        if self.algorithm == "n2v":
            # n2v is only compatible with the n2v loss
            if self.loss != "n2v":
                raise ValueError(
                    f"Algorithm {self.algorithm} only supports loss `n2v`."
                )

            # n2v is only compatible with the UNet model
            if not isinstance(self.model, UNetModel):
                raise ValueError(
                    f"Model for algorithm {self.algorithm} must be a `UNetModel`."
                )

            # n2v requires the number of input and output channels to be the same
            if self.model.in_channels != self.model.num_classes:
                raise ValueError(
                    "N2V requires the same number of input and output channels. Make "
                    "sure that `in_channels` and `num_classes` are the same."
                )

        # supervised algorithms
        if self.algorithm in ("care", "n2n") and self.loss == "n2v":
            raise ValueError("Supervised algorithms do not support loss `n2v`.")

        # a custom algorithm requires a custom architecture, and vice versa
        if (self.algorithm == "custom") != (self.model.architecture == "custom"):
            raise ValueError(
                "Algorithm and model architecture must be both `custom` or not."
            )

        return self

    def __str__(self) -> str:
        """Pretty string representing the configuration.

        Returns
        -------
        str
            Pretty string.
        """
        return pformat(self.model_dump())