careamics-0.0.14-py3-none-any.whl → careamics-0.0.16-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the registry's advisory page for this release for more details.

Files changed (92):
  1. careamics/careamist.py +55 -61
  2. careamics/cli/conf.py +24 -9
  3. careamics/cli/main.py +8 -8
  4. careamics/cli/utils.py +2 -4
  5. careamics/config/__init__.py +8 -0
  6. careamics/config/algorithms/__init__.py +4 -0
  7. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  8. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  9. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  10. careamics/config/algorithms/vae_algorithm_model.py +53 -18
  11. careamics/config/architectures/lvae_model.py +12 -8
  12. careamics/config/callback_model.py +15 -11
  13. careamics/config/configuration.py +9 -8
  14. careamics/config/configuration_factories.py +892 -78
  15. careamics/config/data/data_model.py +7 -14
  16. careamics/config/data/ng_data_model.py +8 -15
  17. careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
  18. careamics/config/inference_model.py +6 -11
  19. careamics/config/likelihood_model.py +4 -4
  20. careamics/config/loss_model.py +6 -2
  21. careamics/config/nm_model.py +30 -7
  22. careamics/config/optimizer_models.py +1 -2
  23. careamics/config/support/supported_algorithms.py +5 -3
  24. careamics/config/support/supported_losses.py +5 -2
  25. careamics/config/training_model.py +8 -38
  26. careamics/config/transformations/normalize_model.py +3 -4
  27. careamics/config/transformations/xy_flip_model.py +2 -2
  28. careamics/config/transformations/xy_random_rotate90_model.py +2 -2
  29. careamics/config/validators/validator_utils.py +1 -2
  30. careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
  31. careamics/dataset/in_memory_dataset.py +2 -2
  32. careamics/dataset/iterable_dataset.py +1 -2
  33. careamics/dataset/patching/random_patching.py +6 -6
  34. careamics/dataset/patching/sequential_patching.py +4 -4
  35. careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
  36. careamics/dataset_ng/dataset.py +3 -3
  37. careamics/dataset_ng/factory.py +19 -19
  38. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  39. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  40. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  41. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  42. careamics/dataset_ng/patching_strategies/random_patching.py +2 -3
  43. careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
  44. careamics/file_io/read/__init__.py +0 -1
  45. careamics/lightning/__init__.py +16 -2
  46. careamics/lightning/callbacks/__init__.py +2 -0
  47. careamics/lightning/callbacks/data_stats_callback.py +23 -0
  48. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
  49. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
  50. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
  51. careamics/lightning/dataset_ng/data_module.py +43 -43
  52. careamics/lightning/lightning_module.py +166 -68
  53. careamics/lightning/microsplit_data_module.py +631 -0
  54. careamics/lightning/predict_data_module.py +16 -9
  55. careamics/lightning/train_data_module.py +29 -18
  56. careamics/losses/__init__.py +7 -1
  57. careamics/losses/loss_factory.py +9 -1
  58. careamics/losses/lvae/losses.py +94 -9
  59. careamics/lvae_training/dataset/__init__.py +8 -8
  60. careamics/lvae_training/dataset/config.py +56 -44
  61. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  62. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  63. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  64. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  65. careamics/model_io/bioimage/model_description.py +12 -11
  66. careamics/model_io/bmz_io.py +12 -8
  67. careamics/models/layers.py +5 -5
  68. careamics/models/lvae/likelihoods.py +30 -14
  69. careamics/models/lvae/lvae.py +2 -2
  70. careamics/models/lvae/noise_models.py +20 -14
  71. careamics/prediction_utils/__init__.py +8 -2
  72. careamics/prediction_utils/lvae_prediction.py +5 -5
  73. careamics/prediction_utils/prediction_outputs.py +48 -3
  74. careamics/prediction_utils/stitch_prediction.py +71 -0
  75. careamics/transforms/compose.py +9 -9
  76. careamics/transforms/n2v_manipulate.py +3 -3
  77. careamics/transforms/n2v_manipulate_torch.py +4 -4
  78. careamics/transforms/normalize.py +4 -6
  79. careamics/transforms/pixel_manipulation.py +6 -8
  80. careamics/transforms/pixel_manipulation_torch.py +5 -7
  81. careamics/transforms/xy_flip.py +3 -5
  82. careamics/transforms/xy_random_rotate90.py +4 -6
  83. careamics/utils/logging.py +8 -8
  84. careamics/utils/metrics.py +2 -2
  85. careamics/utils/plotting.py +1 -3
  86. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -16
  87. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/RECORD +90 -88
  88. careamics/dataset/zarr_dataset.py +0 -151
  89. careamics/file_io/read/zarr.py +0 -60
  90. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
  91. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
  92. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
