careamics 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.

Potentially problematic release.
Files changed (57)
  1. careamics/careamist.py +49 -49
  2. careamics/cli/conf.py +6 -6
  3. careamics/cli/main.py +8 -8
  4. careamics/cli/utils.py +2 -4
  5. careamics/config/algorithms/vae_algorithm_model.py +4 -4
  6. careamics/config/callback_model.py +8 -8
  7. careamics/config/configuration_factories.py +49 -49
  8. careamics/config/data/data_model.py +7 -13
  9. careamics/config/data/ng_data_model.py +8 -14
  10. careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
  11. careamics/config/inference_model.py +6 -10
  12. careamics/config/likelihood_model.py +2 -2
  13. careamics/config/nm_model.py +5 -7
  14. careamics/config/training_model.py +4 -4
  15. careamics/config/transformations/normalize_model.py +3 -3
  16. careamics/config/transformations/xy_flip_model.py +2 -2
  17. careamics/config/transformations/xy_random_rotate90_model.py +2 -2
  18. careamics/config/validators/validator_utils.py +1 -2
  19. careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
  20. careamics/dataset/in_memory_dataset.py +2 -2
  21. careamics/dataset/iterable_dataset.py +1 -2
  22. careamics/dataset/patching/random_patching.py +6 -6
  23. careamics/dataset/patching/sequential_patching.py +4 -4
  24. careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
  25. careamics/dataset_ng/dataset.py +3 -3
  26. careamics/dataset_ng/factory.py +19 -19
  27. careamics/dataset_ng/patching_strategies/random_patching.py +3 -4
  28. careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
  29. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
  30. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
  31. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
  32. careamics/lightning/dataset_ng/data_module.py +43 -43
  33. careamics/lightning/lightning_module.py +12 -14
  34. careamics/lightning/predict_data_module.py +8 -8
  35. careamics/lightning/train_data_module.py +11 -11
  36. careamics/losses/lvae/losses.py +9 -9
  37. careamics/model_io/bioimage/model_description.py +12 -11
  38. careamics/model_io/bmz_io.py +4 -4
  39. careamics/models/layers.py +5 -5
  40. careamics/models/unet.py +16 -10
  41. careamics/prediction_utils/lvae_prediction.py +5 -5
  42. careamics/transforms/compose.py +9 -9
  43. careamics/transforms/n2v_manipulate.py +3 -3
  44. careamics/transforms/n2v_manipulate_torch.py +4 -4
  45. careamics/transforms/normalize.py +4 -6
  46. careamics/transforms/pixel_manipulation.py +6 -8
  47. careamics/transforms/pixel_manipulation_torch.py +5 -7
  48. careamics/transforms/xy_flip.py +3 -5
  49. careamics/transforms/xy_random_rotate90.py +3 -5
  50. careamics/utils/logging.py +8 -8
  51. careamics/utils/metrics.py +2 -2
  52. careamics/utils/plotting.py +1 -3
  53. {careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/METADATA +2 -3
  54. {careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/RECORD +57 -57
  55. {careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/WHEEL +0 -0
  56. {careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/entry_points.txt +0 -0
  57. {careamics-0.0.13.dist-info → careamics-0.0.15.dist-info}/licenses/LICENSE +0 -0
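
Nearly every hunk below makes the same mechanical change: type hints written as `typing.Optional[X]` (and some `Optional[Union[...]]` combinations) are rewritten in the PEP 604 spelling `X | None`, and the now-unused `Optional` imports are dropped. The two spellings are interchangeable for Pydantic models; a minimal sketch (the model and field names here are illustrative, not from the package):

```python
from typing import Optional

from pydantic import BaseModel, Field


class OldStyle(BaseModel):
    seed: Optional[int] = Field(default=None, gt=0)


class NewStyle(BaseModel):
    seed: int | None = Field(default=None, gt=0)


# Both spellings validate identically.
assert OldStyle(seed=3).seed == NewStyle(seed=3).seed == 3
assert OldStyle().seed is None and NewStyle().seed is None
```

Because Pydantic evaluates annotations when it builds a model, the `X | None` form effectively requires Python 3.10+, which presumably reflects the package's supported Python range.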
careamics/config/data/ng_data_model.py
@@ -4,7 +4,7 @@ from __future__ import annotations
 
 from collections.abc import Sequence
 from pprint import pformat
-from typing import Annotated, Any, Literal, Optional, Union
+from typing import Annotated, Any, Literal, Union
 from warnings import warn
 
 import numpy as np
@@ -106,22 +106,16 @@ class NGDataConfig(BaseModel):
     batch_size: int = Field(default=1, ge=1, validate_default=True)
     """Batch size for training."""
 
-    image_means: Optional[list[Float]] = Field(
-        default=None, min_length=0, max_length=32
-    )
+    image_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Means of the data across channels, used for normalization."""
 
-    image_stds: Optional[list[Float]] = Field(default=None, min_length=0, max_length=32)
+    image_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Standard deviations of the data across channels, used for normalization."""
 
-    target_means: Optional[list[Float]] = Field(
-        default=None, min_length=0, max_length=32
-    )
+    target_means: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Means of the target data across channels, used for normalization."""
 
-    target_stds: Optional[list[Float]] = Field(
-        default=None, min_length=0, max_length=32
-    )
+    target_stds: list[Float] | None = Field(default=None, min_length=0, max_length=32)
     """Standard deviations of the target data across channels, used for
    normalization."""
 
@@ -148,7 +142,7 @@ class NGDataConfig(BaseModel):
     test_dataloader_params: dict[str, Any] = Field(default={})
     """Dictionary of PyTorch test dataloader parameters."""
 
-    seed: Optional[int] = Field(default=None, gt=0)
+    seed: int | None = Field(default=None, gt=0)
     """Random seed for reproducibility."""
 
     @field_validator("axes")
@@ -330,8 +324,8 @@ class NGDataConfig(BaseModel):
         self,
         image_means: Union[NDArray, tuple, list, None],
         image_stds: Union[NDArray, tuple, list, None],
-        target_means: Optional[Union[NDArray, tuple, list, None]] = None,
-        target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
+        target_means: Union[NDArray, tuple, list, None] | None = None,
+        target_stds: Union[NDArray, tuple, list, None] | None = None,
     ) -> None:
         """
         Set mean and standard deviation of the data across channels.
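
One detail in the last hunk: the inner `Union[...]` of `set_means_and_stds` already ends in `None`, so both the old `Optional[...]` wrapper and the new trailing `| None` are redundant. They are also harmless, since unions deduplicate; a quick check (Python 3.10+):

```python
from typing import Union

from numpy.typing import NDArray

# The trailing "| None" is a no-op because the union already contains None.
assert (Union[NDArray, tuple, list, None] | None) == Union[NDArray, tuple, list, None]
```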
careamics/config/data/patching_strategies/_overlapping_patched_model.py
@@ -1,7 +1,6 @@
 """Sequential patching Pydantic model."""
 
 from collections.abc import Sequence
-from typing import Optional
 
 from pydantic import Field, ValidationInfo, field_validator
 
@@ -24,7 +23,7 @@ class _OverlappingPatchedModel(_PatchedModel):
     dimension, and the number of dimensions be either 2 or 3.
     """
 
-    overlaps: Optional[Sequence[int]] = Field(
+    overlaps: Sequence[int] | None = Field(
         default=None,
         min_length=2,
         max_length=3,
@@ -37,8 +36,8 @@ class _OverlappingPatchedModel(_PatchedModel):
     @field_validator("overlaps")
     @classmethod
     def overlap_smaller_than_patch_size(
-        cls, overlaps: Optional[Sequence[int]], values: ValidationInfo
-    ) -> Optional[Sequence[int]]:
+        cls, overlaps: Sequence[int] | None, values: ValidationInfo
+    ) -> Sequence[int] | None:
         """
         Validate overlap.
 
@@ -78,7 +77,7 @@ class _OverlappingPatchedModel(_PatchedModel):
 
     @field_validator("overlaps")
     @classmethod
-    def overlap_even(cls, overlaps: Optional[Sequence[int]]) -> Optional[Sequence[int]]:
+    def overlap_even(cls, overlaps: Sequence[int] | None) -> Sequence[int] | None:
         """
         Validate overlaps.
 
careamics/config/inference_model.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Union
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 from typing_extensions import Self
@@ -18,12 +18,10 @@ class InferenceConfig(BaseModel):
     data_type: Literal["array", "tiff", "czi", "custom"]  # As defined in SupportedData
     """Type of input data: numpy.ndarray (array) or path (tiff, czi, or custom)."""
 
-    tile_size: Optional[Union[list[int]]] = Field(
-        default=None, min_length=2, max_length=3
-    )
+    tile_size: Union[list[int]] | None = Field(default=None, min_length=2, max_length=3)
     """Tile size of prediction, only effective if `tile_overlap` is specified."""
 
-    tile_overlap: Optional[Union[list[int]]] = Field(
+    tile_overlap: Union[list[int]] | None = Field(
         default=None, min_length=2, max_length=3
     )
     """Overlap between tiles, only effective if `tile_size` is specified."""
@@ -48,8 +46,8 @@ class InferenceConfig(BaseModel):
     @field_validator("tile_overlap")
     @classmethod
     def all_elements_non_zero_even(
-        cls, tile_overlap: Optional[list[int]]
-    ) -> Optional[list[int]]:
+        cls, tile_overlap: list[int] | None
+    ) -> list[int] | None:
         """
         Validate tile overlap.
 
@@ -86,9 +84,7 @@ class InferenceConfig(BaseModel):
 
     @field_validator("tile_size")
     @classmethod
-    def tile_min_8_power_of_2(
-        cls, tile_list: Optional[list[int]]
-    ) -> Optional[list[int]]:
+    def tile_min_8_power_of_2(cls, tile_list: list[int] | None) -> list[int] | None:
         """
         Validate that each entry is greater or equal than 8 and a power of 2.
 
careamics/config/likelihood_model.py
@@ -1,6 +1,6 @@
 """Likelihood model."""
 
-from typing import Annotated, Literal, Optional, Union
+from typing import Annotated, Literal, Union
 
 import numpy as np
 import torch
@@ -31,7 +31,7 @@ class GaussianLikelihoodConfig(BaseModel):
 
     model_config = ConfigDict(validate_assignment=True)
 
-    predict_logvar: Optional[Literal["pixelwise"]] = None
+    predict_logvar: Literal["pixelwise"] | None = None
     """If `pixelwise`, log-variance is computed for each pixel, else log-variance
     is not computed."""
 
careamics/config/nm_model.py
@@ -1,7 +1,7 @@
 """Noise models config."""
 
 from pathlib import Path
-from typing import Annotated, Literal, Optional, Union
+from typing import Annotated, Literal, Union
 
 import numpy as np
 import torch
@@ -42,21 +42,19 @@ class GaussianMixtureNMConfig(BaseModel):
     # model type
     model_type: Literal["GaussianMixtureNoiseModel"]
 
-    path: Optional[Union[Path, str]] = None
+    path: Union[Path, str] | None = None
     """Path to the directory where the trained noise model (*.npz) is saved in the
     `train` method."""
 
     # TODO remove and use as parameters to the NM functions?
-    signal: Optional[Union[str, Path, np.ndarray]] = Field(default=None, exclude=True)
+    signal: Union[str, Path, np.ndarray] | None = Field(default=None, exclude=True)
     """Path to the file containing signal or respective numpy array."""
 
     # TODO remove and use as parameters to the NM functions?
-    observation: Optional[Union[str, Path, np.ndarray]] = Field(
-        default=None, exclude=True
-    )
+    observation: Union[str, Path, np.ndarray] | None = Field(default=None, exclude=True)
     """Path to the file containing observation or respective numpy array."""
 
-    weight: Optional[Array] = None
+    weight: Array | None = None
     """A [3*n_gaussian, n_coeff] sized array containing the values of the weights
     describing the GMM noise model, with each row corresponding to one
     parameter of each gaussian, namely [mean, standard deviation and weight].
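
For orientation, the `weight` docstring above describes a `[3*n_gaussian, n_coeff]` array with one row per GMM parameter (mean, standard deviation, mixture weight) per Gaussian. A shape-only sketch with hypothetical sizes (the row grouping in the comments is an assumption, not taken from the package):

```python
import numpy as np

n_gaussian, n_coeff = 3, 2  # hypothetical sizes
weight = np.zeros((3 * n_gaussian, n_coeff))
# Assumed layout: rows 0..2 mean coefficients, rows 3..5 std coefficients,
# rows 6..8 mixture-weight coefficients, one polynomial coefficient per column.
assert weight.shape == (9, 2)
```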
careamics/config/training_model.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from pprint import pformat
-from typing import Literal, Optional, Union
+from typing import Literal, Union
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator
 
@@ -41,11 +41,11 @@ class TrainingConfig(BaseModel):
     """Validation step frequency."""
     accumulate_grad_batches: int = Field(default=1, ge=1)
     """Number of batches to accumulate gradients over before stepping the optimizer."""
-    gradient_clip_val: Optional[Union[int, float]] = None
+    gradient_clip_val: Union[int, float] | None = None
     """The value to which to clip the gradient"""
     gradient_clip_algorithm: Literal["value", "norm"] = "norm"
     """The algorithm to use for gradient clipping (see lightning `Trainer`)."""
-    logger: Optional[Literal["wandb", "tensorboard"]] = None
+    logger: Literal["wandb", "tensorboard"] | None = None
     """Logger to use during training. If None, no logger will be used. Available
     loggers are defined in SupportedLogger."""
 
@@ -53,7 +53,7 @@ class TrainingConfig(BaseModel):
     """Checkpoint callback configuration, following PyTorch Lightning Checkpoint
     callback."""
 
-    early_stopping_callback: Optional[EarlyStoppingModel] = Field(
+    early_stopping_callback: EarlyStoppingModel | None = Field(
        default=None, validate_default=True
    )
    """Early stopping callback configuration, following PyTorch Lightning Checkpoint
careamics/config/transformations/normalize_model.py
@@ -1,6 +1,6 @@
 """Pydantic model for the Normalize transform."""
 
-from typing import Literal, Optional
+from typing import Literal
 
 from pydantic import ConfigDict, Field, model_validator
 from typing_extensions import Self
@@ -31,8 +31,8 @@ class NormalizeModel(TransformModel):
     name: Literal["Normalize"] = "Normalize"
     image_means: list = Field(..., min_length=0, max_length=32)
     image_stds: list = Field(..., min_length=0, max_length=32)
-    target_means: Optional[list] = Field(default=None, min_length=0, max_length=32)
-    target_stds: Optional[list] = Field(default=None, min_length=0, max_length=32)
+    target_means: list | None = Field(default=None, min_length=0, max_length=32)
+    target_stds: list | None = Field(default=None, min_length=0, max_length=32)
 
     @model_validator(mode="after")
     def validate_means_stds(self: Self) -> Self:
careamics/config/transformations/xy_flip_model.py
@@ -1,6 +1,6 @@
 """Pydantic model for the XYFlip transform."""
 
-from typing import Literal, Optional
+from typing import Literal
 
 from pydantic import ConfigDict, Field
 
@@ -40,4 +40,4 @@ class XYFlipModel(TransformModel):
         ge=0,
         le=1,
     )
-    seed: Optional[int] = None
+    seed: int | None = None
careamics/config/transformations/xy_random_rotate90_model.py
@@ -1,6 +1,6 @@
 """Pydantic model for the XYRandomRotate90 transform."""
 
-from typing import Literal, Optional
+from typing import Literal
 
 from pydantic import ConfigDict, Field
 
@@ -32,4 +32,4 @@ class XYRandomRotate90Model(TransformModel):
         ge=0,
         le=1,
     )
-    seed: Optional[int] = None
+    seed: int | None = None
careamics/config/validators/validator_utils.py
@@ -5,7 +5,6 @@ These functions are used to validate dimensions and axes of inputs.
 """
 
 from collections.abc import Sequence
-from typing import Optional
 
 _AXES = "STCZYX"
 
@@ -80,7 +79,7 @@ def value_ge_than_8_power_of_2(
 
 
 def patch_size_ge_than_8_power_of_2(
-    patch_list: Optional[Sequence[int]],
+    patch_list: Sequence[int] | None,
 ) -> None:
     """
     Validate that each entry is greater or equal than 8 and a power of 2.
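
The validator's name promises two properties: each entry is at least 8 and a power of two. A standalone sketch of such a predicate (an illustration, not the package's implementation, which raises rather than returning a bool):

```python
def ge_8_power_of_2(value: int) -> bool:
    """Return True if value >= 8 and value is a power of two."""
    return value >= 8 and (value & (value - 1)) == 0


assert ge_8_power_of_2(8) and ge_8_power_of_2(64)
assert not ge_8_power_of_2(4) and not ge_8_power_of_2(12)
```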
careamics/dataset/dataset_utils/iterate_over_files.py
@@ -4,7 +4,7 @@ from __future__ import annotations
 
 from collections.abc import Callable, Generator
 from pathlib import Path
-from typing import Optional, Union
+from typing import Union
 
 from numpy.typing import NDArray
 from torch.utils.data import get_worker_info
@@ -21,9 +21,9 @@ logger = get_logger(__name__)
 def iterate_over_files(
     data_config: Union[DataConfig, InferenceConfig],
     data_files: list[Path],
-    target_files: Optional[list[Path]] = None,
+    target_files: list[Path] | None = None,
     read_source_func: Callable = read_tiff,
-) -> Generator[tuple[NDArray, Optional[NDArray]], None, None]:
+) -> Generator[tuple[NDArray, NDArray | None], None, None]:
     """Iterate over data source and yield whole reshaped images.
 
     Parameters
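
A hedged usage sketch of the generator signature above; the config object and file list are placeholders, and per the return type the second tuple element is `None` whenever `target_files` is omitted:

```python
from pathlib import Path

from careamics.dataset.dataset_utils.iterate_over_files import iterate_over_files


def count_images(data_config, data_files: list[Path]) -> int:
    """Count images yielded without targets (illustrative helper)."""
    n = 0
    for image, target in iterate_over_files(data_config, data_files):
        assert target is None  # no target_files were passed
        n += 1
    return n
```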
careamics/dataset/in_memory_dataset.py
@@ -5,7 +5,7 @@ from __future__ import annotations
 import copy
 from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any, Union
 
 import numpy as np
 from torch.utils.data import Dataset
@@ -49,7 +49,7 @@ class InMemoryDataset(Dataset):
         self,
         data_config: DataConfig,
         inputs: Union[np.ndarray, list[Path]],
-        input_target: Optional[Union[np.ndarray, list[Path]]] = None,
+        input_target: Union[np.ndarray, list[Path]] | None = None,
         read_source_func: Callable = read_tiff,
         **kwargs: Any,
     ) -> None:
careamics/dataset/iterable_dataset.py
@@ -5,7 +5,6 @@ from __future__ import annotations
 import copy
 from collections.abc import Callable, Generator
 from pathlib import Path
-from typing import Optional
 
 import numpy as np
 from torch.utils.data import IterableDataset
@@ -51,7 +50,7 @@ class PathIterableDataset(IterableDataset):
         self,
         data_config: DataConfig,
         src_files: list[Path],
-        target_files: Optional[list[Path]] = None,
+        target_files: list[Path] | None = None,
         read_source_func: Callable = read_tiff,
     ) -> None:
         """Constructors.
careamics/dataset/patching/random_patching.py
@@ -1,7 +1,7 @@
 """Random patching utilities."""
 
 from collections.abc import Generator
-from typing import Optional, Union
+from typing import Union
 
 import numpy as np
 import zarr
@@ -13,9 +13,9 @@ from .validate_patch_dimension import validate_patch_dimensions
 def extract_patches_random(
     arr: np.ndarray,
     patch_size: Union[list[int], tuple[int, ...]],
-    target: Optional[np.ndarray] = None,
-    seed: Optional[int] = None,
-) -> Generator[tuple[np.ndarray, Optional[np.ndarray]], None, None]:
+    target: np.ndarray | None = None,
+    seed: int | None = None,
+) -> Generator[tuple[np.ndarray, np.ndarray | None], None, None]:
     """
     Generate patches from an array in a random manner.
 
@@ -115,8 +115,8 @@ def extract_patches_random_from_chunks(
     arr: zarr.Array,
     patch_size: Union[list[int], tuple[int, ...]],
     chunk_size: Union[list[int], tuple[int, ...]],
-    chunk_limit: Optional[int] = None,
-    seed: Optional[int] = None,
+    chunk_limit: int | None = None,
+    seed: int | None = None,
 ) -> Generator[np.ndarray, None, None]:
     """
     Generate patches from an array in a random manner.
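
A hedged usage sketch of `extract_patches_random`; the axis layout of `arr` and the `patch_size` convention are assumptions here (CAREamics arrays generally carry sample/channel axes in front of the spatial ones):

```python
import numpy as np

from careamics.dataset.patching.random_patching import extract_patches_random

arr = np.random.rand(1, 1, 64, 64)  # assumed sample/channel-leading layout
patches = extract_patches_random(arr, patch_size=(32, 32), seed=0)
patch, target = next(patches)  # target is None: no target array was given
```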
careamics/dataset/patching/sequential_patching.py
@@ -1,6 +1,6 @@
 """Sequential patching functions."""
 
-from typing import Optional, Union
+from typing import Union
 
 import numpy as np
 from skimage.util import view_as_windows
@@ -110,7 +110,7 @@ def _compute_patch_views(
     window_shape: list[int],
     step: tuple[int, ...],
     output_shape: list[int],
-    target: Optional[np.ndarray] = None,
+    target: np.ndarray | None = None,
 ) -> np.ndarray:
     """
     Compute views of an array corresponding to patches.
@@ -151,8 +151,8 @@ def _compute_patch_views(
 def extract_patches_sequential(
     arr: np.ndarray,
     patch_size: Union[list[int], tuple[int, ...]],
-    target: Optional[np.ndarray] = None,
-) -> tuple[np.ndarray, Optional[np.ndarray]]:
+    target: np.ndarray | None = None,
+) -> tuple[np.ndarray, np.ndarray | None]:
     """
     Generate patches from an array in a sequential manner.
 
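`_compute_patch_views` is built on `skimage.util.view_as_windows` (imported above), which returns a strided view whose leading axes index patch positions; a minimal demonstration of that primitive:

```python
import numpy as np
from skimage.util import view_as_windows

arr = np.arange(16).reshape(4, 4)
views = view_as_windows(arr, window_shape=(2, 2), step=2)
assert views.shape == (2, 2, 2, 2)  # a 2x2 grid of non-overlapping 2x2 patches
```
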
careamics/dataset/tiling/lvae_tiled_patching.py
@@ -3,7 +3,7 @@
 import builtins
 import itertools
 from collections.abc import Generator
-from typing import Any, Optional, Union
+from typing import Any, Union
 
 import numpy as np
 from numpy.typing import NDArray
@@ -16,7 +16,7 @@ def extract_tiles(
     arr: NDArray,
     tile_size: NDArray[np.int_],
     overlaps: NDArray[np.int_],
-    padding_kwargs: Optional[dict[str, Any]] = None,
+    padding_kwargs: dict[str, Any] | None = None,
 ) -> Generator[tuple[NDArray, TileInformation], None, None]:
     """Generate tiles from the input array with specified overlap.
 
careamics/dataset_ng/dataset.py
@@ -1,7 +1,7 @@
 from collections.abc import Sequence
 from enum import Enum
 from pathlib import Path
-from typing import Any, Generic, Literal, NamedTuple, Optional, Union
+from typing import Any, Generic, Literal, NamedTuple, Union
 
 import numpy as np
 from numpy.typing import NDArray
@@ -51,7 +51,7 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
         data_config: NGDataConfig,
         mode: Mode,
         input_extractor: PatchExtractor[GenericImageStack],
-        target_extractor: Optional[PatchExtractor[GenericImageStack]] = None,
+        target_extractor: PatchExtractor[GenericImageStack] | None = None,
     ):
         self.config = data_config
         self.mode = mode
@@ -115,7 +115,7 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
 
         return patching_strategy
 
-    def _initialize_transforms(self) -> Optional[Compose]:
+    def _initialize_transforms(self) -> Compose | None:
         normalize = NormalizeModel(
             image_means=self.input_stats.means,
             image_stds=self.input_stats.stds,
careamics/dataset_ng/factory.py
@@ -1,7 +1,7 @@
 from collections.abc import Sequence
 from enum import Enum
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 from numpy.typing import NDArray
 from typing_extensions import ParamSpec
@@ -48,8 +48,8 @@ class DatasetType(Enum):
 def determine_dataset_type(
     data_type: SupportedData,
     in_memory: bool,
-    read_func: Optional[ReadFunc] = None,
-    image_stack_loader: Optional[ImageStackLoader] = None,
+    read_func: ReadFunc | None = None,
+    image_stack_loader: ImageStackLoader | None = None,
 ) -> DatasetType:
     """Determine what the dataset type should be based on the input arguments.
 
@@ -121,10 +121,10 @@ def create_dataset(
     inputs: Any,
     targets: Any,
     in_memory: bool,
-    read_func: Optional[ReadFunc] = None,
-    read_kwargs: Optional[dict[str, Any]] = None,
-    image_stack_loader: Optional[ImageStackLoader] = None,
-    image_stack_loader_kwargs: Optional[dict[str, Any]] = None,
+    read_func: ReadFunc | None = None,
+    read_kwargs: dict[str, Any] | None = None,
+    image_stack_loader: ImageStackLoader | None = None,
+    image_stack_loader_kwargs: dict[str, Any] | None = None,
 ) -> CareamicsDataset[ImageStack]:
     """
     Convenience function to create the CAREamicsDataset.
@@ -201,7 +201,7 @@ def create_array_dataset(
     config: NGDataConfig,
     mode: Mode,
     inputs: Sequence[NDArray[Any]],
-    targets: Optional[Sequence[NDArray[Any]]],
+    targets: Sequence[NDArray[Any]] | None,
 ) -> CareamicsDataset[InMemoryImageStack]:
     """
     Create a CAREamicsDataset from array data.
@@ -223,7 +223,7 @@ def create_array_dataset(
         A CAREamicsDataset.
     """
     input_extractor = create_array_extractor(source=inputs, axes=config.axes)
-    target_extractor: Optional[PatchExtractor[InMemoryImageStack]]
+    target_extractor: PatchExtractor[InMemoryImageStack] | None
     if targets is not None:
         target_extractor = create_array_extractor(source=targets, axes=config.axes)
     else:
@@ -235,7 +235,7 @@ def create_tiff_dataset(
     config: NGDataConfig,
     mode: Mode,
     inputs: Sequence[Path],
-    targets: Optional[Sequence[Path]],
+    targets: Sequence[Path] | None,
 ) -> CareamicsDataset[InMemoryImageStack]:
     """
     Create a CAREamicsDataset from tiff files that will be all loaded into memory.
@@ -260,7 +260,7 @@ def create_tiff_dataset(
         source=inputs,
         axes=config.axes,
     )
-    target_extractor: Optional[PatchExtractor[InMemoryImageStack]]
+    target_extractor: PatchExtractor[InMemoryImageStack] | None
     if targets is not None:
         target_extractor = create_tiff_extractor(source=targets, axes=config.axes)
     else:
@@ -273,7 +273,7 @@ def create_czi_dataset(
     config: NGDataConfig,
     mode: Mode,
     inputs: Sequence[Path],
-    targets: Optional[Sequence[Path]],
+    targets: Sequence[Path] | None,
 ) -> CareamicsDataset[CziImageStack]:
     """
     Create a dataset from CZI files.
@@ -296,7 +296,7 @@ def create_czi_dataset(
     """
 
     input_extractor = create_czi_extractor(source=inputs, axes=config.axes)
-    target_extractor: Optional[PatchExtractor[CziImageStack]]
+    target_extractor: PatchExtractor[CziImageStack] | None
     if targets is not None:
         target_extractor = create_czi_extractor(source=targets, axes=config.axes)
     else:
@@ -309,7 +309,7 @@ def create_ome_zarr_dataset(
     config: NGDataConfig,
     mode: Mode,
     inputs: Sequence[Path],
-    targets: Optional[Sequence[Path]],
+    targets: Sequence[Path] | None,
 ) -> CareamicsDataset[ZarrImageStack]:
     """
     Create a dataset from OME ZARR files.
@@ -332,7 +332,7 @@ def create_ome_zarr_dataset(
     """
 
     input_extractor = create_ome_zarr_extractor(source=inputs, axes=config.axes)
-    target_extractor: Optional[PatchExtractor[ZarrImageStack]]
+    target_extractor: PatchExtractor[ZarrImageStack] | None
     if targets is not None:
         target_extractor = create_ome_zarr_extractor(source=targets, axes=config.axes)
     else:
@@ -345,7 +345,7 @@ def create_custom_file_dataset(
     config: NGDataConfig,
     mode: Mode,
     inputs: Sequence[Path],
-    targets: Optional[Sequence[Path]],
+    targets: Sequence[Path] | None,
     *,
     read_func: ReadFunc,
     read_kwargs: dict[str, Any],
@@ -378,7 +378,7 @@ def create_custom_file_dataset(
     input_extractor = create_custom_file_extractor(
         source=inputs, axes=config.axes, read_func=read_func, read_kwargs=read_kwargs
     )
-    target_extractor: Optional[PatchExtractor[InMemoryImageStack]]
+    target_extractor: PatchExtractor[InMemoryImageStack] | None
     if targets is not None:
         target_extractor = create_custom_file_extractor(
             source=targets,
@@ -396,7 +396,7 @@ def create_custom_image_stack_dataset(
     config: NGDataConfig,
     mode: Mode,
     inputs: Any,
-    targets: Optional[Any],
+    targets: Any | None,
     image_stack_loader: ImageStackLoader[P, GenericImageStack],
     *args: P.args,
     **kwargs: P.kwargs,
@@ -436,7 +436,7 @@ def create_custom_image_stack_dataset(
         *args,
         **kwargs,
     )
-    target_extractor: Optional[PatchExtractor[GenericImageStack]]
+    target_extractor: PatchExtractor[GenericImageStack] | None
     if targets is not None:
         target_extractor = create_custom_image_stack_extractor(
             targets,
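
Most of the churn in this file is the same optional-target branch repeated across the create_* helpers. Worth noting: those `target_extractor: PatchExtractor[...] | None` lines are bare local annotations, which Python never evaluates at runtime (PEP 526), so the `| None` spelling there is purely for readers and type checkers. A runnable demonstration of that rule:

```python
def annotate_locally() -> None:
    # Local variable annotations are stored nowhere and never evaluated,
    # so even an undefined name is legal here (PEP 526); the same applies
    # to the bare "target_extractor: ... | None" lines in the hunks above.
    x: ThisNameDoesNotExist | None  # noqa: F821
    x = None


annotate_locally()
```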
careamics/dataset_ng/patching_strategies/random_patching.py
@@ -1,7 +1,6 @@
 """A module for random patching strategies."""
 
 from collections.abc import Sequence
-from typing import Optional
 
 import numpy as np
 
@@ -31,7 +30,7 @@ class RandomPatchingStrategy:
         self,
         data_shapes: Sequence[Sequence[int]],
         patch_size: Sequence[int],
-        seed: Optional[int] = None,
+        seed: int | None = None,
     ):
         """
         A patching strategy for sampling random patches.
@@ -193,7 +192,7 @@ class FixedRandomPatchingStrategy:
         self,
         data_shapes: Sequence[Sequence[int]],
         patch_size: Sequence[int],
-        seed: Optional[int] = None,
+        seed: int | None = None,
     ):
         """A patching strategy for sampling random patches.
 
@@ -302,7 +301,7 @@ def _generate_random_coords(
         rng.integers(
             np.zeros(len(patch_size), dtype=int),
             np.array(spatial_shape) - np.array(patch_size),
-            endpoint=False,
+            endpoint=True,
             dtype=int,
         ).tolist()
     )
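
The last hunk is the one behavioral change in this diff: `endpoint=True` makes `rng.integers` sample from the closed interval, so a patch origin can now land at exactly `shape - patch_size`. With `endpoint=False`, the last valid top-left corner was unreachable, and the degenerate case `patch_size == spatial_shape` raised because the half-open interval `[0, 0)` is empty:

```python
import numpy as np

rng = np.random.default_rng(0)
spatial_shape, patch_size = np.array([64, 64]), np.array([32, 32])

# Closed interval [0, 32]: origin 32 (the last valid corner) is now reachable.
coords = rng.integers(np.zeros(2, dtype=int), spatial_shape - patch_size,
                      endpoint=True, dtype=int)
assert np.all(coords >= 0) and np.all(coords <= 32)

# Degenerate case patch_size == spatial_shape: only endpoint=True works;
# rng.integers(0, 0, endpoint=False) would raise ValueError.
assert rng.integers(0, 0, endpoint=True) == 0
```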
careamics/dataset_ng/patching_strategies/sequential_patching.py
@@ -1,6 +1,5 @@
 import itertools
 from collections.abc import Sequence
-from typing import Optional
 
 import numpy as np
 from typing_extensions import ParamSpec
@@ -18,7 +17,7 @@ class SequentialPatchingStrategy:
         self,
         data_shapes: Sequence[Sequence[int]],
         patch_size: Sequence[int],
-        overlaps: Optional[Sequence[int]] = None,
+        overlaps: Sequence[int] | None = None,
     ):
         self.data_shapes = data_shapes
         self.patch_size = patch_size