careamics 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (111)
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +4 -3
  3. careamics/cli/conf.py +1 -2
  4. careamics/cli/main.py +1 -2
  5. careamics/cli/utils.py +3 -3
  6. careamics/config/__init__.py +47 -25
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +38 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +30 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +29 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +14 -12
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +6 -1
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +185 -57
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +91 -186
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +1 -2
  25. careamics/config/n2n_configuration.py +101 -0
  26. careamics/config/n2v_configuration.py +266 -0
  27. careamics/config/nm_model.py +1 -2
  28. careamics/config/support/__init__.py +7 -7
  29. careamics/config/support/supported_algorithms.py +5 -4
  30. careamics/config/support/supported_architectures.py +0 -4
  31. careamics/config/transformations/__init__.py +10 -4
  32. careamics/config/transformations/transform_model.py +3 -3
  33. careamics/config/transformations/transform_unions.py +42 -0
  34. careamics/config/validators/__init__.py +12 -1
  35. careamics/config/validators/model_validators.py +84 -0
  36. careamics/config/validators/validator_utils.py +3 -3
  37. careamics/dataset/__init__.py +2 -2
  38. careamics/dataset/dataset_utils/__init__.py +3 -3
  39. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  40. careamics/dataset/dataset_utils/file_utils.py +9 -9
  41. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +3 -3
  63. careamics/lightning/lightning_module.py +11 -7
  64. careamics/lightning/train_data_module.py +36 -45
  65. careamics/losses/__init__.py +3 -3
  66. careamics/lvae_training/calibration.py +64 -57
  67. careamics/lvae_training/dataset/lc_dataset.py +2 -1
  68. careamics/lvae_training/dataset/multich_dataset.py +2 -2
  69. careamics/lvae_training/dataset/types.py +1 -1
  70. careamics/lvae_training/eval_utils.py +123 -128
  71. careamics/model_io/__init__.py +1 -1
  72. careamics/model_io/bioimage/__init__.py +1 -1
  73. careamics/model_io/bioimage/_readme_factory.py +1 -1
  74. careamics/model_io/bioimage/model_description.py +17 -17
  75. careamics/model_io/bmz_io.py +6 -17
  76. careamics/model_io/model_io_utils.py +9 -9
  77. careamics/models/layers.py +16 -16
  78. careamics/models/lvae/likelihoods.py +2 -0
  79. careamics/models/lvae/lvae.py +13 -4
  80. careamics/models/lvae/noise_models.py +280 -217
  81. careamics/models/lvae/stochastic.py +1 -0
  82. careamics/models/model_factory.py +2 -15
  83. careamics/models/unet.py +8 -8
  84. careamics/prediction_utils/__init__.py +1 -1
  85. careamics/prediction_utils/prediction_outputs.py +15 -15
  86. careamics/prediction_utils/stitch_prediction.py +6 -6
  87. careamics/transforms/__init__.py +5 -5
  88. careamics/transforms/compose.py +13 -13
  89. careamics/transforms/n2v_manipulate.py +3 -3
  90. careamics/transforms/pixel_manipulation.py +9 -9
  91. careamics/transforms/xy_random_rotate90.py +4 -4
  92. careamics/utils/__init__.py +5 -5
  93. careamics/utils/context.py +2 -1
  94. careamics/utils/logging.py +11 -10
  95. careamics/utils/metrics.py +25 -0
  96. careamics/utils/plotting.py +78 -0
  97. careamics/utils/torch_utils.py +7 -7
  98. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/METADATA +13 -11
  99. careamics-0.0.7.dist-info/RECORD +178 -0
  100. careamics/config/architectures/custom_model.py +0 -162
  101. careamics/config/architectures/register_model.py +0 -103
  102. careamics/config/configuration_model.py +0 -603
  103. careamics/config/fcn_algorithm_model.py +0 -152
  104. careamics/config/references/__init__.py +0 -45
  105. careamics/config/references/algorithm_descriptions.py +0 -132
  106. careamics/config/references/references.py +0 -39
  107. careamics/config/transformations/transform_union.py +0 -20
  108. careamics-0.0.5.dist-info/RECORD +0 -171
  109. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/WHEEL +0 -0
  110. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/entry_points.txt +0 -0
  111. {careamics-0.0.5.dist-info → careamics-0.0.7.dist-info}/licenses/LICENSE +0 -0