@@ -6,7 +6,7 @@ import os
6
6
  import sys
7
7
  from collections.abc import Sequence
8
8
  from pprint import pformat
9
- from typing import Annotated, Any, Literal, Optional, Union
9
+ from typing import Annotated, Any, Literal, Self, Union
10
10
  from warnings import warn
11
11
 
12
12
  import numpy as np
@@ -19,7 +19,6 @@ from pydantic import (
19
19
  field_validator,
20
20
  model_validator,
21
21
  )
22
- from typing_extensions import Self
23
22
 
24
23
  from ..transformations import XYFlipModel, XYRandomRotate90Model
25
24
  from ..validators import check_axes_validity, patch_size_ge_than_8_power_of_2
@@ -109,22 +108,16 @@ class DataConfig(BaseModel):
109
108
  """Batch size for training."""
110
109
 
111
110
  # Optional fields
112
- image_means: Optional[list[Float]] = Field(
113
- default=None, min_length=0, max_length=32
114
- )
111
+ image_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
115
112
  """Means of the data across channels, used for normalization."""
116
113
 
117
- image_stds: Optional[list[Float]] = Field(default=None, min_length=0, max_length=32)
114
+ image_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
118
115
  """Standard deviations of the data across channels, used for normalization."""
119
116
 
120
- target_means: Optional[list[Float]] = Field(
121
- default=None, min_length=0, max_length=32
122
- )
117
+ target_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
123
118
  """Means of the target data across channels, used for normalization."""
124
119
 
125
- target_stds: Optional[list[Float]] = Field(
126
- default=None, min_length=0, max_length=32
127
- )
120
+ target_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
128
121
  """Standard deviations of the target data across channels, used for
129
122
  normalization."""
130
123
 
@@ -388,8 +381,8 @@ class DataConfig(BaseModel):
388
381
  self,
389
382
  image_means: Union[NDArray, tuple, list, None],
390
383
  image_stds: Union[NDArray, tuple, list, None],
391
- target_means: Optional[Union[NDArray, tuple, list, None]] = None,
392
- target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
384
+ target_means: Union[NDArray, tuple, list, None] | None = None,
385
+ target_stds: Union[NDArray, tuple, list, None] | None = None,
393
386
  ) -> None:
394
387
  """
395
388
  Set mean and standard deviation of the data across channels.
@@ -4,7 +4,7 @@ from __future__ import annotations
4
4
 
5
5
  from collections.abc import Sequence
6
6
  from pprint import pformat
7
- from typing import Annotated, Any, Literal, Optional, Union
7
+ from typing import Annotated, Any, Literal, Self, Union
8
8
  from warnings import warn
9
9
 
10
10
  import numpy as np
@@ -17,7 +17,6 @@ from pydantic import (
17
17
  field_validator,
18
18
  model_validator,
19
19
  )
20
- from typing_extensions import Self
21
20
 
22
21
  from ..transformations import XYFlipModel, XYRandomRotate90Model
23
22
  from ..validators import check_axes_validity
@@ -106,22 +105,16 @@ class NGDataConfig(BaseModel):
106
105
  batch_size: int = Field(default=1, ge=1, validate_default=True)
107
106
  """Batch size for training."""
108
107
 
109
- image_means: Optional[list[Float]] = Field(
110
- default=None, min_length=0, max_length=32
111
- )
108
+ image_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
112
109
  """Means of the data across channels, used for normalization."""
113
110
 
114
- image_stds: Optional[list[Float]] = Field(default=None, min_length=0, max_length=32)
111
+ image_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
115
112
  """Standard deviations of the data across channels, used for normalization."""
116
113
 
117
- target_means: Optional[list[Float]] = Field(
118
- default=None, min_length=0, max_length=32
119
- )
114
+ target_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
120
115
  """Means of the target data across channels, used for normalization."""
121
116
 
122
- target_stds: Optional[list[Float]] = Field(
123
- default=None, min_length=0, max_length=32
124
- )
117
+ target_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
125
118
  """Standard deviations of the target data across channels, used for
