careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +14 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +27 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_factory.py +460 -0
- careamics/config/configuration_model.py +596 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +665 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +390 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc3.dist-info/METADATA +122 -0
- careamics-0.1.0rc3.dist-info/RECORD +109 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
from typing import Generator, List, Optional, Tuple, Union
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import zarr
|
|
5
|
+
|
|
6
|
+
from .validate_patch_dimension import validate_patch_dimensions
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# TODO: split into testable functions
|
|
10
|
+
def extract_patches_random(
    arr: np.ndarray,
    patch_size: Union[List[int], Tuple[int, ...]],
    target: Optional[np.ndarray] = None,
    seed: Optional[int] = None,
) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
    """
    Generate patches from an array in a random manner.

    For each sample (first axis of `arr`), the method computes how many patches of
    size `patch_size` would be needed to cover the sample (rounded up) and extracts
    that many patches at uniformly random positions. Patches can overlap and some
    pixels may never be sampled.

    It returns a generator that yields the following:

    - patch: np.ndarray, dimension C(Z)YX, dtype float32.
    - target_patch: np.ndarray, dimension C(Z)YX, dtype float32, if the target is
      present, None otherwise.

    Parameters
    ----------
    arr : np.ndarray
        Input image array, of shape SC(Z)YX.
    patch_size : Union[List[int], Tuple[int, ...]]
        Patch sizes in each spatial dimension ((Z)YX).
    target : Optional[np.ndarray], optional
        Target array, cropped at the same coordinates as `arr`, by default None.
    seed : Optional[int], optional
        Seed of the random number generator, by default None (non-reproducible).

    Yields
    ------
    Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]
        Generator of (patch, target_patch) tuples.
    """
    is_3d_patch = len(patch_size) == 3

    # patches sanity check
    validate_patch_dimensions(arr, patch_size, is_3d_patch)

    # patch size over the C(Z)YX dimensions of one sample, channels taken in full
    sample_patch_size = [arr.shape[1], *patch_size]

    rng = np.random.default_rng(seed=seed)

    # iterate over the number of samples (S or T)
    for sample_idx in range(arr.shape[0]):
        sample: np.ndarray = arr[sample_idx, ...]
        target_sample = target[sample_idx, ...] if target is not None else None

        n_patches = int(np.ceil(np.prod(sample.shape) / np.prod(sample_patch_size)))

        for _ in range(n_patches):
            # random origin per dimension; endpoint=True so that a patch can end on
            # the last pixel of the sample
            crop_slices = []
            for dim, size in zip(sample.shape, sample_patch_size):
                start = rng.integers(0, dim - size, endpoint=True)
                crop_slices.append(slice(start, start + size))
            index = (..., *crop_slices)

            # astype always allocates a new array, detaching the patch from `arr`
            patch = sample[index].astype(np.float32)

            if target_sample is not None:
                yield patch, target_sample[index].astype(np.float32)
            else:
                yield patch, None
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def extract_patches_random_from_chunks(
    arr: zarr.Array,
    patch_size: Union[List[int], Tuple[int, ...]],
    chunk_size: Union[List[int], Tuple[int, ...]],
    chunk_limit: Optional[int] = None,
    seed: Optional[int] = None,
) -> Generator[np.ndarray, None, None]:
    """
    Generate random patches from randomly positioned chunks of a zarr array.

    At each iteration, a block of shape `chunk_size` is read at a random position
    of `arr` and its singleton dimensions are squeezed. A leading sample axis is
    added back if the squeeze removed it. For each entry along the first axis of
    the chunk, as many random patches are extracted as would be needed to cover it
    (rounded up).

    Parameters
    ----------
    arr : zarr.Array
        Input array. It must have two more dimensions than `patch_size`.
    patch_size : Union[List[int], Tuple[int, ...]]
        Patch sizes in each spatial dimension ((Z)YX).
    chunk_size : Union[List[int], Tuple[int, ...]]
        Size of the blocks read from `arr`. The values apply to the trailing
        dimensions of `arr`; leading dimensions not covered are read in full.
    chunk_limit : Optional[int], optional
        Number of chunks to read. By default None, which reads as many chunks as
        the zarr array stores.
    seed : Optional[int], optional
        Seed of the random number generator, by default None (non-reproducible).

    Yields
    ------
    Generator[np.ndarray, None, None]
        Generator of float32 patches.

    Raises
    ------
    ValueError
        If `chunk_size` has more dimensions than `arr`, or if a squeezed chunk
        entry does not have as many dimensions as `patch_size`.
    """
    is_3d_patch = len(patch_size) == 3

    # Patches sanity check
    validate_patch_dimensions(arr, patch_size, is_3d_patch)

    # the chunk slices are placed after an Ellipsis, i.e. on the trailing
    # dimensions, so the random origins must be drawn from those same dimensions
    n_leading = len(arr.shape) - len(chunk_size)
    if n_leading < 0:
        raise ValueError(
            f"Chunk size {chunk_size} has more dimensions than the array "
            f"({arr.shape})."
        )
    cropped_dims = arr.shape[n_leading:]

    rng = np.random.default_rng(seed=seed)

    # only None falls back to the total number of stored chunks (0 means none)
    num_chunks = (
        chunk_limit if chunk_limit is not None else int(np.prod(arr.cdata_shape))
    )

    # Iterate over num chunks in the array
    for _ in range(num_chunks):
        # max(0, ...) handles arrays smaller than the requested chunk
        chunk_crop_coords = [
            rng.integers(0, max(0, dim - size), endpoint=True)
            for dim, size in zip(cropped_dims, chunk_size)
        ]
        chunk = arr[
            (
                ...,
                *[
                    slice(start, start + size)
                    for start, size in zip(chunk_crop_coords, chunk_size)
                ],
            )
        ].squeeze()

        # Add a singleton dimension if the chunk does not have a sample dimension
        if len(chunk.shape) == len(patch_size):
            chunk = np.expand_dims(chunk, axis=0)

        # Iterate over num samples (S)
        for sample_idx in range(chunk.shape[0]):
            spatial_chunk = chunk[sample_idx]
            if len(spatial_chunk.shape) != len(patch_size):
                raise ValueError("Requested chunk shape is not equal to patch size")

            n_patches = int(
                np.ceil(np.prod(spatial_chunk.shape) / np.prod(patch_size))
            )

            # Iterate over the number of patches
            for _ in range(n_patches):
                patch_crop_coords = [
                    rng.integers(0, dim - size, endpoint=True)
                    for dim, size in zip(spatial_chunk.shape, patch_size)
                ]
                # astype always allocates, no extra copy needed
                patch = spatial_chunk[
                    (
                        ...,
                        *[
                            slice(start, start + size)
                            for start, size in zip(patch_crop_coords, patch_size)
                        ],
                    )
                ].astype(np.float32)
                yield patch
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
from typing import List, Optional, Tuple, Union
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from skimage.util import view_as_windows
|
|
5
|
+
|
|
6
|
+
from .validate_patch_dimension import validate_patch_dimensions
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _compute_number_of_patches(
    arr_shape: Tuple[int, ...], patch_sizes: Union[List[int], Tuple[int, ...]]
) -> Tuple[int, ...]:
    """
    Compute the number of patches that fit in each dimension.

    The count is rounded up, so that the patches can cover the whole array.

    Parameters
    ----------
    arr_shape : Tuple[int, ...]
        Shape of the input array.
    patch_sizes : Union[List[int], Tuple[int, ...]]
        Shape of the patches, with the same number of dimensions as `arr_shape`.

    Returns
    -------
    Tuple[int, ...]
        Number of patches in each dimension.

    Raises
    ------
    ValueError
        If `arr_shape` and `patch_sizes` have different lengths, or if a patch
        size is not strictly positive.
    """
    if len(arr_shape) != len(patch_sizes):
        raise ValueError(
            f"Array shape {arr_shape} and patch size {patch_sizes} should have the "
            f"same dimension, including singleton dimension for S and equal dimension "
            f"for C."
        )

    # a zero patch size would otherwise divide by zero and yield garbage counts
    if any(size <= 0 for size in patch_sizes):
        raise ValueError(
            f"Patch size {patch_sizes} is not compatible with array shape {arr_shape}"
        )

    # exact integer ceiling division
    return tuple(-(-dim // size) for dim, size in zip(arr_shape, patch_sizes))
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _compute_overlap(
    arr_shape: Tuple[int, ...], patch_sizes: Union[List[int], Tuple[int, ...]]
) -> Tuple[int, ...]:
    """
    Compute the overlap between consecutive patches in each dimension.

    The number of patches per dimension is the array size divided by the patch
    size, rounded up. The amount by which these patches exceed the array (clipped
    at 0) is spread over the gaps between consecutive patches and rounded up. When
    a dimension is divisible by the patch size, the overlap is 0.

    Parameters
    ----------
    arr_shape : Tuple[int, ...]
        Input array shape.
    patch_sizes : Union[List[int], Tuple[int, ...]]
        Size of the patches.

    Returns
    -------
    Tuple[int, ...]
        Overlap between patches in each dimension.
    """
    counts = _compute_number_of_patches(arr_shape, patch_sizes)

    overlaps = []
    for size, count, length in zip(patch_sizes, counts, arr_shape):
        # total excess of `count` patches over the array length
        excess = np.clip(count * size - length, 0, None)
        # number of gaps between patches, at least one to avoid dividing by zero
        gaps = max(1, count - 1)
        overlaps.append(np.ceil(excess / gaps).astype(int))

    return tuple(overlaps)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _compute_patch_steps(
    patch_sizes: Union[List[int], Tuple[int, ...]], overlaps: Tuple[int, ...]
) -> Tuple[int, ...]:
    """
    Compute the stride between consecutive patches along each dimension.

    The stride is the patch size minus the overlap. A negative overlap is treated
    as 0, so the stride is never larger than the patch size.

    Parameters
    ----------
    patch_sizes : Union[List[int], Tuple[int, ...]]
        Size of the patches.
    overlaps : Tuple[int, ...]
        Overlap between patches, indexed like `patch_sizes`.

    Returns
    -------
    Tuple[int, ...]
        Stride between patches in each dimension.
    """
    return tuple(
        min(size - overlaps[axis], size) for axis, size in enumerate(patch_sizes)
    )
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# TODO: why stack the target here rather than along a different dimension before calling this function?
|
|
106
|
+
def _compute_patch_views(
    arr: np.ndarray,
    window_shape: List[int],
    step: Tuple[int, ...],
    output_shape: List[int],
    target: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Cut an array into windows, flatten them into a patch axis and shuffle them.

    The windows are obtained with `view_as_windows` and reshaped to `output_shape`.
    They are then shuffled along the patch axis with an unseeded random generator.

    If a target is given, it is stacked with `arr` along a new leading axis of
    size 2 before windowing. Input and target are therefore cut at the same
    positions and shuffled with the same permutation. In that case the returned
    array has shape ``(2, n_patches, arr.shape[1], *output_shape[2:])``, with index
    0 holding the input patches and index 1 the target patches.

    Parameters
    ----------
    arr : np.ndarray
        Array from which the views are extracted.
    window_shape : List[int]
        Shape of the views, one value per dimension of `arr`.
    step : Tuple[int, ...]
        Steps between views, one value per dimension of `arr`.
    output_shape : List[int]
        Shape of the output array, patch axis first (e.g. ``[-1, C, (Z), Y, X]``).
    target : Optional[np.ndarray], optional
        Target array, with the same shape as `arr`, by default None.

    Returns
    -------
    np.ndarray
        Shuffled patches of shape `output_shape`, or stacked input/target patches
        when `target` is given (see above).
    """
    rng = np.random.default_rng()

    if target is not None:
        arr = np.stack([arr, target], axis=0)
        n_stacked = arr.shape[0]
        window_shape = [n_stacked, *window_shape]
        step = (n_stacked, *step)
        # view_as_windows returns (*window_grid, *window_shape): the grid must be
        # flattened into the FIRST axis, and the stacked input/target axis of each
        # window follows it. Putting the stacked axis first would interleave input
        # and target windows.
        output_shape = [-1, n_stacked, arr.shape[2], *output_shape[2:]]

    patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape(
        *output_shape
    )

    # shuffle along the patch axis; input/target pairs stay together
    rng.shuffle(patches, axis=0)

    if target is not None:
        # move the stacked axis first: index 0 -> input patches, 1 -> target patches
        patches = np.ascontiguousarray(np.moveaxis(patches, 1, 0))

    return patches
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def extract_patches_sequential(
    arr: np.ndarray,
    patch_size: Union[List[int], Tuple[int, ...]],
    target: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Extract evenly spaced, possibly overlapping patches from an array.

    Array dimensions should be SC(Z)YX, where S and C can be singleton dimensions.
    Along each dimension, the number of patches is the array size divided by the
    patch size, rounded up. The patches are then spaced with the overlap computed
    by `_compute_overlap`. That overlap is rounded up, so when it does not divide
    evenly, the last pixels along a dimension may not be covered.

    All patches are returned at once, in the random order produced by
    `_compute_patch_views`.

    Parameters
    ----------
    arr : np.ndarray
        Input image array, of shape SC(Z)YX.
    patch_size : Union[List[int], Tuple[int, ...]]
        Patch sizes in each spatial dimension ((Z)YX).
    target : Optional[np.ndarray], optional
        Target array with the same shape as `arr`, by default None.

    Returns
    -------
    Tuple[np.ndarray, Optional[np.ndarray]]
        Patches of shape (n_patches, C, (Z), Y, X) and the corresponding target
        patches, or None if no target was provided.
    """
    is_3d_patch = len(patch_size) == 3

    # Patches sanity check
    validate_patch_dimensions(arr, patch_size, is_3d_patch)

    # Update patch size to encompass S and C dimensions
    patch_size = [1, arr.shape[1], *patch_size]

    # Overlap and step between consecutive windows along each dimension
    overlaps = _compute_overlap(arr_shape=arr.shape, patch_sizes=patch_size)
    window_steps = _compute_patch_steps(patch_sizes=patch_size, overlaps=overlaps)

    # All windows are flattened into one patch axis: (n_patches, C, (Z), Y, X)
    output_shape = [-1, *patch_size[1:]]

    patches = _compute_patch_views(
        arr,
        window_shape=patch_size,
        step=window_steps,
        output_shape=output_shape,
        target=target,
    )

    if target is None:
        return patches, None

    # _compute_patch_views stacks input (index 0) and target (index 1) patches
    input_patches, target_patches = patches[0], patches[1]
    return input_patches, target_patches
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
from typing import Generator, List, Tuple, Union
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from careamics.config.tile_information import TileInformation
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _compute_crop_and_stitch_coords_1d(
    axis_size: int, tile_size: int, overlap: int
) -> Tuple[List[Tuple[int, ...]], ...]:
    """
    Compute the coordinates of each tile along an axis, given the overlap.

    Tiles start every ``tile_size - overlap`` pixels from the beginning of the
    axis. When the next tile would overflow the axis, a final tile aligned with the
    end of the axis is used instead. On each side shared with a neighbouring tile,
    half of the overlap (rounded down) is cropped away before stitching.

    Parameters
    ----------
    axis_size : int
        Length of the axis.
    tile_size : int
        Size of the tile for the given axis.
    overlap : int
        Size of the overlap for the given axis, in ``[0, tile_size)``.

    Returns
    -------
    Tuple[List[Tuple[int, ...]], ...]
        Three lists with one ``(start, end)`` pair per tile:

        - crop coordinates of the tile in the input array,
        - stitch coordinates, where the cropped tile goes in the output array,
        - overlap crop coordinates, the part of the tile (in tile coordinates)
          kept when stitching.

    Raises
    ------
    ValueError
        If `overlap` is negative or not smaller than `tile_size`.
    """
    # step = tile_size - overlap must be strictly positive, otherwise range()
    # either fails (step 0) or silently yields no tile (negative step)
    if overlap < 0 or overlap >= tile_size:
        raise ValueError(
            f"Overlap ({overlap}) must be non-negative and smaller than the tile "
            f"size ({tile_size})."
        )

    # Compute the step between tiles
    step = tile_size - overlap
    half_overlap = overlap // 2
    crop_coords = []
    stitch_coords = []
    overlap_crop_coords = []

    # Iterate over the axis with step
    for i in range(0, max(1, axis_size - overlap), step):
        # Check if the tile fits within the axis
        if i + tile_size <= axis_size:
            # Coordinates to crop one tile
            crop_coords.append((i, i + tile_size))

            # a tile touching the end of the axis is kept up to its last pixel
            reaches_end = i + tile_size >= axis_size

            # Pixel coordinates of the cropped tile in the original image space
            stitch_coords.append(
                (
                    i + half_overlap if i > 0 else 0,
                    axis_size if reaches_end else i + tile_size - half_overlap,
                )
            )

            # Coordinates to crop the overlap from the prediction
            overlap_crop_coords.append(
                (
                    half_overlap if i > 0 else 0,
                    tile_size if reaches_end else tile_size - half_overlap,
                )
            )

        # The tile does not fit: use a last tile aligned with the end of the axis
        else:
            crop_coords.append((max(0, axis_size - tile_size), axis_size))
            # NOTE(review): stitch_coords is only empty here when tile_size >
            # axis_size; stitching then starts at 1 rather than 0 — confirm intent.
            last_tile_end_coord = stitch_coords[-1][1] if stitch_coords else 1
            stitch_coords.append((last_tile_end_coord, axis_size))
            overlap_crop_coords.append(
                (tile_size - (axis_size - last_tile_end_coord), tile_size)
            )
            break

    return crop_coords, stitch_coords, overlap_crop_coords
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def extract_tiles(
    arr: np.ndarray,
    tile_size: Union[List[int], Tuple[int, ...]],
    overlaps: Union[List[int], Tuple[int, ...]],
) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
    """
    Generate tiles from the input array with specified overlap.

    The input array has shape SC(Z)YX, where C can be a singleton. Each sample
    along S is tiled independently, and the tiles together cover the whole sample.
    The yielded tiles have shape C(Z)YX and are views into `arr`, not copies.

    Each tile comes with a `TileInformation` describing:

    - the squeezed shape of its sample,
    - whether it is the last tile of the sample,
    - the coordinates used to crop the overlap from the tile,
    - the coordinates of the cropped tile in the sample.

    Parameters
    ----------
    arr : np.ndarray
        Array of shape (S, C, (Z), Y, X).
    tile_size : Union[List[int], Tuple[int]]
        Tile sizes in each spatial dimension, of length 2 or 3.
    overlaps : Union[List[int], Tuple[int]]
        Overlap values in each spatial dimension, of length 2 or 3.

    Yields
    ------
    Generator[Tuple[np.ndarray, TileInformation], None, None]
        Tile generator, yields the tile and additional information.
    """
    # Iterate over num samples (S)
    for sample_idx in range(arr.shape[0]):
        sample: np.ndarray = arr[sample_idx, ...]
        sample_shape = sample.squeeze().shape

        # One (crop, stitch, overlap crop) triplet of coordinate lists per spatial
        # axis; axis 0 of the sample (C) is never tiled, hence the i + 1.
        # For an axis of size 35, a tile size of 32 and an overlap of 24, this is
        # ([(0, 32), (3, 35)], [(0, 20), (20, 35)], [(0, 20), (17, 32)])
        crop_and_stitch_coords_list = [
            _compute_crop_and_stitch_coords_1d(
                sample.shape[i + 1], tile_size[i], overlaps[i]
            )
            for i in range(len(tile_size))
        ]

        # Regroup the coordinates by type: all crop coordinates, all stitch
        # coordinates and all overlap crop coordinates, each listed per axis
        all_crop_coords, all_stitch_coords, all_overlap_crop_coords = zip(
            *crop_and_stitch_coords_list
        )

        # Index of the last tile of the sample (tiles are the cartesian product)
        max_tile_idx = np.prod([len(axis) for axis in all_crop_coords]) - 1

        tiles_coords = zip(
            itertools.product(*all_crop_coords),
            itertools.product(*all_stitch_coords),
            itertools.product(*all_overlap_crop_coords),
        )
        for tile_idx, (crop_coords, stitch_coords, overlap_crop_coords) in enumerate(
            tiles_coords
        ):
            # Extract tile from the sample (all channels)
            tile: np.ndarray = sample[
                (..., *[slice(start, end) for start, end in crop_coords])  # type: ignore
            ]

            tile_info = TileInformation(
                array_shape=sample_shape,
                tiled=True,
                last_tile=bool(tile_idx == max_tile_idx),
                overlap_crop_coords=overlap_crop_coords,
                stitch_coords=stitch_coords,
            )

            yield tile, tile_info
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from typing import List, Tuple, Union
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def validate_patch_dimensions(
    arr: np.ndarray,
    patch_size: Union[List[int], Tuple[int, ...]],
    is_3d_patch: bool,
) -> None:
    """
    Check patch size and array compatibility.

    This method validates the patch sizes with respect to the array dimensions:

    - Patch must have two dimensions fewer than the array (S and C).
    - Patch sizes are smaller than or equal to the corresponding array dimensions.

    If one of these conditions is not met, a ValueError is raised.

    This method should be called after inputs have been resized.

    Parameters
    ----------
    arr : np.ndarray
        Input array, of shape SC(Z)YX.
    patch_size : Union[List[int], Tuple[int, ...]]
        Size of the patches along each spatial dimension of the array, i.e. every
        dimension except the first two.
    is_3d_patch : bool
        Whether the patch is 3D (ZYX) or not (YX).

    Raises
    ------
    ValueError
        If the patch size is not consistent with the array shape (the array must
        have exactly two more dimensions).
    ValueError
        If the patch size in Z is larger than the array dimension.
    ValueError
        If either of the patch sizes in X or Y is larger than the corresponding
        array dimension.
    """
    if len(patch_size) != len(arr.shape[2:]):
        raise ValueError(
            f"There must be a patch size for each spatial dimensions "
            f"(got {patch_size} patches for dims {arr.shape})."
        )

    # Sanity checks on patch sizes versus array dimension
    if is_3d_patch and patch_size[0] > arr.shape[-3]:
        # report the Z dimension that was actually compared
        raise ValueError(
            f"Z patch size is inconsistent with image shape "
            f"(got {patch_size[0]} patches for dim {arr.shape[-3]})."
        )

    if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
        raise ValueError(
            f"At least one of YX patch dimensions is larger than the corresponding "
            f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]})."
        )
|