careamics 0.1.0rc1__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (132) hide show
  1. careamics/__init__.py +14 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +27 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_factory.py +460 -0
  16. careamics/config/configuration_model.py +596 -0
  17. careamics/config/data_model.py +555 -0
  18. careamics/config/inference_model.py +283 -0
  19. careamics/config/noise_models.py +162 -0
  20. careamics/config/optimizer_models.py +181 -0
  21. careamics/config/references/__init__.py +45 -0
  22. careamics/config/references/algorithm_descriptions.py +131 -0
  23. careamics/config/references/references.py +38 -0
  24. careamics/config/support/__init__.py +33 -0
  25. careamics/config/support/supported_activations.py +24 -0
  26. careamics/config/support/supported_algorithms.py +18 -0
  27. careamics/config/support/supported_architectures.py +18 -0
  28. careamics/config/support/supported_data.py +82 -0
  29. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  30. careamics/config/support/supported_loggers.py +8 -0
  31. careamics/config/support/supported_losses.py +25 -0
  32. careamics/config/support/supported_optimizers.py +55 -0
  33. careamics/config/support/supported_pixel_manipulations.py +15 -0
  34. careamics/config/support/supported_struct_axis.py +19 -0
  35. careamics/config/support/supported_transforms.py +23 -0
  36. careamics/config/tile_information.py +104 -0
  37. careamics/config/training_model.py +65 -0
  38. careamics/config/transformations/__init__.py +14 -0
  39. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  40. careamics/config/transformations/nd_flip_model.py +32 -0
  41. careamics/config/transformations/normalize_model.py +31 -0
  42. careamics/config/transformations/transform_model.py +44 -0
  43. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  44. careamics/config/validators/__init__.py +5 -0
  45. careamics/config/validators/validator_utils.py +100 -0
  46. careamics/conftest.py +26 -0
  47. careamics/dataset/__init__.py +5 -0
  48. careamics/dataset/dataset_utils/__init__.py +19 -0
  49. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  50. careamics/dataset/dataset_utils/file_utils.py +140 -0
  51. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  52. careamics/dataset/dataset_utils/read_utils.py +25 -0
  53. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  54. careamics/dataset/in_memory_dataset.py +321 -131
  55. careamics/dataset/iterable_dataset.py +416 -0
  56. careamics/dataset/patching/__init__.py +8 -0
  57. careamics/dataset/patching/patch_transform.py +44 -0
  58. careamics/dataset/patching/patching.py +212 -0
  59. careamics/dataset/patching/random_patching.py +190 -0
  60. careamics/dataset/patching/sequential_patching.py +206 -0
  61. careamics/dataset/patching/tiled_patching.py +158 -0
  62. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  63. careamics/dataset/zarr_dataset.py +149 -0
  64. careamics/lightning_datamodule.py +665 -0
  65. careamics/lightning_module.py +292 -0
  66. careamics/lightning_prediction_datamodule.py +390 -0
  67. careamics/lightning_prediction_loop.py +116 -0
  68. careamics/losses/__init__.py +4 -1
  69. careamics/losses/loss_factory.py +24 -13
  70. careamics/losses/losses.py +65 -5
  71. careamics/losses/noise_model_factory.py +40 -0
  72. careamics/losses/noise_models.py +524 -0
  73. careamics/model_io/__init__.py +8 -0
  74. careamics/model_io/bioimage/__init__.py +11 -0
  75. careamics/model_io/bioimage/_readme_factory.py +120 -0
  76. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  77. careamics/model_io/bioimage/model_description.py +318 -0
  78. careamics/model_io/bmz_io.py +231 -0
  79. careamics/model_io/model_io_utils.py +80 -0
  80. careamics/models/__init__.py +4 -1
  81. careamics/models/activation.py +35 -0
  82. careamics/models/layers.py +244 -0
  83. careamics/models/model_factory.py +21 -202
  84. careamics/models/unet.py +46 -20
  85. careamics/prediction/__init__.py +1 -3
  86. careamics/prediction/stitch_prediction.py +73 -0
  87. careamics/transforms/__init__.py +41 -0
  88. careamics/transforms/n2v_manipulate.py +113 -0
  89. careamics/transforms/nd_flip.py +93 -0
  90. careamics/transforms/normalize.py +109 -0
  91. careamics/transforms/pixel_manipulation.py +383 -0
  92. careamics/transforms/struct_mask_parameters.py +18 -0
  93. careamics/transforms/tta.py +74 -0
  94. careamics/transforms/xy_random_rotate90.py +95 -0
  95. careamics/utils/__init__.py +10 -13
  96. careamics/utils/base_enum.py +32 -0
  97. careamics/utils/context.py +22 -2
  98. careamics/utils/metrics.py +0 -46
  99. careamics/utils/path_utils.py +24 -0
  100. careamics/utils/ram.py +13 -0
  101. careamics/utils/receptive_field.py +102 -0
  102. careamics/utils/running_stats.py +43 -0
  103. careamics/utils/torch_utils.py +89 -56
  104. careamics-0.1.0rc3.dist-info/METADATA +122 -0
  105. careamics-0.1.0rc3.dist-info/RECORD +109 -0
  106. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
  107. careamics/bioimage/__init__.py +0 -15
  108. careamics/bioimage/docs/Noise2Void.md +0 -5
  109. careamics/bioimage/docs/__init__.py +0 -1
  110. careamics/bioimage/io.py +0 -271
  111. careamics/config/algorithm.py +0 -231
  112. careamics/config/config.py +0 -296
  113. careamics/config/config_filter.py +0 -44
  114. careamics/config/data.py +0 -194
  115. careamics/config/torch_optim.py +0 -118
  116. careamics/config/training.py +0 -534
  117. careamics/dataset/dataset_utils.py +0 -115
  118. careamics/dataset/patching.py +0 -493
  119. careamics/dataset/prepare_dataset.py +0 -174
  120. careamics/dataset/tiff_dataset.py +0 -211
  121. careamics/engine.py +0 -954
  122. careamics/manipulation/__init__.py +0 -4
  123. careamics/manipulation/pixel_manipulation.py +0 -158
  124. careamics/prediction/prediction_utils.py +0 -102
  125. careamics/utils/ascii_logo.txt +0 -9
  126. careamics/utils/augment.py +0 -65
  127. careamics/utils/normalization.py +0 -55
  128. careamics/utils/validators.py +0 -156
  129. careamics/utils/wandb.py +0 -121
  130. careamics-0.1.0rc1.dist-info/METADATA +0 -80
  131. careamics-0.1.0rc1.dist-info/RECORD +0 -46
  132. {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
@@ -1,493 +0,0 @@
1
- """
2
- Tiling submodule.
3
-
4
- These functions are used to tile images into patches or tiles.
5
- """
6
- import itertools
7
- from typing import Generator, List, Optional, Tuple, Union
8
-
9
- import numpy as np
10
- from skimage.util import view_as_windows
11
-
12
- from ..utils.logging import get_logger
13
- from .extraction_strategy import ExtractionStrategy
14
-
15
- logger = get_logger(__name__)
16
-
17
-
18
- def _compute_number_of_patches(
19
- arr: np.ndarray, patch_sizes: Union[List[int], Tuple[int, ...]]
20
- ) -> Tuple[int, ...]:
21
- """
22
- Compute the number of patches that fit in each dimension.
23
-
24
- Array must have one dimension more than the patches (C dimension).
25
-
26
- Parameters
27
- ----------
28
- arr : np.ndarray
29
- Input array.
30
- patch_sizes : Tuple[int]
31
- Size of the patches.
32
-
33
- Returns
34
- -------
35
- Tuple[int]
36
- Number of patches in each dimension.
37
- """
38
- n_patches = [
39
- np.ceil(arr.shape[i + 1] / patch_sizes[i]).astype(int)
40
- for i in range(len(patch_sizes))
41
- ]
42
- return tuple(n_patches)
43
-
44
-
45
- def _compute_overlap(
46
- arr: np.ndarray, patch_sizes: Union[List[int], Tuple[int, ...]]
47
- ) -> Tuple[int, ...]:
48
- """
49
- Compute the overlap between patches in each dimension.
50
-
51
- Array must be of dimensions C(Z)YX, and patches must be of dimensions YX or ZYX.
52
- If the array dimensions are divisible by the patch sizes, then the overlap is 0.
53
- Otherwise, it is the result of the division rounded to the upper value.
54
-
55
- Parameters
56
- ----------
57
- arr : np.ndarray
58
- Input array 3 or 4 dimensions.
59
- patch_sizes : Tuple[int]
60
- Size of the patches.
61
-
62
- Returns
63
- -------
64
- Tuple[int]
65
- Overlap between patches in each dimension.
66
- """
67
- n_patches = _compute_number_of_patches(arr, patch_sizes)
68
-
69
- overlap = [
70
- np.ceil(
71
- np.clip(n_patches[i] * patch_sizes[i] - arr.shape[i + 1], 0, None)
72
- / max(1, (n_patches[i] - 1))
73
- ).astype(int)
74
- for i in range(len(patch_sizes))
75
- ]
76
- return tuple(overlap)
77
-
78
-
79
- def _compute_crop_and_stitch_coords_1d(
80
- axis_size: int, tile_size: int, overlap: int
81
- ) -> Tuple[List[Tuple[int, int]], ...]:
82
- """
83
- Compute the coordinates of each tile along an axis, given the overlap.
84
-
85
- Parameters
86
- ----------
87
- axis_size : int
88
- Length of the axis.
89
- tile_size : int
90
- Size of the tile for the given axis.
91
- overlap : int
92
- Size of the overlap for the given axis.
93
-
94
- Returns
95
- -------
96
- Tuple[Tuple[int]]
97
- Tuple of all coordinates for given axis.
98
- """
99
- # Compute the step between tiles
100
- step = tile_size - overlap
101
- crop_coords = []
102
- stitch_coords = []
103
- overlap_crop_coords = []
104
- # Iterate over the axis with a certain step
105
- for i in range(0, axis_size - overlap, step):
106
- # Check if the tile fits within the axis
107
- if i + tile_size <= axis_size:
108
- # Add the coordinates to crop one tile
109
- crop_coords.append((i, i + tile_size))
110
- # Add the pixel coordinates of the cropped tile in the original image space
111
- stitch_coords.append(
112
- (
113
- i + overlap // 2 if i > 0 else 0,
114
- i + tile_size - overlap // 2
115
- if crop_coords[-1][1] < axis_size
116
- else axis_size,
117
- )
118
- )
119
- # Add the coordinates to crop the overlap from the prediction.
120
- overlap_crop_coords.append(
121
- (
122
- overlap // 2 if i > 0 else 0,
123
- tile_size - overlap // 2
124
- if crop_coords[-1][1] < axis_size
125
- else tile_size,
126
- )
127
- )
128
- # If the tile does not fit within the axis, perform the abovementioned
129
- # operations starting from the end of the axis
130
- else:
131
- # if (axis_size - tile_size, axis_size) not in crop_coords:
132
- crop_coords.append((axis_size - tile_size, axis_size))
133
- last_tile_end_coord = stitch_coords[-1][1]
134
- stitch_coords.append((last_tile_end_coord, axis_size))
135
- overlap_crop_coords.append(
136
- (tile_size - (axis_size - last_tile_end_coord), tile_size)
137
- )
138
- break
139
- return crop_coords, stitch_coords, overlap_crop_coords
140
-
141
-
142
- def _compute_patch_steps(
143
- patch_sizes: Union[List[int], Tuple[int, ...]], overlaps: Tuple[int, ...]
144
- ) -> Tuple[int, ...]:
145
- """
146
- Compute steps between patches.
147
-
148
- Parameters
149
- ----------
150
- patch_sizes : Tuple[int]
151
- Size of the patches.
152
- overlaps : Tuple[int]
153
- Overlap between patches.
154
-
155
- Returns
156
- -------
157
- Tuple[int]
158
- Steps between patches.
159
- """
160
- steps = [
161
- min(patch_sizes[i] - overlaps[i], patch_sizes[i])
162
- for i in range(len(patch_sizes))
163
- ]
164
- return tuple(steps)
165
-
166
-
167
- def _compute_reshaped_view(
168
- arr: np.ndarray,
169
- window_shape: Tuple[int, ...],
170
- step: Tuple[int, ...],
171
- output_shape: Tuple[int, ...],
172
- ) -> np.ndarray:
173
- """
174
- Compute reshaped views of an array, where views correspond to patches.
175
-
176
- Parameters
177
- ----------
178
- arr : np.ndarray
179
- Array from which the views are extracted.
180
- window_shape : Tuple[int]
181
- Shape of the views.
182
- step : Tuple[int]
183
- Steps between views.
184
- output_shape : Tuple[int]
185
- Shape of the output array.
186
-
187
- Returns
188
- -------
189
- np.ndarray
190
- Array with views dimension.
191
- """
192
- rng = np.random.default_rng()
193
- patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape(
194
- *output_shape
195
- )
196
- rng.shuffle(patches, axis=0)
197
- return patches
198
-
199
-
200
- def _patches_sanity_check(
201
- arr: np.ndarray,
202
- patch_size: Union[List[int], Tuple[int, ...]],
203
- is_3d_patch: bool,
204
- ) -> None:
205
- """
206
- Check patch size and array compatibility.
207
-
208
- This method validates the patch sizes with respect to the array dimensions:
209
- - The patch sizes must have one dimension fewer than the array (C dimension).
210
- - Chack that patch sizes are smaller than array dimensions.
211
-
212
- Parameters
213
- ----------
214
- arr : np.ndarray
215
- Input array.
216
- patch_size : Union[List[int], Tuple[int, ...]]
217
- Size of the patches along each dimension of the array, except the first.
218
- is_3d_patch : bool
219
- Whether the patch is 3D or not.
220
-
221
- Raises
222
- ------
223
- ValueError
224
- If the patch size is not consistent with the array shape (one more array
225
- dimension).
226
- ValueError
227
- If the patch size in Z is larger than the array dimension.
228
- ValueError
229
- If either of the patch sizes in X or Y is larger than the corresponding array
230
- dimension.
231
- """
232
- if len(patch_size) != len(arr.shape[1:]):
233
- raise ValueError(
234
- f"There must be a patch size for each spatial dimensions "
235
- f"(got {patch_size} patches for dims {arr.shape})."
236
- )
237
-
238
- # Sanity checks on patch sizes versus array dimension
239
- if is_3d_patch and patch_size[0] > arr.shape[-3]:
240
- raise ValueError(
241
- f"Z patch size is inconsistent with image shape "
242
- f"(got {patch_size[0]} patches for dim {arr.shape[1]})."
243
- )
244
-
245
- if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
246
- raise ValueError(
247
- f"At least one of YX patch dimensions is inconsistent with image shape "
248
- f"(got {patch_size} patches for dims {arr.shape[-2:]})."
249
- )
250
-
251
-
252
- # formerly :
253
- # in dataloader.py#L52, 00d536c
254
- def _extract_patches_sequential(
255
- arr: np.ndarray, patch_size: Union[List[int], Tuple[int]]
256
- ) -> Generator[np.ndarray, None, None]:
257
- """
258
- Generate patches from an array in a sequential manner.
259
-
260
- Array dimensions should be C(Z)YX, where C can be a singleton dimension. The patches
261
- are generated sequentially and cover the whole array.
262
-
263
- Parameters
264
- ----------
265
- arr : np.ndarray
266
- Input image array.
267
- patch_size : Tuple[int]
268
- Patch sizes in each dimension.
269
-
270
- Returns
271
- -------
272
- Generator[np.ndarray, None, None]
273
- Generator of patches.
274
- """
275
- # Patches sanity check
276
- is_3d_patch = len(patch_size) == 3
277
-
278
- _patches_sanity_check(arr, patch_size, is_3d_patch)
279
-
280
- # Compute overlap
281
- overlaps = _compute_overlap(arr=arr, patch_sizes=patch_size)
282
-
283
- # Create view window and overlaps
284
- window_steps = _compute_patch_steps(patch_sizes=patch_size, overlaps=overlaps)
285
-
286
- # Correct for first dimension for computing windowed views
287
- window_shape = (1, *patch_size)
288
- window_steps = (1, *window_steps)
289
-
290
- if is_3d_patch and patch_size[0] == 1:
291
- output_shape = (-1,) + window_shape[1:]
292
- else:
293
- output_shape = (-1, *window_shape)
294
-
295
- # Generate a view of the input array containing pre-calculated number of patches
296
- # in each dimension with overlap.
297
- # Resulting array is resized to (n_patches, C, Z, Y, X) or (n_patches,C, Y, X)
298
- patches = _compute_reshaped_view(
299
- arr, window_shape=window_shape, step=window_steps, output_shape=output_shape
300
- )
301
- logger.info(f"Extracted {patches.shape[0]} patches from input array.")
302
-
303
- # return a generator of patches
304
- return (patches[i, ...] for i in range(patches.shape[0]))
305
-
306
-
307
- def _extract_patches_random(
308
- arr: np.ndarray, patch_size: Union[List[int], Tuple[int]]
309
- ) -> Generator[np.ndarray, None, None]:
310
- """
311
- Generate patches from an array in a random manner.
312
-
313
- The method calculates how many patches the image can be divided into and then
314
- extracts an equal number of random patches.
315
-
316
- Parameters
317
- ----------
318
- arr : np.ndarray
319
- Input image array.
320
- patch_size : Tuple[int]
321
- Patch sizes in each dimension.
322
-
323
- Yields
324
- ------
325
- Generator[np.ndarray, None, None]
326
- Generator of patches.
327
- """
328
- is_3d_patch = len(patch_size) == 3
329
-
330
- # Patches sanity check
331
- _patches_sanity_check(arr, patch_size, is_3d_patch)
332
-
333
- rng = np.random.default_rng()
334
- # shuffle the array along the first axis TODO do we need shuffling?
335
- rng.shuffle(arr, axis=0)
336
-
337
- for sample_idx in range(arr.shape[0]):
338
- sample = arr[sample_idx]
339
- # calculate how many number of patches can image area be divided into
340
- n_patches = np.ceil(np.prod(sample.shape) / np.prod(patch_size)).astype(int)
341
- for _ in range(n_patches):
342
- crop_coords = [
343
- rng.integers(0, arr.shape[i + 1] - patch_size[i])
344
- for i in range(len(patch_size))
345
- ]
346
- patch = (
347
- sample[
348
- (
349
- ...,
350
- *[
351
- slice(c, c + patch_size[i])
352
- for i, c in enumerate(crop_coords)
353
- ],
354
- )
355
- ]
356
- .copy()
357
- .astype(np.float32)
358
- )
359
- yield patch
360
-
361
-
362
- def _extract_tiles(
363
- arr: np.ndarray,
364
- tile_size: Union[List[int], Tuple[int]],
365
- overlaps: Union[List[int], Tuple[int]],
366
- ) -> Generator:
367
- """
368
- Generate tiles from the input array with specified overlap.
369
-
370
- The tiles cover the whole array.
371
-
372
- Parameters
373
- ----------
374
- arr : np.ndarray
375
- Array of shape (S, (Z), Y, X).
376
- tile_size : Union[List[int], Tuple[int]]
377
- Tile sizes in each dimension, of length 2 or 3.
378
- overlaps : Union[List[int], Tuple[int]]
379
- Overlap values in each dimension, of length 2 or 3.
380
-
381
- Yields
382
- ------
383
- Generator
384
- Tile generator that yields the tile with corresponding coordinates to stitch
385
- back the tiles together.
386
- """
387
- # Iterate over num samples (S)
388
- for sample_idx in range(arr.shape[0]):
389
- sample = arr[sample_idx]
390
-
391
- # Create an array of coordinates for cropping and stitching all axes.
392
- # Shape: (axes, type_of_coord, tile_num, start/end coord)
393
- crop_and_stitch_coords_list = [
394
- _compute_crop_and_stitch_coords_1d(
395
- sample.shape[i], tile_size[i], overlaps[i]
396
- )
397
- for i in range(len(tile_size))
398
- ]
399
-
400
- # Rearrange crop coordinates from a list of coordinate pairs per axis to a list
401
- # grouped by type.
402
- # For axis of size 35 and patch size of 32 compute_crop_and_stitch_coords_1d
403
- # will output ([(0, 32), (3, 35)], [(0, 20), (20, 35)], [(0, 20), (17, 32)]),
404
- # where the first list is crop coordinates for 1st axis.
405
- all_crop_coords, all_stitch_coords, all_overlap_crop_coords = zip(
406
- *crop_and_stitch_coords_list
407
- )
408
-
409
- # Iterate over generated coordinate pairs:
410
- for tile_idx, (crop_coords, stitch_coords, overlap_crop_coords) in enumerate(
411
- zip(
412
- itertools.product(*all_crop_coords),
413
- itertools.product(*all_stitch_coords),
414
- itertools.product(*all_overlap_crop_coords),
415
- )
416
- ):
417
- tile = sample[(..., *[slice(c[0], c[1]) for c in list(crop_coords)])]
418
-
419
- # Check if we are at the end of the sample.
420
- # To check that we compute the length of the array that contains all the
421
- # tiles
422
- if tile_idx == np.prod([len(axis) for axis in all_crop_coords]) - 1:
423
- last_tile = True
424
- else:
425
- last_tile = False
426
- yield (
427
- np.expand_dims(tile.astype(np.float32), 0),
428
- last_tile,
429
- arr.shape[1:],
430
- overlap_crop_coords,
431
- stitch_coords,
432
- )
433
-
434
-
435
- def generate_patches(
436
- sample: np.ndarray,
437
- patch_extraction_method: ExtractionStrategy,
438
- patch_size: Optional[Union[List[int], Tuple[int]]] = None,
439
- patch_overlap: Optional[Union[List[int], Tuple[int]]] = None,
440
- ) -> Generator[np.ndarray, None, None]:
441
- """
442
- Generate patches from a sample.
443
-
444
- Parameters
445
- ----------
446
- sample : np.ndarray
447
- Input array.
448
- patch_extraction_method : ExtractionStrategies
449
- Patch extraction method, as defined in extraction_strategy.ExtractionStrategy.
450
- patch_size : Optional[Union[List[int], Tuple[int]]]
451
- Size of the patches along each dimension of the array, except the first.
452
- patch_overlap : Optional[Union[List[int], Tuple[int]]]
453
- Overlap between patches.
454
-
455
- Returns
456
- -------
457
- Generator[np.ndarray, None, None]
458
- Generator yielding patches/tiles.
459
-
460
- Raises
461
- ------
462
- ValueError
463
- If overlap is not specified when using tiling.
464
- ValueError
465
- If patches is None.
466
- """
467
- patches = None
468
-
469
- if patch_size is not None:
470
- patches = None
471
-
472
- if patch_extraction_method == ExtractionStrategy.TILED:
473
- if patch_overlap is None:
474
- raise ValueError(
475
- "Overlaps must be specified when using tiling (got None)."
476
- )
477
- patches = _extract_tiles(
478
- arr=sample, tile_size=patch_size, overlaps=patch_overlap
479
- )
480
-
481
- elif patch_extraction_method == ExtractionStrategy.SEQUENTIAL:
482
- patches = _extract_patches_sequential(sample, patch_size=patch_size)
483
-
484
- elif patch_extraction_method == ExtractionStrategy.RANDOM:
485
- patches = _extract_patches_random(sample, patch_size=patch_size)
486
-
487
- if patches is None:
488
- raise ValueError("No patch generated")
489
-
490
- return patches
491
- else:
492
- # no patching
493
- return (sample for _ in range(1))
@@ -1,174 +0,0 @@
1
- """
2
- Dataset preparation module.
3
-
4
- Methods to set up the datasets for training, validation and prediction.
5
- """
6
- from pathlib import Path
7
- from typing import List, Optional, Union
8
-
9
- from ..config import Configuration
10
- from ..manipulation import default_manipulate
11
- from ..utils import check_tiling_validity
12
- from .extraction_strategy import ExtractionStrategy
13
- from .in_memory_dataset import InMemoryDataset
14
- from .tiff_dataset import TiffDataset
15
-
16
-
17
- def get_train_dataset(
18
- config: Configuration, train_path: str
19
- ) -> Union[TiffDataset, InMemoryDataset]:
20
- """
21
- Create training dataset.
22
-
23
- Depending on the configuration, this methods return either a TiffDataset or an
24
- InMemoryDataset.
25
-
26
- Parameters
27
- ----------
28
- config : Configuration
29
- Configuration.
30
- train_path : Union[str, Path]
31
- Path to training data.
32
-
33
- Returns
34
- -------
35
- Union[TiffDataset, InMemoryDataset]
36
- Dataset.
37
- """
38
- if config.data.in_memory:
39
- dataset = InMemoryDataset(
40
- data_path=train_path,
41
- data_format=config.data.data_format,
42
- axes=config.data.axes,
43
- mean=config.data.mean,
44
- std=config.data.std,
45
- patch_extraction_method=ExtractionStrategy.SEQUENTIAL,
46
- patch_size=config.training.patch_size,
47
- patch_transform=default_manipulate,
48
- patch_transform_params={
49
- "mask_pixel_percentage": config.algorithm.masked_pixel_percentage,
50
- "roi_size": config.algorithm.roi_size,
51
- },
52
- )
53
- else:
54
- dataset = TiffDataset(
55
- data_path=train_path,
56
- data_format=config.data.data_format,
57
- axes=config.data.axes,
58
- mean=config.data.mean,
59
- std=config.data.std,
60
- patch_extraction_method=ExtractionStrategy.RANDOM,
61
- patch_size=config.training.patch_size,
62
- patch_transform=default_manipulate,
63
- patch_transform_params={
64
- "mask_pixel_percentage": config.algorithm.masked_pixel_percentage,
65
- "roi_size": config.algorithm.roi_size,
66
- },
67
- )
68
- return dataset
69
-
70
-
71
- def get_validation_dataset(config: Configuration, val_path: str) -> InMemoryDataset:
72
- """
73
- Create validation dataset.
74
-
75
- Validation dataset is kept in memory.
76
-
77
- Parameters
78
- ----------
79
- config : Configuration
80
- Configuration.
81
- val_path : Union[str, Path]
82
- Path to validation data.
83
-
84
- Returns
85
- -------
86
- TiffDataset
87
- In memory dataset.
88
- """
89
- data_path = val_path
90
-
91
- dataset = InMemoryDataset(
92
- data_path=data_path,
93
- data_format=config.data.data_format,
94
- axes=config.data.axes,
95
- mean=config.data.mean,
96
- std=config.data.std,
97
- patch_extraction_method=ExtractionStrategy.SEQUENTIAL,
98
- patch_size=config.training.patch_size,
99
- patch_transform=default_manipulate,
100
- patch_transform_params={
101
- "mask_pixel_percentage": config.algorithm.masked_pixel_percentage
102
- },
103
- )
104
-
105
- return dataset
106
-
107
-
108
- def get_prediction_dataset(
109
- config: Configuration,
110
- pred_path: Union[str, Path],
111
- *,
112
- tile_shape: Optional[List[int]] = None,
113
- overlaps: Optional[List[int]] = None,
114
- axes: Optional[str] = None,
115
- ) -> TiffDataset:
116
- """
117
- Create prediction dataset.
118
-
119
- To use tiling, both `tile_shape` and `overlaps` must be specified, have same
120
- length, be divisible by 2 and greater than 0. Finally, the overlaps must be
121
- smaller than the tiles.
122
-
123
- By default, axes are extracted from the configuration. To use images with
124
- different axes, set the `axes` parameter. Note that the difference between
125
- configuration and parameter axes must be S or T, but not any of the spatial
126
- dimensions (e.g. 2D vs 3D).
127
-
128
- Parameters
129
- ----------
130
- config : Configuration
131
- Configuration.
132
- pred_path : Union[str, Path]
133
- Path to prediction data.
134
- tile_shape : Optional[List[int]], optional
135
- 2D or 3D shape of the tiles, by default None.
136
- overlaps : Optional[List[int]], optional
137
- 2D or 3D overlaps between tiles, by default None.
138
- axes : Optional[str], optional
139
- Axes of the data, by default None.
140
-
141
- Returns
142
- -------
143
- TiffDataset
144
- Dataset.
145
- """
146
- use_tiling = False # default value
147
-
148
- # Validate tiles and overlaps
149
- if tile_shape is not None and overlaps is not None:
150
- check_tiling_validity(tile_shape, overlaps)
151
-
152
- # Use tiling
153
- use_tiling = True
154
-
155
- # Extraction strategy
156
- if use_tiling:
157
- patch_extraction_method = ExtractionStrategy.TILED
158
- else:
159
- patch_extraction_method = None
160
-
161
- # Create dataset
162
- dataset = TiffDataset(
163
- data_path=pred_path,
164
- data_format=config.data.data_format,
165
- axes=config.data.axes if axes is None else axes, # supersede axes
166
- mean=config.data.mean,
167
- std=config.data.std,
168
- patch_size=tile_shape,
169
- patch_overlap=overlaps,
170
- patch_extraction_method=patch_extraction_method,
171
- patch_transform=None,
172
- )
173
-
174
- return dataset