careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (118)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +164 -231
  4. careamics/config/algorithm_model.py +5 -18
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +11 -4
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -5
  12. careamics/config/configuration_factory.py +27 -41
  13. careamics/config/configuration_model.py +11 -11
  14. careamics/config/data_model.py +89 -63
  15. careamics/config/inference_model.py +28 -81
  16. careamics/config/optimizer_models.py +11 -11
  17. careamics/config/support/__init__.py +0 -2
  18. careamics/config/support/supported_activations.py +2 -0
  19. careamics/config/support/supported_algorithms.py +3 -1
  20. careamics/config/support/supported_architectures.py +2 -0
  21. careamics/config/support/supported_data.py +2 -0
  22. careamics/config/support/supported_loggers.py +2 -0
  23. careamics/config/support/supported_losses.py +2 -0
  24. careamics/config/support/supported_optimizers.py +2 -0
  25. careamics/config/support/supported_pixel_manipulations.py +3 -3
  26. careamics/config/support/supported_struct_axis.py +2 -0
  27. careamics/config/support/supported_transforms.py +4 -16
  28. careamics/config/tile_information.py +28 -58
  29. careamics/config/transformations/__init__.py +3 -2
  30. careamics/config/transformations/normalize_model.py +32 -4
  31. careamics/config/transformations/xy_flip_model.py +43 -0
  32. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  33. careamics/config/validators/validator_utils.py +1 -1
  34. careamics/conftest.py +12 -0
  35. careamics/dataset/__init__.py +12 -1
  36. careamics/dataset/dataset_utils/__init__.py +8 -1
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  38. careamics/dataset/dataset_utils/file_utils.py +4 -3
  39. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  40. careamics/dataset/dataset_utils/read_tiff.py +6 -11
  41. careamics/dataset/dataset_utils/read_utils.py +2 -0
  42. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  43. careamics/dataset/dataset_utils/running_stats.py +186 -0
  44. careamics/dataset/in_memory_dataset.py +88 -154
  45. careamics/dataset/in_memory_pred_dataset.py +88 -0
  46. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  47. careamics/dataset/iterable_dataset.py +121 -191
  48. careamics/dataset/iterable_pred_dataset.py +121 -0
  49. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  50. careamics/dataset/patching/patching.py +109 -39
  51. careamics/dataset/patching/random_patching.py +17 -6
  52. careamics/dataset/patching/sequential_patching.py +14 -8
  53. careamics/dataset/patching/validate_patch_dimension.py +7 -3
  54. careamics/dataset/tiling/__init__.py +10 -0
  55. careamics/dataset/tiling/collate_tiles.py +33 -0
  56. careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
  57. careamics/dataset/zarr_dataset.py +2 -0
  58. careamics/lightning_datamodule.py +46 -25
  59. careamics/lightning_module.py +19 -9
  60. careamics/lightning_prediction_datamodule.py +54 -84
  61. careamics/losses/__init__.py +2 -3
  62. careamics/losses/loss_factory.py +1 -1
  63. careamics/losses/losses.py +11 -7
  64. careamics/lvae_training/__init__.py +0 -0
  65. careamics/lvae_training/data_modules.py +1220 -0
  66. careamics/lvae_training/data_utils.py +618 -0
  67. careamics/lvae_training/eval_utils.py +905 -0
  68. careamics/lvae_training/get_config.py +84 -0
  69. careamics/lvae_training/lightning_module.py +701 -0
  70. careamics/lvae_training/metrics.py +214 -0
  71. careamics/lvae_training/train_lvae.py +339 -0
  72. careamics/lvae_training/train_utils.py +121 -0
  73. careamics/model_io/bioimage/model_description.py +40 -32
  74. careamics/model_io/bmz_io.py +3 -3
  75. careamics/model_io/model_io_utils.py +5 -2
  76. careamics/models/activation.py +2 -0
  77. careamics/models/layers.py +121 -25
  78. careamics/models/lvae/__init__.py +0 -0
  79. careamics/models/lvae/layers.py +1998 -0
  80. careamics/models/lvae/likelihoods.py +312 -0
  81. careamics/models/lvae/lvae.py +985 -0
  82. careamics/models/lvae/noise_models.py +409 -0
  83. careamics/models/lvae/utils.py +395 -0
  84. careamics/models/model_factory.py +1 -1
  85. careamics/models/unet.py +35 -14
  86. careamics/prediction_utils/__init__.py +12 -0
  87. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  88. careamics/prediction_utils/prediction_outputs.py +165 -0
  89. careamics/prediction_utils/stitch_prediction.py +100 -0
  90. careamics/transforms/__init__.py +2 -2
  91. careamics/transforms/compose.py +33 -7
  92. careamics/transforms/n2v_manipulate.py +52 -14
  93. careamics/transforms/normalize.py +171 -48
  94. careamics/transforms/pixel_manipulation.py +35 -11
  95. careamics/transforms/struct_mask_parameters.py +3 -1
  96. careamics/transforms/transform.py +10 -19
  97. careamics/transforms/tta.py +43 -29
  98. careamics/transforms/xy_flip.py +123 -0
  99. careamics/transforms/xy_random_rotate90.py +38 -5
  100. careamics/utils/base_enum.py +28 -0
  101. careamics/utils/path_utils.py +2 -0
  102. careamics/utils/ram.py +4 -2
  103. careamics/utils/receptive_field.py +93 -87
  104. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
  105. careamics-0.1.0rc7.dist-info/RECORD +130 -0
  106. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  107. careamics/config/noise_models.py +0 -162
  108. careamics/config/support/supported_extraction_strategies.py +0 -25
  109. careamics/config/transformations/nd_flip_model.py +0 -27
  110. careamics/lightning_prediction_loop.py +0 -116
  111. careamics/losses/noise_model_factory.py +0 -40
  112. careamics/losses/noise_models.py +0 -524
  113. careamics/prediction/__init__.py +0 -7
  114. careamics/prediction/stitch_prediction.py +0 -74
  115. careamics/transforms/nd_flip.py +0 -67
  116. careamics/utils/running_stats.py +0 -43
  117. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  118. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,139 @@
