careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (155)
  1. careamics/__init__.py +6 -1
  2. careamics/careamist.py +729 -0
  3. careamics/config/__init__.py +39 -0
  4. careamics/config/architectures/__init__.py +17 -0
  5. careamics/config/architectures/architecture_model.py +37 -0
  6. careamics/config/architectures/custom_model.py +162 -0
  7. careamics/config/architectures/lvae_model.py +174 -0
  8. careamics/config/architectures/register_model.py +103 -0
  9. careamics/config/architectures/unet_model.py +118 -0
  10. careamics/config/callback_model.py +123 -0
  11. careamics/config/configuration_factory.py +583 -0
  12. careamics/config/configuration_model.py +604 -0
  13. careamics/config/data_model.py +527 -0
  14. careamics/config/fcn_algorithm_model.py +147 -0
  15. careamics/config/inference_model.py +239 -0
  16. careamics/config/likelihood_model.py +43 -0
  17. careamics/config/nm_model.py +101 -0
  18. careamics/config/optimizer_models.py +187 -0
  19. careamics/config/references/__init__.py +45 -0
  20. careamics/config/references/algorithm_descriptions.py +132 -0
  21. careamics/config/references/references.py +39 -0
  22. careamics/config/support/__init__.py +31 -0
  23. careamics/config/support/supported_activations.py +27 -0
  24. careamics/config/support/supported_algorithms.py +33 -0
  25. careamics/config/support/supported_architectures.py +17 -0
  26. careamics/config/support/supported_data.py +109 -0
  27. careamics/config/support/supported_loggers.py +10 -0
  28. careamics/config/support/supported_losses.py +29 -0
  29. careamics/config/support/supported_optimizers.py +57 -0
  30. careamics/config/support/supported_pixel_manipulations.py +15 -0
  31. careamics/config/support/supported_struct_axis.py +21 -0
  32. careamics/config/support/supported_transforms.py +11 -0
  33. careamics/config/tile_information.py +65 -0
  34. careamics/config/training_model.py +72 -0
  35. careamics/config/transformations/__init__.py +15 -0
  36. careamics/config/transformations/n2v_manipulate_model.py +64 -0
  37. careamics/config/transformations/normalize_model.py +60 -0
  38. careamics/config/transformations/transform_model.py +45 -0
  39. careamics/config/transformations/xy_flip_model.py +43 -0
  40. careamics/config/transformations/xy_random_rotate90_model.py +35 -0
  41. careamics/config/vae_algorithm_model.py +171 -0
  42. careamics/config/validators/__init__.py +5 -0
  43. careamics/config/validators/validator_utils.py +101 -0
  44. careamics/conftest.py +39 -0
  45. careamics/dataset/__init__.py +17 -0
  46. careamics/dataset/dataset_utils/__init__.py +19 -0
  47. careamics/dataset/dataset_utils/dataset_utils.py +101 -0
  48. careamics/dataset/dataset_utils/file_utils.py +141 -0
  49. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  50. careamics/dataset/dataset_utils/running_stats.py +186 -0
  51. careamics/dataset/in_memory_dataset.py +310 -0
  52. careamics/dataset/in_memory_pred_dataset.py +88 -0
  53. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  54. careamics/dataset/iterable_dataset.py +295 -0
  55. careamics/dataset/iterable_pred_dataset.py +122 -0
  56. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  57. careamics/dataset/patching/__init__.py +1 -0
  58. careamics/dataset/patching/patching.py +299 -0
  59. careamics/dataset/patching/random_patching.py +201 -0
  60. careamics/dataset/patching/sequential_patching.py +212 -0
  61. careamics/dataset/patching/validate_patch_dimension.py +64 -0
  62. careamics/dataset/tiling/__init__.py +10 -0
  63. careamics/dataset/tiling/collate_tiles.py +33 -0
  64. careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  65. careamics/dataset/tiling/tiled_patching.py +164 -0
  66. careamics/dataset/zarr_dataset.py +151 -0
  67. careamics/file_io/__init__.py +15 -0
  68. careamics/file_io/read/__init__.py +12 -0
  69. careamics/file_io/read/get_func.py +56 -0
  70. careamics/file_io/read/tiff.py +58 -0
  71. careamics/file_io/read/zarr.py +60 -0
  72. careamics/file_io/write/__init__.py +15 -0
  73. careamics/file_io/write/get_func.py +63 -0
  74. careamics/file_io/write/tiff.py +40 -0
  75. careamics/lightning/__init__.py +18 -0
  76. careamics/lightning/callbacks/__init__.py +11 -0
  77. careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
  78. careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
  79. careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
  80. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
  81. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
  82. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
  83. careamics/lightning/callbacks/progress_bar_callback.py +90 -0
  84. careamics/lightning/lightning_module.py +632 -0
  85. careamics/lightning/predict_data_module.py +333 -0
  86. careamics/lightning/train_data_module.py +680 -0
  87. careamics/losses/__init__.py +15 -0
  88. careamics/losses/fcn/__init__.py +1 -0
  89. careamics/losses/fcn/losses.py +98 -0
  90. careamics/losses/loss_factory.py +155 -0
  91. careamics/losses/lvae/__init__.py +1 -0
  92. careamics/losses/lvae/loss_utils.py +83 -0
  93. careamics/losses/lvae/losses.py +445 -0
  94. careamics/lvae_training/__init__.py +0 -0
  95. careamics/lvae_training/dataset/__init__.py +0 -0
  96. careamics/lvae_training/dataset/data_utils.py +701 -0
  97. careamics/lvae_training/dataset/lc_dataset.py +259 -0
  98. careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  99. careamics/lvae_training/dataset/vae_data_config.py +179 -0
  100. careamics/lvae_training/dataset/vae_dataset.py +1054 -0
  101. careamics/lvae_training/eval_utils.py +905 -0
  102. careamics/lvae_training/get_config.py +84 -0
  103. careamics/lvae_training/lightning_module.py +701 -0
  104. careamics/lvae_training/metrics.py +214 -0
  105. careamics/lvae_training/train_lvae.py +342 -0
  106. careamics/lvae_training/train_utils.py +121 -0
  107. careamics/model_io/__init__.py +7 -0
  108. careamics/model_io/bioimage/__init__.py +11 -0
  109. careamics/model_io/bioimage/_readme_factory.py +121 -0
  110. careamics/model_io/bioimage/bioimage_utils.py +52 -0
  111. careamics/model_io/bioimage/model_description.py +327 -0
  112. careamics/model_io/bmz_io.py +246 -0
  113. careamics/model_io/model_io_utils.py +95 -0
  114. careamics/models/__init__.py +5 -0
  115. careamics/models/activation.py +39 -0
  116. careamics/models/layers.py +493 -0
  117. careamics/models/lvae/__init__.py +3 -0
  118. careamics/models/lvae/layers.py +1998 -0
  119. careamics/models/lvae/likelihoods.py +364 -0
  120. careamics/models/lvae/lvae.py +901 -0
  121. careamics/models/lvae/noise_models.py +541 -0
  122. careamics/models/lvae/utils.py +395 -0
  123. careamics/models/model_factory.py +67 -0
  124. careamics/models/unet.py +443 -0
  125. careamics/prediction_utils/__init__.py +10 -0
  126. careamics/prediction_utils/lvae_prediction.py +158 -0
  127. careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  128. careamics/prediction_utils/prediction_outputs.py +135 -0
  129. careamics/prediction_utils/stitch_prediction.py +112 -0
  130. careamics/transforms/__init__.py +20 -0
  131. careamics/transforms/compose.py +107 -0
  132. careamics/transforms/n2v_manipulate.py +146 -0
  133. careamics/transforms/normalize.py +243 -0
  134. careamics/transforms/pixel_manipulation.py +407 -0
  135. careamics/transforms/struct_mask_parameters.py +20 -0
  136. careamics/transforms/transform.py +24 -0
  137. careamics/transforms/tta.py +88 -0
  138. careamics/transforms/xy_flip.py +123 -0
  139. careamics/transforms/xy_random_rotate90.py +101 -0
  140. careamics/utils/__init__.py +19 -0
  141. careamics/utils/autocorrelation.py +40 -0
  142. careamics/utils/base_enum.py +60 -0
  143. careamics/utils/context.py +66 -0
  144. careamics/utils/logging.py +322 -0
  145. careamics/utils/metrics.py +188 -0
  146. careamics/utils/path_utils.py +26 -0
  147. careamics/utils/ram.py +15 -0
  148. careamics/utils/receptive_field.py +108 -0
  149. careamics/utils/torch_utils.py +127 -0
  150. careamics-0.0.3.dist-info/METADATA +78 -0
  151. careamics-0.0.3.dist-info/RECORD +154 -0
  152. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
  153. {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
  154. careamics-0.0.1.dist-info/METADATA +0 -46
  155. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,299 @@
1
+ """Patching functions."""
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import Callable, Union
6
+
7
+ import numpy as np
8
+ from numpy.typing import NDArray
9
+
10
+ from ...utils.logging import get_logger
11
+ from ..dataset_utils import reshape_array
12
+ from ..dataset_utils.running_stats import compute_normalization_stats
13
+ from .sequential_patching import extract_patches_sequential
14
+
15
+ logger = get_logger(__name__)
16
+
17
+
18
@dataclass
class Stats:
    """Per-channel normalization statistics of a dataset.

    Either field may be ``None`` when the statistics are not available, in which
    case `get_statistics` reports two empty lists.
    """

    # Mean of the data across channels.
    means: Union[NDArray, tuple, list, None]

    # Standard deviation of the data across channels.
    stds: Union[NDArray, tuple, list, None]

    def get_statistics(self) -> tuple[list[float], list[float]]:
        """Return the means and standard deviations as lists.

        Returns
        -------
        tuple of two lists of floats
            Means and standard deviations. Both lists are empty if either
            statistic is missing (``None``).
        """
        has_both = self.means is not None and self.stds is not None
        if not has_both:
            return [], []

        means_list = list(self.means)
        stds_list = list(self.stds)
        return means_list, stds_list
40
+
41
+
42
@dataclass
class PatchedOutput:
    """Container for image patches, optional target patches and their statistics."""

    # Image patches.
    patches: NDArray

    # Target patches, or None when no targets were extracted (unsupervised data).
    targets: Union[NDArray, None]

    # Statistics of the image patches.
    image_stats: Stats

    # Statistics of the target patches.
    target_stats: Stats
57
+
58
+
59
+ # called by in memory dataset
60
def prepare_patches_supervised(
    train_files: list[Path],
    target_files: list[Path],
    axes: str,
    patch_size: Union[list[int], tuple[int, ...]],
    read_source_func: Callable,
) -> PatchedOutput:
    """
    Iterate over data source and create an array of patches and corresponding targets.

    The lists of Paths should be pre-sorted, so that the i-th target file
    corresponds to the i-th training file. Pairs of files that cannot be read or
    patched are logged and skipped.

    Parameters
    ----------
    train_files : list of pathlib.Path
        List of paths to training data.
    target_files : list of pathlib.Path
        List of paths to target data.
    axes : str
        Axes of the data.
    patch_size : list or tuple of int
        Size of the patches.
    read_source_func : Callable
        Function to read the data, called as ``read_source_func(path, axes)``.

    Returns
    -------
    PatchedOutput
        Dataclass holding the source and target patches, with their statistics.

    Raises
    ------
    ValueError
        If no pair of files could be read and patched.
    """
    if len(train_files) != len(target_files):
        # zip() below silently drops unpaired files, make it visible to the user
        logger.warning(
            f"Number of training files ({len(train_files)}) and target files "
            f"({len(target_files)}) differ, unpaired files will be ignored."
        )

    all_patches, all_targets = [], []
    for train_filename, target_filename in zip(train_files, target_files):
        try:
            sample: np.ndarray = read_source_func(train_filename, axes)
            target: np.ndarray = read_source_func(target_filename, axes)

            # reshape arrays (to SC(Z)YX, as expected by extract_patches_sequential)
            sample = reshape_array(sample, axes)
            target = reshape_array(target, axes)

            patches, targets = extract_patches_sequential(
                sample, patch_size=patch_size, target=target
            )

            # ensure targets are not None (type checking)
            if targets is None:
                raise ValueError(f"No target found for {target_filename}.")

        except Exception as e:
            # log the error and skip this pair of files
            logger.error(f"Failed to read {train_filename} or {target_filename}: {e}")
            continue

        # append source and target together so that the two lists stay aligned
        all_patches.append(patches)
        all_targets.append(targets)

    # raise error if no valid samples found
    if not all_patches:
        raise ValueError(
            f"No valid samples found in the input data: {train_files} and "
            f"{target_files}."
        )

    # concatenate once and compute the statistics on the final arrays, rather than
    # on additional concatenated copies (halves the peak memory)
    patch_array: np.ndarray = np.concatenate(all_patches, axis=0)
    target_array: np.ndarray = np.concatenate(all_targets, axis=0)

    image_means, image_stds = compute_normalization_stats(patch_array)
    target_means, target_stds = compute_normalization_stats(target_array)

    logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")

    return PatchedOutput(
        patch_array,
        target_array,
        Stats(image_means, image_stds),
        Stats(target_means, target_stds),
    )
142
+
143
+
144
+ # called by in_memory_dataset
145
def prepare_patches_unsupervised(
    train_files: list[Path],
    axes: str,
    patch_size: Union[list[int], tuple[int]],
    read_source_func: Callable,
) -> PatchedOutput:
    """Iterate over data source and create an array of patches.

    Files that cannot be read or patched are logged and skipped. The mean and
    standard deviation are computed over all extracted patches.

    Parameters
    ----------
    train_files : list of pathlib.Path
        List of paths to training data.
    axes : str
        Axes of the data.
    patch_size : list or tuple of int
        Size of the patches.
    read_source_func : Callable
        Function to read the data, called as ``read_source_func(path, axes)``.

    Returns
    -------
    PatchedOutput
        Dataclass holding patches and their statistics. Targets are None and the
        target statistics are empty.

    Raises
    ------
    ValueError
        If no file could be read and patched.
    """
    all_patches = []
    for filename in train_files:
        try:
            sample: np.ndarray = read_source_func(filename, axes)

            # reshape array (to SC(Z)YX, as expected by extract_patches_sequential)
            sample = reshape_array(sample, axes)

            patches, _ = extract_patches_sequential(sample, patch_size=patch_size)
        except Exception as e:
            # log the error and skip this file
            logger.error(f"Failed to read {filename}: {e}")
            continue

        all_patches.append(patches)

    # raise error if no valid samples found; checking the patches (rather than the
    # number of files read) also covers files that were read but failed to patch
    if not all_patches:
        raise ValueError(f"No valid samples found in the input data: {train_files}.")

    # concatenate once and compute the statistics on the final array
    patch_array: np.ndarray = np.concatenate(all_patches)
    image_means, image_stds = compute_normalization_stats(patch_array)

    logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")

    return PatchedOutput(
        patch_array, None, Stats(image_means, image_stds), Stats((), ())
    )
204
+
205
+
206
+ # called on arrays by in memory dataset
207
def prepare_patches_supervised_array(
    data: NDArray,
    axes: str,
    data_target: NDArray,
    patch_size: Union[list[int], tuple[int]],
) -> PatchedOutput:
    """Extract paired source/target patches from in-memory arrays.

    Both arrays are first reshaped to SC(Z)YX, where S and C can be singleton
    dimensions. Normalization statistics are computed on the reshaped arrays,
    i.e. before patching. The returned patches are of shape SC(Z)YX, where S now
    indexes the patches.

    Parameters
    ----------
    data : numpy.ndarray
        Input data array.
    axes : str
        Axes of the data.
    data_target : numpy.ndarray
        Target data array.
    patch_size : list or tuple of int
        Size of the patches.

    Returns
    -------
    PatchedOutput
        Dataclass holding the source and target patches, with their statistics.

    Raises
    ------
    ValueError
        If no target patches were extracted.
    """
    source = reshape_array(data, axes)
    target = reshape_array(data_target, axes)

    # statistics of the full (reshaped) arrays
    image_stats = Stats(*compute_normalization_stats(source))
    target_stats = Stats(*compute_normalization_stats(target))

    patches, patch_targets = extract_patches_sequential(
        source, patch_size=patch_size, target=target
    )
    if patch_targets is None:
        raise ValueError("No target extracted.")

    logger.info(f"Extracted {patches.shape[0]} patches from input array.")

    return PatchedOutput(patches, patch_targets, image_stats, target_stats)
260
+
261
+
262
+ # called by in memory dataset
263
def prepare_patches_unsupervised_array(
    data: NDArray,
    axes: str,
    patch_size: Union[list[int], tuple[int]],
) -> PatchedOutput:
    """
    Iterate over data source and create an array of patches.

    This method expects an array of shape SC(Z)YX, where S and C can be singleton
    dimensions. Statistics are computed on the reshaped array, before patching.

    Patches returned are of shape SC(Z)YX, where S is now the patches dimension.

    Parameters
    ----------
    data : numpy.ndarray
        Input data array.
    axes : str
        Axes of the data.
    patch_size : list or tuple of int
        Size of the patches.

    Returns
    -------
    PatchedOutput
        Dataclass holding the patches and their statistics. Targets are None and
        the target statistics are empty.
    """
    # reshape array
    reshaped_sample = reshape_array(data, axes)

    # calculate mean and std
    means, stds = compute_normalization_stats(reshaped_sample)

    # generate patches
    patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)

    # report the number of patches, consistently with the other prepare_* functions
    logger.info(f"Extracted {patches.shape[0]} patches from input array.")

    return PatchedOutput(patches, None, Stats(means, stds), Stats((), ()))
