careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. See the registry's advisory for this release for more details.

Files changed (103)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +92 -55
  4. careamics/config/__init__.py +0 -1
  5. careamics/config/algorithm_model.py +5 -3
  6. careamics/config/architectures/architecture_model.py +7 -0
  7. careamics/config/architectures/custom_model.py +8 -1
  8. careamics/config/architectures/register_model.py +3 -1
  9. careamics/config/architectures/unet_model.py +3 -0
  10. careamics/config/architectures/vae_model.py +2 -0
  11. careamics/config/callback_model.py +4 -15
  12. careamics/config/configuration_example.py +4 -4
  13. careamics/config/configuration_factory.py +113 -55
  14. careamics/config/configuration_model.py +14 -16
  15. careamics/config/data_model.py +63 -165
  16. careamics/config/inference_model.py +9 -75
  17. careamics/config/optimizer_models.py +4 -4
  18. careamics/config/references/algorithm_descriptions.py +1 -0
  19. careamics/config/references/references.py +1 -0
  20. careamics/config/support/__init__.py +0 -2
  21. careamics/config/support/supported_activations.py +2 -0
  22. careamics/config/support/supported_algorithms.py +3 -1
  23. careamics/config/support/supported_architectures.py +2 -0
  24. careamics/config/support/supported_data.py +2 -0
  25. careamics/config/support/supported_loggers.py +2 -0
  26. careamics/config/support/supported_losses.py +2 -0
  27. careamics/config/support/supported_optimizers.py +2 -0
  28. careamics/config/support/supported_pixel_manipulations.py +3 -3
  29. careamics/config/support/supported_struct_axis.py +2 -0
  30. careamics/config/support/supported_transforms.py +4 -15
  31. careamics/config/tile_information.py +2 -0
  32. careamics/config/training_model.py +1 -0
  33. careamics/config/transformations/__init__.py +3 -2
  34. careamics/config/transformations/n2v_manipulate_model.py +1 -0
  35. careamics/config/transformations/normalize_model.py +1 -0
  36. careamics/config/transformations/transform_model.py +1 -0
  37. careamics/config/transformations/xy_flip_model.py +43 -0
  38. careamics/config/transformations/xy_random_rotate90_model.py +13 -7
  39. careamics/config/validators/validator_utils.py +1 -0
  40. careamics/conftest.py +13 -0
  41. careamics/dataset/dataset_utils/__init__.py +0 -1
  42. careamics/dataset/dataset_utils/dataset_utils.py +5 -4
  43. careamics/dataset/dataset_utils/file_utils.py +4 -3
  44. careamics/dataset/dataset_utils/read_tiff.py +6 -2
  45. careamics/dataset/dataset_utils/read_utils.py +2 -0
  46. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  47. careamics/dataset/in_memory_dataset.py +84 -76
  48. careamics/dataset/iterable_dataset.py +166 -134
  49. careamics/dataset/patching/__init__.py +0 -7
  50. careamics/dataset/patching/patching.py +56 -14
  51. careamics/dataset/patching/random_patching.py +8 -2
  52. careamics/dataset/patching/sequential_patching.py +20 -14
  53. careamics/dataset/patching/tiled_patching.py +13 -7
  54. careamics/dataset/patching/validate_patch_dimension.py +2 -0
  55. careamics/dataset/zarr_dataset.py +2 -0
  56. careamics/lightning_datamodule.py +63 -41
  57. careamics/lightning_module.py +9 -3
  58. careamics/lightning_prediction_datamodule.py +15 -20
  59. careamics/lightning_prediction_loop.py +8 -6
  60. careamics/losses/__init__.py +1 -3
  61. careamics/losses/loss_factory.py +2 -1
  62. careamics/losses/losses.py +11 -7
  63. careamics/model_io/__init__.py +0 -1
  64. careamics/model_io/bioimage/_readme_factory.py +2 -1
  65. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  66. careamics/model_io/bioimage/model_description.py +1 -0
  67. careamics/model_io/bmz_io.py +4 -3
  68. careamics/models/activation.py +2 -0
  69. careamics/models/layers.py +122 -25
  70. careamics/models/model_factory.py +2 -1
  71. careamics/models/unet.py +114 -19
  72. careamics/prediction/stitch_prediction.py +2 -5
  73. careamics/transforms/__init__.py +4 -25
  74. careamics/transforms/compose.py +124 -0
  75. careamics/transforms/n2v_manipulate.py +65 -34
  76. careamics/transforms/normalize.py +91 -28
  77. careamics/transforms/pixel_manipulation.py +7 -7
  78. careamics/transforms/struct_mask_parameters.py +3 -1
  79. careamics/transforms/transform.py +24 -0
  80. careamics/transforms/tta.py +2 -2
  81. careamics/transforms/xy_flip.py +123 -0
  82. careamics/transforms/xy_random_rotate90.py +66 -60
  83. careamics/utils/__init__.py +0 -1
  84. careamics/utils/base_enum.py +28 -0
  85. careamics/utils/context.py +1 -0
  86. careamics/utils/logging.py +1 -0
  87. careamics/utils/metrics.py +1 -0
  88. careamics/utils/path_utils.py +2 -0
  89. careamics/utils/ram.py +2 -0
  90. careamics/utils/receptive_field.py +93 -87
  91. careamics/utils/torch_utils.py +1 -0
  92. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
  93. careamics-0.1.0rc6.dist-info/RECORD +107 -0
  94. careamics/config/noise_models.py +0 -162
  95. careamics/config/support/supported_extraction_strategies.py +0 -24
  96. careamics/config/transformations/nd_flip_model.py +0 -32
  97. careamics/dataset/patching/patch_transform.py +0 -44
  98. careamics/losses/noise_model_factory.py +0 -40
  99. careamics/losses/noise_models.py +0 -524
  100. careamics/transforms/nd_flip.py +0 -93
  101. careamics-0.1.0rc4.dist-info/RECORD +0 -110
  102. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
  103. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,10 @@
