careamics 0.0.4.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the registry's advisory page for more details.

Files changed (118):
  1. careamics/__init__.py +17 -2
  2. careamics/careamist.py +239 -28
  3. careamics/cli/conf.py +19 -31
  4. careamics/cli/main.py +112 -12
  5. careamics/cli/utils.py +29 -0
  6. careamics/config/__init__.py +48 -24
  7. careamics/config/algorithms/__init__.py +15 -0
  8. careamics/config/algorithms/care_algorithm_model.py +50 -0
  9. careamics/config/algorithms/n2n_algorithm_model.py +42 -0
  10. careamics/config/algorithms/n2v_algorithm_model.py +35 -0
  11. careamics/config/algorithms/unet_algorithm_model.py +88 -0
  12. careamics/config/{vae_algorithm_model.py → algorithms/vae_algorithm_model.py} +26 -23
  13. careamics/config/architectures/__init__.py +1 -11
  14. careamics/config/architectures/architecture_model.py +3 -3
  15. careamics/config/architectures/lvae_model.py +109 -21
  16. careamics/config/architectures/unet_model.py +1 -0
  17. careamics/config/care_configuration.py +100 -0
  18. careamics/config/configuration.py +354 -0
  19. careamics/config/{configuration_factory.py → configuration_factories.py} +152 -81
  20. careamics/config/configuration_io.py +85 -0
  21. careamics/config/data/__init__.py +10 -0
  22. careamics/config/{data_model.py → data/data_model.py} +58 -198
  23. careamics/config/data/n2v_data_model.py +193 -0
  24. careamics/config/likelihood_model.py +8 -8
  25. careamics/config/loss_model.py +56 -0
  26. careamics/config/n2n_configuration.py +101 -0
  27. careamics/config/n2v_configuration.py +266 -0
  28. careamics/config/nm_model.py +24 -25
  29. careamics/config/support/__init__.py +7 -7
  30. careamics/config/support/supported_algorithms.py +0 -3
  31. careamics/config/support/supported_architectures.py +0 -4
  32. careamics/config/transformations/__init__.py +10 -4
  33. careamics/config/transformations/transform_model.py +3 -3
  34. careamics/config/transformations/transform_unions.py +42 -0
  35. careamics/config/validators/validator_utils.py +3 -3
  36. careamics/dataset/__init__.py +2 -2
  37. careamics/dataset/dataset_utils/__init__.py +3 -3
  38. careamics/dataset/dataset_utils/dataset_utils.py +4 -6
  39. careamics/dataset/dataset_utils/file_utils.py +9 -9
  40. careamics/dataset/dataset_utils/iterate_over_files.py +4 -3
  41. careamics/dataset/dataset_utils/running_stats.py +22 -23
  42. careamics/dataset/in_memory_dataset.py +11 -12
  43. careamics/dataset/iterable_dataset.py +4 -4
  44. careamics/dataset/iterable_pred_dataset.py +2 -1
  45. careamics/dataset/iterable_tiled_pred_dataset.py +2 -1
  46. careamics/dataset/patching/random_patching.py +11 -10
  47. careamics/dataset/patching/sequential_patching.py +26 -26
  48. careamics/dataset/patching/validate_patch_dimension.py +3 -3
  49. careamics/dataset/tiling/__init__.py +2 -2
  50. careamics/dataset/tiling/collate_tiles.py +3 -3
  51. careamics/dataset/tiling/lvae_tiled_patching.py +2 -1
  52. careamics/dataset/tiling/tiled_patching.py +11 -10
  53. careamics/file_io/__init__.py +5 -5
  54. careamics/file_io/read/__init__.py +1 -1
  55. careamics/file_io/read/get_func.py +2 -2
  56. careamics/file_io/write/__init__.py +2 -2
  57. careamics/lightning/__init__.py +5 -5
  58. careamics/lightning/callbacks/__init__.py +1 -1
  59. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +3 -3
  60. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +2 -1
  61. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +2 -1
  62. careamics/lightning/callbacks/progress_bar_callback.py +2 -2
  63. careamics/lightning/lightning_module.py +69 -34
  64. careamics/lightning/train_data_module.py +41 -27
  65. careamics/losses/__init__.py +3 -3
  66. careamics/losses/loss_factory.py +1 -85
  67. careamics/losses/lvae/losses.py +223 -164
  68. careamics/lvae_training/calibration.py +184 -0
  69. careamics/lvae_training/dataset/config.py +2 -2
  70. careamics/lvae_training/dataset/multich_dataset.py +11 -19
  71. careamics/lvae_training/dataset/multifile_dataset.py +3 -2
  72. careamics/lvae_training/dataset/types.py +15 -26
  73. careamics/lvae_training/dataset/utils/index_manager.py +4 -4
  74. careamics/lvae_training/eval_utils.py +125 -213
  75. careamics/model_io/__init__.py +1 -1
  76. careamics/model_io/bioimage/__init__.py +1 -1
  77. careamics/model_io/bioimage/_readme_factory.py +26 -34
  78. careamics/model_io/bioimage/cover_factory.py +171 -0
  79. careamics/model_io/bioimage/model_description.py +56 -34
  80. careamics/model_io/bmz_io.py +42 -42
  81. careamics/model_io/model_io_utils.py +9 -9
  82. careamics/models/layers.py +22 -20
  83. careamics/models/lvae/layers.py +348 -975
  84. careamics/models/lvae/likelihoods.py +10 -8
  85. careamics/models/lvae/lvae.py +214 -275
  86. careamics/models/lvae/noise_models.py +179 -112
  87. careamics/models/lvae/stochastic.py +393 -0
  88. careamics/models/lvae/utils.py +82 -73
  89. careamics/models/model_factory.py +2 -15
  90. careamics/models/unet.py +8 -8
  91. careamics/prediction_utils/__init__.py +1 -1
  92. careamics/prediction_utils/prediction_outputs.py +15 -15
  93. careamics/prediction_utils/stitch_prediction.py +6 -6
  94. careamics/transforms/__init__.py +5 -5
  95. careamics/transforms/compose.py +13 -13
  96. careamics/transforms/n2v_manipulate.py +3 -3
  97. careamics/transforms/pixel_manipulation.py +9 -9
  98. careamics/transforms/xy_random_rotate90.py +4 -4
  99. careamics/utils/__init__.py +5 -5
  100. careamics/utils/context.py +2 -1
  101. careamics/utils/lightning_utils.py +57 -0
  102. careamics/utils/logging.py +11 -10
  103. careamics/utils/serializers.py +2 -0
  104. careamics/utils/torch_utils.py +8 -8
  105. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/METADATA +16 -13
  106. careamics-0.0.6.dist-info/RECORD +176 -0
  107. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/WHEEL +1 -1
  108. careamics/config/architectures/custom_model.py +0 -162
  109. careamics/config/architectures/register_model.py +0 -103
  110. careamics/config/configuration_model.py +0 -603
  111. careamics/config/fcn_algorithm_model.py +0 -152
  112. careamics/config/references/__init__.py +0 -45
  113. careamics/config/references/algorithm_descriptions.py +0 -132
  114. careamics/config/references/references.py +0 -39
  115. careamics/config/transformations/transform_union.py +0 -20
  116. careamics-0.0.4.2.dist-info/RECORD +0 -165
  117. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/entry_points.txt +0 -0
  118. {careamics-0.0.4.2.dist-info → careamics-0.0.6.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,7 @@
2
2
 
3
3
  from fnmatch import fnmatch
4
4
  from pathlib import Path
5
- from typing import List, Union
5
+ from typing import Union
6
6
 
7
7
  import numpy as np
8
8
 
@@ -12,12 +12,12 @@ from careamics.utils.logging import get_logger
12
12
  logger = get_logger(__name__)
13
13
 
14
14
 
15
- def get_files_size(files: List[Path]) -> float:
15
+ def get_files_size(files: list[Path]) -> float:
16
16
  """Get files size in MB.
17
17
 
18
18
  Parameters
19
19
  ----------
20
- files : List[Path]
20
+ files : list of pathlib.Path
21
21
  List of files.
22
22
 
23
23
  Returns
@@ -32,7 +32,7 @@ def list_files(
32
32
  data_path: Union[str, Path],
33
33
  data_type: Union[str, SupportedData],
34
34
  extension_filter: str = "",
35
- ) -> List[Path]:
35
+ ) -> list[Path]:
36
36
  """List recursively files in `data_path` and return a sorted list.
37
37
 
38
38
  If `data_path` is a file, its name is validated against the `data_type` using
@@ -55,8 +55,8 @@ def list_files(
55
55
 
56
56
  Returns
57
57
  -------
58
- List[Path]
59
- List of pathlib.Path objects.
58
+ list[Path]
59
+ list of pathlib.Path objects.
60
60
 
61
61
  Raises
62
62
  ------
@@ -105,7 +105,7 @@ def list_files(
105
105
  return files
106
106
 
107
107
 
108
- def validate_source_target_files(src_files: List[Path], tar_files: List[Path]) -> None:
108
+ def validate_source_target_files(src_files: list[Path], tar_files: list[Path]) -> None:
109
109
  """
110
110
  Validate source and target path lists.
111
111
 
@@ -113,9 +113,9 @@ def validate_source_target_files(src_files: List[Path], tar_files: List[Path]) -
113
113
 
114
114
  Parameters
115
115
  ----------
116
- src_files : List[Path]
116
+ src_files : list of pathlib.Path
117
117
  List of source files.
118
- tar_files : List[Path]
118
+ tar_files : list of pathlib.Path
119
119
  List of target files.
120
120
 
121
121
  Raises
@@ -2,13 +2,14 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from collections.abc import Generator
5
6
  from pathlib import Path
6
- from typing import Callable, Generator, Optional, Union
7
+ from typing import Callable, Optional, Union
7
8
 
8
9
  from numpy.typing import NDArray
9
10
  from torch.utils.data import get_worker_info
10
11
 
11
- from careamics.config import DataConfig, InferenceConfig
12
+ from careamics.config import GeneralDataConfig, InferenceConfig
12
13
  from careamics.file_io.read import read_tiff
13
14
  from careamics.utils.logging import get_logger
14
15
 
@@ -18,7 +19,7 @@ logger = get_logger(__name__)
18
19
 
19
20
 
20
21
  def iterate_over_files(
21
- data_config: Union[DataConfig, InferenceConfig],
22
+ data_config: Union[GeneralDataConfig, InferenceConfig],
22
23
  data_files: list[Path],
23
24
  target_files: Optional[list[Path]] = None,
24
25
  read_source_func: Callable = read_tiff,
@@ -34,36 +34,35 @@ def update_iterative_stats(
34
34
  Parameters
35
35
  ----------
36
36
  count : NDArray
37
- Number of elements in the array.
37
+ Number of elements in the array. Shape: (C,).
38
38
  mean : NDArray
39
- Mean of the array.
39
+ Mean of the array. Shape: (C,).
40
40
  m2 : NDArray
41
- Variance of the array.
41
+ Variance of the array. Shape: (C,).
42
42
  new_values : NDArray
43
- New values to add to the mean and variance.
43
+ New values to add to the mean and variance. Shape: (C, 1, 1, Z, Y, X).
44
44
 
45
45
  Returns
46
46
  -------
47
47
  tuple[NDArray, NDArray, NDArray]
48
48
  Updated count, mean, and variance.
49
49
  """
50
- count += np.array([np.prod(channel.shape) for channel in new_values])
51
- # newvalues - oldMean
52
- delta = [
53
- np.subtract(v.flatten(), [m] * len(v.flatten()))
54
- for v, m in zip(new_values, mean)
55
- ]
50
+ num_channels = len(new_values)
56
51
 
57
- mean += np.array([np.sum(d / c) for d, c in zip(delta, count)])
58
- # newvalues - newMeant
59
- delta2 = [
60
- np.subtract(v.flatten(), [m] * len(v.flatten()))
61
- for v, m in zip(new_values, mean)
62
- ]
52
+ # --- update channel-wise counts ---
53
+ count += np.ones_like(count) * np.prod(new_values.shape[1:])
63
54
 
64
- m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)])
55
+ # --- update channel-wise mean ---
56
+ # compute (new_values - old_mean) -> shape: (C, Z*Y*X)
57
+ delta = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1)
58
+ mean += np.sum(delta / count.reshape(num_channels, 1), axis=1)
65
59
 