@@ -0,0 +1,201 @@
1
+ """Random patching utilities."""
2
+
3
+ from typing import Generator, List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import zarr
7
+
8
+ from .validate_patch_dimension import validate_patch_dimensions
9
+
10
+
11
+ # TODO split into testable functions
12
def extract_patches_random(
    arr: np.ndarray,
    patch_size: Union[List[int], Tuple[int, ...]],
    target: Optional[np.ndarray] = None,
    seed: Optional[int] = None,
) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
    """
    Generate patches from an array in a random manner.

    For each sample (first dimension of `arr`), the method calculates how many
    patches the sample can be divided into and then extracts that number of
    randomly located patches.

    It returns a generator that yields the following:

    - patch: np.ndarray, dimension C(Z)YX.
    - target_patch: np.ndarray, dimension C(Z)YX, if the target is present, None
      otherwise.

    Parameters
    ----------
    arr : np.ndarray
        Input image array, SC(Z)YX.
    patch_size : list or tuple of int
        Patch sizes in each spatial dimension.
    target : Optional[np.ndarray], optional
        Target array, cropped at the same coordinates as `arr`, by default None.
    seed : Optional[int], optional
        Random seed, by default None.

    Yields
    ------
    tuple of (np.ndarray, np.ndarray or None)
        Patch and corresponding target patch (None if no target is given), both
        as float32 arrays.
    """
    rng = np.random.default_rng(seed=seed)

    is_3d_patch = len(patch_size) == 3

    # patches sanity check
    validate_patch_dimensions(arr, patch_size, is_3d_patch)

    # shape of a patch within a single sample: all channels + spatial patch size
    sample_patch_size = [arr.shape[1], *patch_size]

    # iterate over the number of samples (S or T)
    for sample_idx in range(arr.shape[0]):
        sample: np.ndarray = arr[sample_idx, ...]

        if target is not None:
            target_sample: np.ndarray = target[sample_idx, ...]

        # calculate the number of patches
        n_patches = np.ceil(
            np.prod(sample.shape) / np.prod(sample_patch_size)
        ).astype(int)

        for _ in range(n_patches):
            # random corner of the patch, chosen so that the patch fits in the sample
            crop_coords = [
                rng.integers(0, sample.shape[i] - size, endpoint=True)
                for i, size in enumerate(sample_patch_size)
            ]
            crop = (
                ...,
                *[
                    slice(c, c + size)
                    for c, size in zip(crop_coords, sample_patch_size)
                ],
            )

            # astype returns a new array, so no additional copy is needed
            patch = sample[crop].astype(np.float32)  # type: ignore

            if target is not None:
                target_patch = target_sample[crop].astype(np.float32)  # type: ignore
                yield patch, target_patch
            else:
                yield patch, None