126
119
  normalization."""
127
120
 
@@ -148,7 +141,7 @@ class NGDataConfig(BaseModel):
148
141
  test_dataloader_params: dict[str, Any] = Field(default={})
149
142
  """Dictionary of PyTorch test dataloader parameters."""
150
143
 
151
- seed: Optional[int] = Field(default=None, gt=0)
144
+ seed: int | None = Field(default=None, gt=0)
152
145
  """Random seed for reproducibility."""
153
146
 
154
147
  @field_validator("axes")
@@ -330,8 +323,8 @@ class NGDataConfig(BaseModel):
330
323
  self,
331
324
  image_means: Union[NDArray, tuple, list, None],
332
325
  image_stds: Union[NDArray, tuple, list, None],
333
- target_means: Optional[Union[NDArray, tuple, list, None]] = None,
334
- target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
326
+ target_means: Union[NDArray, tuple, list, None] | None = None,
327
+ target_stds: Union[NDArray, tuple, list, None] | None = None,
335
328
  ) -> None:
336
329
  """
337
330
  Set mean and standard deviation of the data across channels.
@@ -1,7 +1,6 @@
1
1
  """Sequential patching Pydantic model."""
2
2
 
3
3
  from collections.abc import Sequence
4
- from typing import Optional
5
4
 
6
5
  from pydantic import Field, ValidationInfo, field_validator
7
6
 
@@ -24,7 +23,7 @@ class _OverlappingPatchedModel(_PatchedModel):
24
23
  dimension, and the number of dimensions be either 2 or 3.
25
24
  """
26
25
 
27
- overlaps: Optional[Sequence[int]] = Field(
26
+ overlaps: Sequence[int] | None = Field(
28
27
  default=None,
29
28
  min_length=2,
30
29
  max_length=3,
@@ -37,8 +36,8 @@ class _OverlappingPatchedModel(_PatchedModel):
37
36
  @field_validator("overlaps")
38
37
  @classmethod
39
38
  def overlap_smaller_than_patch_size(
40
- cls, overlaps: Optional[Sequence[int]], values: ValidationInfo
41
- ) -> Optional[Sequence[int]]:
39
+ cls, overlaps: Sequence[int] | None, values: ValidationInfo
40
+ ) -> Sequence[int] | None:
42
41
  """
43
42
  Validate overlap.
44
43
 
@@ -78,7 +77,7 @@ class _OverlappingPatchedModel(_PatchedModel):
78
77
 
79
78
  @field_validator("overlaps")
80
79
  @classmethod
81
- def overlap_even(cls, overlaps: Optional[Sequence[int]]) -> Optional[Sequence[int]]:
80
+ def overlap_even(cls, overlaps: Sequence[int] | None) -> Sequence[int] | None:
82
81
  """
83
82
  Validate overlaps.
84
83
 
@@ -2,10 +2,9 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import Any, Literal, Optional, Union
5
+ from typing import Any, Literal, Self, Union
6
6
 
7
7
  from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
8
- from typing_extensions import Self
9
8
 
10
9
  from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2
11
10
 
@@ -18,12 +17,10 @@ class InferenceConfig(BaseModel):
18
17
  data_type: Literal["array", "tiff", "czi", "custom"] # As defined in SupportedData
19
18
  """Type of input data: numpy.ndarray (array) or path (tiff, czi, or custom)."""
20
19
 
21
- tile_size: Optional[Union[list[int]]] = Field(
22
- default=None, min_length=2, max_length=3
23
- )
20
+ tile_size: Union[list[int]] | None = Field(default=None, min_length=2, max_length=3)
24
21
  """Tile size of prediction, only effective if `tile_overlap` is specified."""
25
22
 
26
- tile_overlap: Optional[Union[list[int]]] = Field(
23
+ tile_overlap: Union[list[int]] | None = Field(
27
24
  default=None, min_length=2, max_length=3
28
25
  )
29
26
  """Overlap between tiles, only effective if `tile_size` is specified."""
@@ -48,8 +45,8 @@ class InferenceConfig(BaseModel):
48
45
  @field_validator("tile_overlap")
49
46
  @classmethod
50
47
  def all_elements_non_zero_even(
51
- cls, tile_overlap: Optional[list[int]]
52
- ) -> Optional[list[int]]:
48
+ cls, tile_overlap: list[int] | None
49
+ ) -> list[int] | None:
53
50
  """
