careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (134) hide show
  1. careamics/__init__.py +16 -4
  2. careamics/callbacks/__init__.py +6 -0
  3. careamics/callbacks/hyperparameters_callback.py +42 -0
  4. careamics/callbacks/progress_bar_callback.py +57 -0
  5. careamics/careamist.py +761 -0
  6. careamics/config/__init__.py +31 -3
  7. careamics/config/algorithm_model.py +167 -0
  8. careamics/config/architectures/__init__.py +17 -0
  9. careamics/config/architectures/architecture_model.py +29 -0
  10. careamics/config/architectures/custom_model.py +150 -0
  11. careamics/config/architectures/register_model.py +101 -0
  12. careamics/config/architectures/unet_model.py +96 -0
  13. careamics/config/architectures/vae_model.py +39 -0
  14. careamics/config/callback_model.py +92 -0
  15. careamics/config/configuration_example.py +89 -0
  16. careamics/config/configuration_factory.py +597 -0
  17. careamics/config/configuration_model.py +597 -0
  18. careamics/config/data_model.py +555 -0
  19. careamics/config/inference_model.py +283 -0
  20. careamics/config/noise_models.py +162 -0
  21. careamics/config/optimizer_models.py +181 -0
  22. careamics/config/references/__init__.py +45 -0
  23. careamics/config/references/algorithm_descriptions.py +131 -0
  24. careamics/config/references/references.py +38 -0
  25. careamics/config/support/__init__.py +33 -0
  26. careamics/config/support/supported_activations.py +24 -0
  27. careamics/config/support/supported_algorithms.py +18 -0
  28. careamics/config/support/supported_architectures.py +18 -0
  29. careamics/config/support/supported_data.py +82 -0
  30. careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
  31. careamics/config/support/supported_loggers.py +8 -0
  32. careamics/config/support/supported_losses.py +25 -0
  33. careamics/config/support/supported_optimizers.py +55 -0
  34. careamics/config/support/supported_pixel_manipulations.py +15 -0
  35. careamics/config/support/supported_struct_axis.py +19 -0
  36. careamics/config/support/supported_transforms.py +23 -0
  37. careamics/config/tile_information.py +104 -0
  38. careamics/config/training_model.py +65 -0
  39. careamics/config/transformations/__init__.py +14 -0
  40. careamics/config/transformations/n2v_manipulate_model.py +63 -0
  41. careamics/config/transformations/nd_flip_model.py +32 -0
  42. careamics/config/transformations/normalize_model.py +31 -0
  43. careamics/config/transformations/transform_model.py +44 -0
  44. careamics/config/transformations/xy_random_rotate90_model.py +29 -0
  45. careamics/config/validators/__init__.py +5 -0
  46. careamics/config/validators/validator_utils.py +100 -0
  47. careamics/conftest.py +26 -0
  48. careamics/dataset/__init__.py +5 -0
  49. careamics/dataset/dataset_utils/__init__.py +19 -0
  50. careamics/dataset/dataset_utils/dataset_utils.py +100 -0
  51. careamics/dataset/dataset_utils/file_utils.py +140 -0
  52. careamics/dataset/dataset_utils/read_tiff.py +61 -0
  53. careamics/dataset/dataset_utils/read_utils.py +25 -0
  54. careamics/dataset/dataset_utils/read_zarr.py +56 -0
  55. careamics/dataset/in_memory_dataset.py +323 -134
  56. careamics/dataset/iterable_dataset.py +416 -0
  57. careamics/dataset/patching/__init__.py +8 -0
  58. careamics/dataset/patching/patch_transform.py +44 -0
  59. careamics/dataset/patching/patching.py +212 -0
  60. careamics/dataset/patching/random_patching.py +190 -0
  61. careamics/dataset/patching/sequential_patching.py +206 -0
  62. careamics/dataset/patching/tiled_patching.py +158 -0
  63. careamics/dataset/patching/validate_patch_dimension.py +60 -0
  64. careamics/dataset/zarr_dataset.py +149 -0
  65. careamics/lightning_datamodule.py +743 -0
  66. careamics/lightning_module.py +292 -0
  67. careamics/lightning_prediction_datamodule.py +396 -0
  68. careamics/lightning_prediction_loop.py +116 -0
  69. careamics/losses/__init__.py +4 -1
  70. careamics/losses/loss_factory.py +24 -14
  71. careamics/losses/losses.py +65 -5
  72. careamics/losses/noise_model_factory.py +40 -0
  73. careamics/losses/noise_models.py +524 -0
  74. careamics/model_io/__init__.py +8 -0
  75. careamics/model_io/bioimage/__init__.py +11 -0
  76. careamics/model_io/bioimage/_readme_factory.py +120 -0
  77. careamics/model_io/bioimage/bioimage_utils.py +48 -0
  78. careamics/model_io/bioimage/model_description.py +318 -0
  79. careamics/model_io/bmz_io.py +231 -0
  80. careamics/model_io/model_io_utils.py +80 -0
  81. careamics/models/__init__.py +4 -1
  82. careamics/models/activation.py +35 -0
  83. careamics/models/layers.py +244 -0
  84. careamics/models/model_factory.py +21 -221
  85. careamics/models/unet.py +46 -20
  86. careamics/prediction/__init__.py +1 -3
  87. careamics/prediction/stitch_prediction.py +73 -0
  88. careamics/transforms/__init__.py +41 -0
  89. careamics/transforms/n2v_manipulate.py +113 -0
  90. careamics/transforms/nd_flip.py +93 -0
  91. careamics/transforms/normalize.py +109 -0
  92. careamics/transforms/pixel_manipulation.py +383 -0
  93. careamics/transforms/struct_mask_parameters.py +18 -0
  94. careamics/transforms/tta.py +74 -0
  95. careamics/transforms/xy_random_rotate90.py +95 -0
  96. careamics/utils/__init__.py +10 -12
  97. careamics/utils/base_enum.py +32 -0
  98. careamics/utils/context.py +22 -2
  99. careamics/utils/metrics.py +0 -46
  100. careamics/utils/path_utils.py +24 -0
  101. careamics/utils/ram.py +13 -0
  102. careamics/utils/receptive_field.py +102 -0
  103. careamics/utils/running_stats.py +43 -0
  104. careamics/utils/torch_utils.py +112 -75
  105. careamics-0.1.0rc4.dist-info/METADATA +122 -0
  106. careamics-0.1.0rc4.dist-info/RECORD +110 -0
  107. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
  108. careamics/bioimage/__init__.py +0 -15
  109. careamics/bioimage/docs/Noise2Void.md +0 -5
  110. careamics/bioimage/docs/__init__.py +0 -1
  111. careamics/bioimage/io.py +0 -182
  112. careamics/bioimage/rdf.py +0 -105
  113. careamics/config/algorithm.py +0 -231
  114. careamics/config/config.py +0 -297
  115. careamics/config/config_filter.py +0 -44
  116. careamics/config/data.py +0 -194
  117. careamics/config/torch_optim.py +0 -118
  118. careamics/config/training.py +0 -534
  119. careamics/dataset/dataset_utils.py +0 -111
  120. careamics/dataset/patching.py +0 -492
  121. careamics/dataset/prepare_dataset.py +0 -175
  122. careamics/dataset/tiff_dataset.py +0 -212
  123. careamics/engine.py +0 -1014
  124. careamics/manipulation/__init__.py +0 -4
  125. careamics/manipulation/pixel_manipulation.py +0 -158
  126. careamics/prediction/prediction_utils.py +0 -106
  127. careamics/utils/ascii_logo.txt +0 -9
  128. careamics/utils/augment.py +0 -65
  129. careamics/utils/normalization.py +0 -55
  130. careamics/utils/validators.py +0 -170
  131. careamics/utils/wandb.py +0 -121
  132. careamics-0.1.0rc2.dist-info/METADATA +0 -81
  133. careamics-0.1.0rc2.dist-info/RECORD +0 -47
  134. {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
@@ -1,492 +0,0 @@
1
- """
2
- Tiling submodule.
3
-
4
- These functions are used to tile images into patches or tiles.
5
- """
6
- import itertools
7
- from typing import Generator, List, Optional, Tuple, Union
8
-
9
- import numpy as np
10
- from skimage.util import view_as_windows
11
-
12
- from careamics.utils.logging import get_logger
13
-
14
- from .extraction_strategy import ExtractionStrategy
15
-
16
- logger = get_logger(__name__)
17
-
18
-
19
- def _compute_number_of_patches(
20
- arr: np.ndarray, patch_sizes: Union[List[int], Tuple[int, ...]]
21
- ) -> Tuple[int, ...]:
22
- """
23
- Compute the number of patches that fit in each dimension.
24
-
25
- Array must have one dimension more than the patches (C dimension).
26
-
27
- Parameters
28
- ----------
29
- arr : np.ndarray
30
- Input array.
31
- patch_sizes : Tuple[int]
32
- Size of the patches.
33
-
34
- Returns
35
- -------
36
- Tuple[int]
37
- Number of patches in each dimension.
38
- """
39
- n_patches = [
40
- np.ceil(arr.shape[i + 1] / patch_sizes[i]).astype(int)
41
- for i in range(len(patch_sizes))
42
- ]
43
- return tuple(n_patches)
44
-
45
-
46
- def _compute_overlap(
47
- arr: np.ndarray, patch_sizes: Union[List[int], Tuple[int, ...]]
48
- ) -> Tuple[int, ...]:
49
- """
50
- Compute the overlap between patches in each dimension.
51
-
52
- Array must be of dimensions C(Z)YX, and patches must be of dimensions YX or ZYX.
53
- If the array dimensions are divisible by the patch sizes, then the overlap is 0.
54
- Otherwise, it is the result of the division rounded to the upper value.
55
-
56
- Parameters
57
- ----------
58
- arr : np.ndarray
59
- Input array 3 or 4 dimensions.
60
- patch_sizes : Tuple[int]
61
- Size of the patches.
62
-
63
- Returns
64
- -------
65
- Tuple[int]
66
- Overlap between patches in each dimension.
67
- """
68
- n_patches = _compute_number_of_patches(arr, patch_sizes)
69
-
70
- overlap = [
71
- np.ceil(
72
- np.clip(n_patches[i] * patch_sizes[i] - arr.shape[i + 1], 0, None)
73
- / max(1, (n_patches[i] - 1))
74
- ).astype(int)
75
- for i in range(len(patch_sizes))
76
- ]
77
- return tuple(overlap)
78
-
79
-
80
- def _compute_crop_and_stitch_coords_1d(
81
- axis_size: int, tile_size: int, overlap: int
82
- ) -> Tuple[List[Tuple[int, int]], ...]:
83
- """
84
- Compute the coordinates of each tile along an axis, given the overlap.
85
-
86
- Parameters
87
- ----------
88
- axis_size : int
89
- Length of the axis.
90
- tile_size : int
91
- Size of the tile for the given axis.
92
- overlap : int
93
- Size of the overlap for the given axis.
94
-
95
- Returns
96
- -------
97
- Tuple[Tuple[int]]
98
- Tuple of all coordinates for given axis.
99
- """
100
- # Compute the step between tiles
101
- step = tile_size - overlap
102
- crop_coords = []
103
- stitch_coords = []
104
- overlap_crop_coords = []
105
- # Iterate over the axis with a certain step
106
- for i in range(0, axis_size - overlap, step):
107
- # Check if the tile fits within the axis
108
- if i + tile_size <= axis_size:
109
- # Add the coordinates to crop one tile
110
- crop_coords.append((i, i + tile_size))
111
- # Add the pixel coordinates of the cropped tile in the original image space
112
- stitch_coords.append(
113
- (
114
- i + overlap // 2 if i > 0 else 0,
115
- i + tile_size - overlap // 2
116
- if crop_coords[-1][1] < axis_size
117
- else axis_size,
118
- )
119
- )
120
- # Add the coordinates to crop the overlap from the prediction.
121
- overlap_crop_coords.append(
122
- (
123
- overlap // 2 if i > 0 else 0,
124
- tile_size - overlap // 2
125
- if crop_coords[-1][1] < axis_size
126
- else tile_size,
127
- )
128
- )
129
- # If the tile does not fit within the axis, perform the abovementioned
130
- # operations starting from the end of the axis
131
- else:
132
- # if (axis_size - tile_size, axis_size) not in crop_coords:
133
- crop_coords.append((axis_size - tile_size, axis_size))
134
- last_tile_end_coord = stitch_coords[-1][1]
135
- stitch_coords.append((last_tile_end_coord, axis_size))
136
- overlap_crop_coords.append(
137
- (tile_size - (axis_size - last_tile_end_coord), tile_size)
138
- )
139
- break
140
- return crop_coords, stitch_coords, overlap_crop_coords
141
-
142
-
143
- def _compute_patch_steps(
144
- patch_sizes: Union[List[int], Tuple[int, ...]], overlaps: Tuple[int, ...]
145
- ) -> Tuple[int, ...]:
146
- """
147
- Compute steps between patches.
148
-
149
- Parameters
150
- ----------
151
- patch_sizes : Tuple[int]
152
- Size of the patches.
153
- overlaps : Tuple[int]
154
- Overlap between patches.
155
-
156
- Returns
157
- -------
158
- Tuple[int]
159
- Steps between patches.
160
- """
161
- steps = [
162
- min(patch_sizes[i] - overlaps[i], patch_sizes[i])
163
- for i in range(len(patch_sizes))
164
- ]
165
- return tuple(steps)
166
-
167
-
168
- def _compute_reshaped_view(
169
- arr: np.ndarray,
170
- window_shape: Tuple[int, ...],
171
- step: Tuple[int, ...],
172
- output_shape: Tuple[int, ...],
173
- ) -> np.ndarray:
174
- """
175
- Compute reshaped views of an array, where views correspond to patches.
176
-
177
- Parameters
178
- ----------
179
- arr : np.ndarray
180
- Array from which the views are extracted.
181
- window_shape : Tuple[int]
182
- Shape of the views.
183
- step : Tuple[int]
184
- Steps between views.
185
- output_shape : Tuple[int]
186
- Shape of the output array.
187
-
188
- Returns
189
- -------
190
- np.ndarray
191
- Array with views dimension.
192
- """
193
- rng = np.random.default_rng()
194
- patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape(
195
- *output_shape
196
- )
197
- rng.shuffle(patches, axis=0)
198
- return patches
199
-
200
-
201
- def _patches_sanity_check(
202
- arr: np.ndarray,
203
- patch_size: Union[List[int], Tuple[int, ...]],
204
- is_3d_patch: bool,
205
- ) -> None:
206
- """
207
- Check patch size and array compatibility.
208
-
209
- This method validates the patch sizes with respect to the array dimensions:
210
- - The patch sizes must have one dimension fewer than the array (C dimension).
211
- - Check that patch sizes do not exceed the array dimensions.
212
-
213
- Parameters
214
- ----------
215
- arr : np.ndarray
216
- Input array.
217
- patch_size : Union[List[int], Tuple[int, ...]]
218
- Size of the patches along each dimension of the array, except the first.
219
- is_3d_patch : bool
220
- Whether the patch is 3D or not.
221
-
222
- Raises
223
- ------
224
- ValueError
225
- If the patch size is not consistent with the array shape (one more array
226
- dimension).
227
- ValueError
228
- If the patch size in Z is larger than the array dimension.
229
- ValueError
230
- If either of the patch sizes in X or Y is larger than the corresponding array
231
- dimension.
232
- """
233
- if len(patch_size) != len(arr.shape[1:]):
234
- raise ValueError(
235
- f"There must be a patch size for each spatial dimensions "
236
- f"(got {patch_size} patches for dims {arr.shape})."
237
- )
238
-
239
- # Sanity checks on patch sizes versus array dimension
240
- if is_3d_patch and patch_size[0] > arr.shape[-3]:
241
- raise ValueError(
242
- f"Z patch size is inconsistent with image shape "
243
- f"(got {patch_size[0]} patches for dim {arr.shape[1]})."
244
- )
245
-
246
- if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
247
- raise ValueError(
248
- f"At least one of YX patch dimensions is inconsistent with image shape "
249
- f"(got {patch_size} patches for dims {arr.shape[-2:]})."
250
- )
251
-
252
-
253
- # formerly :
254
- # in dataloader.py#L52, 00d536c
255
- def _extract_patches_sequential(
256
- arr: np.ndarray, patch_size: Union[List[int], Tuple[int]]
257
- ) -> Generator[np.ndarray, None, None]:
258
- """
259
- Generate patches from an array in a sequential manner.
260
-
261
- Array dimensions should be C(Z)YX, where C can be a singleton dimension. The patches
262
- are generated sequentially and cover the whole array.
263
-
264
- Parameters
265
- ----------
266
- arr : np.ndarray
267
- Input image array.
268
- patch_size : Tuple[int]
269
- Patch sizes in each dimension.
270
-
271
- Returns
272
- -------
273
- Generator[np.ndarray, None, None]
274
- Generator of patches.
275
- """
276
- # Patches sanity check
277
- is_3d_patch = len(patch_size) == 3
278
-
279
- _patches_sanity_check(arr, patch_size, is_3d_patch)
280
-
281
- # Compute overlap
282
- overlaps = _compute_overlap(arr=arr, patch_sizes=patch_size)
283
-
284
- # Create view window and overlaps
285
- window_steps = _compute_patch_steps(patch_sizes=patch_size, overlaps=overlaps)
286
-
287
- # Correct for first dimension for computing windowed views
288
- window_shape = (1, *patch_size)
289
- window_steps = (1, *window_steps)
290
-
291
- if is_3d_patch and patch_size[0] == 1:
292
- output_shape = (-1,) + window_shape[1:]
293
- else:
294
- output_shape = (-1, *window_shape)
295
-
296
- # Generate a view of the input array containing pre-calculated number of patches
297
- # in each dimension with overlap.
298
- # Resulting array is resized to (n_patches, C, Z, Y, X) or (n_patches,C, Y, X)
299
- patches = _compute_reshaped_view(
300
- arr, window_shape=window_shape, step=window_steps, output_shape=output_shape
301
- )
302
- logger.info(f"Extracted {patches.shape[0]} patches from input array.")
303
-
304
- # return a generator of patches
305
- return (patches[i, ...] for i in range(patches.shape[0]))
306
-
307
-
308
- def _extract_patches_random(
309
- arr: np.ndarray, patch_size: Union[List[int], Tuple[int]]
310
- ) -> Generator[np.ndarray, None, None]:
311
- """
312
- Generate patches from an array in a random manner.
313
-
314
- The method calculates how many patches the image can be divided into and then
315
- extracts an equal number of random patches.
316
-
317
- Parameters
318
- ----------
319
- arr : np.ndarray
320
- Input image array.
321
- patch_size : Tuple[int]
322
- Patch sizes in each dimension.
323
-
324
- Yields
325
- ------
326
- Generator[np.ndarray, None, None]
327
- Generator of patches.
328
- """
329
- is_3d_patch = len(patch_size) == 3
330
-
331
- # Patches sanity check
332
- _patches_sanity_check(arr, patch_size, is_3d_patch)
333
-
334
- rng = np.random.default_rng()
335
- # shuffle the array along the first axis TODO do we need shuffling?
336
- rng.shuffle(arr, axis=0)
337
-
338
- for sample_idx in range(arr.shape[0]):
339
- sample = arr[sample_idx]
340
- # calculate how many patches the image area can be divided into
341
- n_patches = np.ceil(np.prod(sample.shape) / np.prod(patch_size)).astype(int)
342
- for _ in range(n_patches):
343
- crop_coords = [
344
- rng.integers(0, arr.shape[i + 1] - patch_size[i])
345
- for i in range(len(patch_size))
346
- ]
347
- patch = (
348
- sample[
349
- (
350
- ...,
351
- *[
352
- slice(c, c + patch_size[i])
353
- for i, c in enumerate(crop_coords)
354
- ],
355
- )
356
- ]
357
- .copy()
358
- .astype(np.float32)
359
- )
360
- yield patch
361
-
362
-
363
- def _extract_tiles(
364
- arr: np.ndarray,
365
- tile_size: Union[List[int], Tuple[int]],
366
- overlaps: Union[List[int], Tuple[int]],
367
- ) -> Generator:
368
- """
369
- Generate tiles from the input array with specified overlap.
370
-
371
- The tiles cover the whole array.
372
-
373
- Parameters
374
- ----------
375
- arr : np.ndarray
376
- Array of shape (S, (Z), Y, X).
377
- tile_size : Union[List[int], Tuple[int]]
378
- Tile sizes in each dimension, of length 2 or 3.
379
- overlaps : Union[List[int], Tuple[int]]
380
- Overlap values in each dimension, of length 2 or 3.
381
-
382
- Yields
383
- ------
384
- Generator
385
- Tile generator that yields the tile with corresponding coordinates to stitch
386
- back the tiles together.
387
- """
388
- # Iterate over num samples (S)
389
- for sample_idx in range(arr.shape[0]):
390
- sample = arr[sample_idx]
391
-
392
- # Create an array of coordinates for cropping and stitching all axes.
393
- # Shape: (axes, type_of_coord, tile_num, start/end coord)
394
- crop_and_stitch_coords_list = [
395
- _compute_crop_and_stitch_coords_1d(
396
- sample.shape[i], tile_size[i], overlaps[i]
397
- )
398
- for i in range(len(tile_size))
399
- ]
400
-
401
- # Rearrange crop coordinates from a list of coordinate pairs per axis to a list
402
- # grouped by type.
403
- # For axis of size 35 and patch size of 32 compute_crop_and_stitch_coords_1d
404
- # will output ([(0, 32), (3, 35)], [(0, 20), (20, 35)], [(0, 20), (17, 32)]),
405
- # where the first list is crop coordinates for 1st axis.
406
- all_crop_coords, all_stitch_coords, all_overlap_crop_coords = zip(
407
- *crop_and_stitch_coords_list
408
- )
409
-
410
- # Iterate over generated coordinate pairs:
411
- for tile_idx, (crop_coords, stitch_coords, overlap_crop_coords) in enumerate(
412
- zip(
413
- itertools.product(*all_crop_coords),
414
- itertools.product(*all_stitch_coords),
415
- itertools.product(*all_overlap_crop_coords),
416
- )
417
- ):
418
- tile = sample[(..., *[slice(c[0], c[1]) for c in list(crop_coords)])]
419
-
420
- # Check if we are at the end of the sample.
421
- # To check that we compute the length of the array that contains all the
422
- # tiles
423
- if tile_idx == np.prod([len(axis) for axis in all_crop_coords]) - 1:
424
- last_tile = True
425
- else:
426
- last_tile = False
427
- yield (
428
- np.expand_dims(tile.astype(np.float32), 0),
429
- last_tile,
430
- arr.shape[1:],
431
- overlap_crop_coords,
432
- stitch_coords,
433
- )
434
-
435
-
436
- def generate_patches(
437
- sample: np.ndarray,
438
- patch_extraction_method: ExtractionStrategy,
439
- patch_size: Optional[Union[List[int], Tuple[int]]] = None,
440
- patch_overlap: Optional[Union[List[int], Tuple[int]]] = None,
441
- ) -> Generator[np.ndarray, None, None]:
442
- """
443
- Generate patches from a sample.
444
-
445
- Parameters
446
- ----------
447
- sample : np.ndarray
448
- Input array.
449
- patch_extraction_method : ExtractionStrategy
450
- Patch extraction method, as defined in extraction_strategy.ExtractionStrategy.
451
- patch_size : Optional[Union[List[int], Tuple[int]]]
452
- Size of the patches along each dimension of the array, except the first.
453
- patch_overlap : Optional[Union[List[int], Tuple[int]]]
454
- Overlap between patches.
455
-
456
- Returns
457
- -------
458
- Generator[np.ndarray, None, None]
459
- Generator yielding patches/tiles.
460
-
461
- Raises
462
- ------
463
- ValueError
464
- If overlap is not specified when using tiling.
465
- ValueError
466
- If patches is None.
467
- """
468
- patches = None
469
-
470
- if patch_size is not None:
471
- patches = None
472
-
473
- if patch_extraction_method == ExtractionStrategy.TILED:
474
- if patch_overlap is None:
475
- raise ValueError(
476
- "Overlaps must be specified when using tiling (got None)."
477
- )
478
- patches = _extract_tiles(
479
- arr=sample, tile_size=patch_size, overlaps=patch_overlap
480
- )
481
-
482
- elif patch_extraction_method == ExtractionStrategy.SEQUENTIAL:
483
- patches = _extract_patches_sequential(sample, patch_size=patch_size)
484
-
485
- else:
486
- # random patching
487
- patches = _extract_patches_random(sample, patch_size=patch_size)
488
-
489
- return patches
490
- else:
491
- # no patching, return a generator for the sample
492
- return (sample for _ in range(1))
@@ -1,175 +0,0 @@
1
- """
2
- Dataset preparation module.
3
-
4
- Methods to set up the datasets for training, validation and prediction.
5
- """
6
- from pathlib import Path
7
- from typing import List, Optional, Union
8
-
9
- from careamics.config import Configuration
10
- from careamics.manipulation import default_manipulate
11
- from careamics.utils import check_tiling_validity
12
-
13
- from .extraction_strategy import ExtractionStrategy
14
- from .in_memory_dataset import InMemoryDataset
15
- from .tiff_dataset import TiffDataset
16
-
17
-
18
- def get_train_dataset(
19
- config: Configuration, train_path: str
20
- ) -> Union[TiffDataset, InMemoryDataset]:
21
- """
22
- Create training dataset.
23
-
24
- Depending on the configuration, this method returns either a TiffDataset or an
25
- InMemoryDataset.
26
-
27
- Parameters
28
- ----------
29
- config : Configuration
30
- Configuration.
31
- train_path : Union[str, Path]
32
- Path to training data.
33
-
34
- Returns
35
- -------
36
- Union[TiffDataset, InMemoryDataset]
37
- Dataset.
38
- """
39
- if config.data.in_memory:
40
- dataset = InMemoryDataset(
41
- data_path=train_path,
42
- data_format=config.data.data_format,
43
- axes=config.data.axes,
44
- mean=config.data.mean,
45
- std=config.data.std,
46
- patch_extraction_method=ExtractionStrategy.SEQUENTIAL,
47
- patch_size=config.training.patch_size,
48
- patch_transform=default_manipulate,
49
- patch_transform_params={
50
- "mask_pixel_percentage": config.algorithm.masked_pixel_percentage,
51
- "roi_size": config.algorithm.roi_size,
52
- },
53
- )
54
- else:
55
- dataset = TiffDataset(
56
- data_path=train_path,
57
- data_format=config.data.data_format,
58
- axes=config.data.axes,
59
- mean=config.data.mean,
60
- std=config.data.std,
61
- patch_extraction_method=ExtractionStrategy.RANDOM,
62
- patch_size=config.training.patch_size,
63
- patch_transform=default_manipulate,
64
- patch_transform_params={
65
- "mask_pixel_percentage": config.algorithm.masked_pixel_percentage,
66
- "roi_size": config.algorithm.roi_size,
67
- },
68
- )
69
- return dataset
70
-
71
-
72
- def get_validation_dataset(config: Configuration, val_path: str) -> InMemoryDataset:
73
- """
74
- Create validation dataset.
75
-
76
- Validation dataset is kept in memory.
77
-
78
- Parameters
79
- ----------
80
- config : Configuration
81
- Configuration.
82
- val_path : Union[str, Path]
83
- Path to validation data.
84
-
85
- Returns
86
- -------
87
- InMemoryDataset
88
- In memory dataset.
89
- """
90
- data_path = val_path
91
-
92
- dataset = InMemoryDataset(
93
- data_path=data_path,
94
- data_format=config.data.data_format,
95
- axes=config.data.axes,
96
- mean=config.data.mean,
97
- std=config.data.std,
98
- patch_extraction_method=ExtractionStrategy.SEQUENTIAL,
99
- patch_size=config.training.patch_size,
100
- patch_transform=default_manipulate,
101
- patch_transform_params={
102
- "mask_pixel_percentage": config.algorithm.masked_pixel_percentage
103
- },
104
- )
105
-
106
- return dataset
107
-
108
-
109
- def get_prediction_dataset(
110
- config: Configuration,
111
- pred_path: Union[str, Path],
112
- *,
113
- tile_shape: Optional[List[int]] = None,
114
- overlaps: Optional[List[int]] = None,
115
- axes: Optional[str] = None,
116
- ) -> TiffDataset:
117
- """
118
- Create prediction dataset.
119
-
120
- To use tiling, both `tile_shape` and `overlaps` must be specified, have same
121
- length, be divisible by 2 and greater than 0. Finally, the overlaps must be
122
- smaller than the tiles.
123
-
124
- By default, axes are extracted from the configuration. To use images with
125
- different axes, set the `axes` parameter. Note that the difference between
126
- configuration and parameter axes must be S or T, but not any of the spatial
127
- dimensions (e.g. 2D vs 3D).
128
-
129
- Parameters
130
- ----------
131
- config : Configuration
132
- Configuration.
133
- pred_path : Union[str, Path]
134
- Path to prediction data.
135
- tile_shape : Optional[List[int]], optional
136
- 2D or 3D shape of the tiles, by default None.
137
- overlaps : Optional[List[int]], optional
138
- 2D or 3D overlaps between tiles, by default None.
139
- axes : Optional[str], optional
140
- Axes of the data, by default None.
141
-
142
- Returns
143
- -------
144
- TiffDataset
145
- Dataset.
146
- """
147
- use_tiling = False # default value
148
-
149
- # Validate tiles and overlaps
150
- if tile_shape is not None and overlaps is not None:
151
- check_tiling_validity(tile_shape, overlaps)
152
-
153
- # Use tiling
154
- use_tiling = True
155
-
156
- # Extraction strategy
157
- if use_tiling:
158
- patch_extraction_method = ExtractionStrategy.TILED
159
- else:
160
- patch_extraction_method = None
161
-
162
- # Create dataset
163
- dataset = TiffDataset(
164
- data_path=pred_path,
165
- data_format=config.data.data_format,
166
- axes=config.data.axes if axes is None else axes, # supersede axes
167
- mean=config.data.mean,
168
- std=config.data.std,
169
- patch_size=tile_shape,
170
- patch_overlap=overlaps,
171
- patch_extraction_method=patch_extraction_method,
172
- patch_transform=None,
173
- )
174
-
175
- return dataset