careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +1 -14
- careamics/careamist.py +212 -294
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -15
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +5 -3
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +18 -98
- careamics/config/configuration_model.py +23 -18
- careamics/config/data_model.py +103 -54
- careamics/config/inference_model.py +41 -19
- careamics/config/optimizer_models.py +13 -7
- careamics/config/support/supported_data.py +29 -4
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +36 -58
- careamics/config/training_model.py +5 -1
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -7
- careamics/dataset/dataset_utils/file_utils.py +2 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +84 -173
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +97 -250
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/patching.py +97 -52
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +6 -3
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +137 -0
- careamics/prediction_utils/stitch_prediction.py +103 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
- careamics-0.1.0rc8.dist-info/RECORD +135 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
- careamics/config/configuration_example.py +0 -89
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc6.dist-info/RECORD +0 -107
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""Iterable tiled prediction dataset used to load data file by file."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Callable, Generator
|
|
7
|
+
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
from torch.utils.data import IterableDataset
|
|
10
|
+
|
|
11
|
+
from careamics.file_io.read import read_tiff
|
|
12
|
+
from careamics.transforms import Compose
|
|
13
|
+
|
|
14
|
+
from ..config import InferenceConfig
|
|
15
|
+
from ..config.tile_information import TileInformation
|
|
16
|
+
from ..config.transformations import NormalizeModel
|
|
17
|
+
from .dataset_utils import iterate_over_files
|
|
18
|
+
from .tiling import extract_tiles
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class IterableTiledPredDataset(IterableDataset):
    """Tiled prediction dataset that loads the data file by file.

    Every sample obtained from the source files is split into overlapping tiles;
    each tile is passed through a normalization transform and yielded together
    with its `TileInformation`.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Inference configuration.
    src_files : list of pathlib.Path
        List of data files.
    read_source_func : Callable, optional
        Read source function for custom types, by default read_tiff.
    **kwargs : Any
        Additional keyword arguments, unused.

    Attributes
    ----------
    prediction_config : InferenceConfig
        Inference configuration the dataset was built from.
    data_files : list of pathlib.Path
        Files iterated over during prediction.
    axes : str
        Description of axes in format STCZYX, taken from the configuration.
    tile_size
        Tile size, taken from the configuration.
    tile_overlap
        Overlap between tiles, taken from the configuration.
    read_source_func : Callable
        Function used to read the source files.
    image_means
        Means used for normalization, taken from the configuration.
    image_stds
        Standard deviations used for normalization, taken from the configuration.
    patch_transform : Compose
        Transform applied to every tile (normalization only).
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        src_files: list[Path],
        read_source_func: Callable = read_tiff,
        **kwargs: Any,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Inference configuration.
        src_files : list of pathlib.Path
            List of data files.
        read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
        **kwargs : Any
            Additional keyword arguments, unused.

        Raises
        ------
        ValueError
            If the tile size or the tile overlap is not provided in the inference
            configuration.
        ValueError
            If mean and std are not provided in the inference configuration.
        """
        if (
            prediction_config.tile_size is None
            or prediction_config.tile_overlap is None
        ):
            raise ValueError(
                "Tile size and overlap must be provided for tiled prediction."
            )

        self.prediction_config = prediction_config
        self.data_files = src_files
        self.axes = prediction_config.axes
        self.tile_size = prediction_config.tile_size
        self.tile_overlap = prediction_config.tile_overlap
        self.read_source_func = read_source_func

        # normalization statistics are mandatory for prediction
        if (
            self.prediction_config.image_means is None
            or self.prediction_config.image_stds is None
        ):
            raise ValueError("Mean and std must be provided for prediction.")

        self.image_means = self.prediction_config.image_means
        self.image_stds = self.prediction_config.image_stds

        # the only transform applied at prediction time is the normalization
        self.patch_transform = Compose(
            transform_list=[
                NormalizeModel(
                    image_means=self.image_means,
                    image_stds=self.image_stds,
                )
            ],
        )

    def __iter__(
        self,
    ) -> Generator[tuple[NDArray, TileInformation], None, None]:
        """
        Iterate over the data source and yield single normalized tiles.

        Yields
        ------
        tuple of NDArray and TileInformation
            A transformed tile and the information describing it.
        """
        # __init__ already rejects missing statistics; this only guards against
        # the attributes being reset to None after construction
        assert (
            self.image_means is not None and self.image_stds is not None
        ), "Mean and std must be provided"

        # only the first element of each yielded pair (the sample) is used here
        for sample, _ in iterate_over_files(
            self.prediction_config,
            self.data_files,
            read_source_func=self.read_source_func,
        ):
            # generate patches, return a generator of single tiles
            patch_gen = extract_tiles(
                arr=sample,
                tile_size=self.tile_size,
                overlaps=self.tile_overlap,
            )

            # apply transform to patches
            for patch_array, tile_info in patch_gen:
                # the second value returned by the transform is discarded
                transformed_patch, _ = self.patch_transform(patch=patch_array)

                yield transformed_patch, tile_info