111
+
112
+
113
def extract_patches_random_from_chunks(
    arr: zarr.Array,
    patch_size: Union[List[int], Tuple[int, ...]],
    chunk_size: Union[List[int], Tuple[int, ...]],
    chunk_limit: Optional[int] = None,
    seed: Optional[int] = None,
) -> Generator[np.ndarray, None, None]:
    """
    Generate patches from randomly located chunks of an array.

    Randomly located chunks of size `chunk_size` are read from the array. For each
    sample of a chunk, the method calculates how many patches the sample can be
    divided into and then extracts that number of randomly located patches.

    Parameters
    ----------
    arr : zarr.Array
        Input image array.
    patch_size : Union[List[int], Tuple[int, ...]]
        Patch sizes in each dimension.
    chunk_size : Union[List[int], Tuple[int, ...]]
        Size of the chunks to load from the array.
    chunk_limit : Optional[int], optional
        Number of chunks to load, by default None. If None (or 0), as many chunks
        as the array contains are loaded.
    seed : Optional[int], optional
        Random seed, by default None.

    Yields
    ------
    np.ndarray
        Patches, as float32 arrays.

    Raises
    ------
    ValueError
        If a chunk sample does not have the same number of dimensions as the patch.
    """
    is_3d_patch = len(patch_size) == 3

    # Patches sanity check
    validate_patch_dimensions(arr, patch_size, is_3d_patch)

    rng = np.random.default_rng(seed=seed)
    # public total number of chunks, instead of the private `_cdata_shape`
    num_chunks = chunk_limit if chunk_limit else arr.nchunks

    # Iterate over num chunks in the array
    for _ in range(num_chunks):
        chunk_crop_coords = [
            rng.integers(0, max(0, arr.shape[i] - chunk_size[i]), endpoint=True)
            for i in range(len(chunk_size))
        ]
        chunk = arr[
            (
                ...,
                *[slice(c, c + chunk_size[i]) for i, c in enumerate(chunk_crop_coords)],
            )
        ].squeeze()

        # Add a singleton dimension if the chunk does not have a sample dimension
        if len(chunk.shape) == len(patch_size):
            chunk = np.expand_dims(chunk, axis=0)

        # Iterate over num samples (S)
        for sample_idx in range(chunk.shape[0]):
            spatial_chunk = chunk[sample_idx]
            if len(spatial_chunk.shape) != len(patch_size):
                raise ValueError(
                    f"Chunk sample of shape {spatial_chunk.shape} does not have the "
                    f"same number of dimensions as the patch size {patch_size}."
                )

            n_patches = np.ceil(
                np.prod(spatial_chunk.shape) / np.prod(patch_size)
            ).astype(int)

            # Iterate over the number of patches
            for _ in range(n_patches):
                patch_crop_coords = [
                    rng.integers(
                        0, spatial_chunk.shape[i] - patch_size[i], endpoint=True
                    )
                    for i in range(len(patch_size))
                ]
                # astype returns a new array, so no additional copy is needed
                patch = spatial_chunk[
                    (
                        ...,
                        *[
                            slice(c, c + patch_size[i])
                            for i, c in enumerate(patch_crop_coords)
                        ],
                    )
                ].astype(np.float32)
                yield patch