1
1
  """Data configuration."""
2
+
2
3
  from __future__ import annotations
3
4
 
4
5
  from pprint import pformat
5
6
  from typing import Any, List, Literal, Optional, Union
6
7
 
7
- from albumentations import Compose
8
8
  from pydantic import (
9
9
  BaseModel,
10
10
  ConfigDict,
@@ -17,14 +17,14 @@ from typing_extensions import Annotated, Self
17
17
 
18
18
  from .support import SupportedTransform
19
19
  from .transformations.n2v_manipulate_model import N2VManipulateModel
20
- from .transformations.nd_flip_model import NDFlipModel
21
20
  from .transformations.normalize_model import NormalizeModel
21
+ from .transformations.xy_flip_model import XYFlipModel
22
22
  from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
23
23
  from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
24
24
 
25
25
  TRANSFORMS_UNION = Annotated[
26
26
  Union[
27
- NDFlipModel,
27
+ XYFlipModel,
28
28
  XYRandomRotate90Model,
29
29
  NormalizeModel,
30
30
  N2VManipulateModel,
@@ -41,6 +41,8 @@ class DataConfig(BaseModel):
41
41
  and then the mean (if they were both `None` before) will raise a validation error.
42
42
  Prefer instead `set_mean_and_std` to set both at once.
43
43
 
44
+ All supported transforms are defined in the SupportedTransform enum.
45
+
44
46
  Examples
45
47
  --------
46
48
  Minimum example:
@@ -56,7 +58,7 @@ class DataConfig(BaseModel):
56
58
  >>> data.set_mean_and_std(mean=214.3, std=84.5)
57
59
 
58
60
  One can pass also a list of transformations, by keyword, using the
59
- SupportedTransform or the name of an Albumentation transform:
61
+ SupportedTransform value:
60
62
  >>> from careamics.config.support import SupportedTransform
61
63
  >>> data = DataConfig(
62
64
  ... data_type="tiff",
@@ -70,9 +72,7 @@ class DataConfig(BaseModel):
70
72
  ... "std": 47.2,
71
73
  ... },
72
74
  ... {
73
- ... "name": "NDFlip",
74
- ... "is_3D": True,
75
- ... "flip_z": True,
75
+ ... "name": "XYFlip",
76
76
  ... }
77
77
  ... ]
78
78
  ... )
@@ -81,7 +81,6 @@ class DataConfig(BaseModel):
81
81
  # Pydantic class configuration
82
82
  model_config = ConfigDict(
83
83
  validate_assignment=True,
84
- arbitrary_types_allowed=True, # Allow Compose declaration
85
84
  )
86
85
 
87
86
  # Dataset configuration
@@ -94,13 +93,13 @@ class DataConfig(BaseModel):
94
93
  mean: Optional[float] = None
95
94
  std: Optional[float] = None
96
95
 
97
- transforms: Union[List[TRANSFORMS_UNION], Compose] = Field(
96
+ transforms: List[TRANSFORMS_UNION] = Field(
98
97
  default=[
99
98
  {
100
99
  "name": SupportedTransform.NORMALIZE.value,
101
100
  },
102
101
  {
103
- "name": SupportedTransform.NDFLIP.value,
102
+ "name": SupportedTransform.XY_FLIP.value,
104
103
  },
105
104
  {
106
105
  "name": SupportedTransform.XY_RANDOM_ROTATE90.value,
@@ -181,19 +180,19 @@ class DataConfig(BaseModel):
181
180
  @field_validator("transforms")
182
181
  @classmethod
183
182
  def validate_prediction_transforms(
184
- cls, transforms: Union[List[TRANSFORMS_UNION], Compose]
185
- ) -> Union[List[TRANSFORMS_UNION], Compose]:
183
+ cls, transforms: List[TRANSFORMS_UNION]
184
+ ) -> List[TRANSFORMS_UNION]:
186
185
  """
187
186
  Validate N2VManipulate transform position in the transform list.
188
187
 
189
188
  Parameters
190
189
  ----------
191
- transforms : Union[List[Transformations_Union], Compose]
190
+ transforms : List[Transformations_Union]
192
191
  Transforms.
193
192
 
194
193
  Returns
195
194
  -------
196
- Union[List[Transformations_Union], Compose]
195
+ List[TRANSFORMS_UNION]
197
196
  Validated transforms.
198
197
 
199
198
  Raises
@@ -201,23 +200,22 @@ class DataConfig(BaseModel):
201
200
  ValueError
202
201
  If multiple instances of N2VManipulate are found.
203
202
  """
204
- if not isinstance(transforms, Compose):
205
- transform_list = [t.name for t in transforms]
206
-
207
- if SupportedTransform.N2V_MANIPULATE in transform_list:
208
- # multiple N2V_MANIPULATE
209
- if transform_list.count(SupportedTransform.N2V_MANIPULATE) > 1:
210
- raise ValueError(
211
- f"Multiple instances of "
212
- f"{SupportedTransform.N2V_MANIPULATE} transforms "
213
- f"are not allowed."
214
- )
215
-
216
- # N2V_MANIPULATE not the last transform
217
- elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
218
- index = transform_list.index(SupportedTransform.N2V_MANIPULATE)
219
- transform = transforms.pop(index)
220
- transforms.append(transform)
203
+ transform_list = [t.name for t in transforms]
204
+
205
+ if SupportedTransform.N2V_MANIPULATE in transform_list:
206
+ # multiple N2V_MANIPULATE
207
+ if transform_list.count(SupportedTransform.N2V_MANIPULATE.value) > 1:
208
+ raise ValueError(
209
+ f"Multiple instances of "
210
+ f"{SupportedTransform.N2V_MANIPULATE} transforms "
211
+ f"are not allowed."
212
+ )
213
+
214
+ # N2V_MANIPULATE not the last transform
215
+ elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
216
+ index = transform_list.index(SupportedTransform.N2V_MANIPULATE.value)
217
+ transform = transforms.pop(index)
218
+ transforms.append(transform)
221
219
 
222
220
  return transforms
223
221
 
@@ -254,13 +252,12 @@ class DataConfig(BaseModel):
254
252
  Self
255
253
  Data model with mean and std added to the Normalize transform.
256
254
  """
257
- if self.mean is not None or self.std is not None:
255
+ if self.mean is not None and self.std is not None:
258
256
  # search in the transforms for Normalize and update parameters
259
- if self.has_transform_list():
260
- for transform in self.transforms:
261
- if transform.name == SupportedTransform.NORMALIZE.value:
262
- transform.mean = self.mean
263
- transform.std = self.std
257
+ for transform in self.transforms:
258
+ if transform.name == SupportedTransform.NORMALIZE.value:
259
+ transform.mean = self.mean
260
+ transform.std = self.std
264
261
 
265
262
  return self
266
263
 
@@ -286,13 +283,6 @@ class DataConfig(BaseModel):
286
283
  f"({self.axes})."
287
284
  )
288
285
 
289
- if self.has_transform_list():
290
- for transform in self.transforms:
291
- if transform.name == SupportedTransform.NDFLIP:
292
- transform.is_3D = True
293
- elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90:
294
- transform.is_3D = True
295
-
296
286
  else:
297
287
  if len(self.patch_size) != 2:
298
288
  raise ValueError(
@@ -300,13 +290,6 @@ class DataConfig(BaseModel):
300
290
  f"({self.axes})."
301
291
  )
302
292
 
303
- if self.has_transform_list():
304
- for transform in self.transforms:
305
- if transform.name == SupportedTransform.NDFLIP:
306
- transform.is_3D = False
307
- elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90:
308
- transform.is_3D = False
309
-
310
293
  return self
311
294
 
312
295
  def __str__(self) -> str:
@@ -332,84 +315,31 @@ class DataConfig(BaseModel):
332
315
  self.__dict__.update(kwargs)
333
316
  self.__class__.model_validate(self.__dict__)
334
317
 
335
- def has_transform_list(self) -> bool:
336
- """
337
- Check if the transforms are a list, as opposed to a Compose object.
338
-
339
- Returns
340
- -------
341
- bool
342
- True if the transforms are a list, False otherwise.
343
- """
344
- return isinstance(self.transforms, list)
345
-
346
318
  def has_n2v_manipulate(self) -> bool:
347
319
  """
348
320
  Check if the transforms contain N2VManipulate.
349
321
 
350
- Use `has_transform_list` to check if the transforms are a list.
351
-
352
322
  Returns
353
323
  -------
354
324
  bool
355
325
  True if the transforms contain N2VManipulate, False otherwise.
356
-
357
- Raises
358
- ------
359
- ValueError
360
- If the transforms are a Compose object.
361
326
  """
362
- if self.has_transform_list():
363
- return any(
364
- transform.name == SupportedTransform.N2V_MANIPULATE.value
365
- for transform in self.transforms
366
- )
367
- else:
368
- raise ValueError(
369
- "Checking for N2VManipulate with Compose transforms is not allowed. "
370
- "Check directly in the Compose."
371
- )
327
+ return any(
328
+ transform.name == SupportedTransform.N2V_MANIPULATE.value
329
+ for transform in self.transforms
330
+ )
372
331
 
373
332
  def add_n2v_manipulate(self) -> None:
374
- """
375
- Add N2VManipulate to the transforms.
376
-
377
- Use `has_transform_list` to check if the transforms are a list.
378
-
379
- Raises
380
- ------
381
- ValueError
382
- If the transforms are a Compose object.
383
- """
384
- if self.has_transform_list():
385
- if not self.has_n2v_manipulate():
386
- self.transforms.append(
387
- N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value)
388
- )
389
- else:
390
- raise ValueError(
391
- "Adding N2VManipulate with Compose transforms is not allowed. Add "
392
- "N2VManipulate directly to the transform in the Compose."
333
+ """Add N2VManipulate to the transforms."""
334
+ if not self.has_n2v_manipulate():
335
+ self.transforms.append(
336
+ N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value)
393
337
  )
394
338
 
395
339
  def remove_n2v_manipulate(self) -> None:
396
- """
397
- Remove N2VManipulate from the transforms.
398
-
399
- Use `has_transform_list` to check if the transforms are a list.
400
-
401
- Raises
402
- ------
403
- ValueError
404
- If the transforms are a Compose object.
405
- """
406
- if self.has_transform_list() and self.has_n2v_manipulate():
340
+ """Remove N2VManipulate from the transforms."""
341
+ if self.has_n2v_manipulate():
407
342
  self.transforms.pop(-1)
408
- else:
409
- raise ValueError(
410
- "Removing N2VManipulate with Compose transforms is not allowed. Remove "
411
- "N2VManipulate directly from the transform in the Compose."
412
- )
413
343
 
414
344
  def set_mean_and_std(self, mean: float, std: float) -> None:
415
345
  """
@@ -427,18 +357,6 @@ class DataConfig(BaseModel):
427
357
  """
428
358
  self._update(mean=mean, std=std)
429
359
 
430
- # search in the transforms for Normalize and update parameters
431
- if self.has_transform_list():
432
- for transform in self.transforms:
433
- if transform.name == SupportedTransform.NORMALIZE.value:
434
- transform.mean = mean
435
- transform.std = std
436
- else:
437
- raise ValueError(
438
- "Setting mean and std with Compose transforms is not allowed. Add "
439
- "mean and std parameters directly to the transform in the Compose."
440
- )
441
-
442
360
  def set_3D(self, axes: str, patch_size: List[int]) -> None:
443
361
  """
444
362
  Set 3D parameters.
@@ -465,8 +383,6 @@ class DataConfig(BaseModel):
465
383
  ------
466
384
  ValueError
467
385
  If the N2V pixel manipulate transform is not found in the transforms.
468
- ValueError
469
- If the transforms are a Compose object.
470
386
  """
471
387
  if use_n2v2:
472
388
  self.set_N2V2_strategy("median")
@@ -486,28 +402,19 @@ class DataConfig(BaseModel):
486
402
  ------
487
403
  ValueError
488
404
  If the N2V pixel manipulate transform is not found in the transforms.
489
- ValueError
490
- If the transforms are a Compose object.
491
405
  """
492
- if isinstance(self.transforms, list):
493
- found_n2v = False
494
-
495
- for transform in self.transforms:
496
- if transform.name == SupportedTransform.N2V_MANIPULATE.value:
497
- transform.strategy = strategy
498
- found_n2v = True
406
+ found_n2v = False
499
407
 
500
- if not found_n2v:
501
- transforms = [t.name for t in self.transforms]
502
- raise ValueError(
503
- f"N2V_Manipulate transform not found in the transforms list "
504
- f"({transforms})."
505
- )
408
+ for transform in self.transforms:
409
+ if transform.name == SupportedTransform.N2V_MANIPULATE.value:
410
+ transform.strategy = strategy
411
+ found_n2v = True
506
412
 
507
- else:
413
+ if not found_n2v:
414
+ transforms = [t.name for t in self.transforms]
508
415
  raise ValueError(
509
- "Setting N2V2 strategy with Compose transforms is not allowed. Add "
510
- "N2V2 strategy parameters directly to the transform in the Compose."
416
+ f"N2V_Manipulate transform not found in the transforms list "
417
+ f"({transforms})."
511
418
  )
512
419
 
513
420
  def set_structN2V_mask(
@@ -529,27 +436,18 @@ class DataConfig(BaseModel):
529
436
  ------
530
437
  ValueError
531
438
  If the N2V pixel manipulate transform is not found in the transforms.
532
- ValueError
533
- If the transforms are a Compose object.
534
439
  """
535
- if isinstance(self.transforms, list):
536
- found_n2v = False
537
-
538
- for transform in self.transforms:
539
- if transform.name == SupportedTransform.N2V_MANIPULATE.value:
540
- transform.struct_mask_axis = mask_axis
541
- transform.struct_mask_span = mask_span
542
- found_n2v = True
440
+ found_n2v = False
543
441
 
544
- if not found_n2v:
545
- transforms = [t.name for t in self.transforms]
546
- raise ValueError(
547
- f"N2V pixel manipulate transform not found in the transforms "
548
- f"({transforms})."
549
- )
442
+ for transform in self.transforms:
443
+ if transform.name == SupportedTransform.N2V_MANIPULATE.value:
444
+ transform.struct_mask_axis = mask_axis
445
+ transform.struct_mask_span = mask_span
446
+ found_n2v = True
550
447
 
551
- else:
448
+ if not found_n2v:
449
+ transforms = [t.name for t in self.transforms]
552
450
  raise ValueError(
553
- "Setting structN2VMask with Compose transforms is not allowed. Add "
554
- "structN2VMask parameters directly to the transform in the Compose."
451
+ f"N2V pixel manipulate transform not found in the transforms "
452
+ f"({transforms})."
555
453
  )
@@ -1,18 +1,14 @@
1
1
  """Pydantic model representing CAREamics prediction configuration."""
2
+
2
3
  from __future__ import annotations
3
4
 
4
5
  from typing import Any, List, Literal, Optional, Union
5
6
 
6
- from albumentations import Compose
7
7
  from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
8
8
  from typing_extensions import Self
9
9
 
10
- from .support import SupportedTransform
11
- from .transformations.normalize_model import NormalizeModel
12
10
  from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
13
11
 
14
- TRANSFORMS_UNION = Union[NormalizeModel]
15
-
16
12
 
17
13
  class InferenceConfig(BaseModel):
18
14
  """Configuration class for the prediction model."""
@@ -33,15 +29,6 @@ class InferenceConfig(BaseModel):
33
29
  mean: float
34
30
  std: float = Field(..., ge=0.0)
35
31
 
36
- transforms: Union[List[TRANSFORMS_UNION], Compose] = Field(
37
- default=[
38
- {
39
- "name": SupportedTransform.NORMALIZE.value,
40
- },
41
- ],
42
- validate_default=True,
43
- )
44
-
45
32
  # only default TTAs are supported for now
46
33
  tta_transforms: bool = Field(default=True)
47
34
 
@@ -51,22 +38,22 @@ class InferenceConfig(BaseModel):
51
38
  @field_validator("tile_overlap")
52
39
  @classmethod
53
40
  def all_elements_non_zero_even(
54
- cls, patch_list: Optional[Union[List[int]]]
41
+ cls, tile_overlap: Optional[Union[List[int]]]
55
42
  ) -> Optional[Union[List[int]]]:
56
43
  """
57
- Validate patch size.
44
+ Validate tile overlap.
58
45
 
59
- Patch size must be non-zero, positive and even.
46
+ Overlaps must be non-zero, positive and even.
60
47
 
61
48
  Parameters
62
49
  ----------
63
- patch_list : Optional[Union[List[int]]]
50
+ tile_overlap : Optional[Union[List[int]]]
64
51
  Patch size.
65
52
 
66
53
  Returns
67
54
  -------
68
55
  Optional[Union[List[int]]]
69
- Validated patch size.
56
+ Validated tile overlap.
70
57
 
71
58
  Raises
72
59
  ------
@@ -75,8 +62,8 @@ class InferenceConfig(BaseModel):
75
62
  ValueError
76
63
  If the patch size is not even.
77
64
  """
78
- if patch_list is not None:
79
- for dim in patch_list:
65
+ if tile_overlap is not None:
66
+ for dim in tile_overlap:
80
67
  if dim < 1:
81
68
  raise ValueError(
82
69
  f"Patch size must be non-zero positive (got {dim})."
@@ -85,7 +72,7 @@ class InferenceConfig(BaseModel):
85
72
  if dim % 2 != 0:
86
73
  raise ValueError(f"Patch size must be even (got {dim}).")
87
74
 
88
- return patch_list
75
+ return tile_overlap
89
76
 
90
77
  @field_validator("tile_size")
91
78
  @classmethod
@@ -149,39 +136,6 @@ class InferenceConfig(BaseModel):
149
136
 
150
137
  return axes
151
138
 
152
- @field_validator("transforms")
153
- @classmethod
154
- def validate_transforms(
155
- cls, transforms: Union[List[TRANSFORMS_UNION], Compose]
156
- ) -> Union[List[TRANSFORMS_UNION], Compose]:
157
- """
158
- Validate that transforms do not have N2V pixel manipulate transforms.
159
-
160
- Parameters
161
- ----------
162
- transforms : Union[List[TransformModel], Compose]
163
- Transforms.
164
-
165
- Returns
166
- -------
167
- Union[List[Transformations_Union], Compose]
168
- Validated transforms.
169
-
170
- Raises
171
- ------
172
- ValueError
173
- If transforms contain N2V pixel manipulate transforms.
174
- """
175
- if not isinstance(transforms, Compose) and transforms is not None:
176
- for transform in transforms:
177
- if transform.name == SupportedTransform.N2V_MANIPULATE.value:
178
- raise ValueError(
179
- "N2V_Manipulate transform is not allowed in "
180
- "prediction transforms."
181
- )
182
-
183
- return transforms
184
-
185
139
  @model_validator(mode="after")
186
140
  def validate_dimensions(self: Self) -> Self:
187
141
  """
@@ -235,26 +189,6 @@ class InferenceConfig(BaseModel):
235
189
 
236
190
  return self
237
191
 
238
- @model_validator(mode="after")
239
- def add_std_and_mean_to_normalize(self: Self) -> Self:
240
- """
241
- Add mean and std to the Normalize transform if it is present.
242
-
243
- Returns
244
- -------
245
- Self
246
- Inference model with mean and std added to the Normalize transform.
247
- """
248
- if self.mean is not None or self.std is not None:
249
- # search in the transforms for Normalize and update parameters
250
- if not isinstance(self.transforms, Compose):
251
- for transform in self.transforms:
252
- if transform.name == SupportedTransform.NORMALIZE.value:
253
- transform.mean = self.mean
254
- transform.std = self.std
255
-
256
- return self
257
-
258
192
  def _update(self, **kwargs: Any) -> None:
259
193
  """
260
194
  Update multiple arguments at once.
@@ -1,3 +1,5 @@
1
+ """Optimizers and schedulers Pydantic models."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  from typing import Dict, Literal
@@ -19,8 +21,7 @@ from .support import SupportedOptimizer
19
21
 
20
22
 
21
23
  class OptimizerModel(BaseModel):
22
- """
23
- Torch optimizer.
24
+ """Torch optimizer Pydantic model.
24
25
 
25
26
  Only parameters supported by the corresponding torch optimizer will be taken
26
27
  into account. For more details, check:
@@ -115,8 +116,7 @@ class OptimizerModel(BaseModel):
115
116
 
116
117
 
117
118
  class LrSchedulerModel(BaseModel):
118
- """
119
- Torch learning rate scheduler.
119
+ """Torch learning rate scheduler Pydantic model.
120
120
 
121
121
  Only parameters supported by the corresponding torch lr scheduler will be taken
122
122
  into account. For more details, check:
@@ -1,4 +1,5 @@
1
1
  """Descriptions of the algorithms used in CAREmics."""
2
+
2
3
  from pydantic import BaseModel
3
4
 
4
5
  CUSTOM = "Custom"
@@ -1,4 +1,5 @@
1
1
  """References for the CAREamics algorithms."""
2
+
2
3
  from bioimageio.spec.generic.v0_3 import CiteEntry
3
4
 
4
5
  N2VRef = CiteEntry(
@@ -14,7 +14,6 @@ __all__ = [
14
14
  "SupportedPixelManipulation",
15
15
  "SupportedTransform",
16
16
  "SupportedData",
17
- "SupportedExtractionStrategy",
18
17
  "SupportedStructAxis",
19
18
  "SupportedLogger",
20
19
  ]
@@ -24,7 +23,6 @@ from .supported_activations import SupportedActivation
24
23
  from .supported_algorithms import SupportedAlgorithm
25
24
  from .supported_architectures import SupportedArchitecture
26
25
  from .supported_data import SupportedData
27
- from .supported_extraction_strategies import SupportedExtractionStrategy
28
26
  from .supported_loggers import SupportedLogger
29
27
  from .supported_losses import SupportedLoss
30
28
  from .supported_optimizers import SupportedOptimizer, SupportedScheduler
@@ -1,3 +1,5 @@
1
+ """Activations supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,3 +1,5 @@
1
+ """Algorithms supported by CAREamics."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  from careamics.utils import BaseEnum
@@ -10,9 +12,9 @@ class SupportedAlgorithm(str, BaseEnum):
10
12
  """
11
13
 
12
14
  N2V = "n2v"
13
- CUSTOM = "custom"
14
15
  CARE = "care"
15
16
  N2N = "n2n"
17
+ CUSTOM = "custom"
16
18
  # PN2V = "pn2v"
17
19
  # HDN = "hdn"
18
20
  # SEG = "segmentation"
@@ -1,3 +1,5 @@
1
+ """Architectures supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,3 +1,5 @@
1
+ """Data supported by CAREamics."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  from typing import Union
@@ -1,3 +1,5 @@
1
+ """Logger supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,3 +1,5 @@
1
+ """Losses supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,3 +1,5 @@
1
+ """Optimizers and schedulers supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
@@ -1,15 +1,15 @@
1
+ """Pixel manipulation methods supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5
 
4
6
  class SupportedPixelManipulation(str, BaseEnum):
5
- """_summary_.
7
+ """Supported Noise2Void pixel manipulations.
6
8
 
7
9
  - Uniform: Replace masked pixel value by a (uniformly) randomly selected neighbor
8
10
  pixel value.
9
11
  - Median: Replace masked pixel value by the mean of the neighborhood.
10
12
  """
11
13
 
12
- # TODO docs
13
-
14
14
  UNIFORM = "uniform"
15
15
  MEDIAN = "median"
@@ -1,3 +1,5 @@
1
+ """StructN2V axes supported by CAREamics."""
2
+
1
3
  from careamics.utils import BaseEnum
2
4
 
3
5