careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of careamics might be problematic.

Files changed (91)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
careamics/config/data_model.py

@@ -3,8 +3,9 @@
 from __future__ import annotations
 
 from pprint import pformat
-from typing import Any, List, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
 
+from numpy.typing import NDArray
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -17,7 +18,6 @@ from typing_extensions import Annotated, Self
 
 from .support import SupportedTransform
 from .transformations.n2v_manipulate_model import N2VManipulateModel
-from .transformations.normalize_model import NormalizeModel
 from .transformations.xy_flip_model import XYFlipModel
 from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
 from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
@@ -26,7 +26,6 @@ TRANSFORMS_UNION = Annotated[
     Union[
         XYFlipModel,
         XYRandomRotate90Model,
-        NormalizeModel,
         N2VManipulateModel,
     ],
     Discriminator("name"),  # used to tell the different transform models apart
@@ -39,7 +38,9 @@ class DataConfig(BaseModel):
 
     If std is specified, mean must be specified as well. Note that setting the std first
     and then the mean (if they were both `None` before) will raise a validation error.
-    Prefer instead `set_mean_and_std` to set both at once.
+    Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected
+    to be lists of floats, one for each channel. For supervised tasks, the mean and std
+    of the target could be different from the input data.
 
     All supported transforms are defined in the SupportedTransform enum.
 
@@ -54,8 +55,8 @@ class DataConfig(BaseModel):
     ...     axes="YX"
     ... )
 
-    To change the mean and std of the data:
-    >>> data.set_mean_and_std(mean=214.3, std=84.5)
+    To change the image_means and image_stds of the data:
+    >>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])
 
     One can pass also a list of transformations, by keyword, using the
     SupportedTransform value:
@@ -67,11 +68,6 @@ class DataConfig(BaseModel):
     ...     axes="YX",
     ...     transforms=[
     ...         {
-    ...             "name": SupportedTransform.NORMALIZE.value,
-    ...             "mean": 167.6,
-    ...             "std": 47.2,
-    ...         },
-    ...         {
     ...             "name": "XYFlip",
     ...         }
     ...     ]
@@ -84,20 +80,41 @@ class DataConfig(BaseModel):
     )
 
     # Dataset configuration
-    data_type: Literal["array", "tiff", "custom"]  # As defined in SupportedData
-    patch_size: Union[List[int]] = Field(..., min_length=2, max_length=3)
-    batch_size: int = Field(default=1, ge=1, validate_default=True)
+    data_type: Literal["array", "tiff", "custom"]
+    """Type of input data, numpy.ndarray (array) or paths (tiff and custom), as defined
+    in SupportedData."""
+
     axes: str
+    """Axes of the data, as defined in SupportedAxes."""
+
+    patch_size: Union[list[int]] = Field(..., min_length=2, max_length=3)
+    """Patch size, as used during training."""
+
+    batch_size: int = Field(default=1, ge=1, validate_default=True)
+    """Batch size for training."""
 
     # Optional fields
-    mean: Optional[float] = None
-    std: Optional[float] = None
+    image_means: Optional[list[float]] = Field(
+        default=None, min_length=0, max_length=32
+    )
+    """Means of the data across channels, used for normalization."""
+
+    image_stds: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
+    """Standard deviations of the data across channels, used for normalization."""
 
-    transforms: List[TRANSFORMS_UNION] = Field(
+    target_means: Optional[list[float]] = Field(
+        default=None, min_length=0, max_length=32
+    )
+    """Means of the target data across channels, used for normalization."""
+
+    target_stds: Optional[list[float]] = Field(
+        default=None, min_length=0, max_length=32
+    )
+    """Standard deviations of the target data across channels, used for
+    normalization."""
+
+    transforms: list[TRANSFORMS_UNION] = Field(
         default=[
-            {
-                "name": SupportedTransform.NORMALIZE.value,
-            },
             {
                 "name": SupportedTransform.XY_FLIP.value,
             },
@@ -110,14 +127,17 @@ class DataConfig(BaseModel):
         ],
         validate_default=True,
     )
+    """List of transformations to apply to the data, available transforms are defined
+    in SupportedTransform. The default values are set for Noise2Void."""
 
     dataloader_params: Optional[dict] = None
+    """Dictionary of PyTorch dataloader parameters."""
 
     @field_validator("patch_size")
     @classmethod
     def all_elements_power_of_2_minimum_8(
-        cls, patch_list: Union[List[int]]
-    ) -> Union[List[int]]:
+        cls, patch_list: Union[list[int]]
+    ) -> Union[list[int]]:
         """
         Validate patch size.
 
@@ -125,12 +145,12 @@ class DataConfig(BaseModel):
 
         Parameters
         ----------
-        patch_list : Union[List[int]]
+        patch_list : list of int
             Patch size.
 
         Returns
         -------
-        Union[List[int]]
+        list of int
             Validated patch size.
 
         Raises
@@ -180,19 +200,19 @@ class DataConfig(BaseModel):
     @field_validator("transforms")
     @classmethod
     def validate_prediction_transforms(
-        cls, transforms: List[TRANSFORMS_UNION]
-    ) -> List[TRANSFORMS_UNION]:
+        cls, transforms: list[TRANSFORMS_UNION]
+    ) -> list[TRANSFORMS_UNION]:
         """
         Validate N2VManipulate transform position in the transform list.
 
         Parameters
         ----------
-        transforms : List[Transformations_Union]
+        transforms : list[Transformations_Union]
             Transforms.
 
         Returns
        -------
-        List[TRANSFORMS_UNION]
+        list of transforms
             Validated transforms.
 
         Raises
@@ -235,29 +255,33 @@ class DataConfig(BaseModel):
             If std is not None and mean is None.
         """
         # check that mean and std are either both None, or both specified
-        if (self.mean is None) != (self.std is None):
+        if (self.image_means and not self.image_stds) or (
+            self.image_stds and not self.image_means
+        ):
             raise ValueError(
                 "Mean and std must be either both None, or both specified."
             )
 
-        return self
+        elif (self.image_means is not None and self.image_stds is not None) and (
+            len(self.image_means) != len(self.image_stds)
+        ):
+            raise ValueError(
+                "Mean and std must be specified for each " "input channel."
+            )
 
-    @model_validator(mode="after")
-    def add_std_and_mean_to_normalize(self: Self) -> Self:
-        """
-        Add mean and std to the Normalize transform if it is present.
+        if (self.target_means and not self.target_stds) or (
+            self.target_stds and not self.target_means
+        ):
+            raise ValueError(
+                "Mean and std must be either both None, or both specified "
+            )
 
-        Returns
-        -------
-        Self
-            Data model with mean and std added to the Normalize transform.
-        """
-        if self.mean is not None and self.std is not None:
-            # search in the transforms for Normalize and update parameters
-            for transform in self.transforms:
-                if transform.name == SupportedTransform.NORMALIZE.value:
-                    transform.mean = self.mean
-                    transform.std = self.std
+        elif self.target_means is not None and self.target_stds is not None:
+            if len(self.target_means) != len(self.target_stds):
+                raise ValueError(
+                    "Mean and std must be either both None, or both specified for each "
+                    "target channel."
+                )
 
         return self
 
@@ -341,23 +365,48 @@ class DataConfig(BaseModel):
         if self.has_n2v_manipulate():
             self.transforms.pop(-1)
 
-    def set_mean_and_std(self, mean: float, std: float) -> None:
+    def set_means_and_stds(
+        self,
+        image_means: Union[NDArray, tuple, list, None],
+        image_stds: Union[NDArray, tuple, list, None],
+        target_means: Optional[Union[NDArray, tuple, list, None]] = None,
+        target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
+    ) -> None:
         """
-        Set mean and standard deviation of the data.
+        Set mean and standard deviation of the data across channels.
 
         This method should be used instead setting the fields directly, as it would
         otherwise trigger a validation error.
 
         Parameters
         ----------
-        mean : float
-            Mean of the data.
-        std : float
-            Standard deviation of the data.
+        image_means : numpy.ndarray, tuple or list
+            Mean values for normalization.
+        image_stds : numpy.ndarray, tuple or list
+            Standard deviation values for normalization.
+        target_means : numpy.ndarray, tuple or list, optional
+            Target mean values for normalization, by default ().
+        target_stds : numpy.ndarray, tuple or list, optional
+            Target standard deviation values for normalization, by default ().
         """
-        self._update(mean=mean, std=std)
+        # make sure we pass a list
+        if image_means is not None:
+            image_means = list(image_means)
+        if image_stds is not None:
+            image_stds = list(image_stds)
+        if target_means is not None:
+            target_means = list(target_means)
+        if target_stds is not None:
+            target_stds = list(target_stds)
+
+        self._update(
+            image_means=image_means,
+            image_stds=image_stds,
+            target_means=target_means,
+            target_stds=target_stds,
+        )
 
-    def set_3D(self, axes: str, patch_size: List[int]) -> None:
+    def set_3D(self, axes: str, patch_size: list[int]) -> None:
         """
         Set 3D parameters.
 
@@ -365,7 +414,7 @@ class DataConfig(BaseModel):
         ----------
         axes : str
             Axes.
-        patch_size : List[int]
+        patch_size : list of int
             Patch size.
         """
         self._update(axes=axes, patch_size=patch_size)
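For orientation, a minimal sketch of how DataConfig is used after this change (values are illustrative; the import path is taken from the file list above). Normalization statistics are now per-channel lists set through set_means_and_stds rather than via a Normalize transform:

from careamics.config.data_model import DataConfig

# Normalize is no longer a transform: statistics live on the config itself
# as per-channel lists (image_means/image_stds, plus optional target_*)
data = DataConfig(
    data_type="tiff",
    patch_size=[128, 128],
    batch_size=4,
    axes="YX",
)

# set both statistics at once; setting image_means or image_stds alone
# would trip the cross-field validator shown in the diff
data.set_means_and_stds(image_means=[214.3], image_stds=[84.5])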
careamics/config/inference_model.py

@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any, List, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 from typing_extensions import Self
@@ -15,31 +15,41 @@ class InferenceConfig(BaseModel):
 
     model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
 
-    # Mandatory fields
     data_type: Literal["array", "tiff", "custom"]  # As defined in SupportedData
-    tile_size: Optional[Union[List[int]]] = Field(
+    """Type of input data: numpy.ndarray (array) or path (tiff or custom)."""
+
+    tile_size: Optional[Union[list[int]]] = Field(
         default=None, min_length=2, max_length=3
     )
-    tile_overlap: Optional[Union[List[int]]] = Field(
+    """Tile size of prediction, only effective if `tile_overlap` is specified."""
+
+    tile_overlap: Optional[Union[list[int]]] = Field(
         default=None, min_length=2, max_length=3
     )
+    """Overlap between tiles, only effective if `tile_size` is specified."""
 
     axes: str
+    """Data axes (TSCZYX) in the order of the input data."""
 
-    mean: float
-    std: float = Field(..., ge=0.0)
+    image_means: list = Field(..., min_length=0, max_length=32)
+    """Mean values for each input channel."""
 
-    # only default TTAs are supported for now
+    image_stds: list = Field(..., min_length=0, max_length=32)
+    """Standard deviation values for each input channel."""
+
+    # TODO only default TTAs are supported for now
     tta_transforms: bool = Field(default=True)
+    """Whether to apply test-time augmentation (all 90 degrees rotations and flips)."""
 
     # Dataloader parameters
     batch_size: int = Field(default=1, ge=1)
+    """Batch size for prediction."""
 
     @field_validator("tile_overlap")
     @classmethod
     def all_elements_non_zero_even(
-        cls, tile_overlap: Optional[Union[List[int]]]
-    ) -> Optional[Union[List[int]]]:
+        cls, tile_overlap: Optional[list[int]]
+    ) -> Optional[list[int]]:
         """
         Validate tile overlap.
 
@@ -47,12 +57,12 @@ class InferenceConfig(BaseModel):
 
         Parameters
         ----------
-        tile_overlap : Optional[Union[List[int]]]
+        tile_overlap : list[int] or None
             Patch size.
 
         Returns
         -------
-        Optional[Union[List[int]]]
+        list[int] or None
             Validated tile overlap.
 
         Raises
@@ -77,19 +87,19 @@ class InferenceConfig(BaseModel):
     @field_validator("tile_size")
     @classmethod
     def tile_min_8_power_of_2(
-        cls, tile_list: Optional[Union[List[int]]]
-    ) -> Optional[Union[List[int]]]:
+        cls, tile_list: Optional[list[int]]
+    ) -> Optional[list[int]]:
         """
         Validate that each entry is greater or equal than 8 and a power of 2.
 
         Parameters
         ----------
-        tile_list : List[int]
+        tile_list : list of int
             Patch size.
 
         Returns
         -------
-        List[int]
+        list of int
             Validated patch size.
 
         Raises
@@ -182,11 +192,23 @@ class InferenceConfig(BaseModel):
             If std is not None and mean is None.
         """
         # check that mean and std are either both None, or both specified
-        if (self.mean is None) != (self.std is None):
+        if not self.image_means and not self.image_stds:
+            raise ValueError("Mean and std must be specified during inference.")
+
+        if (self.image_means and not self.image_stds) or (
+            self.image_stds and not self.image_means
+        ):
             raise ValueError(
                 "Mean and std must be either both None, or both specified."
             )
 
+        elif (self.image_means is not None and self.image_stds is not None) and (
+            len(self.image_means) != len(self.image_stds)
+        ):
+            raise ValueError(
+                "Mean and std must be specified for each " "input channel."
+            )
+
         return self
 
     def _update(self, **kwargs: Any) -> None:
@@ -201,7 +223,7 @@ class InferenceConfig(BaseModel):
         self.__dict__.update(kwargs)
         self.__class__.model_validate(self.__dict__)
 
-    def set_3D(self, axes: str, tile_size: List[int], tile_overlap: List[int]) -> None:
+    def set_3D(self, axes: str, tile_size: list[int], tile_overlap: list[int]) -> None:
         """
         Set 3D parameters.
 
@@ -209,9 +231,9 @@ class InferenceConfig(BaseModel):
         ----------
         axes : str
             Axes.
-        tile_size : List[int]
+        tile_size : list of int
             Tile size.
-        tile_overlap : List[int]
+        tile_overlap : list of int
             Tile overlap.
         """
         self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap)
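Similarly, a hedged sketch of the reworked InferenceConfig (values are illustrative): the scalar mean/std pair is replaced by mandatory per-channel image_means/image_stds lists, and tiling only takes effect when both tile_size and tile_overlap are given:

from careamics.config.inference_model import InferenceConfig

config = InferenceConfig(
    data_type="tiff",
    axes="YX",
    image_means=[214.3],
    image_stds=[84.5],
    tile_size=[256, 256],   # each entry >= 8 and a power of 2
    tile_overlap=[48, 48],  # each entry non-zero and even
)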
careamics/config/optimizer_models.py

@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Dict, Literal
+from typing import Literal
 
 from pydantic import (
     BaseModel,
@@ -32,7 +32,7 @@ class OptimizerModel(BaseModel):
 
     Attributes
     ----------
-    name : TorchOptimizer
+    name : {"Adam", "SGD"}
         Name of the optimizer.
     parameters : dict
         Parameters of the optimizer (see torch documentation).
@@ -45,6 +45,7 @@ class OptimizerModel(BaseModel):
 
     # Mandatory field
     name: Literal["Adam", "SGD"] = Field(default="Adam", validate_default=True)
+    """Name of the optimizer, supported optimizers are defined in SupportedOptimizer."""
 
     # Optional parameters, empty dict default value to allow filtering dictionary
     parameters: dict = Field(
@@ -53,10 +54,11 @@ class OptimizerModel(BaseModel):
         },
         validate_default=True,
     )
+    """Parameters of the optimizer, see PyTorch documentation for more details."""
 
     @field_validator("parameters")
     @classmethod
-    def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> Dict:
+    def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
         """
         Validate optimizer parameters.
 
@@ -71,7 +73,7 @@ class OptimizerModel(BaseModel):
 
         Returns
         -------
-        Dict
+        dict
             Filtered optimizer parameters.
 
         Raises
@@ -127,7 +129,7 @@ class LrSchedulerModel(BaseModel):
 
     Attributes
     ----------
-    name : TorchLRScheduler
+    name : {"ReduceLROnPlateau", "StepLR"}
         Name of the learning rate scheduler.
     parameters : dict
         Parameters of the learning rate scheduler (see torch documentation).
@@ -140,13 +142,17 @@ class LrSchedulerModel(BaseModel):
 
     # Mandatory field
     name: Literal["ReduceLROnPlateau", "StepLR"] = Field(default="ReduceLROnPlateau")
+    """Name of the learning rate scheduler, supported schedulers are defined in
+    SupportedScheduler."""
 
     # Optional parameters
     parameters: dict = Field(default={}, validate_default=True)
+    """Parameters of the learning rate scheduler, see PyTorch documentation for more
+    details."""
 
     @field_validator("parameters")
     @classmethod
-    def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> Dict:
+    def filter_parameters(cls, user_params: dict, values: ValidationInfo) -> dict:
         """Filter parameters based on the learning rate scheduler's signature.
 
         Parameters
@@ -158,7 +164,7 @@ class LrSchedulerModel(BaseModel):
 
         Returns
        -------
-        Dict
+        dict
             Filtered scheduler parameters.
 
         Raises
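A short sketch of the two models above (parameter values are illustrative): the filter_parameters validator checks the supplied dictionary against the signature of the corresponding torch class, so only recognized keys are kept:

from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel

optimizer = OptimizerModel(name="Adam", parameters={"lr": 1e-4})
scheduler = LrSchedulerModel(name="ReduceLROnPlateau", parameters={"factor": 0.5})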
careamics/config/support/supported_data.py

@@ -60,9 +60,9 @@ class SupportedData(str, BaseEnum):
             return super()._missing_(value)
 
     @classmethod
-    def get_extension(cls, data_type: Union[str, SupportedData]) -> str:
+    def get_extension_pattern(cls, data_type: Union[str, SupportedData]) -> str:
         """
-        Path.rglob and fnmatch compatible extension.
+        Get Path.rglob and fnmatch compatible extension.
 
         Parameters
         ----------
@@ -72,13 +72,38 @@ class SupportedData(str, BaseEnum):
         Returns
         -------
         str
-            Corresponding extension.
+            Corresponding extension pattern.
         """
         if data_type == cls.ARRAY:
-            raise NotImplementedError(f"Data {data_type} are not loaded from file.")
+            raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
         elif data_type == cls.TIFF:
             return "*.tif*"
         elif data_type == cls.CUSTOM:
             return "*.*"
         else:
             raise ValueError(f"Data type {data_type} is not supported.")
+
+    @classmethod
+    def get_extension(cls, data_type: Union[str, SupportedData]) -> str:
+        """
+        Get file extension of corresponding data type.
+
+        Parameters
+        ----------
+        data_type : str or SupportedData
+            Data type.
+
+        Returns
+        -------
+        str
+            Corresponding extension.
+        """
+        if data_type == cls.ARRAY:
+            raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
+        elif data_type == cls.TIFF:
+            return ".tiff"
+        elif data_type == cls.CUSTOM:
+            # TODO: improve this message
+            raise NotImplementedError("Custom extensions have to be passed elsewhere.")
+        else:
+            raise ValueError(f"Data type {data_type} is not supported.")
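The split introduced above can be summarized in two calls (behavior read directly from the diff):

from careamics.config.support.supported_data import SupportedData

# glob/fnmatch pattern used to discover files on disk
SupportedData.get_extension_pattern(SupportedData.TIFF)  # "*.tif*"

# canonical extension used when writing files
SupportedData.get_extension(SupportedData.TIFF)  # ".tiff"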
careamics/config/support/supported_transforms.py

@@ -8,5 +8,4 @@ class SupportedTransform(str, BaseEnum):
 
     XY_FLIP = "XYFlip"
     XY_RANDOM_ROTATE90 = "XYRandomRotate90"
-    NORMALIZE = "Normalize"
     N2V_MANIPULATE = "N2VManipulate"
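After this removal the enum exposes three transforms; a quick check, assuming the import path from the file list:

from careamics.config.support.supported_transforms import SupportedTransform

print([t.value for t in SupportedTransform])
# expected: ['XYFlip', 'XYRandomRotate90', 'N2VManipulate']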