1
+ """Iterable tiled prediction dataset used to load data file by file."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any, Callable, Generator
7
+
8
+ from numpy.typing import NDArray
9
+ from torch.utils.data import IterableDataset
10
+
11
+ from careamics.transforms import Compose
12
+
13
+ from ..config import InferenceConfig
14
+ from ..config.tile_information import TileInformation
15
+ from ..config.transformations import NormalizeModel
16
+ from .dataset_utils import iterate_over_files, read_tiff
17
+ from .tiling import extract_tiles
18
+
19
+
20
class IterableTiledPredDataset(IterableDataset):
    """Tiled prediction dataset, loading data file by file.

    Each file in `src_files` is read with `read_source_func` (through
    `iterate_over_files`), split into overlapping tiles with `extract_tiles`,
    normalized with the means and standard deviations of the inference
    configuration, and yielded one tile at a time together with its
    `TileInformation`.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Inference configuration. Its `tile_size`, `tile_overlap`,
        `image_means` and `image_stds` must not be None.
    src_files : list of pathlib.Path
        List of data files.
    read_source_func : Callable, optional
        Read source function for custom types, by default read_tiff.
    **kwargs : Any
        Additional keyword arguments, unused.

    Attributes
    ----------
    prediction_config : InferenceConfig
        Inference configuration.
    data_files : list of pathlib.Path
        List of data files.
    axes : str
        Axes of the data, taken from the inference configuration.
    tile_size : list or tuple of int
        Size of the tiles, taken from the inference configuration.
    tile_overlap : list or tuple of int
        Overlap between tiles, taken from the inference configuration.
    read_source_func : Callable
        Function used to read each data file.
    image_means
        Means used for normalization, taken from the inference configuration.
    image_stds
        Standard deviations used for normalization, taken from the inference
        configuration.
    patch_transform : Compose
        Normalization transform applied to each tile.

    Notes
    -----
    NOTE(review): `__iter__` performs no per-worker sharding. With a
    DataLoader using `num_workers > 1`, every worker would iterate over all
    files and tiles would be duplicated. Confirm that callers use a single
    worker.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        src_files: list[Path],
        read_source_func: Callable = read_tiff,
        **kwargs: Any,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Inference configuration.
        src_files : list of pathlib.Path
            List of data files.
        read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
        **kwargs : Any
            Additional keyword arguments, unused.

        Raises
        ------
        ValueError
            If tile size or tile overlap is not provided in the inference
            configuration.
        ValueError
            If mean and std are not provided in the inference configuration.
        """
        # tiling parameters are mandatory for a tiled dataset
        if (
            prediction_config.tile_size is None
            or prediction_config.tile_overlap is None
        ):
            raise ValueError(
                "Tile size and overlap must be provided for tiled prediction."
            )

        # normalization statistics are mandatory to build the transform
        if (
            prediction_config.image_means is None
            or prediction_config.image_stds is None
        ):
            raise ValueError("Mean and std must be provided for prediction.")

        self.prediction_config = prediction_config
        self.data_files = src_files
        self.axes = prediction_config.axes
        self.tile_size = prediction_config.tile_size
        self.tile_overlap = prediction_config.tile_overlap
        self.read_source_func = read_source_func

        self.image_means = prediction_config.image_means
        self.image_stds = prediction_config.image_stds

        # instantiate normalize transform
        self.patch_transform = Compose(
            transform_list=[
                NormalizeModel(
                    image_means=self.image_means,
                    image_stds=self.image_stds,
                )
            ],
        )

    def __iter__(
        self,
    ) -> Generator[tuple[NDArray, TileInformation], None, None]:
        """
        Iterate over data source and yield single tiles.

        Yields
        ------
        tuple of (NDArray, TileInformation)
            Normalized tile and its tile information.
        """
        # __init__ already rejects None statistics; this re-check narrows the
        # attribute types for static type checkers.
        assert (
            self.image_means is not None and self.image_stds is not None
        ), "Mean and std must be provided"

        for sample, _ in iterate_over_files(
            self.prediction_config,
            self.data_files,
            read_source_func=self.read_source_func,
        ):
            # extract_tiles is a generator of (tile, tile information) pairs
            for patch_array, tile_info in extract_tiles(
                arr=sample,
                tile_size=self.tile_size,
                overlaps=self.tile_overlap,
            ):
                transformed_patch, _ = self.patch_transform(patch=patch_array)

                yield transformed_patch, tile_info
@@ -1,9 +1,6 @@
1
- """
2
- Tiling submodule.
3
-
4
- These functions are used to tile images into patches or tiles.
5
- """
1
+ """Patching functions."""
6
2
 
3
+ from dataclasses import dataclass
7
4
  from pathlib import Path
8
5
  from typing import Callable, List, Tuple, Union
9
6
 
@@ -11,30 +8,69 @@ import numpy as np
11
8
 
12
9
  from ...utils.logging import get_logger
13
10
  from ..dataset_utils import reshape_array
11
+ from ..dataset_utils.running_stats import compute_normalization_stats
14
12
  from .sequential_patching import extract_patches_sequential
15
13
 
16
14
  logger = get_logger(__name__)
17
15
 
18
16
 
17
@dataclass
class Stats:
    """Dataclass to store normalization statistics.

    Attributes
    ----------
    means : numpy.ndarray, tuple, list or None
        Means of the data.
    stds : numpy.ndarray, tuple, list or None
        Standard deviations of the data.
    """

    means: Union[np.ndarray, tuple, list, None]
    stds: Union[np.ndarray, tuple, list, None]


@dataclass
class PatchedOutput:
    """Dataclass to store patches, optional targets and their statistics.

    Attributes
    ----------
    patches : numpy.ndarray
        Array of patches.
    targets : numpy.ndarray or None
        Array of target patches, None when there are no targets.
    image_stats : Stats
        Statistics of the input data.
    target_stats : Stats
        Statistics of the target data (empty tuples when there are no
        targets).
    """

    patches: np.ndarray
    targets: Union[np.ndarray, None]
    image_stats: Stats
    target_stats: Stats


@dataclass
class StatsOutput:
    """Dataclass to store image and target statistics.

    Attributes
    ----------
    image_stats : Stats
        Statistics of the input data.
    target_stats : Stats
        Statistics of the target data.
    """

    image_stats: Stats
    target_stats: Stats
41
+
42
+
19
43
  # called by in memory dataset
20
44
  def prepare_patches_supervised(
21
45
  train_files: List[Path],
22
46
  target_files: List[Path],
23
47
  axes: str,
24
- patch_size: Union[List[int], Tuple[int]],
48
+ patch_size: Union[List[int], Tuple[int, ...]],
25
49
  read_source_func: Callable,
26
- ) -> Tuple[np.ndarray, np.ndarray, float, float]:
50
+ ) -> PatchedOutput:
27
51
  """
