careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (118)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +164 -231
  4. careamics/config/algorithm_model.py +5 -18
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +11 -4
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -5
  12. careamics/config/configuration_factory.py +27 -41
  13. careamics/config/configuration_model.py +11 -11
  14. careamics/config/data_model.py +89 -63
  15. careamics/config/inference_model.py +28 -81
  16. careamics/config/optimizer_models.py +11 -11
  17. careamics/config/support/__init__.py +0 -2
  18. careamics/config/support/supported_activations.py +2 -0
  19. careamics/config/support/supported_algorithms.py +3 -1
  20. careamics/config/support/supported_architectures.py +2 -0
  21. careamics/config/support/supported_data.py +2 -0
  22. careamics/config/support/supported_loggers.py +2 -0
  23. careamics/config/support/supported_losses.py +2 -0
  24. careamics/config/support/supported_optimizers.py +2 -0
  25. careamics/config/support/supported_pixel_manipulations.py +3 -3
  26. careamics/config/support/supported_struct_axis.py +2 -0
  27. careamics/config/support/supported_transforms.py +4 -16
  28. careamics/config/tile_information.py +28 -58
  29. careamics/config/transformations/__init__.py +3 -2
  30. careamics/config/transformations/normalize_model.py +32 -4
  31. careamics/config/transformations/xy_flip_model.py +43 -0
  32. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  33. careamics/config/validators/validator_utils.py +1 -1
  34. careamics/conftest.py +12 -0
  35. careamics/dataset/__init__.py +12 -1
  36. careamics/dataset/dataset_utils/__init__.py +8 -1
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  38. careamics/dataset/dataset_utils/file_utils.py +4 -3
  39. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  40. careamics/dataset/dataset_utils/read_tiff.py +6 -11
  41. careamics/dataset/dataset_utils/read_utils.py +2 -0
  42. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  43. careamics/dataset/dataset_utils/running_stats.py +186 -0
  44. careamics/dataset/in_memory_dataset.py +88 -154
  45. careamics/dataset/in_memory_pred_dataset.py +88 -0
  46. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  47. careamics/dataset/iterable_dataset.py +121 -191
  48. careamics/dataset/iterable_pred_dataset.py +121 -0
  49. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  50. careamics/dataset/patching/patching.py +109 -39
  51. careamics/dataset/patching/random_patching.py +17 -6
  52. careamics/dataset/patching/sequential_patching.py +14 -8
  53. careamics/dataset/patching/validate_patch_dimension.py +7 -3
  54. careamics/dataset/tiling/__init__.py +10 -0
  55. careamics/dataset/tiling/collate_tiles.py +33 -0
  56. careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
  57. careamics/dataset/zarr_dataset.py +2 -0
  58. careamics/lightning_datamodule.py +46 -25
  59. careamics/lightning_module.py +19 -9
  60. careamics/lightning_prediction_datamodule.py +54 -84
  61. careamics/losses/__init__.py +2 -3
  62. careamics/losses/loss_factory.py +1 -1
  63. careamics/losses/losses.py +11 -7
  64. careamics/lvae_training/__init__.py +0 -0
  65. careamics/lvae_training/data_modules.py +1220 -0
  66. careamics/lvae_training/data_utils.py +618 -0
  67. careamics/lvae_training/eval_utils.py +905 -0
  68. careamics/lvae_training/get_config.py +84 -0
  69. careamics/lvae_training/lightning_module.py +701 -0
  70. careamics/lvae_training/metrics.py +214 -0
  71. careamics/lvae_training/train_lvae.py +339 -0
  72. careamics/lvae_training/train_utils.py +121 -0
  73. careamics/model_io/bioimage/model_description.py +40 -32
  74. careamics/model_io/bmz_io.py +3 -3
  75. careamics/model_io/model_io_utils.py +5 -2
  76. careamics/models/activation.py +2 -0
  77. careamics/models/layers.py +121 -25
  78. careamics/models/lvae/__init__.py +0 -0
  79. careamics/models/lvae/layers.py +1998 -0
  80. careamics/models/lvae/likelihoods.py +312 -0
  81. careamics/models/lvae/lvae.py +985 -0
  82. careamics/models/lvae/noise_models.py +409 -0
  83. careamics/models/lvae/utils.py +395 -0
  84. careamics/models/model_factory.py +1 -1
  85. careamics/models/unet.py +35 -14
  86. careamics/prediction_utils/__init__.py +12 -0
  87. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  88. careamics/prediction_utils/prediction_outputs.py +165 -0
  89. careamics/prediction_utils/stitch_prediction.py +100 -0
  90. careamics/transforms/__init__.py +2 -2
  91. careamics/transforms/compose.py +33 -7
  92. careamics/transforms/n2v_manipulate.py +52 -14
  93. careamics/transforms/normalize.py +171 -48
  94. careamics/transforms/pixel_manipulation.py +35 -11
  95. careamics/transforms/struct_mask_parameters.py +3 -1
  96. careamics/transforms/transform.py +10 -19
  97. careamics/transforms/tta.py +43 -29
  98. careamics/transforms/xy_flip.py +123 -0
  99. careamics/transforms/xy_random_rotate90.py +38 -5
  100. careamics/utils/base_enum.py +28 -0
  101. careamics/utils/path_utils.py +2 -0
  102. careamics/utils/ram.py +4 -2
  103. careamics/utils/receptive_field.py +93 -87
  104. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
  105. careamics-0.1.0rc7.dist-info/RECORD +130 -0
  106. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  107. careamics/config/noise_models.py +0 -162
  108. careamics/config/support/supported_extraction_strategies.py +0 -25
  109. careamics/config/transformations/nd_flip_model.py +0 -27
  110. careamics/lightning_prediction_loop.py +0 -116
  111. careamics/losses/noise_model_factory.py +0 -40
  112. careamics/losses/noise_models.py +0 -524
  113. careamics/prediction/__init__.py +0 -7
  114. careamics/prediction/stitch_prediction.py +0 -74
  115. careamics/transforms/nd_flip.py +0 -67
  116. careamics/utils/running_stats.py +0 -43
  117. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  118. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