54
51
  Validate tile overlap.
55
52
 
@@ -86,9 +83,7 @@ class InferenceConfig(BaseModel):
86
83
 
87
84
  @field_validator("tile_size")
88
85
  @classmethod
89
- def tile_min_8_power_of_2(
90
- cls, tile_list: Optional[list[int]]
91
- ) -> Optional[list[int]]:
86
+ def tile_min_8_power_of_2(cls, tile_list: list[int] | None) -> list[int] | None:
92
87
  """
93
88
  Validate that each entry is greater or equal than 8 and a power of 2.
94
89
 
@@ -1,6 +1,6 @@
1
1
  """Likelihood model."""
2
2
 
3
- from typing import Annotated, Literal, Optional, Union
3
+ from typing import Annotated, Literal, Union
4
4
 
5
5
  import numpy as np
6
6
  import torch
@@ -31,7 +31,7 @@ class GaussianLikelihoodConfig(BaseModel):
31
31
 
32
32
  model_config = ConfigDict(validate_assignment=True)
33
33
 
34
- predict_logvar: Optional[Literal["pixelwise"]] = None
34
+ predict_logvar: Literal["pixelwise"] | None = None
35
35
  """If `pixelwise`, log-variance is computed for each pixel, else log-variance
36
36
  is not computed."""
37
37
 
@@ -50,11 +50,11 @@ class NMLikelihoodConfig(BaseModel):
50
50
  model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)
51
51
 
52
52
  # TODO remove and use as parameters to the likelihood functions?
53
- data_mean: Tensor = torch.zeros(1)
53
+ data_mean: Tensor | None = None
54
54
  """The mean of the data, used to unnormalize data for noise model evaluation.
55
55
  Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
56
56
 
57
57
  # TODO remove and use as parameters to the likelihood functions?
58
- data_std: Tensor = torch.ones(1)
58
+ data_std: Tensor | None = None
59
59
  """The standard deviation of the data, used to unnormalize data for noise
60
60
  model evaluation. Shape is (target_ch,) (or (1, target_ch, [1], 1, 1))."""
@@ -35,7 +35,9 @@ class LVAELossConfig(BaseModel):
35
35
  validate_assignment=True, validate_default=True, arbitrary_types_allowed=True
36
36
  )
37
37
 
38
- loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"]
38
+ loss_type: Literal[
39
+ "hdn", "microsplit", "musplit", "denoisplit", "denoisplit_musplit"
40
+ ]
39
41
  """Type of loss to use for LVAE."""
40
42
 
41
43
  reconstruction_weight: float = 1.0
@@ -50,7 +52,9 @@ class LVAELossConfig(BaseModel):
50
52
  """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
51
53
  kl_params: KLLossConfig = KLLossConfig()
52
54
  """KL loss configuration."""
53
-
55
+ # TODO revisit weights for the losses
54
56
  # TODO: remove?
55
57
  non_stochastic: bool = False
56
58
  """Whether to sample latents and compute KL."""
59
+
60
+ # TODO what are the correct parameters for HDN ?
@@ -1,7 +1,7 @@
1
1
  """Noise models config."""
2
2
 
3
3
  from pathlib import Path
4
- from typing import Annotated, Literal, Optional, Union
4
+ from typing import Annotated, Literal, Self, Union
5
5
 
6
6
  import numpy as np
7
7
  import torch
@@ -11,6 +11,7 @@ from pydantic import (
11
11
  Field,
12
12
  PlainSerializer,
13
13
  PlainValidator,
14
+ model_validator,
14
15
  )
15
16
 
16
17
  from careamics.utils.serializers import _array_to_json, _to_numpy
@@ -42,21 +43,19 @@ class GaussianMixtureNMConfig(BaseModel):
42
43
  # model type
43
44
  model_type: Literal["GaussianMixtureNoiseModel"]
44
45
 
45
- path: Optional[Union[Path, str]] = None
46
+ path: Union[Path, str] | None = None
46
47
  """Path to the directory where the trained noise model (*.npz) is saved in the