66
- return (count, mean, m2)
60
+ # --- update channel-wise SoS ---
61
+ # compute (new_values - new_mean) -> shape: (C, Z*Y*X)
62
+ delta2 = new_values.reshape(num_channels, -1) - mean.reshape(num_channels, 1)
63
+ m2 += np.sum(delta * delta2, axis=1)
64
+
65
+ return count, mean, m2
67
66
 
68
67
 
69
68
  def finalize_iterative_stats(
@@ -74,18 +73,18 @@ def finalize_iterative_stats(
74
73
  Parameters
75
74
  ----------
76
75
  count : NDArray
77
- Number of elements in the array.
76
+ Number of elements in the array. Shape: (C,).
78
77
  mean : NDArray
79
- Mean of the array.
78
+ Mean of the array. Shape: (C,).
80
79
  m2 : NDArray
81
- Variance of the array.
80
+ Variance of the array. Shape: (C,).
82
81
 
83
82
  Returns
84
83
  -------
85
84
  tuple[NDArray, NDArray]
86
- Final mean and standard deviation.
85
+ Final channel-wise mean and standard deviation.
87
86
  """
88
- std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)])
87
+ std = np.sqrt(m2 / count)
89
88
  if any(c < 2 for c in count):
90
89
  return np.full(mean.shape, np.nan), np.full(std.shape, np.nan)
91
90
  else:
@@ -9,13 +9,9 @@ from typing import Any, Callable, Optional, Union
9
9
  import numpy as np
10
10
  from torch.utils.data import Dataset
11
11
 
12
- from careamics.file_io.read import read_tiff
13
- from careamics.transforms import Compose
14
-
15
- from ..config import DataConfig
16
- from ..config.transformations import NormalizeModel
17
- from ..utils.logging import get_logger
18
- from .patching.patching import (
12
+ from careamics.config import GeneralDataConfig, N2VDataConfig
13
+ from careamics.config.transformations import NormalizeModel
14
+ from careamics.dataset.patching.patching import (
19
15
  PatchedOutput,
20
16
  Stats,
21
17
  prepare_patches_supervised,
@@ -23,6 +19,9 @@ from .patching.patching import (
23
19
  prepare_patches_unsupervised,
24
20
  prepare_patches_unsupervised_array,
25
21
  )
22
+ from careamics.file_io.read import read_tiff
23
+ from careamics.transforms import Compose
24
+ from careamics.utils.logging import get_logger
26
25
 
27
26
  logger = get_logger(__name__)
28
27
 
@@ -47,7 +46,7 @@ class InMemoryDataset(Dataset):
47
46
 
48
47
  def __init__(
49
48
  self,
50
- data_config: DataConfig,
49
+ data_config: GeneralDataConfig,
51
50
  inputs: Union[np.ndarray, list[Path]],
52
51
  input_target: Optional[Union[np.ndarray, list[Path]]] = None,
53
52
  read_source_func: Callable = read_tiff,
@@ -58,7 +57,7 @@ class InMemoryDataset(Dataset):
58
57
 
59
58
  Parameters
60
59
  ----------
61
- data_config : DataConfig
60
+ data_config : GeneralDataConfig
62
61
  Data configuration.
63
62
  inputs : numpy.ndarray or list[pathlib.Path]
64
63
  Input data.
@@ -124,7 +123,7 @@ class InMemoryDataset(Dataset):
124
123
  target_stds=self.target_stats.stds,
125
124
  )
126
125
  ]
127
- + self.data_config.transforms,
126
+ + list(self.data_config.transforms),
128
127
  )
129
128
 
130
129
  def _prepare_patches(self, supervised: bool) -> PatchedOutput:
@@ -219,12 +218,12 @@ class InMemoryDataset(Dataset):
219
218
 
220
219
  return self.patch_transform(patch=patch, target=target)
221
220
 
222
- elif self.data_config.has_n2v_manipulate(): # TODO not compatible with HDN
221
+ elif isinstance(self.data_config, N2VDataConfig):
223
222
  return self.patch_transform(patch=patch)
224
223
  else:
225
224
  raise ValueError(
226
225
  "Something went wrong! No target provided (not supervised training) "
227
- "and no N2V manipulation (no N2V training)."
226
+ "while the algorithm is not Noise2Void."
228
227
  )
229
228
 
230
229
  def get_data_statistics(self) -> tuple[list[float], list[float]]:
@@ -10,7 +10,7 @@ from typing import Callable, Optional
10
10
  import numpy as np
11
11
  from torch.utils.data import IterableDataset
12
12
 
13
- from careamics.config import DataConfig
13
+ from careamics.config import GeneralDataConfig
14
14
  from careamics.config.transformations import NormalizeModel
15
15
  from careamics.file_io.read import read_tiff
16
16
  from careamics.transforms import Compose
@@ -49,7 +49,7 @@ class PathIterableDataset(IterableDataset):
49
49
 
50
50
  def __init__(
51
51
  self,
52
- data_config: DataConfig,
52
+ data_config: GeneralDataConfig,
53
53
  src_files: list[Path],
54
54
  target_files: Optional[list[Path]] = None,
55
55
  read_source_func: Callable = read_tiff,
@@ -58,7 +58,7 @@ class PathIterableDataset(IterableDataset):
58
58
 
59
59
  Parameters
60
60
  ----------
61
- data_config : DataConfig
61
+ data_config : GeneralDataConfig
62
62
  Data configuration.
63
63
  src_files : list[Path]
64
64
  List of data files.
@@ -115,7 +115,7 @@ class PathIterableDataset(IterableDataset):
115
115
  target_stds=self.target_stats.stds,
116
116
  )
117
117
  ]
118
- + data_config.transforms
118
+ + list(data_config.transforms)
119
119
  )
120
120
 
121
121
  def _calculate_mean_and_std(self) -> tuple[Stats, Stats]:
@@ -2,8 +2,9 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from collections.abc import Generator
5
6
  from pathlib import Path
6
- from typing import Any, Callable, Generator
7
+ from typing import Any, Callable
7
8
 
8
9
  from numpy.typing import NDArray
9
10
  from torch.utils.data import IterableDataset
@@ -2,8 +2,9 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from collections.abc import Generator
5
6
  from pathlib import Path
6
- from typing import Any, Callable, Generator
7
+ from typing import Any, Callable
7
8
 
8
9
  from numpy.typing import NDArray
9
10
  from torch.utils.data import IterableDataset
@@ -1,6 +1,7 @@
1
1
  """Random patching utilities."""
2
2
 
3
- from typing import Generator, List, Optional, Tuple, Union
3
+ from collections.abc import Generator
4
+ from typing import Optional, Union
4
5
 
5
6
  import numpy as np
6
7
  import zarr
@@ -11,10 +12,10 @@ from .validate_patch_dimension import validate_patch_dimensions
11
12
  # TOOD split in testable functions
12
13
  def extract_patches_random(
13
14
  arr: np.ndarray,
14
- patch_size: Union[List[int], Tuple[int, ...]],
15
+ patch_size: Union[list[int], tuple[int, ...]],
15
16
  target: Optional[np.ndarray] = None,
16
17
  seed: Optional[int] = None,
17
- ) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
18
+ ) -> Generator[tuple[np.ndarray, Optional[np.ndarray]], None, None]:
18
19
  """
19
20
  Generate patches from an array in a random manner.
20
21
 
@@ -31,12 +32,12 @@ def extract_patches_random(
31
32
  ----------
32
33
  arr : np.ndarray
33
34
  Input image array.
34
- patch_size : Tuple[int]
35
+ patch_size : tuple of int
35
36
  Patch sizes in each dimension.
36
37
  target : Optional[np.ndarray], optional
37
38
  Target array, by default None.
38
- seed : Optional[int], optional
39
- Random seed, by default None.
39
+ seed : int or None, default=None
40
+ Random seed.
40
41
 
41
42
  Yields
42
43
  ------
@@ -112,8 +113,8 @@ def extract_patches_random(
112
113
 
113
114
  def extract_patches_random_from_chunks(
114
115
  arr: zarr.Array,
115
- patch_size: Union[List[int], Tuple[int, ...]],
116
- chunk_size: Union[List[int], Tuple[int, ...]],
116
+ patch_size: Union[list[int], tuple[int, ...]],
117
+ chunk_size: Union[list[int], tuple[int, ...]],
117
118
  chunk_limit: Optional[int] = None,
118
119
  seed: Optional[int] = None,
119
120
  ) -> Generator[np.ndarray, None, None]:
@@ -127,9 +128,9 @@ def extract_patches_random_from_chunks(
127
128
  ----------
128
129
  arr : np.ndarray
129
130
  Input image array.
130
- patch_size : Union[List[int], Tuple[int, ...]]
131
+ patch_size : Union[list[int], tuple[int, ...]]
131
132
  Patch sizes in each dimension.
132
- chunk_size : Union[List[int], Tuple[int, ...]]
133
+ chunk_size : Union[list[int], tuple[int, ...]]
133
134
  Chunk sizes to load from the.
134
135
  chunk_limit : Optional[int], optional
135
136
  Number of chunks to load, by default None.
@@ -1,6 +1,6 @@
1
1
  """Sequential patching functions."""
2
2
 
3
- from typing import List, Optional, Tuple, Union
3
+ from typing import Optional, Union
4
4
 
5
5
  import numpy as np
6
6
  from skimage.util import view_as_windows
@@ -9,21 +9,21 @@ from .validate_patch_dimension import validate_patch_dimensions
9
9
 
10
10
 
11
11
  def _compute_number_of_patches(
12
- arr_shape: Tuple[int, ...], patch_sizes: Union[List[int], Tuple[int, ...]]
13
- ) -> Tuple[int, ...]:
12
+ arr_shape: tuple[int, ...], patch_sizes: Union[list[int], tuple[int, ...]]
13
+ ) -> tuple[int, ...]:
14
14
  """
15
15
  Compute the number of patches that fit in each dimension.
16
16
 
17
17
  Parameters
18
18
  ----------
19
- arr_shape : Tuple[int, ...]
19
+ arr_shape : tuple[int, ...]
20
20
  Shape of the input array.
21
- patch_sizes : Union[List[int], Tuple[int, ...]
21
+ patch_sizes : Union[list[int], tuple[int, ...]
22
22
  Shape of the patches.
23
23
 
24
24
  Returns
25
25
  -------
26
- Tuple[int, ...]
26
+ tuple[int, ...]
27
27
  Number of patches in each dimension.
28
28
  """
29
29
  if len(arr_shape) != len(patch_sizes):
@@ -47,8 +47,8 @@ def _compute_number_of_patches(
47
47
 
48
48
 
49
49
  def _compute_overlap(
50
- arr_shape: Tuple[int, ...], patch_sizes: Union[List[int], Tuple[int, ...]]
51
- ) -> Tuple[int, ...]:
50
+ arr_shape: tuple[int, ...], patch_sizes: Union[list[int], tuple[int, ...]]
51
+ ) -> tuple[int, ...]:
52
52
  """
53
53
  Compute the overlap between patches in each dimension.
54
54
 
@@ -57,14 +57,14 @@ def _compute_overlap(
57
57
 
58
58
  Parameters
59
59
  ----------
60
- arr_shape : Tuple[int, ...]
60
+ arr_shape : tuple[int, ...]
61
61
  Input array shape.
62
- patch_sizes : Union[List[int], Tuple[int, ...]]
62
+ patch_sizes : Union[list[int], tuple[int, ...]]
63
63
  Size of the patches.
64
64
 
65
65
  Returns
66
66
  -------
67
- Tuple[int, ...]
67
+ tuple[int, ...]
68
68
  Overlap between patches in each dimension.
69
69
  """
70
70
  n_patches = _compute_number_of_patches(arr_shape, patch_sizes)
@@ -80,21 +80,21 @@ def _compute_overlap(
80
80
 
81
81
 
82
82
  def _compute_patch_steps(
83
- patch_sizes: Union[List[int], Tuple[int, ...]], overlaps: Tuple[int, ...]
84
- ) -> Tuple[int, ...]:
83
+ patch_sizes: Union[list[int], tuple[int, ...]], overlaps: tuple[int, ...]
84
+ ) -> tuple[int, ...]:
85
85
  """
86
86
  Compute steps between patches.
87
87
 
88
88
  Parameters
89
89
  ----------
90
- patch_sizes : Tuple[int]
90
+ patch_sizes : tuple[int]
91
91
  Size of the patches.
92
- overlaps : Tuple[int]
92
+ overlaps : tuple[int]
93
93
  Overlap between patches.
94
94
 
95
95
  Returns
96
96
  -------
97
- Tuple[int]
97
+ tuple[int]
98
98
  Steps between patches.
99
99
  """
100
100
  steps = [
@@ -107,9 +107,9 @@ def _compute_patch_steps(
107
107
  # TODO why stack the target here and not on a different dimension before this function?
108
108
  def _compute_patch_views(
109
109
  arr: np.ndarray,
110
- window_shape: List[int],
111
- step: Tuple[int, ...],
112
- output_shape: List[int],
110
+ window_shape: list[int],
111
+ step: tuple[int, ...],
112
+ output_shape: list[int],
113
113
  target: Optional[np.ndarray] = None,
114
114
  ) -> np.ndarray:
115
115
  """
@@ -119,11 +119,11 @@ def _compute_patch_views(
119
119
  ----------
120
120
  arr : np.ndarray
121
121
  Array from which the views are extracted.
122
- window_shape : Tuple[int]
122
+ window_shape : tuple[int]
123
123
  Shape of the views.
124
- step : Tuple[int]
124
+ step : tuple[int]
125
125
  Steps between views.
126
- output_shape : Tuple[int]
126
+ output_shape : tuple[int]
127
127
  Shape of the output array.
128
128
  target : Optional[np.ndarray], optional
129
129
  Target array, by default None.
@@ -150,9 +150,9 @@ def _compute_patch_views(
150
150
 
151
151
  def extract_patches_sequential(
152
152
  arr: np.ndarray,
153
- patch_size: Union[List[int], Tuple[int, ...]],
153
+ patch_size: Union[list[int], tuple[int, ...]],
154
154
  target: Optional[np.ndarray] = None,
155
- ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
155
+ ) -> tuple[np.ndarray, Optional[np.ndarray]]:
156
156
  """
157
157
  Generate patches from an array in a sequential manner.
158
158
 
@@ -163,14 +163,14 @@ def extract_patches_sequential(
163
163
  ----------
164
164
  arr : np.ndarray
165
165
  Input image array.
166
- patch_size : Tuple[int]
166
+ patch_size : tuple[int]
167
167
  Patch sizes in each dimension.
168
168
  target : Optional[np.ndarray], optional
169
169
  Target array, by default None.
170
170
 
171
171
  Returns
172
172
  -------
173
- Tuple[np.ndarray, Optional[np.ndarray]]
173
+ tuple[np.ndarray, Optional[np.ndarray]]
174
174
  Patches.
175
175
  """
176
176
  is_3d_patch = len(patch_size) == 3
@@ -1,13 +1,13 @@
1
1
  """Patch validation functions."""
2
2
 
3
- from typing import List, Tuple, Union
3
+ from typing import Union
4
4
 
5
5
  import numpy as np
6
6
 
7
7
 
8
8
  def validate_patch_dimensions(
9
9
  arr: np.ndarray,
10
- patch_size: Union[List[int], Tuple[int, ...]],
10
+ patch_size: Union[list[int], tuple[int, ...]],
11
11
  is_3d_patch: bool,
12
12
  ) -> None:
13
13
  """
@@ -26,7 +26,7 @@ def validate_patch_dimensions(
26
26
  ----------
27
27
  arr : np.ndarray
28
28
  Input array.
29
- patch_size : Union[List[int], Tuple[int, ...]]
29
+ patch_size : Union[list[int], tuple[int, ...]]
30
30
  Size of the patches along each dimension of the array, except the first.
31
31
  is_3d_patch : bool
32
32
  Whether the patch is 3D or not.
@@ -1,9 +1,9 @@
1
1
  """Tiling functions."""
2
2
 
3
3
  __all__ = [
4
- "stitch_prediction",
5
- "extract_tiles",
6
4
  "collate_tiles",
5
+ "extract_tiles",
6
+ "stitch_prediction",
7
7
  ]
8
8
 
9
9
  from .collate_tiles import collate_tiles
@@ -1,6 +1,6 @@
1
1
  """Collate function for tiling."""
2
2
 
3
- from typing import Any, List, Tuple
3
+ from typing import Any
4
4
 
5
5
  import numpy as np
6
6
  from torch.utils.data.dataloader import default_collate
@@ -8,7 +8,7 @@ from torch.utils.data.dataloader import default_collate
8
8
  from careamics.config.tile_information import TileInformation
9
9
 
10
10
 
11
- def collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
11
+ def collate_tiles(batch: list[tuple[np.ndarray, TileInformation]]) -> Any:
12
12
  """
13
13
  Collate tiles received from CAREamics prediction dataloader.
14
14
 
@@ -19,7 +19,7 @@ def collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
19
19
 
20
20
  Parameters
21
21
  ----------
22
- batch : List[Tuple[np.ndarray, TileInformation], ...]
22
+ batch : list[tuple[np.ndarray, TileInformation], ...]
23
23
  Batch of tiles.
24
24
 
25
25
  Returns
@@ -2,7 +2,8 @@
2
2
 
3
3
  import builtins
4
4
  import itertools
5
- from typing import Any, Generator, Optional, Union
5
+ from collections.abc import Generator
6
+ from typing import Any, Optional, Union
6
7
 
7
8
  import numpy as np
8
9
  from numpy.typing import NDArray
@@ -1,7 +1,8 @@
1
1
  """Tiled patching utilities."""
2
2
 
3
3
  import itertools
4
- from typing import Generator, List, Tuple, Union
4
+ from collections.abc import Generator
5
+ from typing import Union
5
6
 
6
7
  import numpy as np
7
8
 
@@ -10,7 +11,7 @@ from careamics.config.tile_information import TileInformation
10
11
 
11
12
  def _compute_crop_and_stitch_coords_1d(
12
13
  axis_size: int, tile_size: int, overlap: int
13
- ) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]:
14
+ ) -> tuple[list[tuple[int, int]], list[tuple[int, int]], list[tuple[int, int]]]:
14
15
  """
15
16
  Compute the coordinates of each tile along an axis, given the overlap.
16
17
 
@@ -25,8 +26,8 @@ def _compute_crop_and_stitch_coords_1d(
25
26
 
26
27
  Returns
27
28
  -------
28
- Tuple[Tuple[int, ...], ...]
29
- Tuple of all coordinates for given axis.
29
+ tuple[tuple[int, ...], ...]
30
+ tuple of all coordinates for given axis.
30
31
  """
31
32
  # Compute the step between tiles
32
33
  step = tile_size - overlap
@@ -81,9 +82,9 @@ def _compute_crop_and_stitch_coords_1d(
81
82
 
82
83
  def extract_tiles(
83
84
  arr: np.ndarray,
84
- tile_size: Union[List[int], Tuple[int, ...]],
85
- overlaps: Union[List[int], Tuple[int, ...]],
86
- ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
85
+ tile_size: Union[list[int], tuple[int, ...]],
86
+ overlaps: Union[list[int], tuple[int, ...]],
87
+ ) -> Generator[tuple[np.ndarray, TileInformation], None, None]:
87
88
  """Generate tiles from the input array with specified overlap.
88
89
 
89
90
  The tiles cover the whole array. The method returns a generator that yields
@@ -98,14 +99,14 @@ def extract_tiles(
98
99
  ----------
99
100
  arr : np.ndarray
100
101
  Array of shape (S, C, (Z), Y, X).
101
- tile_size : Union[List[int], Tuple[int]]
102
+ tile_size : Union[list[int], tuple[int]]
102
103
  Tile sizes in each dimension, of length 2 or 3.
103
- overlaps : Union[List[int], Tuple[int]]
104
+ overlaps : Union[list[int], tuple[int]]
104
105
  Overlap values in each dimension, of length 2 or 3.
105
106
 
106
107
  Yields
107
108
  ------
108
- Generator[Tuple[np.ndarray, TileInformation], None, None]
109
+ Generator[tuple[np.ndarray, TileInformation], None, None]
109
110
  Tile generator, yields the tile and additional information.
110
111
  """
111
112
  # Iterate over num samples (S)
@@ -1,13 +1,13 @@
1
1
  """Functions relating reading and writing image files."""
2
2
 
3
3
  __all__ = [
4
- "read",
5
- "write",
6
- "get_read_func",
7
- "get_write_func",
8
4
  "ReadFunc",
9
- "WriteFunc",
10
5
  "SupportedWriteType",
6
+ "WriteFunc",
7
+ "get_read_func",
8
+ "get_write_func",
9
+ "read",
10
+ "write",
11
11
  ]
12
12
 
13
13
  from . import read, write
@@ -1,10 +1,10 @@
1
1
  """Functions relating to reading image files of different formats."""
2
2
 
3
3
  __all__ = [
4
+ "ReadFunc",
4
5
  "get_read_func",
5
6
  "read_tiff",
6
7
  "read_zarr",
7
- "ReadFunc",
8
8
  ]
9
9
 
10
10
  from .get_func import ReadFunc, get_read_func