@@ -3,8 +3,9 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  from pprint import pformat
6
- from typing import Any, List, Literal, Optional, Union
6
+ from typing import Any, Literal, Optional, Union
7
7
 
8
+ from numpy.typing import NDArray
8
9
  from pydantic import (
9
10
  BaseModel,
10
11
  ConfigDict,
@@ -17,16 +18,14 @@ from typing_extensions import Annotated, Self
17
18
 
18
19
  from .support import SupportedTransform
19
20
  from .transformations.n2v_manipulate_model import N2VManipulateModel
20
- from .transformations.nd_flip_model import NDFlipModel
21
- from .transformations.normalize_model import NormalizeModel
21
+ from .transformations.xy_flip_model import XYFlipModel
22
22
  from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
23
23
  from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
24
24
 
25
25
  TRANSFORMS_UNION = Annotated[
26
26
  Union[
27
- NDFlipModel,
27
+ XYFlipModel,
28
28
  XYRandomRotate90Model,
29
- NormalizeModel,
30
29
  N2VManipulateModel,
31
30
  ],
32
31
  Discriminator("name"), # used to tell the different transform models apart
@@ -39,7 +38,11 @@ class DataConfig(BaseModel):
39
38
 
40
39
  If std is specified, mean must be specified as well. Note that setting the std first
41
40
  and then the mean (if they were both `None` before) will raise a validation error.
42
- Prefer instead `set_mean_and_std` to set both at once.
41
+ Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected
42
+ to be lists of floats, one for each channel. For supervised tasks, the mean and std
43
+ of the target could be different from the input data.
44
+
45
+ All supported transforms are defined in the SupportedTransform enum.
43
46
 
44
47
  Examples
45
48
  --------
@@ -53,10 +56,10 @@ class DataConfig(BaseModel):
53
56
  ... )
54
57
 
55
58
  To change the mean and std of the data:
56
- >>> data.set_mean_and_std(mean=214.3, std=84.5)
59
+ >>> data.set_mean_and_std(image_means=[214.3], image_stds=[84.5])
57
60
 
58
61
  One can pass also a list of transformations, by keyword, using the
59
- SupportedTransform or the name of an Albumentation transform:
62
+ SupportedTransform value:
60
63
  >>> from careamics.config.support import SupportedTransform
61
64
  >>> data = DataConfig(
62
65
  ... data_type="tiff",
@@ -65,12 +68,7 @@ class DataConfig(BaseModel):
65
68
  ... axes="YX",
66
69
  ... transforms=[
67
70
  ... {
68
- ... "name": SupportedTransform.NORMALIZE.value,
69
- ... "mean": 167.6,
70
- ... "std": 47.2,
71
- ... },
72
- ... {
73
- ... "name": "NDFlip",
71
+ ... "name": "XYFlip",
74
72
  ... }
75
73
  ... ]
76
74
  ... )
@@ -83,21 +81,26 @@ class DataConfig(BaseModel):
83
81
 
84
82
  # Dataset configuration
85
83
  data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData
86
- patch_size: Union[List[int]] = Field(..., min_length=2, max_length=3)
84
+ patch_size: Union[list[int]] = Field(..., min_length=2, max_length=3)
87
85
  batch_size: int = Field(default=1, ge=1, validate_default=True)
88
86
  axes: str
89
87
 
90
88
  # Optional fields
91
- mean: Optional[float] = None
92
- std: Optional[float] = None
89
+ image_means: Optional[list[float]] = Field(
90
+ default=None, min_length=0, max_length=32
91
+ )
92
+ image_stds: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
93
+ target_means: Optional[list[float]] = Field(
94
+ default=None, min_length=0, max_length=32
95
+ )
96
+ target_stds: Optional[list[float]] = Field(
97
+ default=None, min_length=0, max_length=32
98
+ )
93
99
 
94
- transforms: List[TRANSFORMS_UNION] = Field(
100
+ transforms: list[TRANSFORMS_UNION] = Field(
95
101
  default=[
96
102
  {
97
- "name": SupportedTransform.NORMALIZE.value,
98
- },
99
- {
100
- "name": SupportedTransform.NDFLIP.value,
103
+ "name": SupportedTransform.XY_FLIP.value,
101
104
  },
102
105
  {
103
106
  "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
@@ -114,8 +117,8 @@ class DataConfig(BaseModel):
114
117
  @field_validator("patch_size")
115
118
  @classmethod
116
119
  def all_elements_power_of_2_minimum_8(
117
- cls, patch_list: Union[List[int]]
118
- ) -> Union[List[int]]:
120
+ cls, patch_list: Union[list[int]]
121
+ ) -> Union[list[int]]:
119
122
  """
120
123
  Validate patch size.
121
124
 
@@ -123,12 +126,12 @@ class DataConfig(BaseModel):
123
126
 
124
127
  Parameters
125
128
  ----------
126
- patch_list : Union[List[int]]
129
+ patch_list : list of int
127
130
  Patch size.
128
131
 
129
132
  Returns
130
133
  -------
131
- Union[List[int]]
134
+ list of int
132
135
  Validated patch size.
133
136
 
134
137
  Raises
@@ -178,19 +181,19 @@ class DataConfig(BaseModel):
178
181
  @field_validator("transforms")
179
182
  @classmethod
180
183
  def validate_prediction_transforms(
181
- cls, transforms: List[TRANSFORMS_UNION]
182
- ) -> List[TRANSFORMS_UNION]:
184
+ cls, transforms: list[TRANSFORMS_UNION]
185
+ ) -> list[TRANSFORMS_UNION]:
183
186
  """
184
187
  Validate N2VManipulate transform position in the transform list.
185
188
 
186
189
  Parameters
187
190
  ----------
188
- transforms : List[Transformations_Union]
191
+ transforms : list[Transformations_Union]
189
192
  Transforms.
190
193
 
191
194
  Returns
192
195
  -------
193
- List[TRANSFORMS_UNION]
196
+ list of transforms
194
197
  Validated transforms.
195
198
 
196
199
  Raises
@@ -202,7 +205,7 @@ class DataConfig(BaseModel):
202
205
 
203
206
  if SupportedTransform.N2V_MANIPULATE in transform_list:
204
207
  # multiple N2V_MANIPULATE
205
- if transform_list.count(SupportedTransform.N2V_MANIPULATE) > 1:
208
+ if transform_list.count(SupportedTransform.N2V_MANIPULATE.value) > 1:
206
209
  raise ValueError(
207
210
  f"Multiple instances of "
208
211
  f"{SupportedTransform.N2V_MANIPULATE} transforms "
@@ -211,7 +214,7 @@ class DataConfig(BaseModel):
211
214
 
212
215
  # N2V_MANIPULATE not the last transform
213
216
  elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
214
- index = transform_list.index(SupportedTransform.N2V_MANIPULATE)
217
+ index = transform_list.index(SupportedTransform.N2V_MANIPULATE.value)
215
218
  transform = transforms.pop(index)
216
219
  transforms.append(transform)
217
220
 
@@ -233,29 +236,33 @@ class DataConfig(BaseModel):
233
236
  If std is not None and mean is None.
234
237
  """
235
238
  # check that mean and std are either both None, or both specified
236
- if (self.mean is None) != (self.std is None):
239
+ if (self.image_means and not self.image_stds) or (
240
+ self.image_stds and not self.image_means
241
+ ):
237
242
  raise ValueError(
238
243
  "Mean and std must be either both None, or both specified."
239
244
  )
240
245
 
241
- return self
246
+ elif (self.image_means is not None and self.image_stds is not None) and (
247
+ len(self.image_means) != len(self.image_stds)
248
+ ):
249
+ raise ValueError(
250
+ "Mean and std must be specified for each " "input channel."
251
+ )
242
252
 
243
- @model_validator(mode="after")
244
- def add_std_and_mean_to_normalize(self: Self) -> Self:
245
- """
246
- Add mean and std to the Normalize transform if it is present.
253
+ if (self.target_means and not self.target_stds) or (
254
+ self.target_stds and not self.target_means
255
+ ):
256
+ raise ValueError(
257
+ "Mean and std must be either both None, or both specified "
258
+ )
247
259
 
248
- Returns
249
- -------
250
- Self
251
- Data model with mean and std added to the Normalize transform.
252
- """
253
- if self.mean is not None or self.std is not None:
254
- # search in the transforms for Normalize and update parameters
255
- for transform in self.transforms:
256
- if transform.name == SupportedTransform.NORMALIZE.value:
257
- transform.mean = self.mean
258
- transform.std = self.std
260
+ elif self.target_means is not None and self.target_stds is not None:
261
+ if len(self.target_means) != len(self.target_stds):
262
+ raise ValueError(
263
+ "Mean and std must be either both None, or both specified for each "
264
+ "target channel."
265
+ )
259
266
 
260
267
  return self
261
268
 
@@ -339,7 +346,13 @@ class DataConfig(BaseModel):
339
346
  if self.has_n2v_manipulate():
340
347
  self.transforms.pop(-1)
341
348
 
342
- def set_mean_and_std(self, mean: float, std: float) -> None:
349
+ def set_mean_and_std(
350
+ self,
351
+ image_means: Union[NDArray, tuple, list, None],
352
+ image_stds: Union[NDArray, tuple, list, None],
353
+ target_means: Optional[Union[NDArray, tuple, list, None]] = None,
354
+ target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
355
+ ) -> None:
343
356
  """
344
357
  Set mean and standard deviation of the data.
345
358
 
@@ -348,20 +361,33 @@ class DataConfig(BaseModel):
348
361
 
349
362
  Parameters
350
363
  ----------
351
- mean : float
352
- Mean of the data.
353
- std : float
354
- Standard deviation of the data.
364
+ image_means : NDArray or tuple or list
365
+ Mean values for normalization.
366
+ image_stds : NDArray or tuple or list
367
+ Standard deviation values for normalization.
368
+ target_means : NDArray or tuple or list, optional
369
+ Target mean values for normalization, by default None.
370
+ target_stds : NDArray or tuple or list, optional
371
+ Target standard deviation values for normalization, by default None.
355
372
  """
356
- self._update(mean=mean, std=std)
357
-
358
- # search in the transforms for Normalize and update parameters
359
- for transform in self.transforms:
360
- if transform.name == SupportedTransform.NORMALIZE.value:
361
- transform.mean = mean
362
- transform.std = std
373
+ # make sure we pass a list
374
+ if image_means is not None:
375
+ image_means = list(image_means)
376
+ if image_stds is not None:
377
+ image_stds = list(image_stds)
378
+ if target_means is not None:
379
+ target_means = list(target_means)
380
+ if target_stds is not None:
381
+ target_stds = list(target_stds)
382
+
383
+ self._update(
384
+ image_means=image_means,
385
+ image_stds=image_stds,
386
+ target_means=target_means,
387
+ target_stds=target_stds,
388
+ )
363
389
 
364
- def set_3D(self, axes: str, patch_size: List[int]) -> None:
390
+ def set_3D(self, axes: str, patch_size: list[int]) -> None:
365
391
  """
366
392
  Set 3D parameters.
367
393
 
@@ -369,7 +395,7 @@ class DataConfig(BaseModel):
369
395
  ----------
370
396
  axes : str
371
397
  Axes.
372
- patch_size : List[int]
398
+ patch_size : list of int
373
399
  Patch size.
374
400
  """
375
401
  self._update(axes=axes, patch_size=patch_size)
@@ -2,17 +2,13 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import Any, List, Literal, Optional, Union
5
+ from typing import Any, Literal, Optional, Union
6
6
 
7
7
  from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
8
8
  from typing_extensions import Self
9
9
 
10
- from .support import SupportedTransform
11
- from .transformations.normalize_model import NormalizeModel
12
10
  from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
13
11
 
14
- TRANSFORMS_UNION = Union[NormalizeModel]
15
-
16
12
 
17
13
  class InferenceConfig(BaseModel):
18
14
  """Configuration class for the prediction model."""
@@ -21,26 +17,17 @@ class InferenceConfig(BaseModel):
21
17
 
22
18
  # Mandatory fields
23
19
  data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData
24
- tile_size: Optional[Union[List[int]]] = Field(
20
+ tile_size: Optional[Union[list[int]]] = Field(
25
21
  default=None, min_length=2, max_length=3
26
22
  )
27
- tile_overlap: Optional[Union[List[int]]] = Field(
23
+ tile_overlap: Optional[Union[list[int]]] = Field(
28
24
  default=None, min_length=2, max_length=3
29
25
  )
30
26
 
31
27
  axes: str
32
28
 
33
- mean: float
34
- std: float = Field(..., ge=0.0)
35
-
36
- transforms: List[TRANSFORMS_UNION] = Field(
37
- default=[
38
- {
39
- "name": SupportedTransform.NORMALIZE.value,
40
- },
41
- ],
42
- validate_default=True,
43
- )
29
+ image_means: list = Field(..., min_length=0, max_length=32)
30
+ image_stds: list = Field(..., min_length=0, max_length=32)
44
31
 
45
32
  # only default TTAs are supported for now
46
33
  tta_transforms: bool = Field(default=True)
@@ -51,8 +38,8 @@ class InferenceConfig(BaseModel):
51
38
  @field_validator("tile_overlap")
52
39
  @classmethod
53
40
  def all_elements_non_zero_even(
54
- cls, tile_overlap: Optional[Union[List[int]]]
55
- ) -> Optional[Union[List[int]]]:
41
+ cls, tile_overlap: Optional[list[int]]
42
+ ) -> Optional[list[int]]:
56
43
  """
57
44
  Validate tile overlap.
58
45
 
@@ -60,12 +47,12 @@ class InferenceConfig(BaseModel):
60
47
 
61
48
  Parameters
62
49
  ----------
63
- tile_overlap : Optional[Union[List[int]]]
50
+ tile_overlap : list[int] or None
64
51
  Tile overlap.
65
52
 
66
53
  Returns
67
54
  -------
68
- Optional[Union[List[int]]]
55
+ list[int] or None
69
56
  Validated tile overlap.
70
57
 
71
58
  Raises
@@ -90,19 +77,19 @@ class InferenceConfig(BaseModel):
90
77
  @field_validator("tile_size")
91
78
  @classmethod
92
79
  def tile_min_8_power_of_2(
93
- cls, tile_list: Optional[Union[List[int]]]
94
- ) -> Optional[Union[List[int]]]:
80
+ cls, tile_list: Optional[list[int]]
81
+ ) -> Optional[list[int]]:
95
82
  """
96
83
  Validate that each entry is greater or equal than 8 and a power of 2.
97
84
 
98
85
  Parameters
99
86
  ----------
100
- tile_list : List[int]
87
+ tile_list : list of int
101
88
  Tile size.
102
89
 
103
90
  Returns
104
91
  -------
105
- List[int]
92
+ list of int
106
93
  Validated tile size.
107
94
 
108
95
  Raises
@@ -149,39 +136,6 @@ class InferenceConfig(BaseModel):
149
136
 
150
137
  return axes
151
138
 
152
- @field_validator("transforms")
153
- @classmethod
154
- def validate_transforms(
155
- cls, transforms: List[TRANSFORMS_UNION]
156
- ) -> List[TRANSFORMS_UNION]:
157
- """
158
- Validate that transforms do not have N2V pixel manipulate transforms.
159
-
160
- Parameters
161
- ----------
162
- transforms : List[TRANSFORMS_UNION]
163
- Transforms.
164
-
165
- Returns
166
- -------
167
- List[TRANSFORMS_UNION]
168
- Validated transforms.
169
-
170
- Raises
171
- ------
172
- ValueError
173
- If transforms contain N2V pixel manipulate transforms.
174
- """
175
- if transforms is not None:
176
- for transform in transforms:
177
- if transform.name == SupportedTransform.N2V_MANIPULATE.value:
178
- raise ValueError(
179
- "N2V_Manipulate transform is not allowed in "
180
- "prediction transforms."
181
- )
182
-
183
- return transforms
184
-
185
139
  @model_validator(mode="after")
186
140
  def validate_dimensions(self: Self) -> Self:
187
141
  """
@@ -228,29 +182,22 @@ class InferenceConfig(BaseModel):
228
182
  If std is not None and mean is None.
229
183
  """
230
184
  # check that mean and std are either both None, or both specified
231
- if (self.mean is None) != (self.std is None):
185
+ if not self.image_means and not self.image_stds:
186
+ raise ValueError("Mean and std must be specified during inference.")
187
+
188
+ if (self.image_means and not self.image_stds) or (
189
+ self.image_stds and not self.image_means
190
+ ):
232
191
  raise ValueError(
233
192
  "Mean and std must be either both None, or both specified."
234
193
  )
235
194
 
236
- return self
237
-
238
- @model_validator(mode="after")
239
- def add_std_and_mean_to_normalize(self: Self) -> Self:
240
- """
241
- Add mean and std to the Normalize transform if it is present.
242
-
243
- Returns
244
- -------
245
- Self
246
- Inference model with mean and std added to the Normalize transform.
247
- """
248
- if self.mean is not None or self.std is not None:
249
- # search in the transforms for Normalize and update parameters
250
- for transform in self.transforms:
251
- if transform.name == SupportedTransform.NORMALIZE.value:
252
- transform.mean = self.mean
253
- transform.std = self.std
195
+ elif (self.image_means is not None and self.image_stds is not None) and (
196
+ len(self.image_means) != len(self.image_stds)
197
+ ):
198
+ raise ValueError(
199
+ "Mean and std must be specified for each " "input channel."
200
+ )
254
201
 
255
202
  return self
256
203
 
@@ -266,7 +213,7 @@ class InferenceConfig(BaseModel):
266
213
  self.__dict__.update(kwargs)
267
214
  self.__class__.model_validate(self.__dict__)
268
215
 
269
- def set_3D(self, axes: str, tile_size: List[int], tile_overlap: List[int]) -> None:
216
+ def set_3D(self, axes: str, tile_size: list[int], tile_overlap: list[int]) -> None:
270
217
  """
271
218
  Set 3D parameters.
272
219
 
@@ -274,9 +221,9 @@ class InferenceConfig(BaseModel):
274
221
  ----------
275
222
  axes : str
276
223
  Axes.
277
- tile_size : List[int]
224
+ tile_size : list of int
278
225
  Tile size.
279
- tile_overlap : List[int]
226
+ tile_overlap : list of int
280
227
  Tile overlap.
281
228
  """
282
229
  self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap)
@@ -1,6 +1,8 @@
1
+ """Optimizers and schedulers Pydantic models."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
- from typing import Dict, Literal
5
+ from typing import Literal
4
6
 
5
7
  from pydantic import (
6
8
  BaseModel,
@@ -19,8 +21,7 @@ from .support import SupportedOptimizer
19
21
 
20
22
 
21
23
  class OptimizerModel(BaseModel):
22
- """
23
- Torch optimizer.
24
+ """Torch optimizer Pydantic model.
24
25
 
25
26
  Only parameters supported by the corresponding torch optimizer will be taken
26
27
  into account. For more details, check:
@@ -31,7 +32,7 @@ class OptimizerModel(BaseModel):
31
32
 
32
33
  Attributes
33
34
  ----------
34
- name : TorchOptimizer
35
+ name : {"Adam", "SGD"}
35
36
  Name of the optimizer.
36
37
  parameters : dict
37
38
  Parameters of the optimizer (see torch documentation).
@@ -55,7 +56,7 @@ class OptimizerModel(BaseModel):
55
56
 
56
57
  @field_validator("parameters")
57
58
  @classmethod
58
- def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> Dict:
59
+ def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
59
60
  """
60
61
  Validate optimizer parameters.
61
62
 
@@ -70,7 +71,7 @@ class OptimizerModel(BaseModel):
70
71
 
71
72
  Returns
72
73
  -------
73
- Dict
74
+ dict
74
75
  Filtered optimizer parameters.
75
76
 
76
77
  Raises
@@ -115,8 +116,7 @@ class OptimizerModel(BaseModel):
115
116
 
116
117
 
117
118
  class LrSchedulerModel(BaseModel):
118
- """
119
- Torch learning rate scheduler.
119
+ """Torch learning rate scheduler Pydantic model.
120
120
 
121
121
  Only parameters supported by the corresponding torch lr scheduler will be taken
122
122
  into account. For more details, check:
@@ -127,7 +127,7 @@ class LrSchedulerModel(BaseModel):
127
127
 
128
128
  Attributes
129
129
  ----------
130
- name : TorchLRScheduler
130
+ name : {"ReduceLROnPlateau", "StepLR"}
131
131
  Name of the learning rate scheduler.
132
132
  parameters : dict
133
133
  Parameters of the learning rate scheduler (see torch documentation).
@@ -146,7 +146,7 @@ class LrSchedulerModel(BaseModel):
146
146
 
147
147
  @field_validator("parameters")
148
148
  @classmethod
149
- def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> Dict:
149
+ def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
150
150
  """Filter parameters based on the learning rate scheduler's signature.
151
151
 
152
152
  Parameters
@@ -158,7 +158,7 @@ class LrSchedulerModel(BaseModel):
158
158
 
159
159
  Returns
160
160
  -------
161
- Dict
161
+ dict
162
162
  Filtered scheduler parameters.
163
163
 
164
164
  Raises
@@ -14,7 +14,6 @@ __all__ = [
14
14
  "SupportedPixelManipulation",
15
15
  "SupportedTransform",
16
16
  "SupportedData",
17
- "SupportedExtractionStrategy",
18
17
  "SupportedStructAxis",
19
18
  "SupportedLogger",
20
19
  ]
@@ -24,7 +23,6 @@ from .supported_activations import SupportedActivation
24
23
  from .supported_algorithms import SupportedAlgorithm
25
24
  from .supported_architectures import SupportedArchitecture
26
25
  from .supported_data import SupportedData
27
- from .supported_extraction_strategies import SupportedExtractionStrategy
28
26
  from .supported_loggers import SupportedLogger
29
27
  from .supported_losses import SupportedLoss
30
28
  from .supported_optimizers import SupportedOptimizer, SupportedScheduler
@@ -1,3 +1,5 @@
1
+ """Activations supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,3 +1,5 @@
1
+ """Algorithms supported by CAREamics."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  from careamics.utils import BaseEnum
@@ -10,9 +12,9 @@ class SupportedAlgorithm(str, BaseEnum):
10
12
  """
11
13
 
12
14
  N2V = "n2v"
13
- CUSTOM = "custom"
14
15
  CARE = "care"
15
16
  N2N = "n2n"
17
+ CUSTOM = "custom"
16
18
  # PN2V = "pn2v"
17
19
  # HDN = "hdn"
18
20
  # SEG = "segmentation"
@@ -1,3 +1,5 @@
1
+ """Architectures supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,3 +1,5 @@
1
+ """Data supported by CAREamics."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  from typing import Union
@@ -1,3 +1,5 @@
1
+ """Logger supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,3 +1,5 @@
1
+ """Losses supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,3 +1,5 @@
1
+ """Optimizers and schedulers supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,15 +1,15 @@
1
+ """Pixel manipulation methods supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
4
6
  class SupportedPixelManipulation(str, BaseEnum):
5
- """_summary_.
7
+ """Supported Noise2Void pixel manipulations.
6
8
 
7
9
  - Uniform: Replace masked pixel value by a (uniformly) randomly selected neighbor
8
10
  pixel value.
9
11
  - Median: Replace masked pixel value by the median of the neighborhood.
10
12
  """
11
13
 
12
- # TODO docs
13
-
14
14
  UNIFORM = "uniform"
15
15
  MEDIAN = "median"
@@ -1,3 +1,5 @@
1
+ """StructN2V axes supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5