careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6

careamics/dataset/patching/patching.py
@@ -0,0 +1,299 @@
+"""Patching functions."""
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable, Union
+
+import numpy as np
+from numpy.typing import NDArray
+
+from ...utils.logging import get_logger
+from ..dataset_utils import reshape_array
+from ..dataset_utils.running_stats import compute_normalization_stats
+from .sequential_patching import extract_patches_sequential
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class Stats:
+    """Dataclass to store statistics."""
+
+    means: Union[NDArray, tuple, list, None]
+    """Mean of the data across channels."""
+
+    stds: Union[NDArray, tuple, list, None]
+    """Standard deviation of the data across channels."""
+
+    def get_statistics(self) -> tuple[list[float], list[float]]:
+        """Return the means and standard deviations.
+
+        Returns
+        -------
+        tuple of two lists of floats
+            Means and standard deviations.
+        """
+        if self.means is None or self.stds is None:
+            return [], []
+
+        return list(self.means), list(self.stds)
+
+
+@dataclass
+class PatchedOutput:
+    """Dataclass to store patches and statistics."""
+
+    patches: Union[NDArray]
+    """Image patches."""
+
+    targets: Union[NDArray, None]
+    """Target patches."""
+
+    image_stats: Stats
+    """Statistics of the image patches."""
+
+    target_stats: Stats
+    """Statistics of the target patches."""
+
+
+# called by in memory dataset
+def prepare_patches_supervised(
+    train_files: list[Path],
+    target_files: list[Path],
+    axes: str,
+    patch_size: Union[list[int], tuple[int, ...]],
+    read_source_func: Callable,
+) -> PatchedOutput:
+    """
+    Iterate over data source and create an array of patches and corresponding targets.
+
+    The lists of Paths should be pre-sorted.
+
+    Parameters
+    ----------
+    train_files : list of pathlib.Path
+        List of paths to training data.
+    target_files : list of pathlib.Path
+        List of paths to target data.
+    axes : str
+        Axes of the data.
+    patch_size : list or tuple of int
+        Size of the patches.
+    read_source_func : Callable
+        Function to read the data.
+
+    Returns
+    -------
+    PatchedOutput
+        Dataclass holding the source and target patches, with their statistics.
+    """
+    means, stds, num_samples = 0, 0, 0
+    all_patches, all_targets = [], []
+    for train_filename, target_filename in zip(train_files, target_files):
+        try:
+            sample: np.ndarray = read_source_func(train_filename, axes)
+            target: np.ndarray = read_source_func(target_filename, axes)
+            means += sample.mean()
+            stds += sample.std()
+            num_samples += 1
+
+            # reshape array
+            sample = reshape_array(sample, axes)
+            target = reshape_array(target, axes)
+
+            # generate patches, return a generator
+            patches, targets = extract_patches_sequential(
+                sample, patch_size=patch_size, target=target
+            )
+
+            # convert generator to list and add to all_patches
+            all_patches.append(patches)
+
+            # ensure targets are not None (type checking)
+            if targets is not None:
+                all_targets.append(targets)
+            else:
+                raise ValueError(f"No target found for {target_filename}.")
+
+        except Exception as e:
+            # emit warning and continue
+            logger.error(f"Failed to read {train_filename} or {target_filename}: {e}")
+
+    # raise error if no valid samples found
+    if num_samples == 0 or len(all_patches) == 0:
+        raise ValueError(
+            f"No valid samples found in the input data: {train_files} and "
+            f"{target_files}."
+        )
+
+    image_means, image_stds = compute_normalization_stats(np.concatenate(all_patches))
+    target_means, target_stds = compute_normalization_stats(np.concatenate(all_targets))
+
+    patch_array: np.ndarray = np.concatenate(all_patches, axis=0)
+    target_array: np.ndarray = np.concatenate(all_targets, axis=0)
+    logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
+
+    return PatchedOutput(
+        patch_array,
+        target_array,
+        Stats(image_means, image_stds),
+        Stats(target_means, target_stds),
+    )
+
+
+# called by in_memory_dataset
+def prepare_patches_unsupervised(
+    train_files: list[Path],
+    axes: str,
+    patch_size: Union[list[int], tuple[int]],
+    read_source_func: Callable,
+) -> PatchedOutput:
+    """Iterate over data source and create an array of patches.
+
+    This method returns the mean and standard deviation of the image.
+
+    Parameters
+    ----------
+    train_files : list of pathlib.Path
+        List of paths to training data.
+    axes : str
+        Axes of the data.
+    patch_size : list or tuple of int
+        Size of the patches.
+    read_source_func : Callable
+        Function to read the data.
+
+    Returns
+    -------
+    PatchedOutput
+        Dataclass holding patches and their statistics.
+    """
+    means, stds, num_samples = 0, 0, 0
+    all_patches = []
+    for filename in train_files:
+        try:
+            sample: np.ndarray = read_source_func(filename, axes)
+            means += sample.mean()
+            stds += sample.std()
+            num_samples += 1
+
+            # reshape array
+            sample = reshape_array(sample, axes)
+
+            # generate patches, return a generator
+            patches, _ = extract_patches_sequential(sample, patch_size=patch_size)
+
+            # convert generator to list and add to all_patches
+            all_patches.append(patches)
+        except Exception as e:
+            # emit warning and continue
+            logger.error(f"Failed to read {filename}: {e}")
+
+    # raise error if no valid samples found
+    if num_samples == 0:
+        raise ValueError(f"No valid samples found in the input data: {train_files}.")
+
+    image_means, image_stds = compute_normalization_stats(np.concatenate(all_patches))
+
+    patch_array: np.ndarray = np.concatenate(all_patches)
+    logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
+
+    return PatchedOutput(
+        patch_array, None, Stats(image_means, image_stds), Stats((), ())
+    )
+
+
+# called on arrays by in memory dataset
+def prepare_patches_supervised_array(
+    data: NDArray,
+    axes: str,
+    data_target: NDArray,
+    patch_size: Union[list[int], tuple[int]],
+) -> PatchedOutput:
+    """Iterate over data source and create an array of patches.
+
+    This method expects an array of shape SC(Z)YX, where S and C can be singleton
+    dimensions.
+
+    Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
+
+    Parameters
+    ----------
+    data : numpy.ndarray
+        Input data array.
+    axes : str
+        Axes of the data.
+    data_target : numpy.ndarray
+        Target data array.
+    patch_size : list or tuple of int
+        Size of the patches.
+
+    Returns
+    -------
+    PatchedOutput
+        Dataclass holding the source and target patches, with their statistics.
+    """
+    # reshape array
+    reshaped_sample = reshape_array(data, axes)
+    reshaped_target = reshape_array(data_target, axes)
+
+    # compute statistics
+    image_means, image_stds = compute_normalization_stats(reshaped_sample)
+    target_means, target_stds = compute_normalization_stats(reshaped_target)
+
+    # generate patches, return a generator
+    patches, patch_targets = extract_patches_sequential(
+        reshaped_sample, patch_size=patch_size, target=reshaped_target
+    )
+
+    if patch_targets is None:
+        raise ValueError("No target extracted.")
+
+    logger.info(f"Extracted {patches.shape[0]} patches from input array.")
+
+    return PatchedOutput(
+        patches,
+        patch_targets,
+        Stats(image_means, image_stds),
+        Stats(target_means, target_stds),
+    )
+
+
+# called by in memory dataset
+def prepare_patches_unsupervised_array(
+    data: NDArray,
+    axes: str,
+    patch_size: Union[list[int], tuple[int]],
+) -> PatchedOutput:
+    """
+    Iterate over data source and create an array of patches.
+
+    This method expects an array of shape SC(Z)YX, where S and C can be singleton
+    dimensions.
+
+    Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
+
+    Parameters
+    ----------
+    data : numpy.ndarray
+        Input data array.
+    axes : str
+        Axes of the data.
+    patch_size : list or tuple of int
+        Size of the patches.
+
+    Returns
+    -------
+    PatchedOutput
+        Dataclass holding the patches and their statistics.
+    """
+    # reshape array
+    reshaped_sample = reshape_array(data, axes)
+
+    # calculate mean and std
+    means, stds = compute_normalization_stats(reshaped_sample)
+
+    # generate patches, return a generator
+    patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)
+
+    return PatchedOutput(patches, None, Stats(means, stds), Stats((), ()))
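
For orientation, a minimal usage sketch of the new array-based patching entry point (not part of the diff; it assumes a random 64x64 single-channel image and the module path listed above):

import numpy as np
from careamics.dataset.patching.patching import prepare_patches_unsupervised_array

# random single-channel 2D image (illustrative input)
image = np.random.rand(64, 64).astype(np.float32)

# "YX" data is reshaped internally to SC(Z)YX before sequential patch extraction
output = prepare_patches_unsupervised_array(image, axes="YX", patch_size=[16, 16])

print(output.patches.shape)                 # expected: (16, 1, 16, 16), i.e. (patches, C, Y, X)
print(output.image_stats.get_statistics())  # means and stds as lists of floats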

careamics/dataset/patching/random_patching.py
@@ -0,0 +1,201 @@
+"""Random patching utilities."""
+
+from typing import Generator, List, Optional, Tuple, Union
+
+import numpy as np
+import zarr
+
+from .validate_patch_dimension import validate_patch_dimensions
+
+
+# TODO split in testable functions
+def extract_patches_random(
+    arr: np.ndarray,
+    patch_size: Union[List[int], Tuple[int, ...]],
+    target: Optional[np.ndarray] = None,
+    seed: Optional[int] = None,
+) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
+    """
+    Generate patches from an array in a random manner.
+
+    The method calculates how many patches the image can be divided into and then
+    extracts an equal number of random patches.
+
+    It returns a generator that yields the following:
+
+    - patch: np.ndarray, dimension C(Z)YX.
+    - target_patch: np.ndarray, dimension C(Z)YX, if the target is present, None
+      otherwise.
+
+    Parameters
+    ----------
+    arr : np.ndarray
+        Input image array.
+    patch_size : Tuple[int]
+        Patch sizes in each dimension.
+    target : Optional[np.ndarray], optional
+        Target array, by default None.
+    seed : Optional[int], optional
+        Random seed, by default None.
+
+    Yields
+    ------
+    Generator[np.ndarray, None, None]
+        Generator of patches.
+    """
+    rng = np.random.default_rng(seed=seed)
+
+    is_3d_patch = len(patch_size) == 3
+
+    # patches sanity check
+    validate_patch_dimensions(arr, patch_size, is_3d_patch)
+
+    # Update patch size to encompass S and C dimensions
+    patch_size = [1, arr.shape[1], *patch_size]
+
+    # iterate over the number of samples (S or T)
+    for sample_idx in range(arr.shape[0]):
+        # get sample array
+        sample: np.ndarray = arr[sample_idx, ...]
+
+        # same for target
+        if target is not None:
+            target_sample: np.ndarray = target[sample_idx, ...]
+
+        # calculate the number of patches
+        n_patches = np.ceil(np.prod(sample.shape) / np.prod(patch_size)).astype(int)
+
+        # iterate over the number of patches
+        for _ in range(n_patches):
+            # get crop coordinates
+            crop_coords = [
+                rng.integers(0, sample.shape[i] - patch_size[1:][i], endpoint=True)
+                for i in range(len(patch_size[1:]))
+            ]
+
+            # extract patch
+            patch = (
+                sample[
+                    (
+                        ...,  # type: ignore
+                        *[  # type: ignore
+                            slice(c, c + patch_size[1:][i])
+                            for i, c in enumerate(crop_coords)
+                        ],
+                    )
+                ]
+                .copy()
+                .astype(np.float32)
+            )
+
+            # same for target
+            if target is not None:
+                target_patch = (
+                    target_sample[
+                        (
+                            ...,  # type: ignore
+                            *[  # type: ignore
+                                slice(c, c + patch_size[1:][i])
+                                for i, c in enumerate(crop_coords)
+                            ],
+                        )
+                    ]
+                    .copy()
+                    .astype(np.float32)
+                )
+                # return patch and target patch
+                yield patch, target_patch
+            else:
+                # return patch
+                yield patch, None
+
+
+def extract_patches_random_from_chunks(
+    arr: zarr.Array,
+    patch_size: Union[List[int], Tuple[int, ...]],
+    chunk_size: Union[List[int], Tuple[int, ...]],
+    chunk_limit: Optional[int] = None,
+    seed: Optional[int] = None,
+) -> Generator[np.ndarray, None, None]:
+    """
+    Generate patches from an array in a random manner.
+
+    The method calculates how many patches the image can be divided into and then
+    extracts an equal number of random patches.
+
+    Parameters
+    ----------
+    arr : np.ndarray
+        Input image array.
+    patch_size : Union[List[int], Tuple[int, ...]]
+        Patch sizes in each dimension.
+    chunk_size : Union[List[int], Tuple[int, ...]]
+        Chunk sizes to load from the array.
+    chunk_limit : Optional[int], optional
+        Number of chunks to load, by default None.
+    seed : Optional[int], optional
+        Random seed, by default None.
+
+    Yields
+    ------
+    Generator[np.ndarray, None, None]
+        Generator of patches.
+    """
+    is_3d_patch = len(patch_size) == 3
+
+    # Patches sanity check
+    validate_patch_dimensions(arr, patch_size, is_3d_patch)
+
+    rng = np.random.default_rng(seed=seed)
+    num_chunks = chunk_limit if chunk_limit else np.prod(arr._cdata_shape)
+
+    # Iterate over num chunks in the array
+    for _ in range(num_chunks):
+        chunk_crop_coords = [
+            rng.integers(0, max(0, arr.shape[i] - chunk_size[i]), endpoint=True)
+            for i in range(len(chunk_size))
+        ]
+        chunk = arr[
+            (
+                ...,
+                *[slice(c, c + chunk_size[i]) for i, c in enumerate(chunk_crop_coords)],
+            )
+        ].squeeze()
+
+        # Add a singleton dimension if the chunk does not have a sample dimension
+        if len(chunk.shape) == len(patch_size):
+            chunk = np.expand_dims(chunk, axis=0)
+
+        # Iterate over num samples (S)
+        for sample_idx in range(chunk.shape[0]):
+            spatial_chunk = chunk[sample_idx]
+            assert len(spatial_chunk.shape) == len(
+                patch_size
+            ), "Requested chunk shape is not equal to patch size"
+
+            n_patches = np.ceil(
+                np.prod(spatial_chunk.shape) / np.prod(patch_size)
+            ).astype(int)
+
+            # Iterate over the number of patches
+            for _ in range(n_patches):
+                patch_crop_coords = [
+                    rng.integers(
+                        0, spatial_chunk.shape[i] - patch_size[i], endpoint=True
+                    )
+                    for i in range(len(patch_size))
+                ]
+                patch = (
+                    spatial_chunk[
+                        (
+                            ...,
+                            *[
+                                slice(c, c + patch_size[i])
+                                for i, c in enumerate(patch_crop_coords)
+                            ],
+                        )
+                    ]
+                    .copy()
+                    .astype(np.float32)
+                )
+                yield patch
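
A short sketch of how the random patching generator might be consumed (illustrative only, not part of the diff; assumes an array already reshaped to SCYX):

import numpy as np
from careamics.dataset.patching.random_patching import extract_patches_random

arr = np.random.rand(2, 1, 64, 64)  # S, C, Y, X

# yields (patch, target_patch) tuples; target_patch is None when no target is given
for patch, target_patch in extract_patches_random(arr, patch_size=(16, 16), seed=42):
    assert patch.shape == (1, 16, 16)  # C(Z)YX
    assert target_patch is None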

careamics/dataset/patching/sequential_patching.py
@@ -0,0 +1,212 @@
+"""Sequential patching functions."""
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+from skimage.util import view_as_windows
+
+from .validate_patch_dimension import validate_patch_dimensions
+
+
+def _compute_number_of_patches(
+    arr_shape: Tuple[int, ...], patch_sizes: Union[List[int], Tuple[int, ...]]
+) -> Tuple[int, ...]:
+    """
+    Compute the number of patches that fit in each dimension.
+
+    Parameters
+    ----------
+    arr_shape : Tuple[int, ...]
+        Shape of the input array.
+    patch_sizes : Union[List[int], Tuple[int, ...]]
+        Shape of the patches.
+
+    Returns
+    -------
+    Tuple[int, ...]
+        Number of patches in each dimension.
+    """
+    if len(arr_shape) != len(patch_sizes):
+        raise ValueError(
+            f"Array shape {arr_shape} and patch size {patch_sizes} should have the "
+            f"same dimension, including singleton dimension for S and equal dimension "
+            f"for C."
+        )
+
+    try:
+        n_patches = [
+            np.ceil(arr_shape[i] / patch_sizes[i]).astype(int)
+            for i in range(len(patch_sizes))
+        ]
+    except IndexError as e:
+        raise ValueError(
+            f"Patch size {patch_sizes} is not compatible with array shape {arr_shape}"
+        ) from e
+
+    return tuple(n_patches)
+
+
+def _compute_overlap(
+    arr_shape: Tuple[int, ...], patch_sizes: Union[List[int], Tuple[int, ...]]
+) -> Tuple[int, ...]:
+    """
+    Compute the overlap between patches in each dimension.
+
+    If the array dimensions are divisible by the patch sizes, then the overlap is
+    0. Otherwise, it is the result of the division rounded to the upper value.
+
+    Parameters
+    ----------
+    arr_shape : Tuple[int, ...]
+        Input array shape.
+    patch_sizes : Union[List[int], Tuple[int, ...]]
+        Size of the patches.
+
+    Returns
+    -------
+    Tuple[int, ...]
+        Overlap between patches in each dimension.
+    """
+    n_patches = _compute_number_of_patches(arr_shape, patch_sizes)
+
+    overlap = [
+        np.ceil(
+            np.clip(n_patches[i] * patch_sizes[i] - arr_shape[i], 0, None)
+            / max(1, (n_patches[i] - 1))
+        ).astype(int)
+        for i in range(len(patch_sizes))
+    ]
+    return tuple(overlap)
+
+
+def _compute_patch_steps(
+    patch_sizes: Union[List[int], Tuple[int, ...]], overlaps: Tuple[int, ...]
+) -> Tuple[int, ...]:
+    """
+    Compute steps between patches.
+
+    Parameters
+    ----------
+    patch_sizes : Tuple[int]
+        Size of the patches.
+    overlaps : Tuple[int]
+        Overlap between patches.
+
+    Returns
+    -------
+    Tuple[int]
+        Steps between patches.
+    """
+    steps = [
+        min(patch_sizes[i] - overlaps[i], patch_sizes[i])
+        for i in range(len(patch_sizes))
+    ]
+    return tuple(steps)
+
+
+# TODO why stack the target here and not on a different dimension before this function?
+def _compute_patch_views(
+    arr: np.ndarray,
+    window_shape: List[int],
+    step: Tuple[int, ...],
+    output_shape: List[int],
+    target: Optional[np.ndarray] = None,
+) -> np.ndarray:
+    """
+    Compute views of an array corresponding to patches.
+
+    Parameters
+    ----------
+    arr : np.ndarray
+        Array from which the views are extracted.
+    window_shape : Tuple[int]
+        Shape of the views.
+    step : Tuple[int]
+        Steps between views.
+    output_shape : Tuple[int]
+        Shape of the output array.
+    target : Optional[np.ndarray], optional
+        Target array, by default None.
+
+    Returns
+    -------
+    np.ndarray
+        Array with views dimension.
+    """
+    rng = np.random.default_rng()
+
+    if target is not None:
+        arr = np.stack([arr, target], axis=0)
+        window_shape = [arr.shape[0], *window_shape]
+        step = (arr.shape[0], *step)
+        output_shape = [-1, arr.shape[0], arr.shape[2], *output_shape[2:]]
+
+    patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape(
+        *output_shape
+    )
+    rng.shuffle(patches, axis=0)
+    return patches
+
+
+def extract_patches_sequential(
+    arr: np.ndarray,
+    patch_size: Union[List[int], Tuple[int, ...]],
+    target: Optional[np.ndarray] = None,
+) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+    """
+    Generate patches from an array in a sequential manner.
+
+    Array dimensions should be SC(Z)YX, where S and C can be singleton dimensions. The
+    patches are generated sequentially and cover the whole array.
+
+    Parameters
+    ----------
+    arr : np.ndarray
+        Input image array.
+    patch_size : Tuple[int]
+        Patch sizes in each dimension.
+    target : Optional[np.ndarray], optional
+        Target array, by default None.
+
+    Returns
+    -------
+    Tuple[np.ndarray, Optional[np.ndarray]]
+        Patches.
+    """
+    is_3d_patch = len(patch_size) == 3
+
+    # Patches sanity check
+    validate_patch_dimensions(arr, patch_size, is_3d_patch)
+
+    # Update patch size to encompass S and C dimensions
+    patch_size = [1, arr.shape[1], *patch_size]
+
+    # Compute overlap
+    overlaps = _compute_overlap(arr_shape=arr.shape, patch_sizes=patch_size)
+
+    # Create view window and overlaps
+    window_steps = _compute_patch_steps(patch_sizes=patch_size, overlaps=overlaps)
+
+    output_shape = [
+        -1,
+    ] + patch_size[1:]
+
+    # Generate a view of the input array containing pre-calculated number of patches
+    # in each dimension with overlap.
+    # Resulting array is resized to (n_patches, C, Z, Y, X) or (n_patches, C, Y, X)
+    patches = _compute_patch_views(
+        arr,
+        window_shape=patch_size,
+        step=window_steps,
+        output_shape=output_shape,
+        target=target,
+    )
+
+    if target is not None:
+        # target was concatenated to patches in _compute_patch_views
+        return (
+            patches[:, 0, ...],
+            patches[:, 1, ...],
+        )  # TODO in _compute_patch_views?
+    else:
+        return patches, None
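
As a worked example of the overlap/step arithmetic above (illustrative, calling the private helpers directly): for a spatial dimension of length 100 and a patch size of 32, ceil(100 / 32) = 4 patches are needed, the overlap is ceil((4 * 32 - 100) / 3) = 10, and the step between windows is 32 - 10 = 22.

from careamics.dataset.patching.sequential_patching import (
    _compute_overlap,
    _compute_patch_steps,
)

arr_shape = (1, 1, 100)   # S, C, X with singleton S and C
patch_sizes = (1, 1, 32)

overlaps = _compute_overlap(arr_shape=arr_shape, patch_sizes=patch_sizes)
steps = _compute_patch_steps(patch_sizes=patch_sizes, overlaps=overlaps)

print(overlaps)  # overlap per dimension: 0, 0, 10
print(steps)     # step per dimension: 1, 1, 22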