47
48
  `train` method."""
48
49
 
49
50
  # TODO remove and use as parameters to the NM functions?
50
- signal: Optional[Union[str, Path, np.ndarray]] = Field(default=None, exclude=True)
51
+ signal: Union[str, Path, np.ndarray] | None = Field(default=None, exclude=True)
51
52
  """Path to the file containing signal or respective numpy array."""
52
53
 
53
54
  # TODO remove and use as parameters to the NM functions?
54
- observation: Optional[Union[str, Path, np.ndarray]] = Field(
55
- default=None, exclude=True
56
- )
55
+ observation: Union[str, Path, np.ndarray] | None = Field(default=None, exclude=True)
57
56
  """Path to the file containing observation or respective numpy array."""
58
57
 
59
- weight: Optional[Array] = None
58
+ weight: Array | None = None
60
59
  """A [3*n_gaussian, n_coeff] sized array containing the values of the weights
61
60
  describing the GMM noise model, with each row corresponding to one
62
61
  parameter of each gaussian, namely [mean, standard deviation and weight].
@@ -88,6 +87,30 @@ class GaussianMixtureNMConfig(BaseModel):
88
87
  tol: float = Field(default=1e-10)
89
88
  """Tolerance used in the computation of the noise model likelihood."""
90
89
 
90
+ @model_validator(mode="after")
91
+ def validate_path(self: Self) -> Self:
92
+ """Validate that the path points to a valid .npz file if provided.
93
+
94
+ Returns
95
+ -------
96
+ Self
97
+ Returns itself.
98
+
99
+ Raises
100
+ ------
101
+ ValueError
102
+ If the path is provided but does not point to a valid .npz file.
103
+ """
104
+ if self.path is not None:
105
+ path = Path(self.path)
106
+ if not path.exists():
107
+ raise ValueError(f"Path {path} does not exist.")
108
+ if path.suffix != ".npz":
109
+ raise ValueError(f"Path {path} must point to a .npz file.")
110
+ if not path.is_file():
111
+ raise ValueError(f"Path {path} must point to a file.")
112
+ return self
113
+
91
114
  # @model_validator(mode="after")
92
115
  # def validate_path_to_pretrained_vs_training_data(self: Self) -> Self:
93
116
  # """Validate paths provided in the config.
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import Literal
5
+ from typing import Literal, Self
6
6
 