|
|
@@ -1,37 +1,83 @@
|
|
|
1
1
|
"""Patching functions."""
|
|
2
2
|
|
|
3
|
+
from dataclasses import dataclass
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from typing import Callable,
|
|
5
|
+
from typing import Callable, Union
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
8
|
+
from numpy.typing import NDArray
|
|
7
9
|
|
|
8
10
|
from ...utils.logging import get_logger
|
|
9
11
|
from ..dataset_utils import reshape_array
|
|
12
|
+
from ..dataset_utils.running_stats import compute_normalization_stats
|
|
10
13
|
from .sequential_patching import extract_patches_sequential
|
|
11
14
|
|
|
12
15
|
logger = get_logger(__name__)
|
|
13
16
|
|
|
14
17
|
|
|
18
|
+
@dataclass
class Stats:
    """Dataclass to store per-channel normalization statistics.

    The values are stored as given (array, tuple, list or None);
    `get_statistics` exposes them as plain lists of built-in floats.
    """

    means: Union[NDArray, tuple, list, None]
    """Mean of the data across channels."""

    stds: Union[NDArray, tuple, list, None]
    """Standard deviation of the data across channels."""

    def get_statistics(self) -> tuple[list[float], list[float]]:
        """Return the means and standard deviations.

        Returns
        -------
        tuple of two lists of floats
            Means and standard deviations. Both lists are empty if either
            `means` or `stds` is None.
        """
        if self.means is None or self.stds is None:
            return [], []

        # cast each entry so that numpy scalars (e.g. float32) come back as
        # built-in floats, matching the annotated return type
        return [float(m) for m in self.means], [float(s) for s in self.stds]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class PatchedOutput:
    """Dataclass to store patches and statistics.

    Returned by the `prepare_patches_*` functions of this module. The
    unsupervised variants pass None as `targets` and empty statistics as
    `target_stats`.
    """

    patches: NDArray
    """Image patches."""

    targets: Union[NDArray, None]
    """Target patches."""

    image_stats: Stats
    """Statistics of the image patches."""

    target_stats: Stats
    """Statistics of the target patches."""