28
52
  Iterate over data source and create an array of patches and corresponding targets.
29
53
 
54
+ The lists of Paths should be pre-sorted.
55
+
56
+ Parameters
57
+ ----------
58
+ train_files : List[Path]
59
+ List of paths to training data.
60
+ target_files : List[Path]
61
+ List of paths to target data.
62
+ axes : str
63
+ Axes of the data.
64
+ patch_size : Union[List[int], Tuple[int]]
65
+ Size of the patches.
66
+ read_source_func : Callable
67
+ Function to read the data.
68
+
30
69
  Returns
31
70
  -------
32
71
  np.ndarray
33
72
  Array of patches.
34
73
  """
35
- train_files.sort()
36
- target_files.sort()
37
-
38
74
  means, stds, num_samples = 0, 0, 0
39
75
  all_patches, all_targets = [], []
40
76
  for train_filename, target_filename in zip(train_files, target_files):
@@ -74,17 +110,18 @@ def prepare_patches_supervised(
74
110
  f"{target_files}."
75
111
  )
76
112
 
77
- result_mean, result_std = means / num_samples, stds / num_samples
113
+ image_means, image_stds = compute_normalization_stats(np.concatenate(all_patches))
114
+ target_means, target_stds = compute_normalization_stats(np.concatenate(all_targets))
78
115
 
79
116
  patch_array: np.ndarray = np.concatenate(all_patches, axis=0)
80
117
  target_array: np.ndarray = np.concatenate(all_targets, axis=0)
81
118
  logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
82
119
 
83
- return (
120
+ return PatchedOutput(
84
121
  patch_array,
85
122
  target_array,
86
- result_mean,
87
- result_std,
123
+ Stats(image_means, image_stds),
124
+ Stats(target_means, target_stds),
88
125
  )
89
126
 
90
127
 
@@ -94,14 +131,26 @@ def prepare_patches_unsupervised(
94
131
  axes: str,
95
132
  patch_size: Union[List[int], Tuple[int]],
96
133
  read_source_func: Callable,
97
- ) -> Tuple[np.ndarray, None, float, float]:
98
- """
99
- Iterate over data source and create an array of patches.
134
+ ) -> PatchedOutput:
135
+ """Iterate over data source and create an array of patches.
136
+
137
+ This method returns the mean and standard deviation of the image.
138
+
139
+ Parameters
140
+ ----------
141
+ train_files : List[Path]
142
+ List of paths to training data.
143
+ axes : str
144
+ Axes of the data.
145
+ patch_size : Union[List[int], Tuple[int]]
146
+ Size of the patches.
147
+ read_source_func : Callable
148
+ Function to read the data.
100
149
 
101
150
  Returns
102
151
  -------
103
- np.ndarray
104
- Array of patches.
152
+ Tuple[np.ndarray, None, float, float]
153
+ Source and target patches, mean and standard deviation.
105
154
  """