@@ -0,0 +1,212 @@
1
+ """Sequential patching functions."""
2
+
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ from skimage.util import view_as_windows
7
+
8
+ from .validate_patch_dimension import validate_patch_dimensions
9
+
10
+
11
+ def _compute_number_of_patches(
12
+ arr_shape: Tuple[int, ...], patch_sizes: Union[List[int], Tuple[int, ...]]
13
+ ) -> Tuple[int, ...]:
14
+ """
15
+ Compute the number of patches that fit in each dimension.
16
+
17
+ Parameters
18
+ ----------
19
+ arr_shape : Tuple[int, ...]
20
+ Shape of the input array.
21
+ patch_sizes : Union[List[int], Tuple[int, ...]
22
+ Shape of the patches.
23
+
24
+ Returns
25
+ -------
26
+ Tuple[int, ...]
27
+ Number of patches in each dimension.
28
+ """
29
+ if len(arr_shape) != len(patch_sizes):
30
+ raise ValueError(
31
+ f"Array shape {arr_shape} and patch size {patch_sizes} should have the "
32
+ f"same dimension, including singleton dimension for S and equal dimension "
33
+ f"for C."
34
+ )
35
+
36
+ try:
37
+ n_patches = [
38
+ np.ceil(arr_shape[i] / patch_sizes[i]).astype(int)
39
+ for i in range(len(patch_sizes))
40
+ ]
41
+ except IndexError as e:
42
+ raise ValueError(
43
+ f"Patch size {patch_sizes} is not compatible with array shape {arr_shape}"
44
+ ) from e
45
+
46
+ return tuple(n_patches)
47
+
48
+
49
def _compute_overlap(
    arr_shape: Tuple[int, ...], patch_sizes: Union[List[int], Tuple[int, ...]]
) -> Tuple[int, ...]:
    """
    Compute the overlap between consecutive patches in each dimension.

    The number of patches per dimension is ``ceil(arr_shape / patch_sizes)``. The
    amount by which these patches exceed the array size is spread over the gaps
    between consecutive patches, rounded up. If the array dimension is divisible
    by the patch size, the overlap is 0.

    Parameters
    ----------
    arr_shape : Tuple[int, ...]
        Input array shape.
    patch_sizes : Union[List[int], Tuple[int, ...]]
        Size of the patches.

    Returns
    -------
    Tuple[int, ...]
        Overlap between patches in each dimension.
    """
    n_patches = _compute_number_of_patches(arr_shape, patch_sizes)

    overlaps = []
    for count, size, dim in zip(n_patches, patch_sizes, arr_shape):
        # total length by which `count` patches exceed the array dimension
        excess = np.clip(count * size - dim, 0, None)
        # number of gaps between consecutive patches (at least 1)
        gaps = max(1, count - 1)
        overlaps.append(np.ceil(excess / gaps).astype(int))

    return tuple(overlaps)
