careamics 0.0.1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +7 -1
- careamics/bioimage/__init__.py +15 -0
- careamics/bioimage/docs/Noise2Void.md +5 -0
- careamics/bioimage/docs/__init__.py +1 -0
- careamics/bioimage/io.py +182 -0
- careamics/bioimage/rdf.py +105 -0
- careamics/config/__init__.py +11 -0
- careamics/config/algorithm.py +231 -0
- careamics/config/config.py +297 -0
- careamics/config/config_filter.py +44 -0
- careamics/config/data.py +194 -0
- careamics/config/torch_optim.py +118 -0
- careamics/config/training.py +534 -0
- careamics/dataset/__init__.py +1 -0
- careamics/dataset/dataset_utils.py +111 -0
- careamics/dataset/extraction_strategy.py +21 -0
- careamics/dataset/in_memory_dataset.py +202 -0
- careamics/dataset/patching.py +492 -0
- careamics/dataset/prepare_dataset.py +175 -0
- careamics/dataset/tiff_dataset.py +212 -0
- careamics/engine.py +1014 -0
- careamics/losses/__init__.py +4 -0
- careamics/losses/loss_factory.py +38 -0
- careamics/losses/losses.py +34 -0
- careamics/manipulation/__init__.py +4 -0
- careamics/manipulation/pixel_manipulation.py +158 -0
- careamics/models/__init__.py +4 -0
- careamics/models/layers.py +152 -0
- careamics/models/model_factory.py +251 -0
- careamics/models/unet.py +322 -0
- careamics/prediction/__init__.py +9 -0
- careamics/prediction/prediction_utils.py +106 -0
- careamics/utils/__init__.py +20 -0
- careamics/utils/ascii_logo.txt +9 -0
- careamics/utils/augment.py +65 -0
- careamics/utils/context.py +45 -0
- careamics/utils/logging.py +321 -0
- careamics/utils/metrics.py +160 -0
- careamics/utils/normalization.py +55 -0
- careamics/utils/torch_utils.py +89 -0
- careamics/utils/validators.py +170 -0
- careamics/utils/wandb.py +121 -0
- careamics-0.1.0rc2.dist-info/METADATA +81 -0
- careamics-0.1.0rc2.dist-info/RECORD +47 -0
- {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""In-memory dataset module."""
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from careamics.utils import normalize
|
|
9
|
+
from careamics.utils.logging import get_logger
|
|
10
|
+
|
|
11
|
+
from .dataset_utils import (
|
|
12
|
+
list_files,
|
|
13
|
+
read_tiff,
|
|
14
|
+
)
|
|
15
|
+
from .extraction_strategy import ExtractionStrategy
|
|
16
|
+
from .patching import generate_patches
|
|
17
|
+
|
|
18
|
+
# Module-level logger obtained from careamics' logging utilities.
logger = get_logger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class InMemoryDataset(torch.utils.data.Dataset):
    """
    Dataset storing data in memory and allowing generating patches from it.

    All files with extension `data_format` found in `data_path` are read when the
    dataset is built, split into patches and concatenated into a single array
    (`self.data`). Each item is one normalized (and optionally transformed) patch.

    Parameters
    ----------
    data_path : Union[str, Path]
        Path to the data, must be a directory.
    data_format : str
        Extension of the data files, without period.
    axes : str
        Description of axes in format STCZYX.
    patch_extraction_method : ExtractionStrategy
        Patch extraction strategy, as defined in extraction_strategy.
    patch_size : Union[List[int], Tuple[int]]
        Size of the patches along each axis, must be of dimension 2 or 3.
    patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
        Overlap of the patches, must be of dimension 2 or 3, by default None.
    mean : Optional[float], optional
        Expected mean of the dataset, by default None (computed from the data).
    std : Optional[float], optional
        Expected standard deviation of the dataset, by default None (computed
        from the data).
    patch_transform : Optional[Callable], optional
        Patch transform to apply, by default None.
    patch_transform_params : Optional[Dict], optional
        Patch transform parameters, by default None.
    """

    def __init__(
        self,
        data_path: Union[str, Path],
        data_format: str,
        axes: str,
        patch_extraction_method: ExtractionStrategy,
        patch_size: Union[List[int], Tuple[int]],
        patch_overlap: Optional[Union[List[int], Tuple[int]]] = None,
        mean: Optional[float] = None,
        std: Optional[float] = None,
        patch_transform: Optional[Callable] = None,
        patch_transform_params: Optional[Dict] = None,
    ) -> None:
        """
        Constructor.

        Parameters
        ----------
        data_path : Union[str, Path]
            Path to the data, must be a directory.
        data_format : str
            Extension of the data files, without period.
        axes : str
            Description of axes in format STCZYX.
        patch_extraction_method : ExtractionStrategy
            Patch extraction strategy, as defined in extraction_strategy.
        patch_size : Union[List[int], Tuple[int]]
            Size of the patches along each axis, must be of dimension 2 or 3.
        patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
            Overlap of the patches, must be of dimension 2 or 3, by default None.
        mean : Optional[float], optional
            Expected mean of the dataset, by default None. If either `mean` or
            `std` is None, both are replaced by statistics computed from the data.
        std : Optional[float], optional
            Expected standard deviation of the dataset, by default None.
        patch_transform : Optional[Callable], optional
            Patch transform to apply, by default None.
        patch_transform_params : Optional[Dict], optional
            Patch transform parameters, by default None.

        Raises
        ------
        ValueError
            If data_path is not a directory.
        ValueError
            If no file with the requested extension is found in data_path.
        """
        self.data_path = Path(data_path)
        if not self.data_path.is_dir():
            raise ValueError("Path to data should be an existing folder.")

        self.data_format = data_format
        self.axes = axes

        self.files = list_files(self.data_path, self.data_format)

        self.patch_size = patch_size
        self.patch_overlap = patch_overlap
        self.patch_extraction_method = patch_extraction_method
        self.patch_transform = patch_transform
        self.patch_transform_params = patch_transform_params

        self.mean = mean
        self.std = std

        # Generate patches
        self.data, computed_mean, computed_std = self._prepare_patches()

        # Compare against None rather than relying on truthiness: a mean of 0.0
        # is a legitimate user-provided value and must not be discarded.
        if self.mean is None or self.std is None:
            self.mean, self.std = computed_mean, computed_std
            logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")

        assert self.mean is not None
        assert self.std is not None

    def _prepare_patches(self) -> Tuple[np.ndarray, float, float]:
        """
        Iterate over data source and create an array of patches.

        The returned statistics are the average, over files, of each file's
        mean and standard deviation.

        Returns
        -------
        Tuple[np.ndarray, float, float]
            Array of patches concatenated along the first axis, mean and standard
            deviation of the data.

        Raises
        ------
        ValueError
            If no file was found to build the dataset from.
        """
        means, stds, num_samples = 0, 0, 0
        self.all_patches = []
        for filename in self.files:
            sample = read_tiff(filename, self.axes)
            means += sample.mean()
            stds += np.std(sample)
            num_samples += 1

            # generate patches, return a generator
            patches = generate_patches(
                sample,
                self.patch_extraction_method,
                self.patch_size,
                self.patch_overlap,
            )

            for patch in patches:
                # Patches that have exactly as many dimensions as `patch_size`
                # (e.g. random patches) lack a leading axis; without one,
                # np.concatenate would merge them along their first spatial axis
                # instead of stacking them as separate items.
                if (
                    isinstance(patch, np.ndarray)
                    and self.patch_size is not None
                    and patch.ndim == len(self.patch_size)
                ):
                    patch = patch[np.newaxis, ...]
                self.all_patches.append(patch)

        if num_samples == 0:
            raise ValueError(
                f"No '{self.data_format}' files found in {self.data_path}."
            )

        result_mean, result_std = means / num_samples, stds / num_samples
        return np.concatenate(self.all_patches), result_mean, result_std

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Length of the dataset, i.e. the number of patches stacked along the
            first axis of `self.data`.
        """
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[np.ndarray]:
        """
        Return the patch corresponding to the provided index.

        The patch is squeezed, normalized with the dataset mean and std and, if
        a patch transform is set, passed through it.

        Parameters
        ----------
        index : int
            Index of the patch to return.

        Returns
        -------
        Tuple[np.ndarray]
            Patch (or the output of `patch_transform` when one is set).

        Raises
        ------
        ValueError
            If dataset mean and std are not set.
        """
        patch = self.data[index].squeeze()

        if self.mean is None or self.std is None:
            raise ValueError("Dataset mean and std must be set before using it.")

        if isinstance(patch, tuple):
            # Normalize the image only, keeping the remaining tuple items as-is.
            normalized = normalize(img=patch[0], mean=self.mean, std=self.std)
            patch = (normalized, *patch[1:])
        else:
            patch = normalize(img=patch, mean=self.mean, std=self.std)

        if self.patch_transform is not None:
            # Treat missing transform parameters as an empty dict without
            # mutating the dataset's configuration.
            params = (
                {}
                if self.patch_transform_params is None
                else self.patch_transform_params
            )
            patch = self.patch_transform(patch, **params)

        return patch
|
|
@@ -0,0 +1,492 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tiling submodule.
|
|
3
|
+
|
|
4
|
+
These functions are used to tile images into patches or tiles.
|
|
5
|
+
"""
|
|
6
|
+
import itertools
|
|
7
|
+
from typing import Generator, List, Optional, Tuple, Union
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from skimage.util import view_as_windows
|
|
11
|
+
|
|
12
|
+
from careamics.utils.logging import get_logger
|
|
13
|
+
|
|
14
|
+
from .extraction_strategy import ExtractionStrategy
|
|
15
|
+
|
|
16
|
+
# Module-level logger obtained from careamics' logging utilities.
logger = get_logger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _compute_number_of_patches(
|
|
20
|
+
arr: np.ndarray, patch_sizes: Union[List[int], Tuple[int, ...]]
|
|
21
|
+
) -> Tuple[int, ...]:
|
|
22
|
+
"""
|
|
23
|
+
Compute the number of patches that fit in each dimension.
|
|
24
|
+
|
|
25
|
+
Array must have one dimension more than the patches (C dimension).
|
|
26
|
+
|
|
27
|
+
Parameters
|
|
28
|
+
----------
|
|
29
|
+
arr : np.ndarray
|
|
30
|
+
Input array.
|
|
31
|
+
patch_sizes : Tuple[int]
|
|
32
|
+
Size of the patches.
|
|
33
|
+
|
|
34
|
+
Returns
|
|
35
|
+
-------
|
|
36
|
+
Tuple[int]
|
|
37
|
+
Number of patches in each dimension.
|
|
38
|
+
"""
|
|
39
|
+
n_patches = [
|
|
40
|
+
np.ceil(arr.shape[i + 1] / patch_sizes[i]).astype(int)
|
|
41
|
+
for i in range(len(patch_sizes))
|
|
42
|
+
]
|
|
43
|
+
return tuple(n_patches)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _compute_overlap(
|
|
47
|
+
arr: np.ndarray, patch_sizes: Union[List[int], Tuple[int, ...]]
|
|
48
|
+
) -> Tuple[int, ...]:
|
|
49
|
+
"""
|
|
50
|
+
Compute the overlap between patches in each dimension.
|
|
51
|
+
|
|
52
|
+
Array must be of dimensions C(Z)YX, and patches must be of dimensions YX or ZYX.
|
|
53
|
+
If the array dimensions are divisible by the patch sizes, then the overlap is 0.
|
|
54
|
+
Otherwise, it is the result of the division rounded to the upper value.
|
|
55
|
+
|
|
56
|
+
Parameters
|
|
57
|
+
----------
|
|
58
|
+
arr : np.ndarray
|
|
59
|
+
Input array 3 or 4 dimensions.
|
|
60
|
+
patch_sizes : Tuple[int]
|
|
61
|
+
Size of the patches.
|
|
62
|
+
|
|
63
|
+
Returns
|
|
64
|
+
-------
|
|
65
|
+
Tuple[int]
|
|
66
|
+
Overlap between patches in each dimension.
|
|
67
|
+
"""
|
|
68
|
+
n_patches = _compute_number_of_patches(arr, patch_sizes)
|
|
69
|
+
|
|
70
|
+
overlap = [
|
|
71
|
+
np.ceil(
|
|
72
|
+
np.clip(n_patches[i] * patch_sizes[i] - arr.shape[i + 1], 0, None)
|
|
73
|
+
/ max(1, (n_patches[i] - 1))
|
|
74
|
+
).astype(int)
|
|
75
|
+
for i in range(len(patch_sizes))
|
|
76
|
+
]
|
|
77
|
+
return tuple(overlap)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _compute_crop_and_stitch_coords_1d(
|
|
81
|
+
axis_size: int, tile_size: int, overlap: int
|
|
82
|
+
) -> Tuple[List[Tuple[int, int]], ...]:
|
|
83
|
+
"""
|
|
84
|
+
Compute the coordinates of each tile along an axis, given the overlap.
|
|
85
|
+
|
|
86
|
+
Parameters
|
|
87
|
+
----------
|
|
88
|
+
axis_size : int
|
|
89
|
+
Length of the axis.
|
|
90
|
+
tile_size : int
|
|
91
|
+
Size of the tile for the given axis.
|
|
92
|
+
overlap : int
|
|
93
|
+
Size of the overlap for the given axis.
|
|
94
|
+
|
|
95
|
+
Returns
|
|
96
|
+
-------
|
|
97
|
+
Tuple[Tuple[int]]
|
|
98
|
+
Tuple of all coordinates for given axis.
|
|
99
|
+
"""
|
|
100
|
+
# Compute the step between tiles
|
|
101
|
+
step = tile_size - overlap
|
|
102
|
+
crop_coords = []
|
|
103
|
+
stitch_coords = []
|
|
104
|
+
overlap_crop_coords = []
|
|
105
|
+
# Iterate over the axis with a certain step
|
|
106
|
+
for i in range(0, axis_size - overlap, step):
|
|
107
|
+
# Check if the tile fits within the axis
|
|
108
|
+
if i + tile_size <= axis_size:
|
|
109
|
+
# Add the coordinates to crop one tile
|
|
110
|
+
crop_coords.append((i, i + tile_size))
|
|
111
|
+
# Add the pixel coordinates of the cropped tile in the original image space
|
|
112
|
+
stitch_coords.append(
|
|
113
|
+
(
|
|
114
|
+
i + overlap // 2 if i > 0 else 0,
|
|
115
|
+
i + tile_size - overlap // 2
|
|
116
|
+
if crop_coords[-1][1] < axis_size
|
|
117
|
+
else axis_size,
|
|
118
|
+
)
|
|
119
|
+
)
|
|
120
|
+
# Add the coordinates to crop the overlap from the prediction.
|
|
121
|
+
overlap_crop_coords.append(
|
|
122
|
+
(
|
|
123
|
+
overlap // 2 if i > 0 else 0,
|
|
124
|
+
tile_size - overlap // 2
|
|
125
|
+
if crop_coords[-1][1] < axis_size
|
|
126
|
+
else tile_size,
|
|
127
|
+
)
|
|
128
|
+
)
|
|
129
|
+
# If the tile does not fit within the axis, perform the abovementioned
|
|
130
|
+
# operations starting from the end of the axis
|
|
131
|
+
else:
|
|
132
|
+
# if (axis_size - tile_size, axis_size) not in crop_coords:
|
|
133
|
+
crop_coords.append((axis_size - tile_size, axis_size))
|
|
134
|
+
last_tile_end_coord = stitch_coords[-1][1]
|
|
135
|
+
stitch_coords.append((last_tile_end_coord, axis_size))
|
|
136
|
+
overlap_crop_coords.append(
|
|
137
|
+
(tile_size - (axis_size - last_tile_end_coord), tile_size)
|
|
138
|
+
)
|
|
139
|
+
break
|
|
140
|
+
return crop_coords, stitch_coords, overlap_crop_coords
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _compute_patch_steps(
|
|
144
|
+
patch_sizes: Union[List[int], Tuple[int, ...]], overlaps: Tuple[int, ...]
|
|
145
|
+
) -> Tuple[int, ...]:
|
|
146
|
+
"""
|
|
147
|
+
Compute steps between patches.
|
|
148
|
+
|
|
149
|
+
Parameters
|
|
150
|
+
----------
|
|
151
|
+
patch_sizes : Tuple[int]
|
|
152
|
+
Size of the patches.
|
|
153
|
+
overlaps : Tuple[int]
|
|
154
|
+
Overlap between patches.
|
|
155
|
+
|
|
156
|
+
Returns
|
|
157
|
+
-------
|
|
158
|
+
Tuple[int]
|
|
159
|
+
Steps between patches.
|
|
160
|
+
"""
|
|
161
|
+
steps = [
|
|
162
|
+
min(patch_sizes[i] - overlaps[i], patch_sizes[i])
|
|
163
|
+
for i in range(len(patch_sizes))
|
|
164
|
+
]
|
|
165
|
+
return tuple(steps)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _compute_reshaped_view(
    arr: np.ndarray,
    window_shape: Tuple[int, ...],
    step: Tuple[int, ...],
    output_shape: Tuple[int, ...],
) -> np.ndarray:
    """
    Compute reshaped views of an array, where views correspond to patches.

    Windows are extracted with `skimage.util.view_as_windows`, reshaped to
    `output_shape` and shuffled randomly along the first axis. The input array
    is never modified.

    Parameters
    ----------
    arr : np.ndarray
        Array from which the views are extracted.
    window_shape : Tuple[int, ...]
        Shape of the views.
    step : Tuple[int, ...]
        Steps between views.
    output_shape : Tuple[int, ...]
        Shape of the output array.

    Returns
    -------
    np.ndarray
        Array of patches, shuffled along the first axis.
    """
    rng = np.random.default_rng()
    patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape(
        *output_shape
    )

    # `reshape` returns a view whenever the window strides allow it. Such a view
    # aliases `arr` (and overlapping windows alias each other), so an in-place
    # shuffle would write into, and scramble, the input data. Copy only in that
    # case to avoid a needless second copy in the common path.
    if np.may_share_memory(patches, arr):
        patches = patches.copy()

    rng.shuffle(patches, axis=0)
    return patches
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _patches_sanity_check(
|
|
202
|
+
arr: np.ndarray,
|
|
203
|
+
patch_size: Union[List[int], Tuple[int, ...]],
|
|
204
|
+
is_3d_patch: bool,
|
|
205
|
+
) -> None:
|
|
206
|
+
"""
|
|
207
|
+
Check patch size and array compatibility.
|
|
208
|
+
|
|
209
|
+
This method validates the patch sizes with respect to the array dimensions:
|
|
210
|
+
- The patch sizes must have one dimension fewer than the array (C dimension).
|
|
211
|
+
- Chack that patch sizes are smaller than array dimensions.
|
|
212
|
+
|
|
213
|
+
Parameters
|
|
214
|
+
----------
|
|
215
|
+
arr : np.ndarray
|
|
216
|
+
Input array.
|
|
217
|
+
patch_size : Union[List[int], Tuple[int, ...]]
|
|
218
|
+
Size of the patches along each dimension of the array, except the first.
|
|
219
|
+
is_3d_patch : bool
|
|
220
|
+
Whether the patch is 3D or not.
|
|
221
|
+
|
|
222
|
+
Raises
|
|
223
|
+
------
|
|
224
|
+
ValueError
|
|
225
|
+
If the patch size is not consistent with the array shape (one more array
|
|
226
|
+
dimension).
|
|
227
|
+
ValueError
|
|
228
|
+
If the patch size in Z is larger than the array dimension.
|
|
229
|
+
ValueError
|
|
230
|
+
If either of the patch sizes in X or Y is larger than the corresponding array
|
|
231
|
+
dimension.
|
|
232
|
+
"""
|
|
233
|
+
if len(patch_size) != len(arr.shape[1:]):
|
|
234
|
+
raise ValueError(
|
|
235
|
+
f"There must be a patch size for each spatial dimensions "
|
|
236
|
+
f"(got {patch_size} patches for dims {arr.shape})."
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
# Sanity checks on patch sizes versus array dimension
|
|
240
|
+
if is_3d_patch and patch_size[0] > arr.shape[-3]:
|
|
241
|
+
raise ValueError(
|
|
242
|
+
f"Z patch size is inconsistent with image shape "
|
|
243
|
+
f"(got {patch_size[0]} patches for dim {arr.shape[1]})."
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
|
|
247
|
+
raise ValueError(
|
|
248
|
+
f"At least one of YX patch dimensions is inconsistent with image shape "
|
|
249
|
+
f"(got {patch_size} patches for dims {arr.shape[-2:]})."
|
|
250
|
+
)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
# formerly :
|
|
254
|
+
# in dataloader.py#L52, 00d536c
|
|
255
|
+
def _extract_patches_sequential(
    arr: np.ndarray, patch_size: Union[List[int], Tuple[int]]
) -> Generator[np.ndarray, None, None]:
    """
    Split an array into patches that together cover all of it.

    `arr` is expected to be C(Z)YX, where C can be a singleton dimension, and
    `patch_size` gives one size per spatial axis. The overlap between
    neighbouring patches comes from `_compute_overlap`, so that a whole number
    of patches spans each axis. Patches are returned in the order produced by
    `_compute_reshaped_view`, which shuffles them.

    Parameters
    ----------
    arr : np.ndarray
        Input image array.
    patch_size : Union[List[int], Tuple[int]]
        Patch sizes in each spatial dimension.

    Returns
    -------
    Generator[np.ndarray, None, None]
        Generator of patches, each of shape ``(1, *patch_size)``, or of shape
        ``patch_size`` for 3D patches whose Z size is 1.

    Raises
    ------
    ValueError
        If `patch_size` is incompatible with the shape of `arr`.
    """
    three_dimensional = len(patch_size) == 3
    _patches_sanity_check(arr, patch_size, three_dimensional)

    spatial_overlaps = _compute_overlap(arr=arr, patch_sizes=patch_size)
    spatial_steps = _compute_patch_steps(
        patch_sizes=patch_size, overlaps=spatial_overlaps
    )

    # The leading (C) axis is windowed one element at a time.
    full_window = (1, *patch_size)
    full_steps = (1, *spatial_steps)

    # For 3D patches with a Z size of 1, the leading singleton window axis is
    # dropped from the flattened output shape.
    drop_leading_window_axis = three_dimensional and patch_size[0] == 1
    flat_shape = (
        (-1, *full_window[1:]) if drop_leading_window_axis else (-1, *full_window)
    )

    patches = _compute_reshaped_view(
        arr, window_shape=full_window, step=full_steps, output_shape=flat_shape
    )
    logger.info(f"Extracted {patches.shape[0]} patches from input array.")

    n_extracted = patches.shape[0]
    return (patches[idx, ...] for idx in range(n_extracted))
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def _extract_patches_random(
    arr: np.ndarray, patch_size: Union[List[int], Tuple[int]]
) -> Generator[np.ndarray, None, None]:
    """
    Generate patches from an array in a random manner.

    Samples (first axis of `arr`) are visited in random order. For each sample,
    the method computes how many patches its area could be divided into and
    extracts that many patches at uniformly random positions. `arr` itself is
    not modified.

    Parameters
    ----------
    arr : np.ndarray
        Input image array, with one more dimension than `patch_size`.
    patch_size : Union[List[int], Tuple[int]]
        Patch sizes in each dimension.

    Yields
    ------
    np.ndarray
        Float32 patch of shape `patch_size`.

    Raises
    ------
    ValueError
        If `patch_size` is incompatible with the shape of `arr` (raised when the
        generator is first advanced).
    """
    is_3d_patch = len(patch_size) == 3

    # Patches sanity check
    _patches_sanity_check(arr, patch_size, is_3d_patch)

    rng = np.random.default_rng()

    # Visit samples in random order through a permutation of indices instead of
    # shuffling `arr` in place, which would mutate the caller's array.
    for sample_idx in rng.permutation(arr.shape[0]):
        sample = arr[sample_idx]
        # calculate how many patches the image area can be divided into
        n_patches = np.ceil(np.prod(sample.shape) / np.prod(patch_size)).astype(int)
        for _ in range(n_patches):
            # `integers` excludes `high`; the +1 makes the last valid start
            # position reachable and allows patch size == axis size (which the
            # sanity check accepts but previously raised in `integers`).
            crop_coords = [
                rng.integers(0, arr.shape[i + 1] - patch_size[i] + 1)
                for i in range(len(patch_size))
            ]
            patch_slices = tuple(
                slice(c, c + patch_size[i]) for i, c in enumerate(crop_coords)
            )
            # astype returns a new array, so the patch never aliases `arr`.
            yield sample[(..., *patch_slices)].astype(np.float32)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def _extract_tiles(
    arr: np.ndarray,
    tile_size: Union[List[int], Tuple[int]],
    overlaps: Union[List[int], Tuple[int]],
) -> Generator:
    """
    Generate tiles from the input array with specified overlap.

    The tiles cover the whole array.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (S, (Z), Y, X).
    tile_size : Union[List[int], Tuple[int]]
        Tile sizes in each dimension, of length 2 or 3.
    overlaps : Union[List[int], Tuple[int]]
        Overlap values in each dimension, of length 2 or 3.

    Yields
    ------
    Generator
        Tuples ``(tile, last_tile, shape, overlap_crop_coords, stitch_coords)``:
        the float32 tile with a leading axis of size 1, whether it is the last
        tile of its sample, ``arr.shape[1:]``, and the per-axis coordinates
        needed to crop the tile prediction and stitch it back into place.
    """
    # Iterate over num samples (S)
    for sample_idx in range(arr.shape[0]):
        sample = arr[sample_idx]

        # Coordinates for cropping and stitching, computed independently for
        # each axis: one (crop, stitch, overlap_crop) triple of lists per axis.
        crop_and_stitch_coords_list = [
            _compute_crop_and_stitch_coords_1d(
                sample.shape[i], tile_size[i], overlaps[i]
            )
            for i in range(len(tile_size))
        ]

        # Regroup from per-axis triples to per-type tuples of per-axis lists.
        # E.g. for an axis of size 35, tile size 32 and overlap 24,
        # _compute_crop_and_stitch_coords_1d outputs
        # ([(0, 32), (3, 35)], [(0, 20), (20, 35)], [(0, 20), (17, 32)]),
        # i.e. crop, stitch and overlap-crop coordinates for that axis.
        all_crop_coords, all_stitch_coords, all_overlap_crop_coords = zip(
            *crop_and_stitch_coords_list
        )

        # Total number of tiles of this sample; it does not depend on the tile,
        # so it is computed once instead of on every iteration.
        n_tiles = int(np.prod([len(axis) for axis in all_crop_coords]))

        # Iterate over all combinations of per-axis coordinates.
        for tile_idx, (crop_coords, stitch_coords, overlap_crop_coords) in enumerate(
            zip(
                itertools.product(*all_crop_coords),
                itertools.product(*all_stitch_coords),
                itertools.product(*all_overlap_crop_coords),
            )
        ):
            tile = sample[(..., *[slice(start, end) for start, end in crop_coords])]

            # Flag the last tile of the sample so consumers know when stitching
            # of this sample can be finalized.
            last_tile = tile_idx == n_tiles - 1

            yield (
                np.expand_dims(tile.astype(np.float32), 0),
                last_tile,
                arr.shape[1:],
                overlap_crop_coords,
                stitch_coords,
            )
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
def generate_patches(
    sample: np.ndarray,
    patch_extraction_method: ExtractionStrategy,
    patch_size: Optional[Union[List[int], Tuple[int]]] = None,
    patch_overlap: Optional[Union[List[int], Tuple[int]]] = None,
) -> Generator[np.ndarray, None, None]:
    """
    Generate patches from a sample.

    Parameters
    ----------
    sample : np.ndarray
        Input array.
    patch_extraction_method : ExtractionStrategy
        Patch extraction method, as defined in extraction_strategy.ExtractionStrategy.
        ExtractionStrategy.TILED and ExtractionStrategy.SEQUENTIAL are handled
        explicitly; any other value selects random patching.
    patch_size : Optional[Union[List[int], Tuple[int]]]
        Size of the patches along each dimension of the array, except the first.
        If None, no patching is performed.
    patch_overlap : Optional[Union[List[int], Tuple[int]]]
        Overlap between patches, required for tiling.

    Returns
    -------
    Generator[np.ndarray, None, None]
        Generator yielding patches/tiles. With ExtractionStrategy.TILED it yields
        the tuples produced by `_extract_tiles`. When `patch_size` is None it
        yields `sample` once.

    Raises
    ------
    ValueError
        If overlap is not specified when using tiling.
    """
    if patch_size is None:
        # no patching, return a generator for the sample
        return (sample for _ in range(1))

    if patch_extraction_method == ExtractionStrategy.TILED:
        if patch_overlap is None:
            raise ValueError(
                "Overlaps must be specified when using tiling (got None)."
            )
        return _extract_tiles(
            arr=sample, tile_size=patch_size, overlaps=patch_overlap
        )

    if patch_extraction_method == ExtractionStrategy.SEQUENTIAL:
        return _extract_patches_sequential(sample, patch_size=patch_size)

    # random patching
    return _extract_patches_random(sample, patch_size=patch_size)
|