106
155
  means, stds, num_samples = 0, 0, 0
107
156
  all_patches = []
@@ -128,12 +177,14 @@ def prepare_patches_unsupervised(
128
177
  if num_samples == 0:
129
178
  raise ValueError(f"No valid samples found in the input data: {train_files}.")
130
179
 
131
- result_mean, result_std = means / num_samples, stds / num_samples
180
+ image_means, image_stds = compute_normalization_stats(np.concatenate(all_patches))
132
181
 
133
182
  patch_array: np.ndarray = np.concatenate(all_patches)
134
183
  logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
135
184
 
136
- return patch_array, _, result_mean, result_std # TODO return object?
185
+ return PatchedOutput(
186
+ patch_array, None, Stats(image_means, image_stds), Stats((), ())
187
+ )
137
188
 
138
189
 
139
190
  # called on arrays by in memory dataset
@@ -142,7 +193,7 @@ def prepare_patches_supervised_array(
142
193
  axes: str,
143
194
  data_target: np.ndarray,
144
195
  patch_size: Union[List[int], Tuple[int]],
145
- ) -> Tuple[np.ndarray, np.ndarray, float, float]:
196
+ ) -> PatchedOutput:
146
197
  """Iterate over data source and create an array of patches.
147
198
 
148
199
  This method expects an array of shape SC(Z)YX, where S and C can be singleton
@@ -150,19 +201,30 @@ def prepare_patches_supervised_array(
150
201
 
151
202
  Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
152
203
 
204
+ Parameters
205
+ ----------
206
+ data : np.ndarray
207
+ Input data array.
208
+ axes : str
209
+ Axes of the data.
210
+ data_target : np.ndarray
211
+ Target data array.
212
+ patch_size : Union[List[int], Tuple[int]]
213
+ Size of the patches.
214
+
153
215
  Returns
154
216
  -------
155
- np.ndarray
156
- Array of patches.
217
+ Tuple[np.ndarray, np.ndarray, float, float]
218
+ Source and target patches, mean and standard deviation.
157
219
  """