7
7
  from pydantic import (
8
8
  BaseModel,
@@ -13,7 +13,6 @@ from pydantic import (
13
13
  model_validator,
14
14
  )
15
15
  from torch import optim
16
- from typing_extensions import Self
17
16
 
18
17
  from careamics.utils.torch_utils import filter_parameters
19
18
 
@@ -26,9 +26,11 @@ class SupportedAlgorithm(str, BaseEnum):
26
26
  MUSPLIT = "musplit"
27
27
  """An image splitting approach based on ladder VAE architectures."""
28
28
 
29
+ MICROSPLIT = "microsplit"
30
+ """A micro-level image splitting approach based on ladder VAE architectures."""
31
+
29
32
  DENOISPLIT = "denoisplit"
30
33
  """An image splitting and denoising approach based on ladder VAE architectures."""
31
34
 
32
- # PN2V = "pn2v"
33
- # HDN = "hdn"
34
- # SEG = "segmentation"
35
+ HDN = "hdn"
36
+ """Hierarchical Denoising Network, an unsupervised denoising algorithm"""
@@ -21,9 +21,12 @@ class SupportedLoss(str, BaseEnum):
21
21
  MAE = "mae"
22
22
  N2V = "n2v"
23
23
  # PN2V = "pn2v"
24
- # HDN = "hdn"
24
+ HDN = "hdn"
25
25
  MUSPLIT = "musplit"
26
+ MICROSPLIT = "microsplit"
26
27
  DENOISPLIT = "denoisplit"
27
- DENOISPLIT_MUSPLIT = "denoisplit_musplit"
28
+ DENOISPLIT_MUSPLIT = (
29
+ "denoisplit_musplit" # TODO refac losses, leave only microsplit
30
+ )
28
31
  # CE = "ce"
29
32
  # DICE = "dice"
@@ -3,9 +3,9 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  from pprint import pformat
6
- from typing import Literal, Optional, Union
6
+ from typing import Literal
7
7
 
8
- from pydantic import BaseModel, ConfigDict, Field, field_validator
8
+ from pydantic import BaseModel, ConfigDict, Field
9
9
 
10
10
  from .callback_model import CheckpointModel, EarlyStoppingModel
11
11
 
@@ -29,31 +29,20 @@ class TrainingConfig(BaseModel):
29
29
  model_config = ConfigDict(
30
30
  validate_assignment=True,
31
31
  )
32
+ lightning_trainer_config: dict | None = None
33
+ """Configuration for the PyTorch Lightning Trainer, following PyTorch Lightning
34
+ Trainer class"""
32
35
 
33
- num_epochs: int = Field(default=20, ge=1)
34
- """Number of epochs, greater than 0."""
35
-
36
- precision: Literal["64", "32", "16-mixed", "bf16-mixed"] = Field(default="32")
37
- """Numerical precision"""
38
- max_steps: int = Field(default=-1, ge=-1)
39
- """Maximum number of steps to train for. -1 means no limit."""
40
- check_val_every_n_epoch: int = Field(default=1, ge=1)
41
- """Validation step frequency."""
42
- accumulate_grad_batches: int = Field(default=1, ge=1)
43
- """Number of batches to accumulate gradients over before stepping the optimizer."""
44
- gradient_clip_val: Optional[Union[int, float]] = None
45
- """The value to which to clip the gradient"""
46
- gradient_clip_algorithm: Literal["value", "norm"] = "norm"
47
- """The algorithm to use for gradient clipping (see lightning `Trainer`)."""
48
- logger: Optional[Literal["wandb", "tensorboard"]] = None
36
+ logger: Literal["wandb", "tensorboard"] | None = None
49
37
  """Logger to use during training. If None, no logger will be used. Available
50
38
  loggers are defined in SupportedLogger."""
51
39
 
40
+ # Only basic callbacks
52
41
  checkpoint_callback: CheckpointModel = CheckpointModel()
53
42
  """Checkpoint callback configuration, following PyTorch Lightning Checkpoint
54
43
  callback."""
55
44
 
56
- early_stopping_callback: Optional[EarlyStoppingModel] = Field(
45
+ early_stopping_callback: EarlyStoppingModel | None = Field(
57
46
  default=None, validate_default=True
58
47
  )
59
48
  """Early stopping callback configuration, following PyTorch Lightning Checkpoint
@@ -78,22 +67,3 @@ class TrainingConfig(BaseModel):
78
67
  Whether the logger is defined or not.
79
68
  """
80
69
  return self.logger is not None
81
-
82
- @field_validator("max_steps")
83
- @classmethod
84
- def validate_max_steps(cls, max_steps: int) -> int:
85
- """Validate the max_steps parameter.
86
-
87
- Parameters
88
- ----------
89
- max_steps : int
90
- Maximum number of steps to train for. -1 means no limit.
91
-
92
- Returns
93
- -------
94
- int
95
- Validated max_steps.
96
- """
97
- if max_steps == 0:
98
- raise ValueError("max_steps must be greater than 0. Use -1 for no limit.")
99
- return max_steps
@@ -1,9 +1,8 @@
1
1
  """Pydantic model for the Normalize transform."""
2
2
 
3
- from typing import Literal, Optional
3
+ from typing import Literal, Self
4
4
 
5
5
  from pydantic import ConfigDict, Field, model_validator
6
- from typing_extensions import Self
7
6
 
8
7
  from .transform_model import TransformModel
9
8
 
@@ -31,8 +30,8 @@ class NormalizeModel(TransformModel):
31
30
  name: Literal["Normalize"] = "Normalize"
32
31
  image_means: list = Field(..., min_length=0, max_length=32)
33
32
  image_stds: list = Field(..., min_length=0, max_length=32)
34
- target_means: Optional[list] = Field(default=None, min_length=0, max_length=32)
35
- target_stds: Optional[list] = Field(default=None, min_length=0, max_length=32)
33
+ target_means: list | None = Field(default=None, min_length=0, max_length=32)
34
+ target_stds: list | None = Field(default=None, min_length=0, max_length=32)
36
35
 
37
36
  @model_validator(mode="after")
38
37
  def validate_means_stds(self: Self) -> Self:
@@ -1,6 +1,6 @@
1
1
  """Pydantic model for the XYFlip transform."""
2
2
 
3
- from typing import Literal, Optional
3
+ from typing import Literal
4
4
 
5
5
  from pydantic import ConfigDict, Field
6
6
 
@@ -40,4 +40,4 @@ class XYFlipModel(TransformModel):
40
40
  ge=0,
41
41
  le=1,
42
42
  )
43
- seed: Optional[int] = None
43
+ seed: int | None = None
@@ -1,6 +1,6 @@
1
1
  """Pydantic model for the XYRandomRotate90 transform."""
2
2
 
3
- from typing import Literal, Optional
3
+ from typing import Literal
4
4
 
5
5
  from pydantic import ConfigDict, Field
6
6
 
@@ -32,4 +32,4 @@ class XYRandomRotate90Model(TransformModel):
32
32
  ge=0,
33
33
  le=1,
34
34
  )