80
+
81
+
82
+ def _compute_patch_steps(
83
+ patch_sizes: Union[List[int], Tuple[int, ...]], overlaps: Tuple[int, ...]
84
+ ) -> Tuple[int, ...]:
85
+ """
86
+ Compute steps between patches.
87
+
88
+ Parameters
89
+ ----------
90
+ patch_sizes : Tuple[int]
91
+ Size of the patches.
92
+ overlaps : Tuple[int]
93
+ Overlap between patches.
94
+
95
+ Returns
96
+ -------
97
+ Tuple[int]
98
+ Steps between patches.
99
+ """
100
+ steps = [
101
+ min(patch_sizes[i] - overlaps[i], patch_sizes[i])
102
+ for i in range(len(patch_sizes))
103
+ ]
104
+ return tuple(steps)
105
+
106
+
107
+ # TODO why stack the target here and not on a different dimension before this function?
108
def _compute_patch_views(
    arr: np.ndarray,
    window_shape: List[int],
    step: Tuple[int, ...],
    output_shape: List[int],
    target: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Compute patches of an array from a sliding-window view, in random order.

    If a target is given, it is stacked with `arr` along a new leading axis, so
    that source and target patches are extracted and shuffled together. The
    resulting patches then have an extra dimension of size 2 at axis 1, holding
    the source (index 0) and the target (index 1).

    Parameters
    ----------
    arr : np.ndarray
        Array from which the views are extracted.
    window_shape : list of int
        Shape of the views.
    step : tuple of int
        Steps between views.
    output_shape : list of int
        Shape of the output array.
    target : Optional[np.ndarray], optional
        Target array, by default None.

    Returns
    -------
    np.ndarray
        Patches, shuffled along the first dimension. The result never shares
        memory with the input arrays.
    """
    rng = np.random.default_rng()

    if target is not None:
        arr = np.stack([arr, target], axis=0)
        window_shape = [arr.shape[0], *window_shape]
        step = (arr.shape[0], *step)
        output_shape = [-1, arr.shape[0], arr.shape[2], *output_shape[2:]]

    patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape(
        *output_shape
    )

    # `reshape` returns a view into `arr` instead of a copy whenever the window
    # strides allow it (e.g. when a single dimension holds several patches).
    # Shuffling such a view in place would write into `arr` (the caller's data
    # when there is no target) and, if the windows overlap, corrupt the patches.
    if np.may_share_memory(patches, arr):
        patches = patches.copy()

    rng.shuffle(patches, axis=0)
    return patches
149
+
150
+
151
def extract_patches_sequential(
    arr: np.ndarray,
    patch_size: Union[List[int], Tuple[int, ...]],
    target: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Generate patches from an array in a sequential manner.

    Array dimensions should be SC(Z)YX, where S and C can be singleton dimensions.
    Patches are laid out on a regular grid, overlapping where the array size is
    not a multiple of the patch size, so that they cover the whole array. The
    patches are returned in shuffled order.

    Parameters
    ----------
    arr : np.ndarray
        Input image array.
    patch_size : Tuple[int]
        Patch sizes in each spatial dimension.
    target : Optional[np.ndarray], optional
        Target array, by default None.

    Returns
    -------
    Tuple[np.ndarray, Optional[np.ndarray]]
        Patches of shape (n_patches, C, (Z), Y, X) and, if `target` is given, the
        corresponding target patches (otherwise None).
    """
    is_3d_patch = len(patch_size) == 3

    # Patches sanity check
    validate_patch_dimensions(arr, patch_size, is_3d_patch)

    # Full window: singleton S dimension, all channels, then the spatial patch size
    window_shape = [1, arr.shape[1], *patch_size]

    overlaps = _compute_overlap(arr_shape=arr.shape, patch_sizes=window_shape)
    steps = _compute_patch_steps(patch_sizes=window_shape, overlaps=overlaps)

    # Patches are flattened to (n_patches, C, Z, Y, X) or (n_patches, C, Y, X)
    patches = _compute_patch_views(
        arr,
        window_shape=window_shape,
        step=steps,
        output_shape=[-1, *window_shape[1:]],
        target=target,
    )

    if target is None:
        return patches, None

    # _compute_patch_views stacked source and target along axis 1
    return patches[:, 0, ...], patches[:, 1, ...]