158
- # compute statistics
159
- mean = data.mean()
160
- std = data.std()
161
-
162
220
  # reshape array
163
221
  reshaped_sample = reshape_array(data, axes)
164
222
  reshaped_target = reshape_array(data_target, axes)
165
223
 
224
+ # compute statistics
225
+ image_means, image_stds = compute_normalization_stats(reshaped_sample)
226
+ target_means, target_stds = compute_normalization_stats(reshaped_target)
227
+
166
228
  # generate patches, return a generator
167
229
  patches, patch_targets = extract_patches_sequential(
168
230
  reshaped_sample, patch_size=patch_size, target=reshaped_target
@@ -173,11 +235,11 @@ def prepare_patches_supervised_array(
173
235
 
174
236
  logger.info(f"Extracted {patches.shape[0]} patches from input array.")
175
237
 
176
- return (
238
+ return PatchedOutput(
177
239
  patches,
178
240
  patch_targets,
179
- mean,
180
- std,
241
+ Stats(image_means, image_stds),
242
+ Stats(target_means, target_stds),
181
243
  )
182
244
 
183
245
 
@@ -186,7 +248,7 @@ def prepare_patches_unsupervised_array(
186
248
  data: np.ndarray,
187
249
  axes: str,
188
250
  patch_size: Union[List[int], Tuple[int]],
189
- ) -> Tuple[np.ndarray, None, float, float]:
251
+ ) -> PatchedOutput:
190
252
  """
191
253
  Iterate over data source and create an array of patches.
192
254
 
@@ -195,19 +257,27 @@ def prepare_patches_unsupervised_array(
195
257
 
196
258
  Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
197
259
 
260
+ Parameters
261
+ ----------
262
+ data : np.ndarray
263
+ Input data array.
264
+ axes : str
265
+ Axes of the data.
266
+ patch_size : Union[List[int], Tuple[int]]
267
+ Size of the patches.
268
+
198
269
  Returns
199
270
  -------
200
- np.ndarray
201
- Array of patches.
271
+ Tuple[np.ndarray, None, float, float]
272
+ Source patches, mean and standard deviation.
202
273
  """
203
- # calculate mean and std
204
- mean = data.mean()
205
- std = data.std()
206
-
207
274
  # reshape array
208
275
  reshaped_sample = reshape_array(data, axes)
209
276
 
277
+ # calculate mean and std
278
+ means, stds = compute_normalization_stats(reshaped_sample)
279
+
210
280
  # generate patches, return a generator
211
281
  patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)
212
282
 
213
- return patches, _, mean, std # TODO inelegant, replace by dataclass?
283
+ return PatchedOutput(patches, None, Stats(means, stds), Stats((), ()))
@@ -1,3 +1,5 @@
1
+ """Random patching utilities."""
2
+
1
3
  from typing import Generator, List, Optional, Tuple, Union
2
4
 
3
5
  import numpy as np
@@ -11,6 +13,7 @@ def extract_patches_random(
11
13
  arr: np.ndarray,
12
14
  patch_size: Union[List[int], Tuple[int, ...]],
13
15
  target: Optional[np.ndarray] = None,
16
+ seed: Optional[int] = None,
14
17
  ) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
15
18
  """
16
19
  Generate patches from an array in a random manner.
@@ -30,12 +33,18 @@ def extract_patches_random(
30
33
  Input image array.
31
34
  patch_size : Tuple[int]
32
35
  Patch sizes in each dimension.
36
+ target : Optional[np.ndarray], optional
37
+ Target array, by default None.
38
+ seed : Optional[int], optional
39
+ Random seed, by default None.
33
40
 
34
41
  Yields
35
42
  ------
36
43
  Generator[np.ndarray, None, None]
37
44
  Generator of patches.
38
45
  """
46
+ rng = np.random.default_rng(seed=seed)
47
+
39
48
  is_3d_patch = len(patch_size) == 3
40
49
 
41
50
  # patches sanity check
@@ -44,9 +53,6 @@ def extract_patches_random(
44
53
  # Update patch size to encompass S and C dimensions
45
54
  patch_size = [1, arr.shape[1], *patch_size]
46
55
 
47
- # random generator
48
- rng = np.random.default_rng()
49
-
50
56
  # iterate over the number of samples (S or T)
51
57
  for sample_idx in range(arr.shape[0]):
52
58
  # get sample array
@@ -109,6 +115,7 @@ def extract_patches_random_from_chunks(
109
115
  patch_size: Union[List[int], Tuple[int, ...]],
110
116
  chunk_size: Union[List[int], Tuple[int, ...]],
111
117
  chunk_limit: Optional[int] = None,
118
+ seed: Optional[int] = None,
112
119
  ) -> Generator[np.ndarray, None, None]:
113
120
  """
114
121
  Generate patches from an array in a random manner.
@@ -120,10 +127,14 @@ def extract_patches_random_from_chunks(
120
127
  ----------
121
128
  arr : np.ndarray
122
129
  Input image array.
123
- patch_size : Tuple[int]
130
+ patch_size : Union[List[int], Tuple[int, ...]]
124
131
  Patch sizes in each dimension.
125
- chunk_size : Tuple[int]
132
+ chunk_size : Union[List[int], Tuple[int, ...]]
126
133
  Chunk sizes to load from the.
134
+ chunk_limit : Optional[int], optional
135
+ Number of chunks to load, by default None.
136
+ seed : Optional[int], optional
137
+ Random seed, by default None.
127
138
 
128
139
  Yields
129
140
  ------
@@ -135,7 +146,7 @@ def extract_patches_random_from_chunks(
135
146
  # Patches sanity check
136
147
  validate_patch_dimensions(arr, patch_size, is_3d_patch)
137
148
 
138
- rng = np.random.default_rng()
149
+ rng = np.random.default_rng(seed=seed)
139
150
  num_chunks = chunk_limit if chunk_limit else np.prod(arr._cdata_shape)
140
151
 
141
152
  # Iterate over num chunks in the array
@@ -1,3 +1,5 @@
1
+ """Sequential patching functions."""
2
+
1
3
  from typing import List, Optional, Tuple, Union
2
4
 
3
5
  import numpy as np
@@ -14,14 +16,14 @@ def _compute_number_of_patches(
14
16
 
15
17
  Parameters
16
18
  ----------
17
- arr : Tuple[int, ...]
19
+ arr_shape : Tuple[int, ...]
18
20
  Shape of the input array.
19
- patch_sizes : Tuple[int]
21
+ patch_sizes : Union[List[int], Tuple[int, ...]
20
22
  Shape of the patches.
21
23
 
22
24
  Returns
23
25
  -------
24
- Tuple[int]
26
+ Tuple[int, ...]
25
27
  Number of patches in each dimension.
26
28
  """
27
29
  if len(arr_shape) != len(patch_sizes):
@@ -55,14 +57,14 @@ def _compute_overlap(
55
57
 
56
58
  Parameters
57
59
  ----------
58
- arr : Tuple[int, ...]
60
+ arr_shape : Tuple[int, ...]
59
61
  Input array shape.
60
- patch_sizes : Tuple[int]
62
+ patch_sizes : Union[List[int], Tuple[int, ...]]
61
63
  Size of the patches.
62
64
 
63
65
  Returns
64
66
  -------
65
- Tuple[int]
67
+ Tuple[int, ...]
66
68
  Overlap between patches in each dimension.
67
69
  """
68
70
  n_patches = _compute_number_of_patches(arr_shape, patch_sizes)
@@ -123,6 +125,8 @@ def _compute_patch_views(
123
125
  Steps between views.
124
126
  output_shape : Tuple[int]
125
127
  Shape of the output array.
128
+ target : Optional[np.ndarray], optional
129
+ Target array, by default None.
126
130
 
127
131
  Returns
128
132
  -------
@@ -161,11 +165,13 @@ def extract_patches_sequential(
161
165
  Input image array.
162
166
  patch_size : Tuple[int]
163
167
  Patch sizes in each dimension.
168
+ target : Optional[np.ndarray], optional
169
+ Target array, by default None.
164
170
 
165
171
  Returns
166
172
  -------
167
- Generator[Tuple[np.ndarray, ...], None, None]
168
- Generator of patches.
173
+ Tuple[np.ndarray, Optional[np.ndarray]]
174
+ Patches.
169
175
  """
170
176
  is_3d_patch = len(patch_size) == 3
171
177
 
@@ -1,3 +1,5 @@
1
+ """Patch validation functions."""
2
+
1
3
  from typing import List, Tuple, Union
2
4
 
3
5
  import numpy as np
@@ -43,18 +45,20 @@ def validate_patch_dimensions(
43
45
  if len(patch_size) != len(arr.shape[2:]):
44
46
  raise ValueError(
45
47
  f"There must be a patch size for each spatial dimensions "
46
- f"(got {patch_size} patches for dims {arr.shape})."
48
+ f"(got {patch_size} patches for dims {arr.shape}). Check the axes order."
47
49
  )
48
50
 
49
51
  # Sanity checks on patch sizes versus array dimension
50
52
  if is_3d_patch and patch_size[0] > arr.shape[-3]:
51
53
  raise ValueError(
52
54
  f"Z patch size is inconsistent with image shape "
53
- f"(got {patch_size[0]} patches for dim {arr.shape[1]})."
55
+ f"(got {patch_size[0]} patches for dim {arr.shape[1]}). Check the axes "
56
+ f"order."
54
57
  )
55
58
 
56
59
  if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
57
60
  raise ValueError(
58
61
  f"At least one of YX patch dimensions is larger than the corresponding "
59
- f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]})."
62
+ f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]}). "
63
+ f"Check the axes order."
60
64
  )