35
- seed: Optional[int] = None
35
+ seed: int | None = None
@@ -5,7 +5,6 @@ These functions are used to validate dimensions and axes of inputs.
5
5
  """
6
6
 
7
7
  from collections.abc import Sequence
8
- from typing import Optional
9
8
 
10
9
  _AXES = "STCZYX"
11
10
 
@@ -80,7 +79,7 @@ def value_ge_than_8_power_of_2(
80
79
 
81
80
 
82
81
  def patch_size_ge_than_8_power_of_2(
83
- patch_list: Optional[Sequence[int]],
82
+ patch_list: Sequence[int] | None,
84
83
  ) -> None:
85
84
  """
86
85
  Validate that each entry is greater or equal than 8 and a power of 2.
@@ -4,7 +4,7 @@ from __future__ import annotations
4
4
 
5
5
  from collections.abc import Callable, Generator
6
6
  from pathlib import Path
7
- from typing import Optional, Union
7
+ from typing import Union
8
8
 
9
9
  from numpy.typing import NDArray
10
10
  from torch.utils.data import get_worker_info
@@ -21,9 +21,9 @@ logger = get_logger(__name__)
21
21
  def iterate_over_files(
22
22
  data_config: Union[DataConfig, InferenceConfig],
23
23
  data_files: list[Path],
24
- target_files: Optional[list[Path]] = None,
24
+ target_files: list[Path] | None = None,
25
25
  read_source_func: Callable = read_tiff,
26
- ) -> Generator[tuple[NDArray, Optional[NDArray]], None, None]:
26
+ ) -> Generator[tuple[NDArray, NDArray | None], None, None]:
27
27
  """Iterate over data source and yield whole reshaped images.
28
28
 
29
29
  Parameters
@@ -5,7 +5,7 @@ from __future__ import annotations
5
5
  import copy
6
6
  from collections.abc import Callable
7
7
  from pathlib import Path
8
- from typing import Any, Optional, Union
8
+ from typing import Any, Union
9
9
 
10
10
  import numpy as np
11
11
  from torch.utils.data import Dataset
@@ -49,7 +49,7 @@ class InMemoryDataset(Dataset):
49
49
  self,
50
50
  data_config: DataConfig,
51
51
  inputs: Union[np.ndarray, list[Path]],
52
- input_target: Optional[Union[np.ndarray, list[Path]]] = None,
52
+ input_target: Union[np.ndarray, list[Path]] | None = None,
53
53
  read_source_func: Callable = read_tiff,
54
54
  **kwargs: Any,
55
55
  ) -> None:
@@ -5,7 +5,6 @@ from __future__ import annotations
5
5
  import copy
6
6
  from collections.abc import Callable, Generator
7
7
  from pathlib import Path
8
- from typing import Optional
9
8
 
10
9
  import numpy as np
11
10
  from torch.utils.data import IterableDataset
@@ -51,7 +50,7 @@ class PathIterableDataset(IterableDataset):
51
50
  self,
52
51
  data_config: DataConfig,
53
52
  src_files: list[Path],
54
- target_files: Optional[list[Path]] = None,
53
+ target_files: list[Path] | None = None,
55
54
  read_source_func: Callable = read_tiff,
56
55
  ) -> None:
57
56
  """Constructors.
