careamics 0.0.1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic; see the registry's release page for more details.

Files changed (48)
  1. careamics/__init__.py +7 -1
  2. careamics/bioimage/__init__.py +15 -0
  3. careamics/bioimage/docs/Noise2Void.md +5 -0
  4. careamics/bioimage/docs/__init__.py +1 -0
  5. careamics/bioimage/io.py +182 -0
  6. careamics/bioimage/rdf.py +105 -0
  7. careamics/config/__init__.py +11 -0
  8. careamics/config/algorithm.py +231 -0
  9. careamics/config/config.py +297 -0
  10. careamics/config/config_filter.py +44 -0
  11. careamics/config/data.py +194 -0
  12. careamics/config/torch_optim.py +118 -0
  13. careamics/config/training.py +534 -0
  14. careamics/dataset/__init__.py +1 -0
  15. careamics/dataset/dataset_utils.py +111 -0
  16. careamics/dataset/extraction_strategy.py +21 -0
  17. careamics/dataset/in_memory_dataset.py +202 -0
  18. careamics/dataset/patching.py +492 -0
  19. careamics/dataset/prepare_dataset.py +175 -0
  20. careamics/dataset/tiff_dataset.py +212 -0
  21. careamics/engine.py +1014 -0
  22. careamics/losses/__init__.py +4 -0
  23. careamics/losses/loss_factory.py +38 -0
  24. careamics/losses/losses.py +34 -0
  25. careamics/manipulation/__init__.py +4 -0
  26. careamics/manipulation/pixel_manipulation.py +158 -0
  27. careamics/models/__init__.py +4 -0
  28. careamics/models/layers.py +152 -0
  29. careamics/models/model_factory.py +251 -0
  30. careamics/models/unet.py +322 -0
  31. careamics/prediction/__init__.py +9 -0
  32. careamics/prediction/prediction_utils.py +106 -0
  33. careamics/utils/__init__.py +20 -0
  34. careamics/utils/ascii_logo.txt +9 -0
  35. careamics/utils/augment.py +65 -0
  36. careamics/utils/context.py +45 -0
  37. careamics/utils/logging.py +321 -0
  38. careamics/utils/metrics.py +160 -0
  39. careamics/utils/normalization.py +55 -0
  40. careamics/utils/torch_utils.py +89 -0
  41. careamics/utils/validators.py +170 -0
  42. careamics/utils/wandb.py +121 -0
  43. careamics-0.1.0rc2.dist-info/METADATA +81 -0
  44. careamics-0.1.0rc2.dist-info/RECORD +47 -0
  45. {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/WHEEL +1 -1
  46. {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/licenses/LICENSE +1 -1
  47. careamics-0.0.1.dist-info/METADATA +0 -46
  48. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,202 @@
1
+ """In-memory dataset module."""
2
+ from pathlib import Path
3
+ from typing import Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from careamics.utils import normalize
9
+ from careamics.utils.logging import get_logger
10
+
11
+ from .dataset_utils import (
12
+ list_files,
13
+ read_tiff,
14
+ )
15
+ from .extraction_strategy import ExtractionStrategy
16
+ from .patching import generate_patches
17
+
18
+ logger = get_logger(__name__)
19
+
20
+
21
class InMemoryDataset(torch.utils.data.Dataset):
    """
    Dataset storing data in memory and allowing generating patches from it.

    Parameters
    ----------
    data_path : Union[str, Path]
        Path to the data, must be a directory.
    data_format : str
        Extension of the data files, without period.
    axes : str
        Description of axes in format STCZYX.
    patch_extraction_method : ExtractionStrategy
        Patch extraction strategy, as defined in extraction_strategy.
    patch_size : Union[List[int], Tuple[int]]
        Size of the patches along each axis, must be of dimension 2 or 3.
    patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
        Overlap of the patches, must be of dimension 2 or 3, by default None.
    mean : Optional[float], optional
        Expected mean of the dataset, by default None.
    std : Optional[float], optional
        Expected standard deviation of the dataset, by default None.
    patch_transform : Optional[Callable], optional
        Patch transform to apply, by default None.
    patch_transform_params : Optional[Dict], optional
        Patch transform parameters, by default None.
    """

    def __init__(
        self,
        data_path: Union[str, Path],
        data_format: str,
        axes: str,
        patch_extraction_method: ExtractionStrategy,
        patch_size: Union[List[int], Tuple[int]],
        patch_overlap: Optional[Union[List[int], Tuple[int]]] = None,
        mean: Optional[float] = None,
        std: Optional[float] = None,
        patch_transform: Optional[Callable] = None,
        patch_transform_params: Optional[Dict] = None,
    ) -> None:
        """
        Constructor.

        Reads every file of the folder, extracts patches from each of them and
        keeps all patches in memory.

        Parameters
        ----------
        data_path : Union[str, Path]
            Path to the data, must be a directory.
        data_format : str
            Extension of the data files, without period.
        axes : str
            Description of axes in format STCZYX.
        patch_extraction_method : ExtractionStrategy
            Patch extraction strategy, as defined in extraction_strategy.
        patch_size : Union[List[int], Tuple[int]]
            Size of the patches along each axis, must be of dimension 2 or 3.
        patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
            Overlap of the patches, must be of dimension 2 or 3, by default None.
        mean : Optional[float], optional
            Expected mean of the dataset, by default None (computed from the data).
        std : Optional[float], optional
            Expected standard deviation of the dataset, by default None (computed
            from the data).
        patch_transform : Optional[Callable], optional
            Patch transform to apply, by default None.
        patch_transform_params : Optional[Dict], optional
            Patch transform parameters, by default None.

        Raises
        ------
        ValueError
            If data_path is not a directory, or if no file was found in it.
        """
        self.data_path = Path(data_path)
        if not self.data_path.is_dir():
            raise ValueError("Path to data should be an existing folder.")

        self.data_format = data_format
        self.axes = axes

        self.files = list_files(self.data_path, self.data_format)

        self.patch_size = patch_size
        self.patch_overlap = patch_overlap
        self.patch_extraction_method = patch_extraction_method
        self.patch_transform = patch_transform
        self.patch_transform_params = patch_transform_params

        self.mean = mean
        self.std = std

        # Generate patches
        self.data, computed_mean, computed_std = self._prepare_patches()

        # `None` means "not provided": a user-supplied mean of 0.0 is valid and
        # must not be replaced by the computed statistics.
        if mean is None or std is None:
            self.mean, self.std = computed_mean, computed_std
            logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")

        assert self.mean is not None
        assert self.std is not None

    def _prepare_patches(self) -> Tuple[np.ndarray, float, float]:
        """
        Iterate over data source and create an array of patches.

        The returned statistics are the average, over files, of the per-file
        mean and of the per-file standard deviation.

        Returns
        -------
        Tuple[np.ndarray, float, float]
            Array of patches with one patch per entry of the first axis, the
            averaged mean and the averaged standard deviation.

        Raises
        ------
        ValueError
            If no file was found in the data folder.
        """
        # Patches from the random strategy have the shape of `patch_size`, whereas
        # sequential patches carry an extra leading singleton axis. Give every
        # patch that leading axis so that concatenation yields one patch per row.
        n_spatial_dims = None if self.patch_size is None else len(self.patch_size)

        means, stds, num_samples = 0, 0, 0
        self.all_patches = []
        for filename in self.files:
            sample = read_tiff(filename, self.axes)
            means += sample.mean()
            stds += np.std(sample)
            num_samples += 1

            # generate patches, return a generator
            patches = generate_patches(
                sample,
                self.patch_extraction_method,
                self.patch_size,
                self.patch_overlap,
            )

            # convert generator to list and add to all_patches
            self.all_patches.extend(
                patch[np.newaxis]
                if isinstance(patch, np.ndarray) and patch.ndim == n_spatial_dims
                else patch
                for patch in patches
            )

        if num_samples == 0:
            raise ValueError(
                f"No file with extension '{self.data_format}' found in "
                f"{self.data_path}."
            )

        result_mean, result_std = means / num_samples, stds / num_samples
        return np.concatenate(self.all_patches), result_mean, result_std

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Number of patches held in memory.
        """
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[np.ndarray]:
        """
        Return the patch corresponding to the provided index.

        The patch is squeezed, normalized with the dataset mean and standard
        deviation, and then passed through `patch_transform` if one is set.

        Parameters
        ----------
        index : int
            Index of the patch to return.

        Returns
        -------
        Tuple[np.ndarray]
            Patch.

        Raises
        ------
        ValueError
            If dataset mean and std are not set.
        """
        patch = self.data[index].squeeze()

        if self.mean is None or self.std is None:
            raise ValueError("Dataset mean and std must be set before using it.")

        if isinstance(patch, tuple):
            # Normalize only the image; keep the remaining entries of the tuple.
            normalized = normalize(img=patch[0], mean=self.mean, std=self.std)
            patch = (normalized, *patch[1:])
        else:
            patch = normalize(img=patch, mean=self.mean, std=self.std)

        if self.patch_transform is not None:
            # No parameters means no extra keyword arguments; do not mutate state.
            params = self.patch_transform_params or {}
            patch = self.patch_transform(patch, **params)
        return patch
@@ -0,0 +1,492 @@
1
+ """
2
+ Tiling submodule.
3
+
4
+ These functions are used to tile images into patches or tiles.
5
+ """
6
+ import itertools
7
+ from typing import Generator, List, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+ from skimage.util import view_as_windows
11
+
12
+ from careamics.utils.logging import get_logger
13
+
14
+ from .extraction_strategy import ExtractionStrategy
15
+
16
+ logger = get_logger(__name__)
17
+
18
+
19
+ def _compute_number_of_patches(
20
+ arr: np.ndarray, patch_sizes: Union[List[int], Tuple[int, ...]]
21
+ ) -> Tuple[int, ...]:
22
+ """
23
+ Compute the number of patches that fit in each dimension.
24
+
25
+ Array must have one dimension more than the patches (C dimension).
26
+
27
+ Parameters
28
+ ----------
29
+ arr : np.ndarray
30
+ Input array.
31
+ patch_sizes : Tuple[int]
32
+ Size of the patches.
33
+
34
+ Returns
35
+ -------
36
+ Tuple[int]
37
+ Number of patches in each dimension.
38
+ """
39
+ n_patches = [
40
+ np.ceil(arr.shape[i + 1] / patch_sizes[i]).astype(int)
41
+ for i in range(len(patch_sizes))
42
+ ]
43
+ return tuple(n_patches)
44
+
45
+
46
def _compute_overlap(
    arr: np.ndarray, patch_sizes: Union[List[int], Tuple[int, ...]]
) -> Tuple[int, ...]:
    """
    Compute the overlap needed between patches in each spatial dimension.

    `arr` has dimensions C(Z)YX and `patch_sizes` covers YX or ZYX. With `n`
    the number of patches needed along an axis of length `L` and patch size
    `p`, the overlap is `ceil(max(0, n * p - L) / max(1, n - 1))`. It is 0
    whenever `p` divides `L`.

    Parameters
    ----------
    arr : np.ndarray
        Input array with 3 or 4 dimensions.
    patch_sizes : Union[List[int], Tuple[int, ...]]
        Size of the patches.

    Returns
    -------
    Tuple[int, ...]
        Overlap between patches in each spatial dimension.
    """
    counts = _compute_number_of_patches(arr, patch_sizes)

    overlaps = []
    for axis, (count, size) in enumerate(zip(counts, patch_sizes)):
        # How far the patch grid sticks out past the axis end...
        excess = np.clip(count * size - arr.shape[axis + 1], 0, None)
        # ...spread over the gaps between consecutive patches.
        n_gaps = max(1, (count - 1))
        overlaps.append(np.ceil(excess / n_gaps).astype(int))
    return tuple(overlaps)
78
+
79
+
80
+ def _compute_crop_and_stitch_coords_1d(
81
+ axis_size: int, tile_size: int, overlap: int
82
+ ) -> Tuple[List[Tuple[int, int]], ...]:
83
+ """
84
+ Compute the coordinates of each tile along an axis, given the overlap.
85
+
86
+ Parameters
87
+ ----------
88
+ axis_size : int
89
+ Length of the axis.
90
+ tile_size : int
91
+ Size of the tile for the given axis.
92
+ overlap : int
93
+ Size of the overlap for the given axis.
94
+
95
+ Returns
96
+ -------
97
+ Tuple[Tuple[int]]
98
+ Tuple of all coordinates for given axis.
99
+ """
100
+ # Compute the step between tiles
101
+ step = tile_size - overlap
102
+ crop_coords = []
103
+ stitch_coords = []
104
+ overlap_crop_coords = []
105
+ # Iterate over the axis with a certain step
106
+ for i in range(0, axis_size - overlap, step):
107
+ # Check if the tile fits within the axis
108
+ if i + tile_size <= axis_size:
109
+ # Add the coordinates to crop one tile
110
+ crop_coords.append((i, i + tile_size))
111
+ # Add the pixel coordinates of the cropped tile in the original image space
112
+ stitch_coords.append(
113
+ (
114
+ i + overlap // 2 if i > 0 else 0,
115
+ i + tile_size - overlap // 2
116
+ if crop_coords[-1][1] < axis_size
117
+ else axis_size,
118
+ )
119
+ )
120
+ # Add the coordinates to crop the overlap from the prediction.
121
+ overlap_crop_coords.append(
122
+ (
123
+ overlap // 2 if i > 0 else 0,
124
+ tile_size - overlap // 2
125
+ if crop_coords[-1][1] < axis_size
126
+ else tile_size,
127
+ )
128
+ )
129
+ # If the tile does not fit within the axis, perform the abovementioned
130
+ # operations starting from the end of the axis
131
+ else:
132
+ # if (axis_size - tile_size, axis_size) not in crop_coords:
133
+ crop_coords.append((axis_size - tile_size, axis_size))
134
+ last_tile_end_coord = stitch_coords[-1][1]
135
+ stitch_coords.append((last_tile_end_coord, axis_size))
136
+ overlap_crop_coords.append(
137
+ (tile_size - (axis_size - last_tile_end_coord), tile_size)
138
+ )
139
+ break
140
+ return crop_coords, stitch_coords, overlap_crop_coords
141
+
142
+
143
+ def _compute_patch_steps(
144
+ patch_sizes: Union[List[int], Tuple[int, ...]], overlaps: Tuple[int, ...]
145
+ ) -> Tuple[int, ...]:
146
+ """
147
+ Compute steps between patches.
148
+
149
+ Parameters
150
+ ----------
151
+ patch_sizes : Tuple[int]
152
+ Size of the patches.
153
+ overlaps : Tuple[int]
154
+ Overlap between patches.
155
+
156
+ Returns
157
+ -------
158
+ Tuple[int]
159
+ Steps between patches.
160
+ """
161
+ steps = [
162
+ min(patch_sizes[i] - overlaps[i], patch_sizes[i])
163
+ for i in range(len(patch_sizes))
164
+ ]
165
+ return tuple(steps)
166
+
167
+
168
def _compute_reshaped_view(
    arr: np.ndarray,
    window_shape: Tuple[int, ...],
    step: Tuple[int, ...],
    output_shape: Tuple[int, ...],
) -> np.ndarray:
    """
    Compute reshaped views of an array, where views correspond to patches.

    Windows are extracted with `view_as_windows`, reshaped to `output_shape`
    and shuffled along the first axis. The input array is never modified.

    Parameters
    ----------
    arr : np.ndarray
        Array from which the views are extracted.
    window_shape : Tuple[int, ...]
        Shape of the views.
    step : Tuple[int, ...]
        Steps between views.
    output_shape : Tuple[int, ...]
        Shape of the output array.

    Returns
    -------
    np.ndarray
        Array of patches (shuffled along the first axis).
    """
    rng = np.random.default_rng()
    patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape(
        *output_shape
    )
    # `reshape` copies in general, but it may return a view of `arr`, for example
    # when the window grid has a single non-singleton axis. Shuffling such a view
    # in place would reorder `arr` itself and, for overlapping windows,
    # overwrite pixels shared by several patches. Copy before shuffling.
    if np.may_share_memory(patches, arr):
        patches = patches.copy()
    rng.shuffle(patches, axis=0)
    return patches
199
+
200
+
201
+ def _patches_sanity_check(
202
+ arr: np.ndarray,
203
+ patch_size: Union[List[int], Tuple[int, ...]],
204
+ is_3d_patch: bool,
205
+ ) -> None:
206
+ """
207
+ Check patch size and array compatibility.
208
+
209
+ This method validates the patch sizes with respect to the array dimensions:
210
+ - The patch sizes must have one dimension fewer than the array (C dimension).
211
+ - Chack that patch sizes are smaller than array dimensions.
212
+
213
+ Parameters
214
+ ----------
215
+ arr : np.ndarray
216
+ Input array.
217
+ patch_size : Union[List[int], Tuple[int, ...]]
218
+ Size of the patches along each dimension of the array, except the first.
219
+ is_3d_patch : bool
220
+ Whether the patch is 3D or not.
221
+
222
+ Raises
223
+ ------
224
+ ValueError
225
+ If the patch size is not consistent with the array shape (one more array
226
+ dimension).
227
+ ValueError
228
+ If the patch size in Z is larger than the array dimension.
229
+ ValueError
230
+ If either of the patch sizes in X or Y is larger than the corresponding array
231
+ dimension.
232
+ """
233
+ if len(patch_size) != len(arr.shape[1:]):
234
+ raise ValueError(
235
+ f"There must be a patch size for each spatial dimensions "
236
+ f"(got {patch_size} patches for dims {arr.shape})."
237
+ )
238
+
239
+ # Sanity checks on patch sizes versus array dimension
240
+ if is_3d_patch and patch_size[0] > arr.shape[-3]:
241
+ raise ValueError(
242
+ f"Z patch size is inconsistent with image shape "
243
+ f"(got {patch_size[0]} patches for dim {arr.shape[1]})."
244
+ )
245
+
246
+ if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
247
+ raise ValueError(
248
+ f"At least one of YX patch dimensions is inconsistent with image shape "
249
+ f"(got {patch_size} patches for dims {arr.shape[-2:]})."
250
+ )
251
+
252
+
253
# formerly :
# in dataloader.py#L52, 00d536c
def _extract_patches_sequential(
    arr: np.ndarray, patch_size: Union[List[int], Tuple[int]]
) -> Generator[np.ndarray, None, None]:
    """
    Generate patches from an array in a sequential manner.

    The array has dimensions C(Z)YX, where C can be a singleton dimension. The
    overlap between patches is chosen so that the patch grid covers the whole
    array. Validation and extraction happen eagerly; only iteration over the
    resulting patches is lazy.

    Parameters
    ----------
    arr : np.ndarray
        Input image array.
    patch_size : Union[List[int], Tuple[int]]
        Patch sizes in each dimension.

    Returns
    -------
    Generator[np.ndarray, None, None]
        Generator of patches.
    """
    is_3d_patch = len(patch_size) == 3
    _patches_sanity_check(arr, patch_size, is_3d_patch)

    # Steps between windows, derived from the overlap required for the grid of
    # patches to cover each spatial axis.
    spatial_steps = _compute_patch_steps(
        patch_sizes=patch_size,
        overlaps=_compute_overlap(arr=arr, patch_sizes=patch_size),
    )

    # The first axis is sampled one index at a time: window of 1, step of 1.
    window_shape = (1, *patch_size)
    window_steps = (1, *spatial_steps)

    # 3D patches with a singleton Z drop the leading window axis:
    # (n_patches, 1, Y, X). Otherwise the result is (n_patches, 1, *patch_size).
    if not (is_3d_patch and patch_size[0] == 1):
        output_shape = (-1, *window_shape)
    else:
        output_shape = (-1, *window_shape[1:])

    patches = _compute_reshaped_view(
        arr, window_shape=window_shape, step=window_steps, output_shape=output_shape
    )
    n_extracted = patches.shape[0]
    logger.info(f"Extracted {n_extracted} patches from input array.")

    return (patches[idx, ...] for idx in range(n_extracted))
306
+
307
+
308
def _extract_patches_random(
    arr: np.ndarray, patch_size: Union[List[int], Tuple[int]]
) -> Generator[np.ndarray, None, None]:
    """
    Generate patches from an array in a random manner.

    Samples (first axis of `arr`) are visited in random order. For each sample,
    the method computes how many patches the sample could be divided into
    (rounded up) and extracts that many patches at uniformly random positions.
    The input array is not modified.

    Parameters
    ----------
    arr : np.ndarray
        Input image array.
    patch_size : Union[List[int], Tuple[int]]
        Patch sizes in each dimension.

    Yields
    ------
    np.ndarray
        Patch of shape `patch_size`, as a float32 copy.
    """
    is_3d_patch = len(patch_size) == 3

    # Patches sanity check
    _patches_sanity_check(arr, patch_size, is_3d_patch)

    rng = np.random.default_rng()

    # Visit samples in random order without shuffling `arr` in place, which
    # would silently reorder the caller's array.
    for sample_idx in rng.permutation(arr.shape[0]):
        sample = arr[sample_idx]
        # calculate how many patches the sample area/volume can be divided into
        n_patches = np.ceil(np.prod(sample.shape) / np.prod(patch_size)).astype(int)
        for _ in range(n_patches):
            # `integers` excludes its upper bound. The `+ 1` makes the last valid
            # start position reachable and allows a patch as large as the image,
            # which the sanity check accepts.
            crop_coords = [
                rng.integers(0, sample.shape[i] - patch_size[i] + 1)
                for i in range(len(patch_size))
            ]
            window = tuple(
                slice(c, c + patch_size[i]) for i, c in enumerate(crop_coords)
            )
            # `astype` returns a new array, so the patch never aliases `arr`.
            yield sample[(..., *window)].astype(np.float32)
361
+
362
+
363
def _extract_tiles(
    arr: np.ndarray,
    tile_size: Union[List[int], Tuple[int]],
    overlaps: Union[List[int], Tuple[int]],
) -> Generator:
    """
    Generate tiles from the input array with specified overlap.

    The tiles cover the whole array. Each sample (first axis) is tiled
    independently, and the tiles of a sample are the Cartesian product of the
    per-axis tile coordinates.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (S, (Z), Y, X).
    tile_size : Union[List[int], Tuple[int]]
        Tile sizes in each dimension, of length 2 or 3.
    overlaps : Union[List[int], Tuple[int]]
        Overlap values in each dimension, of length 2 or 3.

    Yields
    ------
    Generator
        Tuples of (float32 tile with a new leading axis, whether it is the last
        tile of its sample, shape of `arr` without its first axis, coordinates
        of the part of the tile to keep, coordinates where that part goes in
        the stitched output).
    """
    n_axes = len(tile_size)

    for sample_idx in range(arr.shape[0]):
        sample = arr[sample_idx]

        # One (crop, stitch, overlap_crop) triple of coordinate lists per axis.
        per_axis_coords = [
            _compute_crop_and_stitch_coords_1d(
                sample.shape[axis], tile_size[axis], overlaps[axis]
            )
            for axis in range(n_axes)
        ]

        # Regroup by coordinate type. For an axis of size 35, a tile size of 32
        # and an overlap of 24, the per-axis triple is
        # ([(0, 32), (3, 35)], [(0, 20), (20, 35)], [(0, 20), (17, 32)]).
        crop_per_axis, stitch_per_axis, overlap_crop_per_axis = zip(*per_axis_coords)

        # Total number of tiles in this sample, used to flag the last one.
        n_tiles = np.prod([len(coords) for coords in crop_per_axis])

        tile_coords = zip(
            itertools.product(*crop_per_axis),
            itertools.product(*stitch_per_axis),
            itertools.product(*overlap_crop_per_axis),
        )
        for tile_idx, (crop, stitch, overlap_crop) in enumerate(tile_coords):
            window = [slice(start, end) for start, end in crop]
            tile = sample[(..., *window)]

            is_last = bool(tile_idx == n_tiles - 1)
            yield (
                np.expand_dims(tile.astype(np.float32), 0),
                is_last,
                arr.shape[1:],
                overlap_crop,
                stitch,
            )
434
+
435
+
436
def generate_patches(
    sample: np.ndarray,
    patch_extraction_method: ExtractionStrategy,
    patch_size: Optional[Union[List[int], Tuple[int]]] = None,
    patch_overlap: Optional[Union[List[int], Tuple[int]]] = None,
) -> Generator[np.ndarray, None, None]:
    """
    Generate patches from a sample.

    Dispatches on the extraction strategy: tiling, sequential patching, or
    random patching for any other strategy. Without a patch size, the sample
    is yielded once, unchanged.

    Parameters
    ----------
    sample : np.ndarray
        Input array.
    patch_extraction_method : ExtractionStrategy
        Patch extraction method, as defined in extraction_strategy.ExtractionStrategy.
    patch_size : Optional[Union[List[int], Tuple[int]]]
        Size of the patches along each dimension of the array, except the first.
    patch_overlap : Optional[Union[List[int], Tuple[int]]]
        Overlap between patches, required for tiling.

    Returns
    -------
    Generator[np.ndarray, None, None]
        Generator yielding patches, or tile tuples when tiling.

    Raises
    ------
    ValueError
        If overlap is not specified when using tiling.
    """
    if patch_size is None:
        # no patching, return a generator for the sample
        return (sample for _ in range(1))

    if patch_extraction_method == ExtractionStrategy.TILED:
        if patch_overlap is None:
            raise ValueError(
                "Overlaps must be specified when using tiling (got None)."
            )
        return _extract_tiles(arr=sample, tile_size=patch_size, overlaps=patch_overlap)

    if patch_extraction_method == ExtractionStrategy.SEQUENTIAL:
        return _extract_patches_sequential(sample, patch_size=patch_size)

    # any other strategy: random patching
    return _extract_patches_random(sample, patch_size=patch_size)