@@ -0,0 +1,10 @@
1
+ """Tiling functions."""
2
+
3
# Public API of the tiling subpackage. Only names actually imported below may
# appear in ``__all__``: listing an undefined name (previously
# "stitch_prediction", which appears to live in ``careamics.prediction_utils``)
# makes ``from careamics.dataset.tiling import *`` raise AttributeError.
__all__ = [
    "extract_tiles",
    "collate_tiles",
]

from .collate_tiles import collate_tiles
from .tiled_patching import extract_tiles
@@ -0,0 +1,33 @@
1
+ """Collate function for tiling."""
2
+
3
+ from typing import Any, List, Tuple
4
+
5
+ import numpy as np
6
+ from torch.utils.data.dataloader import default_collate
7
+
8
+ from careamics.config.tile_information import TileInformation
9
+
10
+
11
def collate_tiles(
    batch: List[Tuple[np.ndarray, TileInformation]],
) -> Tuple[Any, List[TileInformation]]:
    """
    Collate tiles received from CAREamics tiled prediction dataloader.

    The tiled prediction dataloader yields ``(tile, TileInformation)`` pairs.
    The tiles are collated with PyTorch's ``default_collate``. The
    ``TileInformation`` objects are not collated; they are returned as a
    plain list in the same order as the tiles.

    Parameters
    ----------
    batch : List[Tuple[np.ndarray, TileInformation]]
        Batch of (tile, tile information) pairs.

    Returns
    -------
    tuple of (Any, list of TileInformation)
        Tiles collated by ``default_collate``, and the list of tile
        information for each tile of the batch.
    """
    tiles = [tile for tile, _ in batch]
    tile_infos = [tile_info for _, tile_info in batch]

    return default_collate(tiles), tile_infos
