careamics 0.1.0rc1__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +14 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +27 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_factory.py +460 -0
- careamics/config/configuration_model.py +596 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +321 -131
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +665 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +390 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -13
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -202
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -13
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +89 -56
- careamics-0.1.0rc3.dist-info/METADATA +122 -0
- careamics-0.1.0rc3.dist-info/RECORD +109 -0
- {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -271
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -296
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -115
- careamics/dataset/patching.py +0 -493
- careamics/dataset/prepare_dataset.py +0 -174
- careamics/dataset/tiff_dataset.py +0 -211
- careamics/engine.py +0 -954
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -102
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -156
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc1.dist-info/METADATA +0 -80
- careamics-0.1.0rc1.dist-info/RECORD +0 -46
- {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
careamics/dataset/patching.py
DELETED
|
@@ -1,493 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Tiling submodule.
|
|
3
|
-
|
|
4
|
-
These functions are used to tile images into patches or tiles.
|
|
5
|
-
"""
|
|
6
|
-
import itertools
|
|
7
|
-
from typing import Generator, List, Optional, Tuple, Union
|
|
8
|
-
|
|
9
|
-
import numpy as np
|
|
10
|
-
from skimage.util import view_as_windows
|
|
11
|
-
|
|
12
|
-
from ..utils.logging import get_logger
|
|
13
|
-
from .extraction_strategy import ExtractionStrategy
|
|
14
|
-
|
|
15
|
-
# Module-level logger; `_extract_patches_sequential` reports the number of
# extracted patches through it.
logger = get_logger(__name__)
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
def _compute_number_of_patches(
    arr: np.ndarray, patch_sizes: Union[List[int], Tuple[int, ...]]
) -> Tuple[int, ...]:
    """
    Compute how many patches are needed to span each patched dimension.

    The first axis of `arr` is never patched, so `arr` must have one more
    dimension than `patch_sizes`. Along every other axis, the count is the
    axis length divided by the patch size, rounded up.

    Parameters
    ----------
    arr : np.ndarray
        Input array.
    patch_sizes : Union[List[int], Tuple[int, ...]]
        Size of the patches along every axis except the first.

    Returns
    -------
    Tuple[int, ...]
        Number of patches along each patched dimension.
    """
    counts = []
    for dim, size in enumerate(patch_sizes):
        # arr.shape[0] is the non-patched axis, hence the +1 offset.
        counts.append(np.ceil(arr.shape[dim + 1] / size).astype(int))
    return tuple(counts)
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
def _compute_overlap(
    arr: np.ndarray, patch_sizes: Union[List[int], Tuple[int, ...]]
) -> Tuple[int, ...]:
    """
    Compute the overlap between neighbouring patches in each dimension.

    `patch_sizes` covers every array dimension except the first (YX or ZYX).
    The number of patches per dimension is taken from
    `_compute_number_of_patches`. The length by which those patches, laid edge
    to edge, exceed the array is split between the gaps separating them and
    rounded up. If an array dimension is divisible by its patch size, the
    overlap is 0.

    Parameters
    ----------
    arr : np.ndarray
        Input array with 3 or 4 dimensions.
    patch_sizes : Union[List[int], Tuple[int, ...]]
        Size of the patches.

    Returns
    -------
    Tuple[int, ...]
        Overlap between patches in each dimension.
    """
    n_patches = _compute_number_of_patches(arr, patch_sizes)

    overlaps = []
    for dim, (count, size) in enumerate(zip(n_patches, patch_sizes)):
        # Extra length covered when `count` patches are placed side by side.
        excess = np.clip(count * size - arr.shape[dim + 1], 0, None)
        # Spread it over the gaps between patches; a single patch still
        # counts as one gap to avoid dividing by zero.
        n_gaps = max(1, count - 1)
        overlaps.append(np.ceil(excess / n_gaps).astype(int))
    return tuple(overlaps)
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
def _compute_crop_and_stitch_coords_1d(
    axis_size: int, tile_size: int, overlap: int
) -> Tuple[List[Tuple[int, int]], ...]:
    """
    Compute the coordinates of each tile along an axis, given the overlap.

    Tiles start every `tile_size - overlap` pixels from 0. When the next tile
    would run past the end of the axis, a final tile aligned with the end of
    the axis is added instead and iteration stops.

    Three lists of ``(start, end)`` pairs are returned, one entry per tile:

    - crop coordinates: region of the axis to crop the tile from;
    - stitch coordinates: region of the axis the tile contributes to when
      stitching the tiles back together;
    - overlap crop coordinates: region of the tile (tile-local coordinates)
      that is kept, i.e. the tile with half the overlap trimmed on each inner
      border.

    For example, ``axis_size=35, tile_size=32, overlap=24`` returns
    ``([(0, 32), (3, 35)], [(0, 20), (20, 35)], [(0, 20), (17, 32)])``.

    Parameters
    ----------
    axis_size : int
        Length of the axis.
    tile_size : int
        Size of the tile for the given axis.
    overlap : int
        Size of the overlap for the given axis.

    Returns
    -------
    Tuple[List[Tuple[int, int]], ...]
        Crop, stitch and overlap crop coordinates for the given axis.

    Raises
    ------
    ValueError
        If `overlap` is negative or not smaller than `tile_size`.
    ValueError
        If `tile_size` is larger than `axis_size`.
    """
    # Without these guards, a non-positive step makes `range` raise or return
    # meaningless tiles, and a tile larger than the axis either fails with an
    # IndexError (no previous stitch coordinate) or silently yields no tile.
    if overlap < 0 or overlap >= tile_size:
        raise ValueError(
            f"Overlap must be non-negative and smaller than the tile size "
            f"(got overlap {overlap} for tile size {tile_size})."
        )
    if tile_size > axis_size:
        raise ValueError(
            f"Tile size must not exceed the axis size "
            f"(got tile size {tile_size} for axis of size {axis_size})."
        )

    step = tile_size - overlap
    half_overlap = overlap // 2
    crop_coords: List[Tuple[int, int]] = []
    stitch_coords: List[Tuple[int, int]] = []
    overlap_crop_coords: List[Tuple[int, int]] = []

    for start in range(0, axis_size - overlap, step):
        end = start + tile_size
        if end <= axis_size:
            is_first = start == 0
            is_last = end == axis_size
            crop_coords.append((start, end))
            # Region of the axis this tile fills: inner borders are trimmed by
            # half the overlap, outer borders of the axis are kept.
            stitch_coords.append(
                (
                    0 if is_first else start + half_overlap,
                    axis_size if is_last else end - half_overlap,
                )
            )
            # Same region, expressed in the tile's own coordinates.
            overlap_crop_coords.append(
                (
                    0 if is_first else half_overlap,
                    tile_size if is_last else tile_size - half_overlap,
                )
            )
        else:
            # The next tile would overflow: use a final tile aligned with the
            # end of the axis and keep only the part not yet covered.
            last_stitch_end = stitch_coords[-1][1]
            crop_coords.append((axis_size - tile_size, axis_size))
            stitch_coords.append((last_stitch_end, axis_size))
            overlap_crop_coords.append(
                (tile_size - (axis_size - last_stitch_end), tile_size)
            )
            break

    return crop_coords, stitch_coords, overlap_crop_coords
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
def _compute_patch_steps(
    patch_sizes: Union[List[int], Tuple[int, ...]], overlaps: Tuple[int, ...]
) -> Tuple[int, ...]:
    """
    Compute the stride between consecutive patches along each dimension.

    The stride is the patch size minus the overlap, capped at the patch size
    (the cap only matters for a negative overlap).

    Parameters
    ----------
    patch_sizes : Union[List[int], Tuple[int, ...]]
        Size of the patches.
    overlaps : Tuple[int, ...]
        Overlap between patches.

    Returns
    -------
    Tuple[int, ...]
        Steps between patches.
    """
    steps = []
    for dim, size in enumerate(patch_sizes):
        stride = size - overlaps[dim]
        steps.append(size if size < stride else stride)
    return tuple(steps)
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
def _compute_reshaped_view(
    arr: np.ndarray,
    window_shape: Tuple[int, ...],
    step: Tuple[int, ...],
    output_shape: Tuple[int, ...],
) -> np.ndarray:
    """
    Extract windows from an array, reshape them and shuffle them.

    Windows of shape `window_shape` are taken every `step` elements with
    `skimage.util.view_as_windows` and reshaped to `output_shape`. The result
    is then shuffled in place along its first axis with an unseeded
    `np.random.default_rng()`, so the order differs between calls.

    Parameters
    ----------
    arr : np.ndarray
        Array from which the windows are extracted.
    window_shape : Tuple[int, ...]
        Shape of the windows.
    step : Tuple[int, ...]
        Steps between windows.
    output_shape : Tuple[int, ...]
        Shape of the output array.

    Returns
    -------
    np.ndarray
        Array of windows, shuffled along the first axis.
    """
    windows = view_as_windows(arr, window_shape=window_shape, step=step)
    reshaped = windows.reshape(*output_shape)
    generator = np.random.default_rng()
    generator.shuffle(reshaped, axis=0)
    return reshaped
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
def _patches_sanity_check(
    arr: np.ndarray,
    patch_size: Union[List[int], Tuple[int, ...]],
    is_3d_patch: bool,
) -> None:
    """
    Check patch size and array compatibility.

    This method validates the patch sizes with respect to the array dimensions:

    - There must be one patch size per array dimension except the first.
    - Patch sizes must not be larger than the corresponding array dimensions.

    Parameters
    ----------
    arr : np.ndarray
        Input array.
    patch_size : Union[List[int], Tuple[int, ...]]
        Size of the patches along each dimension of the array, except the first.
    is_3d_patch : bool
        Whether the patch is 3D or not.

    Raises
    ------
    ValueError
        If the patch size is not consistent with the array shape (one more array
        dimension).
    ValueError
        If the patch size in Z is larger than the array dimension.
    ValueError
        If either of the patch sizes in X or Y is larger than the corresponding array
        dimension.
    """
    if len(patch_size) != len(arr.shape[1:]):
        raise ValueError(
            f"There must be a patch size for each spatial dimensions "
            f"(got {patch_size} patches for dims {arr.shape})."
        )

    # Sanity checks on patch sizes versus array dimension. The Z message
    # reports arr.shape[-3], i.e. the dimension actually compared.
    if is_3d_patch and patch_size[0] > arr.shape[-3]:
        raise ValueError(
            f"Z patch size is inconsistent with image shape "
            f"(got {patch_size[0]} patches for dim {arr.shape[-3]})."
        )

    if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
        raise ValueError(
            f"At least one of YX patch dimensions is inconsistent with image shape "
            f"(got {patch_size} patches for dims {arr.shape[-2:]})."
        )
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
# Formerly in dataloader.py#L52 (commit 00d536c).
def _extract_patches_sequential(
    arr: np.ndarray, patch_size: Union[List[int], Tuple[int]]
) -> Generator[np.ndarray, None, None]:
    """
    Generate patches from an array in a sequential manner.

    The first array dimension is not patched. Windows of shape
    ``(1, *patch_size)`` are slid over the array on a regular grid whose stride
    is derived from the overlap computed by `_compute_overlap`. All windows are
    extracted at once, shuffled along the patch axis by `_compute_reshaped_view`,
    and then handed out one by one.

    NOTE(review): because the overlap is rounded up, the grid can stop short of
    the far border. For example, an axis of length 100 with patches of 32 gives
    windows starting at 0, 22, 44 and 66, so the last two pixels are never
    extracted. TODO confirm whether full coverage is required here.

    Parameters
    ----------
    arr : np.ndarray
        Input image array; must have one more dimension than `patch_size`.
    patch_size : Union[List[int], Tuple[int]]
        Patch sizes in each dimension except the first.

    Returns
    -------
    Generator[np.ndarray, None, None]
        Generator of patches.
    """
    is_3d_patch = len(patch_size) == 3
    _patches_sanity_check(arr, patch_size, is_3d_patch)

    # Overlap and stride of the patch grid, per patched dimension.
    overlaps = _compute_overlap(arr=arr, patch_sizes=patch_size)
    steps = _compute_patch_steps(patch_sizes=patch_size, overlaps=overlaps)

    # Windows span a single element (stride 1) along the first axis.
    window_shape = (1, *patch_size)
    window_steps = (1, *steps)

    # For 3D patches with a singleton Z, the window's leading (non-patched)
    # axis is left out of the output shape.
    if is_3d_patch and patch_size[0] == 1:
        output_shape = (-1, *window_shape[1:])
    else:
        output_shape = (-1, *window_shape)

    patches = _compute_reshaped_view(
        arr, window_shape=window_shape, step=window_steps, output_shape=output_shape
    )
    logger.info(f"Extracted {patches.shape[0]} patches from input array.")

    # Lazily hand out the patches, one entry of the first axis at a time.
    return (patch for patch in patches)
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
def _extract_patches_random(
    arr: np.ndarray, patch_size: Union[List[int], Tuple[int]]
) -> Generator[np.ndarray, None, None]:
    """
    Generate patches from an array in a random manner.

    Entries along the first axis (samples) are visited in a random order. For
    each sample, the method computes how many patches of `patch_size` its area
    could be divided into, then extracts that many patches at uniformly random
    positions. The input array is not modified.

    Parameters
    ----------
    arr : np.ndarray
        Input image array; must have one more dimension than `patch_size`.
    patch_size : Union[List[int], Tuple[int]]
        Patch sizes in each dimension except the first.

    Yields
    ------
    np.ndarray
        Float32 copy of a randomly positioned patch of shape `patch_size`.

    Raises
    ------
    ValueError
        If `patch_size` is incompatible with the array shape. Because this is a
        generator, the error is raised when the generator is first advanced.
    """
    is_3d_patch = len(patch_size) == 3

    # Patches sanity check
    _patches_sanity_check(arr, patch_size, is_3d_patch)

    rng = np.random.default_rng()

    # Visit samples in a random order through a permutation of their indices,
    # rather than shuffling `arr` in place (which mutated the caller's data).
    for sample_idx in rng.permutation(arr.shape[0]):
        sample = arr[sample_idx]
        # Number of non-overlapping patches the sample area could hold.
        n_patches = np.ceil(np.prod(sample.shape) / np.prod(patch_size)).astype(int)
        for _ in range(n_patches):
            # `endpoint=True` makes the upper bound inclusive: the last valid
            # start (axis length - patch size) can be drawn, and a patch as
            # large as the axis no longer makes `integers(0, 0)` raise.
            crop_coords = [
                rng.integers(0, sample.shape[i] - patch_size[i], endpoint=True)
                for i in range(len(patch_size))
            ]
            # `astype` returns a new array, so the patch never aliases `arr`.
            patch = sample[
                (
                    ...,
                    *[
                        slice(c, c + patch_size[i])
                        for i, c in enumerate(crop_coords)
                    ],
                )
            ].astype(np.float32)
            yield patch
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
def _extract_tiles(
    arr: np.ndarray,
    tile_size: Union[List[int], Tuple[int]],
    overlaps: Union[List[int], Tuple[int]],
) -> Generator:
    """
    Generate tiles from the input array with specified overlap.

    The tiles cover the whole array. All samples share the same shape, so the
    tiling grid is computed once and reused for every sample.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (S, (Z), Y, X).
    tile_size : Union[List[int], Tuple[int]]
        Tile sizes in each dimension, of length 2 or 3.
    overlaps : Union[List[int], Tuple[int]]
        Overlap values in each dimension, of length 2 or 3.

    Yields
    ------
    Generator
        For each tile of each sample, a tuple containing:

        - the float32 tile with an added leading axis of size 1;
        - a bool that is True for the last tile of the sample;
        - the sample shape (``arr.shape[1:]``);
        - per-axis ``(start, end)`` coordinates of the tile region to keep;
        - per-axis ``(start, end)`` coordinates where that region goes in the
          sample.
    """
    n_samples = arr.shape[0]
    if n_samples == 0:
        return

    # The grid depends only on the sample shape, which is the same for all
    # samples, so compute it once instead of once per sample.
    sample_shape = arr.shape[1:]
    coords_per_axis = [
        _compute_crop_and_stitch_coords_1d(sample_shape[i], tile_size[i], overlaps[i])
        for i in range(len(tile_size))
    ]

    # Regroup the per-axis (crop, stitch, overlap crop) triplets by coordinate
    # type. For instance, for an axis of size 35, a tile size of 32 and an
    # overlap of 24, `_compute_crop_and_stitch_coords_1d` returns
    # ([(0, 32), (3, 35)], [(0, 20), (20, 35)], [(0, 20), (17, 32)]); after
    # regrouping, `all_crop_coords` holds the first of these lists for every
    # axis, and so on.
    all_crop_coords, all_stitch_coords, all_overlap_crop_coords = zip(
        *coords_per_axis
    )
    # Tiles per sample: product of the number of tiles along each axis.
    n_tiles = int(np.prod([len(axis) for axis in all_crop_coords]))

    for sample_idx in range(n_samples):
        sample = arr[sample_idx]

        # Rebuild the Cartesian products for each sample: iterators are
        # exhausted after a single pass.
        tile_coords = zip(
            itertools.product(*all_crop_coords),
            itertools.product(*all_stitch_coords),
            itertools.product(*all_overlap_crop_coords),
        )
        for tile_idx, (crop_coords, stitch_coords, overlap_crop_coords) in enumerate(
            tile_coords
        ):
            tile = sample[(..., *[slice(start, end) for start, end in crop_coords])]
            yield (
                np.expand_dims(tile.astype(np.float32), 0),
                tile_idx == n_tiles - 1,  # True for the last tile of the sample
                sample_shape,
                overlap_crop_coords,
                stitch_coords,
            )
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
def generate_patches(
    sample: np.ndarray,
    patch_extraction_method: ExtractionStrategy,
    patch_size: Optional[Union[List[int], Tuple[int]]] = None,
    patch_overlap: Optional[Union[List[int], Tuple[int]]] = None,
) -> Generator[np.ndarray, None, None]:
    """
    Generate patches from a sample.

    If `patch_size` is None, no patching is performed and the returned
    generator yields `sample` once. Otherwise, patches are produced according to
    `patch_extraction_method`.

    Parameters
    ----------
    sample : np.ndarray
        Input array.
    patch_extraction_method : ExtractionStrategy
        Patch extraction method: `ExtractionStrategy.TILED`, `SEQUENTIAL` or
        `RANDOM`.
    patch_size : Optional[Union[List[int], Tuple[int]]]
        Size of the patches along each dimension of the array, except the first.
    patch_overlap : Optional[Union[List[int], Tuple[int]]]
        Overlap between patches, required for tiling.

    Returns
    -------
    Generator[np.ndarray, None, None]
        Generator yielding patches/tiles.

    Raises
    ------
    ValueError
        If overlap is not specified when using tiling.
    ValueError
        If `patch_size` is given but `patch_extraction_method` is not one of the
        supported strategies (e.g. None).
    """
    if patch_size is None:
        # No patching: yield the whole sample once.
        return (sample for _ in range(1))

    if patch_extraction_method == ExtractionStrategy.TILED:
        if patch_overlap is None:
            raise ValueError(
                "Overlaps must be specified when using tiling (got None)."
            )
        return _extract_tiles(arr=sample, tile_size=patch_size, overlaps=patch_overlap)

    if patch_extraction_method == ExtractionStrategy.SEQUENTIAL:
        return _extract_patches_sequential(sample, patch_size=patch_size)

    if patch_extraction_method == ExtractionStrategy.RANDOM:
        return _extract_patches_random(sample, patch_size=patch_size)

    # Name the offending method instead of failing with an opaque message.
    raise ValueError(
        f"No patch generated: unsupported patch extraction method "
        f"{patch_extraction_method!r} for patch size {patch_size}."
    )
|
|
@@ -1,174 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Dataset preparation module.
|
|
3
|
-
|
|
4
|
-
Methods to set up the datasets for training, validation and prediction.
|
|
5
|
-
"""
|
|
6
|
-
from pathlib import Path
|
|
7
|
-
from typing import List, Optional, Union
|
|
8
|
-
|
|
9
|
-
from ..config import Configuration
|
|
10
|
-
from ..manipulation import default_manipulate
|
|
11
|
-
from ..utils import check_tiling_validity
|
|
12
|
-
from .extraction_strategy import ExtractionStrategy
|
|
13
|
-
from .in_memory_dataset import InMemoryDataset
|
|
14
|
-
from .tiff_dataset import TiffDataset
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def get_train_dataset(
    config: Configuration, train_path: str
) -> Union[TiffDataset, InMemoryDataset]:
    """
    Create the training dataset.

    When `config.data.in_memory` is True, an `InMemoryDataset` with sequential
    patch extraction is built; otherwise a `TiffDataset` with random patch
    extraction. Both apply `default_manipulate` to every patch, using the masked
    pixel percentage and ROI size from the algorithm configuration.

    Parameters
    ----------
    config : Configuration
        Configuration.
    train_path : Union[str, Path]
        Path to training data.

    Returns
    -------
    Union[TiffDataset, InMemoryDataset]
        Dataset.
    """
    if config.data.in_memory:
        dataset_cls = InMemoryDataset
        strategy = ExtractionStrategy.SEQUENTIAL
    else:
        dataset_cls = TiffDataset
        strategy = ExtractionStrategy.RANDOM

    # Pixel manipulation parameters shared by both dataset types.
    manipulation_params = {
        "mask_pixel_percentage": config.algorithm.masked_pixel_percentage,
        "roi_size": config.algorithm.roi_size,
    }

    return dataset_cls(
        data_path=train_path,
        data_format=config.data.data_format,
        axes=config.data.axes,
        mean=config.data.mean,
        std=config.data.std,
        patch_extraction_method=strategy,
        patch_size=config.training.patch_size,
        patch_transform=default_manipulate,
        patch_transform_params=manipulation_params,
    )
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
def get_validation_dataset(config: Configuration, val_path: str) -> InMemoryDataset:
    """
    Create validation dataset.

    The validation dataset is always kept in memory and uses sequential patch
    extraction. Patches are manipulated with `default_manipulate`, using the
    same masked pixel percentage and ROI size as the training dataset.

    Parameters
    ----------
    config : Configuration
        Configuration.
    val_path : Union[str, Path]
        Path to validation data.

    Returns
    -------
    InMemoryDataset
        In memory dataset.
    """
    return InMemoryDataset(
        data_path=val_path,
        data_format=config.data.data_format,
        axes=config.data.axes,
        mean=config.data.mean,
        std=config.data.std,
        patch_extraction_method=ExtractionStrategy.SEQUENTIAL,
        patch_size=config.training.patch_size,
        patch_transform=default_manipulate,
        patch_transform_params={
            "mask_pixel_percentage": config.algorithm.masked_pixel_percentage,
            # Previously omitted: pass the configured ROI size, as for the
            # training dataset, so validation patches are manipulated the same
            # way as training patches.
            "roi_size": config.algorithm.roi_size,
        },
    )
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
def get_prediction_dataset(
    config: Configuration,
    pred_path: Union[str, Path],
    *,
    tile_shape: Optional[List[int]] = None,
    overlaps: Optional[List[int]] = None,
    axes: Optional[str] = None,
) -> TiffDataset:
    """
    Create the prediction dataset.

    Tiling is enabled only when both `tile_shape` and `overlaps` are given.
    They are then validated with `check_tiling_validity`; per the original
    contract they must have the same length, be divisible by 2 and greater than
    0, with overlaps smaller than tiles. Without tiling, no extraction strategy
    is set, but `tile_shape` and `overlaps` are still forwarded to the dataset
    unchanged.

    Axes default to those of the configuration. Use `axes` for images whose
    axes differ from the configuration; the difference must only concern S or
    T, not spatial dimensions (e.g. 2D vs 3D).

    Parameters
    ----------
    config : Configuration
        Configuration.
    pred_path : Union[str, Path]
        Path to prediction data.
    tile_shape : Optional[List[int]], optional
        2D or 3D shape of the tiles, by default None.
    overlaps : Optional[List[int]], optional
        2D or 3D overlaps between tiles, by default None.
    axes : Optional[str], optional
        Axes of the data, by default None.

    Returns
    -------
    TiffDataset
        Dataset.
    """
    tiling = tile_shape is not None and overlaps is not None
    if tiling:
        check_tiling_validity(tile_shape, overlaps)

    return TiffDataset(
        data_path=pred_path,
        data_format=config.data.data_format,
        # Explicitly provided axes take precedence over the configured ones.
        axes=axes if axes is not None else config.data.axes,
        mean=config.data.mean,
        std=config.data.std,
        patch_size=tile_shape,
        patch_overlap=overlaps,
        patch_extraction_method=ExtractionStrategy.TILED if tiling else None,
        patch_transform=None,
    )
|