@@ -2,8 +2,10 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from collections.abc import Sequence
5
6
  from pprint import pformat
6
- from typing import Any, Literal, Optional, Union
7
+ from typing import Annotated, Any, Literal, Optional, Union
8
+ from warnings import warn
7
9
 
8
10
  import numpy as np
9
11
  from numpy.typing import NDArray
@@ -15,11 +17,10 @@ from pydantic import (
15
17
  field_validator,
16
18
  model_validator,
17
19
  )
18
- from typing_extensions import Annotated, Self
20
+ from typing_extensions import Self
19
21
 
20
- from .support import SupportedTransform
21
- from .transformations import TRANSFORMS_UNION, N2VManipulateModel
22
- from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
22
+ from ..transformations import N2V_TRANSFORMS_UNION, XYFlipModel, XYRandomRotate90Model
23
+ from ..validators import check_axes_validity, patch_size_ge_than_8_power_of_2
23
24
 
24
25
 
25
26
  def np_float_to_scientific_str(x: float) -> str:
@@ -45,47 +46,8 @@ Float = Annotated[float, PlainSerializer(np_float_to_scientific_str, return_type
45
46
  """Annotated float type, used to serialize floats to strings."""
46
47
 
47
48
 
48
- class DataConfig(BaseModel):
49
- """
50
- Data configuration.
51
-
52
- If std is specified, mean must be specified as well. Note that setting the std first
53
- and then the mean (if they were both `None` before) will raise a validation error.
54
- Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected
55
- to be lists of floats, one for each channel. For supervised tasks, the mean and std
56
- of the target could be different from the input data.
57
-
58
- All supported transforms are defined in the SupportedTransform enum.
59
-
60
- Examples
61
- --------
62
- Minimum example:
63
-
64
- >>> data = DataConfig(
65
- ... data_type="array", # defined in SupportedData
66
- ... patch_size=[128, 128],
67
- ... batch_size=4,
68
- ... axes="YX"
69
- ... )
70
-
71
- To change the image_means and image_stds of the data:
72
- >>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])
73
-
74
- One can pass also a list of transformations, by keyword, using the
75
- SupportedTransform value:
76
- >>> from careamics.config.support import SupportedTransform
77
- >>> data = DataConfig(
78
- ... data_type="tiff",
79
- ... patch_size=[128, 128],
80
- ... batch_size=4,
81
- ... axes="YX",
82
- ... transforms=[
83
- ... {
84
- ... "name": "XYFlip",
85
- ... }
86
- ... ]
87
- ... )
88
- """
49
+ class GeneralDataConfig(BaseModel):
50
+ """General data configuration."""
89
51
 
90
52
  # Pydantic class configuration
91
53
  model_config = ConfigDict(
@@ -126,25 +88,26 @@ class DataConfig(BaseModel):
126
88
  """Standard deviations of the target data across channels, used for
127
89
  normalization."""
128
90
 
129
- transforms: list[TRANSFORMS_UNION] = Field(
91
+ # defining as Sequence allows assigning subclasses of TransformModel without mypy
92
+ # complaining; this is important, for instance, to differentiate N2VDataConfig and
93
+ # DataConfig
94
+ transforms: Sequence[N2V_TRANSFORMS_UNION] = Field(
130
95
  default=[
131
- {
132
- "name": SupportedTransform.XY_FLIP.value,
133
- },
134
- {
135
- "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
136
- },
137
- {
138
- "name": SupportedTransform.N2V_MANIPULATE.value,
139
- },
96
+ XYFlipModel(),
97
+ XYRandomRotate90Model(),
140
98
  ],
141
99
  validate_default=True,
142
100
  )
143
101
  """List of transformations to apply to the data, available transforms are defined
144
- in SupportedTransform. The default values are set for Noise2Void."""
102
+ in SupportedTransform."""
103
+
104
+ train_dataloader_params: dict[str, Any] = Field(
105
+ default={"shuffle": True}, validate_default=True
106
+ )
107
+ """Dictionary of PyTorch training dataloader parameters."""
145
108
 
146
- dataloader_params: Optional[dict] = None
147
- """Dictionary of PyTorch dataloader parameters."""
109
+ val_dataloader_params: dict[str, Any] = Field(default={})
110
+ """Dictionary of PyTorch validation dataloader parameters."""
148
111
 
149
112
  @field_validator("patch_size")
150
113
  @classmethod
@@ -210,47 +173,44 @@ class DataConfig(BaseModel):
210
173
 
211
174
  return axes
212
175
 
213
- @field_validator("transforms")
176
+ @field_validator("train_dataloader_params")
214
177
  @classmethod
215
- def validate_prediction_transforms(
216
- cls, transforms: list[TRANSFORMS_UNION]
217
- ) -> list[TRANSFORMS_UNION]:
178
+ def shuffle_train_dataloader(
179
+ cls, train_dataloader_params: dict[str, Any]
180
+ ) -> dict[str, Any]:
218
181
  """
219
- Validate N2VManipulate transform position in the transform list.
182
+ Validate that "shuffle" is included in the training dataloader params.
183
+
184
+ A warning will be raised if `shuffle=False`.
220
185
 
221
186
  Parameters
222
187
  ----------
223
- transforms : list[Transformations_Union]
224
- Transforms.
188
+ train_dataloader_params : dict of {str: Any}
189
+ The training dataloader parameters.
225
190
 
226
191
  Returns
227
192
  -------
228
- list of transforms
229
- Validated transforms.
193
+ dict of {str: Any}
194
+ The validated training dataloader parameters.
230
195
 
231
196
  Raises
232
197
  ------
233
198
  ValueError
234
- If multiple instances of N2VManipulate are found.
199
+ If "shuffle" is not included in the training dataloader params.
235
200
  """
236
- transform_list = [t.name for t in transforms]
237
-
238
- if SupportedTransform.N2V_MANIPULATE in transform_list:
239
- # multiple N2V_MANIPULATE
240
- if transform_list.count(SupportedTransform.N2V_MANIPULATE.value) > 1:
241
- raise ValueError(
242
- f"Multiple instances of "
243
- f"{SupportedTransform.N2V_MANIPULATE} transforms "
244
- f"are not allowed."
245
- )
246
-
247
- # N2V_MANIPULATE not the last transform
248
- elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
249
- index = transform_list.index(SupportedTransform.N2V_MANIPULATE.value)
250
- transform = transforms.pop(index)
251
- transforms.append(transform)
252
-
253
- return transforms
201
+ if "shuffle" not in train_dataloader_params:
202
+ raise ValueError(
203
+ "Value for 'shuffle' was not included in the `train_dataloader_params`."
204
+ )
205
+ elif ("shuffle" in train_dataloader_params) and (
206
+ not train_dataloader_params["shuffle"]
207
+ ):
208
+ warn(
209
+ "Dataloader parameters include `shuffle=False`, this will be passed to "
210
+ "the training dataloader and may result in bad results.",
211
+ stacklevel=1,
212
+ )
213
+ return train_dataloader_params
254
214
 
255
215
  @model_validator(mode="after")
256
216
  def std_only_with_mean(self: Self) -> Self:
@@ -350,32 +310,6 @@ class DataConfig(BaseModel):
350
310
  self.__dict__.update(kwargs)
351
311
  self.__class__.model_validate(self.__dict__)
352
312
 
353
- def has_n2v_manipulate(self) -> bool:
354
- """
355
- Check if the transforms contain N2VManipulate.
356
-
357
- Returns
358
- -------
359
- bool
360
- True if the transforms contain N2VManipulate, False otherwise.
361
- """
362
- return any(
363
- transform.name == SupportedTransform.N2V_MANIPULATE.value
364
- for transform in self.transforms
365
- )
366
-
367
- def add_n2v_manipulate(self) -> None:
368
- """Add N2VManipulate to the transforms."""
369
- if not self.has_n2v_manipulate():
370
- self.transforms.append(
371
- N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value)
372
- )
373
-
374
- def remove_n2v_manipulate(self) -> None:
375
- """Remove N2VManipulate from the transforms."""
376
- if self.has_n2v_manipulate():
377
- self.transforms.pop(-1)
378
-
379
313
  def set_means_and_stds(
380
314
  self,
381
315
  image_means: Union[NDArray, tuple, list, None],
@@ -430,84 +364,55 @@ class DataConfig(BaseModel):
430
364
  """
431
365
  self._update(axes=axes, patch_size=patch_size)
432
366
 
433
- def set_N2V2(self, use_n2v2: bool) -> None:
434
- """
435
- Set N2V2.
436
-
437
- Parameters
438
- ----------
439
- use_n2v2 : bool
440
- Whether to use N2V2.
441
367
 
442
- Raises
443
- ------
444
- ValueError
445
- If the N2V pixel manipulate transform is not found in the transforms.
446
- """
447
- if use_n2v2:
448
- self.set_N2V2_strategy("median")
449
- else:
450
- self.set_N2V2_strategy("uniform")
451
-
452
- def set_N2V2_strategy(self, strategy: Literal["uniform", "median"]) -> None:
453
- """
454
- Set N2V2 strategy.
455
-
456
- Parameters
457
- ----------
458
- strategy : Literal["uniform", "median"]
459
- Strategy to use for N2V2.
460
-
461
- Raises
462
- ------
463
- ValueError
464
- If the N2V pixel manipulate transform is not found in the transforms.
465
- """
466
- found_n2v = False
467
-
468
- for transform in self.transforms:
469
- if transform.name == SupportedTransform.N2V_MANIPULATE.value:
470
- transform.strategy = strategy
471
- found_n2v = True
368
+ class DataConfig(GeneralDataConfig):
369
+ """
370
+ Data configuration.
472
371
 
473
- if not found_n2v:
474
- transforms = [t.name for t in self.transforms]
475
- raise ValueError(
476
- f"N2V_Manipulate transform not found in the transforms list "
477
- f"({transforms})."
478
- )
372
+ If std is specified, mean must be specified as well. Note that setting the std first
373
+ and then the mean (if they were both `None` before) will raise a validation error.
374
+ Prefer instead `set_means_and_stds` to set both at once. Means and stds are expected
375
+ to be lists of floats, one for each channel. For supervised tasks, the mean and std
376
+ of the target could be different from the input data.
479
377
 
480
- def set_structN2V_mask(
481
- self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int
482
- ) -> None:
483
- """
484
- Set structN2V mask parameters.
378
+ All supported transforms are defined in the SupportedTransform enum.
485
379
 
486
- Setting `mask_axis` to `none` will disable structN2V.
380
+ Examples
381
+ --------
382
+ Minimum example:
487
383
 
488
- Parameters
489
- ----------
490
- mask_axis : Literal["horizontal", "vertical", "none"]
491
- Axis along which to apply the mask. `none` will disable structN2V.
492
- mask_span : int
493
- Total span of the mask in pixels.
384
+ >>> data = DataConfig(
385
+ ... data_type="array", # defined in SupportedData
386
+ ... patch_size=[128, 128],
387
+ ... batch_size=4,
388
+ ... axes="YX"
389
+ ... )
494
390
 
495
- Raises
496
- ------
497
- ValueError
498
- If the N2V pixel manipulate transform is not found in the transforms.
499
- """
500
- found_n2v = False
391
+ To change the image_means and image_stds of the data:
392
+ >>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])
501
393
 
502
- for transform in self.transforms:
503
- if transform.name == SupportedTransform.N2V_MANIPULATE.value:
504
- transform.struct_mask_axis = mask_axis
505
- transform.struct_mask_span = mask_span
506
- found_n2v = True
394
+ One can pass also a list of transformations, by keyword, using the
395
+ SupportedTransform value:
396
+ >>> from careamics.config.support import SupportedTransform
397
+ >>> data = DataConfig(
398
+ ... data_type="tiff",
399
+ ... patch_size=[128, 128],
400
+ ... batch_size=4,
401
+ ... axes="YX",
402
+ ... transforms=[
403
+ ... {
404
+ ... "name": "XYFlip",
405
+ ... }
406
+ ... ]
407
+ ... )
408
+ """
507
409
 
508
- if not found_n2v:
509
- transforms = [t.name for t in self.transforms]
510
- raise ValueError(
511
- f"N2V pixel manipulate transform not found in the transforms "
512
- f"({transforms})."
513
- )
410
+ transforms: Sequence[Union[XYFlipModel, XYRandomRotate90Model]] = Field(
411
+ default=[
412
+ XYFlipModel(),
413
+ XYRandomRotate90Model(),
414
+ ],
415
+ validate_default=True,
416
+ )
417
+ """List of transformations to apply to the data, available transforms are defined
418
+ in SupportedTransform. This excludes N2V specific transformations."""
@@ -0,0 +1,193 @@
1
+ """Noise2Void specific data configuration model."""
2
+
3
+ from collections.abc import Sequence
4
+ from typing import Literal
5
+
6
+ from pydantic import Field, field_validator
7
+
8
+ from careamics.config.data.data_model import GeneralDataConfig
9
+ from careamics.config.support import SupportedTransform
10
+ from careamics.config.transformations import (
11
+ N2V_TRANSFORMS_UNION,
12
+ N2VManipulateModel,
13
+ XYFlipModel,
14
+ XYRandomRotate90Model,
15
+ )
16
+
17
+
18
+ class N2VDataConfig(GeneralDataConfig):
19
+ """N2V specific data configuration model."""
20
+
21
+ transforms: Sequence[N2V_TRANSFORMS_UNION] = Field(
22
+ default=[XYFlipModel(), XYRandomRotate90Model(), N2VManipulateModel()],
23
+ validate_default=True,
24
+ )
25
+ """N2V compatible transforms. N2VManipulate should be the last transform."""
26
+
27
+ @field_validator("transforms")
28
+ @classmethod
29
+ def validate_transforms(
30
+ cls, transforms: list[N2V_TRANSFORMS_UNION]
31
+ ) -> list[N2V_TRANSFORMS_UNION]:
32
+ """
33
+ Validate N2VManipulate transform position in the transform list.
34
+
35
+ Parameters
36
+ ----------
37
+ transforms : list of transforms compatible with N2V
38
+ Transforms.
39
+
40
+ Returns
41
+ -------
42
+ list of transforms
43
+ Validated transforms.
44
+
45
+ Raises
46
+ ------
47
+ ValueError
48
+ If multiple instances of N2VManipulate are found or if it is not the last
49
+ transform.
50
+ """
51
+ transform_list = [t.name for t in transforms]
52
+
53
+ if SupportedTransform.N2V_MANIPULATE in transform_list:
54
+ # multiple N2V_MANIPULATE
55
+ if transform_list.count(SupportedTransform.N2V_MANIPULATE.value) > 1:
56
+ raise ValueError(
57
+ f"Multiple instances of "
58
+ f"{SupportedTransform.N2V_MANIPULATE} transforms "
59
+ f"are not allowed."
60
+ )
61
+
62
+ # N2V_MANIPULATE not the last transform
63
+ elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
64
+ raise ValueError(
65
+ f"{SupportedTransform.N2V_MANIPULATE} transform "
66
+ f"should be the last transform."
67
+ )
68
+
69
+ else:
70
+ raise ValueError(
71
+ f"{SupportedTransform.N2V_MANIPULATE} transform "
72
+ f"is required for N2V training."
73
+ )
74
+
75
+ return transforms
76
+
77
+ def set_n2v2(self, use_n2v2: bool) -> None:
78
+ """
79
+ Set the N2V transform to the N2V2 version.
80
+
81
+ Parameters
82
+ ----------
83
+ use_n2v2 : bool
84
+ Whether to use N2V2.
85
+
86
+ Raises
87
+ ------
88
+ ValueError
89
+ If the N2V pixel manipulate transform is not found in the transforms.
90
+ """
91
+ if use_n2v2:
92
+ self.set_masking_strategy("median")
93
+ else:
94
+ self.set_masking_strategy("uniform")
95
+
96
+ def set_masking_strategy(self, strategy: Literal["uniform", "median"]) -> None:
97
+ """
98
+ Set masking strategy.
99
+
100
+ Parameters
101
+ ----------
102
+ strategy : "uniform" or "median"
103
+ Strategy to use for N2V2.
104
+
105
+ Raises
106
+ ------
107
+ ValueError
108
+ If the N2V pixel manipulate transform is not found in the transforms.
109
+ """
110
+ found_n2v = False
111
+
112
+ for transform in self.transforms:
113
+ if transform.name == SupportedTransform.N2V_MANIPULATE.value:
114
+ transform.strategy = strategy
115
+ found_n2v = True
116
+
117
+ if not found_n2v:
118
+ transforms = [t.name for t in self.transforms]
119
+ raise ValueError(
120
+ f"N2V_Manipulate transform not found in the transforms list "
121
+ f"({transforms})."
122
+ )
123
+
124
+ def get_masking_strategy(self) -> Literal["uniform", "median"]:
125
+ """
126
+ Get masking strategy.
127
+
128
+ Returns
129
+ -------
130
+ "uniform" or "median"
131
+ Strategy used for N2V2.
132
+ """
133
+ for transform in self.transforms:
134
+ if transform.name == SupportedTransform.N2V_MANIPULATE.value:
135
+ return transform.strategy
136
+
137
+ raise ValueError(
138
+ f"{SupportedTransform.N2V_MANIPULATE} transform "
139
+ f"is required for N2V training."
140
+ )
141
+
142
+ def set_structN2V_mask(
143
+ self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int
144
+ ) -> None:
145
+ """
146
+ Set structN2V mask parameters.
147
+
148
+ Setting `mask_axis` to `none` will disable structN2V.
149
+
150
+ Parameters
151
+ ----------
152
+ mask_axis : Literal["horizontal", "vertical", "none"]
153
+ Axis along which to apply the mask. `none` will disable structN2V.
154
+ mask_span : int
155
+ Total span of the mask in pixels.
156
+
157
+ Raises
158
+ ------
159
+ ValueError
160
+ If the N2V pixel manipulate transform is not found in the transforms.
161
+ """
162
+ found_n2v = False
163
+
164
+ for transform in self.transforms:
165
+ if transform.name == SupportedTransform.N2V_MANIPULATE.value:
166
+ transform.struct_mask_axis = mask_axis
167
+ transform.struct_mask_span = mask_span
168
+ found_n2v = True
169
+
170
+ if not found_n2v:
171
+ transforms = [t.name for t in self.transforms]
172
+ raise ValueError(
173
+ f"N2V pixel manipulate transform not found in the transforms "
174
+ f"({transforms})."
175
+ )
176
+
177
+ def is_using_struct_n2v(self) -> bool:
178
+ """
179
+ Check if structN2V is enabled.
180
+
181
+ Returns
182
+ -------
183
+ bool
184
+ Whether structN2V is enabled or not.
185
+ """
186
+ for transform in self.transforms:
187
+ if transform.name == SupportedTransform.N2V_MANIPULATE.value:
188
+ return transform.struct_mask_axis != "none"
189
+
190
+ raise ValueError(
191
+ f"N2V pixel manipulate transform not found in the transforms "
192
+ f"({self.transforms})."
193
+ )
@@ -1,11 +1,10 @@
1
1
  """Likelihood model."""
2
2
 
3
- from typing import Literal, Optional, Union
3
+ from typing import Annotated, Literal, Optional, Union
4
4
 
5
5
  import numpy as np
6
6
  import torch
7
7
  from pydantic import BaseModel, ConfigDict, PlainSerializer, PlainValidator
8
- from typing_extensions import Annotated
9
8
 
10
9
  from careamics.models.lvae.noise_models import (
11
10
  GaussianMixtureNoiseModel,
@@ -0,0 +1,101 @@
1
+ """N2N configuration."""
2
+
3
+ from bioimageio.spec.generic.v0_3 import CiteEntry
4
+
5
+ from careamics.config.algorithms import N2NAlgorithm
6
+ from careamics.config.configuration import Configuration
7
+ from careamics.config.data import DataConfig
8
+
9
+ N2N = "Noise2Noise"
10
+
11
+ N2N_DESCRIPTION = (
12
+ "Noise2Noise is a deep-learning-based algorithm that uses a U-Net "
13
+ "architecture to restore images. Noise2Noise is a self-supervised "
14
+ "algorithm that requires only noisy images to train the network. "
15
+ "The algorithm learns to predict the clean image from the noisy "
16
+ "image. Noise2Noise is particularly useful when clean images are "
17
+ "not available for training."
18
+ )
19
+
20
+ N2N_REF = CiteEntry(
21
+ text="Lehtinen, J., Munkberg, J., Hasselgren, J., Laine, S., Karras, T., "
22
+ 'Aittala, M. and Aila, T., 2018. "Noise2Noise: Learning image restoration '
23
+ 'without clean data". arXiv preprint arXiv:1803.04189.',
24
+ doi="10.48550/arXiv.1803.04189",
25
+ )
26
+
27
+
28
+ class N2NConfiguration(Configuration):
29
+ """Noise2Noise configuration."""
30
+
31
+ algorithm_config: N2NAlgorithm
32
+ """Algorithm configuration."""
33
+
34
+ data_config: DataConfig
35
+ """Data configuration."""
36
+
37
+ def get_algorithm_friendly_name(self) -> str:
38
+ """
39
+ Get the algorithm friendly name.
40
+
41
+ Returns
42
+ -------
43
+ str
44
+ Friendly name of the algorithm.
45
+ """
46
+ return N2N
47
+
48
+ def get_algorithm_keywords(self) -> list[str]:
49
+ """
50
+ Get algorithm keywords.
51
+
52
+ Returns
53
+ -------
54
+ list[str]
55
+ List of keywords.
56
+ """
57
+ return [
58
+ "restoration",
59
+ "UNet",
60
+ "3D" if "Z" in self.data_config.axes else "2D",
61
+ "CAREamics",
62
+ "pytorch",
63
+ N2N,
64
+ ]
65
+
66
+ def get_algorithm_references(self) -> str:
67
+ """
68
+ Get the algorithm references.
69
+
70
+ This is used to generate the README of the BioImage Model Zoo export.
71
+
72
+ Returns
73
+ -------
74
+ str
75
+ Algorithm references.
76
+ """
77
+ return N2N_REF.text + " doi: " + N2N_REF.doi
78
+
79
+ def get_algorithm_citations(self) -> list[CiteEntry]:
80
+ """
81
+ Return a list of citation entries of the current algorithm.
82
+
83
+ This is used to generate the model description for the BioImage Model Zoo.
84
+
85
+ Returns
86
+ -------
87
+ list[CiteEntry]
88
+ List of citation entries.
89
+ """
90
+ return [N2N_REF]
91
+
92
+ def get_algorithm_description(self) -> str:
93
+ """
94
+ Get the algorithm description.
95
+
96
+ Returns
97
+ -------
98
+ str
99
+ Algorithm description.
100
+ """
101
+ return N2N_DESCRIPTION