@@ -1,7 +1,7 @@
1
1
  """Random patching utilities."""
2
2
 
3
3
  from collections.abc import Generator
4
- from typing import Optional, Union
4
+ from typing import Union
5
5
 
6
6
  import numpy as np
7
7
  import zarr
@@ -13,9 +13,9 @@ from .validate_patch_dimension import validate_patch_dimensions
13
13
  def extract_patches_random(
14
14
  arr: np.ndarray,
15
15
  patch_size: Union[list[int], tuple[int, ...]],
16
- target: Optional[np.ndarray] = None,
17
- seed: Optional[int] = None,
18
- ) -> Generator[tuple[np.ndarray, Optional[np.ndarray]], None, None]:
16
+ target: np.ndarray | None = None,
17
+ seed: int | None = None,
18
+ ) -> Generator[tuple[np.ndarray, np.ndarray | None], None, None]:
19
19
  """
20
20
  Generate patches from an array in a random manner.
21
21
 
@@ -115,8 +115,8 @@ def extract_patches_random_from_chunks(
115
115
  arr: zarr.Array,
116
116
  patch_size: Union[list[int], tuple[int, ...]],
117
117
  chunk_size: Union[list[int], tuple[int, ...]],
118
- chunk_limit: Optional[int] = None,
119
- seed: Optional[int] = None,
118
+ chunk_limit: int | None = None,
119
+ seed: int | None = None,
120
120
  ) -> Generator[np.ndarray, None, None]:
121
121
  """
122
122
  Generate patches from an array in a random manner.
@@ -1,6 +1,6 @@
1
1
  """Sequential patching functions."""
2
2
 
3
- from typing import Optional, Union
3
+ from typing import Union
4
4
 
5
5
  import numpy as np
6
6
  from skimage.util import view_as_windows
@@ -110,7 +110,7 @@ def _compute_patch_views(
110
110
  window_shape: list[int],
111
111
  step: tuple[int, ...],
112
112
  output_shape: list[int],
113
- target: Optional[np.ndarray] = None,
113
+ target: np.ndarray | None = None,
114
114
  ) -> np.ndarray:
115
115
  """
116
116
  Compute views of an array corresponding to patches.
@@ -151,8 +151,8 @@ def _compute_patch_views(
151
151
  def extract_patches_sequential(
152
152
  arr: np.ndarray,
153
153
  patch_size: Union[list[int], tuple[int, ...]],
154
- target: Optional[np.ndarray] = None,
155
- ) -> tuple[np.ndarray, Optional[np.ndarray]]:
154
+ target: np.ndarray | None = None,
155
+ ) -> tuple[np.ndarray, np.ndarray | None]:
156
156
  """
157
157
  Generate patches from an array in a sequential manner.
158
158
 
@@ -3,7 +3,7 @@
3
3
  import builtins
4
4
  import itertools
5
5
  from collections.abc import Generator
6
- from typing import Any, Optional, Union
6
+ from typing import Any, Union
7
7
 
8
8
  import numpy as np
9
9
  from numpy.typing import NDArray
@@ -16,7 +16,7 @@ def extract_tiles(
16
16
  arr: NDArray,
17
17
  tile_size: NDArray[np.int_],
18
18
  overlaps: NDArray[np.int_],
19
- padding_kwargs: Optional[dict[str, Any]] = None,
19
+ padding_kwargs: dict[str, Any] | None = None,
20
20
  ) -> Generator[tuple[NDArray, TileInformation], None, None]:
21
21
  """Generate tiles from the input array with specified overlap.
22
22