|
|
57
|
+
|
|
58
|
+
|
|
15
59
|
# called by in memory dataset
|
|
16
60
|
def prepare_patches_supervised(
|
|
17
|
-
train_files:
|
|
18
|
-
target_files:
|
|
61
|
+
train_files: list[Path],
|
|
62
|
+
target_files: list[Path],
|
|
19
63
|
axes: str,
|
|
20
|
-
patch_size: Union[
|
|
64
|
+
patch_size: Union[list[int], tuple[int, ...]],
|
|
21
65
|
read_source_func: Callable,
|
|
22
|
-
) ->
|
|
66
|
+
) -> PatchedOutput:
|
|
23
67
|
"""
|
|
24
68
|
Iterate over data source and create an array of patches and corresponding targets.
|
|
25
69
|
|
|
70
|
+
The lists of Paths should be pre-sorted.
|
|
71
|
+
|
|
26
72
|
Parameters
|
|
27
73
|
----------
|
|
28
|
-
train_files :
|
|
74
|
+
train_files : list of pathlib.Path
|
|
29
75
|
List of paths to training data.
|
|
30
|
-
target_files :
|
|
76
|
+
target_files : list of pathlib.Path
|
|
31
77
|
List of paths to target data.
|
|
32
78
|
axes : str
|
|
33
79
|
Axes of the data.
|
|
34
|
-
patch_size :
|
|
80
|
+
patch_size : list or tuple of int
|
|
35
81
|
Size of the patches.
|
|
36
82
|
read_source_func : Callable
|
|
37
83
|
Function to read the data.
|
|
@@ -41,9 +87,6 @@ def prepare_patches_supervised(
|
|
|
41
87
|
np.ndarray
|
|
42
88
|
Array of patches.
|
|
43
89
|
"""
|
|
44
|
-
train_files.sort()
|
|
45
|
-
target_files.sort()
|
|
46
|
-
|
|
47
90
|
means, stds, num_samples = 0, 0, 0
|
|
48
91
|
all_patches, all_targets = [], []
|
|
49
92
|
for train_filename, target_filename in zip(train_files, target_files):
|
|
@@ -83,46 +126,47 @@ def prepare_patches_supervised(
|
|
|
83
126
|
f"{target_files}."
|
|
84
127
|
)
|
|
85
128
|
|
|
86
|
-
|
|
129
|
+
image_means, image_stds = compute_normalization_stats(np.concatenate(all_patches))
|
|
130
|
+
target_means, target_stds = compute_normalization_stats(np.concatenate(all_targets))
|
|
87
131
|
|
|
88
132
|
patch_array: np.ndarray = np.concatenate(all_patches, axis=0)
|
|
89
133
|
target_array: np.ndarray = np.concatenate(all_targets, axis=0)
|
|
90
134
|
logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
|
|
91
135
|
|
|
92
|
-
return (
|
|
136
|
+
return PatchedOutput(
|
|
93
137
|
patch_array,
|
|
94
138
|
target_array,
|
|
95
|
-
|
|
96
|
-
|
|
139
|
+
Stats(image_means, image_stds),
|
|
140
|
+
Stats(target_means, target_stds),
|
|
97
141
|
)
|
|
98
142
|
|
|
99
143
|
|
|
100
144
|
# called by in_memory_dataset
|
|
101
145
|
def prepare_patches_unsupervised(
|
|
102
|
-
train_files:
|
|
146
|
+
train_files: list[Path],
|
|
103
147
|
axes: str,
|
|
104
|
-
patch_size: Union[
|
|
148
|
+
patch_size: Union[list[int], tuple[int]],
|
|
105
149
|
read_source_func: Callable,
|
|
106
|
-
) ->
|
|
150
|
+
) -> PatchedOutput:
|
|
107
151
|
"""Iterate over data source and create an array of patches.
|
|
108
152
|
|
|
109
153
|
This method returns the mean and standard deviation of the image.
|
|
110
154
|
|
|
111
155
|
Parameters
|
|
112
156
|
----------
|
|
113
|
-
train_files :
|
|
157
|
+
train_files : list of pathlib.Path
|
|
114
158
|
List of paths to training data.
|
|
115
159
|
axes : str
|
|
116
160
|
Axes of the data.
|
|
117
|
-
patch_size :
|
|
161
|
+
patch_size : list or tuple of int
|
|
118
162
|
Size of the patches.
|
|
119
163
|
read_source_func : Callable
|
|
120
164
|
Function to read the data.
|
|
121
165
|
|
|
122
166
|
Returns
|
|
123
167
|
-------
|
|
124
|
-
|
|
125
|
-
|
|
168
|
+
PatchedOutput
|
|
169
|
+
Dataclass holding patches and their statistics.
|
|
126
170
|
"""
|
|
127
171
|
means, stds, num_samples = 0, 0, 0
|
|
128
172
|
all_patches = []
|
|
@@ -149,21 +193,23 @@ def prepare_patches_unsupervised(
|
|
|
149
193
|
if num_samples == 0:
|
|
150
194
|
raise ValueError(f"No valid samples found in the input data: {train_files}.")
|
|
151
195
|
|
|
152
|
-
|
|
196
|
+
image_means, image_stds = compute_normalization_stats(np.concatenate(all_patches))
|
|
153
197
|
|
|
154
198
|
patch_array: np.ndarray = np.concatenate(all_patches)
|
|
155
199
|
logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
|
|
156
200
|
|
|
157
|
-
return
|
|
201
|
+
return PatchedOutput(
|
|
202
|
+
patch_array, None, Stats(image_means, image_stds), Stats((), ())
|
|
203
|
+
)
|
|
158
204
|
|
|
159
205
|
|
|
160
206
|
# called on arrays by in memory dataset
|
|
161
207
|
def prepare_patches_supervised_array(
|
|
162
|
-
data:
|
|
208
|
+
data: NDArray,
|
|
163
209
|
axes: str,
|
|
164
|
-
data_target:
|
|
165
|
-
patch_size: Union[
|
|
166
|
-
) ->
|
|
210
|
+
data_target: NDArray,
|
|
211
|
+
patch_size: Union[list[int], tuple[int]],
|
|
212
|
+
) -> PatchedOutput:
|
|
167
213
|
"""Iterate over data source and create an array of patches.
|
|
168
214
|
|
|
169
215
|
This method expects an array of shape SC(Z)YX, where S and C can be singleton
|
|
@@ -173,28 +219,28 @@ def prepare_patches_supervised_array(
|
|
|
173
219
|
|
|
174
220
|
Parameters
|
|
175
221
|
----------
|
|
176
|
-
data :
|
|
222
|
+
data : numpy.ndarray
|
|
177
223
|
Input data array.
|
|
178
224
|
axes : str
|
|
179
225
|
Axes of the data.
|
|
180
|
-
data_target :
|
|
226
|
+
data_target : numpy.ndarray
|
|
181
227
|
Target data array.
|
|
182
|
-
patch_size :
|
|
228
|
+
patch_size : list or tuple of int
|
|
183
229
|
Size of the patches.
|
|
184
230
|
|
|
185
231
|
Returns
|
|
186
232
|
-------
|
|
187
|
-
|
|
188
|
-
|
|
233
|
+
PatchedOutput
|
|
234
|
+
Dataclass holding the source and target patches, with their statistics.
|
|
189
235
|
"""
|
|
190
|
-
# compute statistics
|
|
191
|
-
mean = data.mean()
|
|
192
|
-
std = data.std()
|
|
193
|
-
|
|
194
236
|
# reshape array
|
|
195
237
|
reshaped_sample = reshape_array(data, axes)
|
|
196
238
|
reshaped_target = reshape_array(data_target, axes)
|
|
197
239
|
|
|
240
|
+
# compute statistics
|
|
241
|
+
image_means, image_stds = compute_normalization_stats(reshaped_sample)
|
|
242
|
+
target_means, target_stds = compute_normalization_stats(reshaped_target)
|
|
243
|
+
|
|
198
244
|
# generate patches, return a generator
|
|
199
245
|
patches, patch_targets = extract_patches_sequential(
|
|
200
246
|
reshaped_sample, patch_size=patch_size, target=reshaped_target
|
|
@@ -205,20 +251,20 @@ def prepare_patches_supervised_array(
|
|
|
205
251
|
|
|
206
252
|
logger.info(f"Extracted {patches.shape[0]} patches from input array.")
|
|
207
253
|
|
|
208
|
-
return (
|
|
254
|
+
return PatchedOutput(
|
|
209
255
|
patches,
|
|
210
256
|
patch_targets,
|
|
211
|
-
|
|
212
|
-
|
|
257
|
+
Stats(image_means, image_stds),
|
|
258
|
+
Stats(target_means, target_stds),
|
|
213
259
|
)
|
|
214
260
|
|
|
215
261
|
|
|
216
262
|
# called by in memory dataset
|
|
217
263
|
def prepare_patches_unsupervised_array(
|
|
218
|
-
data:
|
|
264
|
+
data: NDArray,
|
|
219
265
|
axes: str,
|
|
220
|
-
patch_size: Union[
|
|
221
|
-
) ->
|
|
266
|
+
patch_size: Union[list[int], tuple[int]],
|
|
267
|
+
) -> PatchedOutput:
|
|
222
268
|
"""
|
|
223
269
|
Iterate over data source and create an array of patches.
|
|
224
270
|
|
|
@@ -229,26 +275,25 @@ def prepare_patches_unsupervised_array(
|
|
|
229
275
|
|
|
230
276
|
Parameters
|
|
231
277
|
----------
|
|
232
|
-
data :
|
|
278
|
+
data : numpy.ndarray
|
|
233
279
|
Input data array.
|
|
234
280
|
axes : str
|
|
235
281
|
Axes of the data.
|
|
236
|
-
patch_size :
|
|
282
|
+
patch_size : list or tuple of int
|
|
237
283
|
Size of the patches.
|
|
238
284
|
|
|
239
285
|
Returns
|
|
240
286
|
-------
|
|
241
|
-
|
|
242
|
-
|
|
287
|
+
PatchedOutput
|
|
288
|
+
Dataclass holding the patches and their statistics.
|
|
243
289
|
"""
|
|
244
|
-
# calculate mean and std
|
|
245
|
-
mean = data.mean()
|
|
246
|
-
std = data.std()
|
|
247
|
-
|
|
248
290
|
# reshape array
|
|
249
291
|
reshaped_sample = reshape_array(data, axes)
|
|
250
292
|
|
|
293
|
+
# calculate mean and std
|
|
294
|
+
means, stds = compute_normalization_stats(reshaped_sample)
|
|
295
|
+
|
|
251
296
|
# generate patches, return a generator
|
|
252
297
|
patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)
|
|
253
298
|
|
|
254
|
-
return patches,
|
|
299
|
+
return PatchedOutput(patches, None, Stats(means, stds), Stats((), ()))
|
|
@@ -13,6 +13,7 @@ def extract_patches_random(
|
|
|
13
13
|
arr: np.ndarray,
|
|
14
14
|
patch_size: Union[List[int], Tuple[int, ...]],
|
|
15
15
|
target: Optional[np.ndarray] = None,
|
|
16
|
+
seed: Optional[int] = None,
|
|
16
17
|
) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
|
|
17
18
|
"""
|
|
18
19
|
Generate patches from an array in a random manner.
|
|
@@ -34,12 +35,16 @@ def extract_patches_random(
|
|
|
34
35
|
Patch sizes in each dimension.
|
|
35
36
|
target : Optional[np.ndarray], optional
|
|
36
37
|
Target array, by default None.
|
|
38
|
+
seed : Optional[int], optional
|
|
39
|
+
Random seed, by default None.
|
|
37
40
|
|
|
38
41
|
Yields
|
|
39
42
|
------
|
|
40
43
|
Generator[np.ndarray, None, None]
|
|
41
44
|
Generator of patches.
|
|
42
45
|
"""
|
|
46
|
+
rng = np.random.default_rng(seed=seed)
|
|
47
|
+
|
|
43
48
|
is_3d_patch = len(patch_size) == 3
|
|
44
49
|
|
|
45
50
|
# patches sanity check
|
|
@@ -48,9 +53,6 @@ def extract_patches_random(
|
|
|
48
53
|
# Update patch size to encompass S and C dimensions
|
|
49
54
|
patch_size = [1, arr.shape[1], *patch_size]
|
|
50
55
|
|
|
51
|
-
# random generator
|
|
52
|
-
rng = np.random.default_rng()
|
|
53
|
-
|
|
54
56
|
# iterate over the number of samples (S or T)
|
|
55
57
|
for sample_idx in range(arr.shape[0]):
|
|
56
58
|
# get sample array
|
|
@@ -113,6 +115,7 @@ def extract_patches_random_from_chunks(
|
|
|
113
115
|
patch_size: Union[List[int], Tuple[int, ...]],
|
|
114
116
|
chunk_size: Union[List[int], Tuple[int, ...]],
|
|
115
117
|
chunk_limit: Optional[int] = None,
|
|
118
|
+
seed: Optional[int] = None,
|
|
116
119
|
) -> Generator[np.ndarray, None, None]:
|
|
117
120
|
"""
|
|
118
121
|
Generate patches from an array in a random manner.
|
|
@@ -130,6 +133,8 @@ def extract_patches_random_from_chunks(
|
|
|
130
133
|
Chunk sizes to load from the.
|
|
131
134
|
chunk_limit : Optional[int], optional
|
|
132
135
|
Number of chunks to load, by default None.
|
|
136
|
+
seed : Optional[int], optional
|
|
137
|
+
Random seed, by default None.
|
|
133
138
|
|
|
134
139
|
Yields
|
|
135
140
|
------
|
|
@@ -141,7 +146,7 @@ def extract_patches_random_from_chunks(
|
|
|
141
146
|
# Patches sanity check
|
|
142
147
|
validate_patch_dimensions(arr, patch_size, is_3d_patch)
|
|
143
148
|
|
|
144
|
-
rng = np.random.default_rng()
|
|
149
|
+
rng = np.random.default_rng(seed=seed)
|
|
145
150
|
num_chunks = chunk_limit if chunk_limit else np.prod(arr._cdata_shape)
|
|
146
151
|
|
|
147
152
|
# Iterate over num chunks in the array
|
|
@@ -45,18 +45,20 @@ def validate_patch_dimensions(
|
|
|
45
45
|
if len(patch_size) != len(arr.shape[2:]):
|
|
46
46
|
raise ValueError(
|
|
47
47
|
f"There must be a patch size for each spatial dimensions "
|
|
48
|
-
f"(got {patch_size} patches for dims {arr.shape})."
|
|
48
|
+
f"(got {patch_size} patches for dims {arr.shape}). Check the axes order."
|
|
49
49
|
)
|
|
50
50
|
|
|
51
51
|
# Sanity checks on patch sizes versus array dimension
|
|
52
52
|
if is_3d_patch and patch_size[0] > arr.shape[-3]:
|
|
53
53
|
raise ValueError(
|
|
54
54
|
f"Z patch size is inconsistent with image shape "
|
|
55
|
-
f"(got {patch_size[0]} patches for dim {arr.shape[1]})."
|
|
55
|
+
f"(got {patch_size[0]} patches for dim {arr.shape[1]}). Check the axes "
|
|
56
|
+
f"order."
|
|
56
57
|
)
|
|
57
58
|
|
|
58
59
|
if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
|
|
59
60
|
raise ValueError(
|
|
60
61
|
f"At least one of YX patch dimensions is larger than the corresponding "
|
|
61
|
-
f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]})."
|
|
62
|
+
f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]}). "
|
|
63
|
+
f"Check the axes order."
|
|
62
64
|
)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Collate function for tiling."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, List, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from torch.utils.data.dataloader import default_collate
|
|
7
|
+
|
|
8
|
+
from careamics.config.tile_information import TileInformation
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def collate_tiles(
    batch: List[Tuple[np.ndarray, TileInformation]]
) -> Tuple[Any, List[TileInformation]]:
    """
    Collate tiles received from CAREamics prediction dataloader.

    Each element of the batch is a tuple of an array and its TileInformation.
    The arrays are collated with PyTorch's `default_collate`, while the
    TileInformation objects are returned untouched as a list, in batch order.

    Parameters
    ----------
    batch : List[Tuple[np.ndarray, TileInformation]]
        Batch of tiles.

    Returns
    -------
    Tuple[Any, List[TileInformation]]
        Collated arrays and the list of corresponding tile information.
    """
    tile_arrays = [tile for tile, _ in batch]
    # tile information is kept as a plain list rather than being collated
    tile_infos = [tile_info for _, tile_info in batch]

    return default_collate(tile_arrays), tile_infos
|
|
@@ -84,15 +84,15 @@ def extract_tiles(
|
|
|
84
84
|
tile_size: Union[List[int], Tuple[int, ...]],
|
|
85
85
|
overlaps: Union[List[int], Tuple[int, ...]],
|
|
86
86
|
) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
|
|
87
|
-
"""
|
|
88
|
-
Generate tiles from the input array with specified overlap.
|
|
87
|
+
"""Generate tiles from the input array with specified overlap.
|
|
89
88
|
|
|
90
89
|
The tiles cover the whole array. The method returns a generator that yields
|
|
91
90
|
tuples of array and tile information, the latter includes whether
|
|
92
91
|
the tile is the last one, the coordinates of the overlap crop, and the coordinates
|
|
93
92
|
of the stitched tile.
|
|
94
93
|
|
|
95
|
-
|
|
94
|
+
Input array should have shape SC(Z)YX, while the returned tiles have shape C(Z)YX,
|
|
95
|
+
where C can be a singleton.
|
|
96
96
|
|
|
97
97
|
Parameters
|
|
98
98
|
----------
|
|
@@ -155,10 +155,10 @@ def extract_tiles(
|
|
|
155
155
|
# create tile information
|
|
156
156
|
tile_info = TileInformation(
|
|
157
157
|
array_shape=sample.squeeze().shape,
|
|
158
|
-
tiled=True,
|
|
159
158
|
last_tile=last_tile,
|
|
160
159
|
overlap_crop_coords=overlap_crop_coords,
|
|
161
160
|
stitch_coords=stitch_coords,
|
|
161
|
+
sample_id=sample_idx,
|
|
162
162
|
)
|
|
163
163
|
|
|
164
164
|
yield tile, tile_info
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Module to get read functions."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Callable, Dict, Protocol, Union
|
|
5
|
+
|
|
6
|
+
from numpy.typing import NDArray
|
|
7
|
+
|
|
8
|
+
from careamics.config.support import SupportedData
|
|
9
|
+
|
|
10
|
+
from .tiff import read_tiff
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# This is very strict, function signature has to match including arg names
|
|
14
|
+
# See WriteFunc notes
|
|
15
|
+
class ReadFunc(Protocol):
    """Structural type describing the callables accepted as read functions."""

    def __call__(self, file_path: Path, *args, **kwargs) -> NDArray:
        """
        Signature that type hinted read functions must match (excluding self).

        Parameters
        ----------
        file_path : pathlib.Path
            Path to the file to read.
        *args
            Any further positional arguments.
        **kwargs
            Any further keyword arguments.
        """
        ...
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Registry mapping each supported data type to the function used to read it;
# `get_read_func` looks data types up in this dictionary.
READ_FUNCS: Dict[SupportedData, ReadFunc] = {
    SupportedData.TIFF: read_tiff,
}
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_read_func(data_type: Union[str, SupportedData]) -> Callable:
    """
    Look up the read function registered for a data type.

    Parameters
    ----------
    data_type : str or SupportedData
        Data type to read.

    Returns
    -------
    callable
        Read function registered in `READ_FUNCS` for this data type.

    Raises
    ------
    NotImplementedError
        If no read function is registered for the data type.
    """
    # guard clause: reject anything that has no registered reader
    if data_type not in READ_FUNCS:
        raise NotImplementedError(f"Data type '{data_type}' is not supported.")

    # convert to the enum so that the dictionary key type is explicit
    return READ_FUNCS[SupportedData(data_type)]
|