@@ -1,3 +1,5 @@
1
+ """Tiled patching utilities."""
2
+
1
3
  import itertools
2
4
  from typing import Generator, List, Tuple, Union
3
5
 
@@ -8,7 +10,7 @@ from careamics.config.tile_information import TileInformation
8
10
 
9
11
  def _compute_crop_and_stitch_coords_1d(
10
12
  axis_size: int, tile_size: int, overlap: int
11
- ) -> Tuple[List[Tuple[int, ...]], ...]:
13
+ ) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]:
12
14
  """
13
15
  Compute the coordinates of each tile along an axis, given the overlap.
14
16
 
@@ -82,15 +84,15 @@ def extract_tiles(
82
84
  tile_size: Union[List[int], Tuple[int, ...]],
83
85
  overlaps: Union[List[int], Tuple[int, ...]],
84
86
  ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
85
- """
86
- Generate tiles from the input array with specified overlap.
87
+ """Generate tiles from the input array with specified overlap.
87
88
 
88
89
  The tiles cover the whole array. The method returns a generator that yields
89
90
  tuples of array and tile information, the latter includes whether
90
91
  the tile is the last one, the coordinates of the overlap crop, and the coordinates
91
92
  of the stitched tile.
92
93
 
93
- The array has shape C(Z)YX, where C can be a singleton.
94
+ Input array should have shape SC(Z)YX, while the returned tiles have shape C(Z)YX,
95
+ where C can be a singleton.
94
96
 
95
97
  Parameters
96
98
  ----------
@@ -153,10 +155,10 @@ def extract_tiles(
153
155
  # create tile information
154
156
  tile_info = TileInformation(
155
157
  array_shape=sample.squeeze().shape,
156
- tiled=True,
157
158
  last_tile=last_tile,
158
159
  overlap_crop_coords=overlap_crop_coords,
159
160
  stitch_coords=stitch_coords,
161
+ sample_id=sample_idx,
160
162
  )
161
163
 
162
164
  yield tile, tile_info
@@ -1,3 +1,5 @@
1
+ """Zarr dataset."""
2
+
1
3
  # from itertools import islice
2
4
  # from typing import Callable, Dict, List, Optional, Tuple, Union
3
5