careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +164 -231
- careamics/config/algorithm_model.py +5 -18
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +11 -4
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -5
- careamics/config/configuration_factory.py +27 -41
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +89 -63
- careamics/config/inference_model.py +28 -81
- careamics/config/optimizer_models.py +11 -11
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -16
- careamics/config/tile_information.py +28 -58
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/config/validators/validator_utils.py +1 -1
- careamics/conftest.py +12 -0
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +6 -11
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +88 -154
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +121 -191
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +109 -39
- careamics/dataset/patching/random_patching.py +17 -6
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/validate_patch_dimension.py +7 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +46 -25
- careamics/lightning_module.py +19 -9
- careamics/lightning_prediction_datamodule.py +54 -84
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +3 -3
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +52 -14
- careamics/transforms/normalize.py +171 -48
- careamics/transforms/pixel_manipulation.py +35 -11
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/tta.py +43 -29
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +4 -2
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
- careamics-0.1.0rc7.dist-info/RECORD +130 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/lightning_prediction_loop.py +0 -116
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -74
- careamics/transforms/nd_flip.py +0 -67
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""Iterable tiled prediction dataset used to load data file by file."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Callable, Generator
|
|
7
|
+
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
from torch.utils.data import IterableDataset
|
|
10
|
+
|
|
11
|
+
from careamics.transforms import Compose
|
|
12
|
+
|
|
13
|
+
from ..config import InferenceConfig
|
|
14
|
+
from ..config.tile_information import TileInformation
|
|
15
|
+
from ..config.transformations import NormalizeModel
|
|
16
|
+
from .dataset_utils import iterate_over_files, read_tiff
|
|
17
|
+
from .tiling import extract_tiles
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class IterableTiledPredDataset(IterableDataset):
    """Tiled prediction dataset loading data file by file.

    Each source file is read with ``read_source_func``, split into overlapping
    tiles, and every tile is normalized with the means and standard deviations
    taken from the inference configuration before being yielded together with
    its tile information.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Inference configuration.
    src_files : list of pathlib.Path
        List of data files.
    read_source_func : Callable, optional
        Read source function for custom types, by default read_tiff.
    **kwargs : Any
        Additional keyword arguments, unused.

    Attributes
    ----------
    prediction_config : InferenceConfig
        Inference configuration.
    data_files : list of pathlib.Path
        List of data files.
    axes : str
        Description of axes in format STCZYX.
    tile_size
        Tile size, taken from ``prediction_config.tile_size``.
    tile_overlap
        Tile overlap, taken from ``prediction_config.tile_overlap``.
    read_source_func : Callable
        Function used to read each data file.
    image_means
        Means used for normalization, from ``prediction_config.image_means``.
    image_stds
        Standard deviations used for normalization, from
        ``prediction_config.image_stds``.
    patch_transform : Compose
        Normalization transform applied to every tile.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        src_files: list[Path],
        read_source_func: Callable = read_tiff,
        **kwargs: Any,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Inference configuration.
        src_files : List[Path]
            List of data files.
        read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
        **kwargs : Any
            Additional keyword arguments, unused.

        Raises
        ------
        ValueError
            If tile size or tile overlap are not provided in the inference
            configuration.
        ValueError
            If mean and std are not provided in the inference configuration.
        """
        # tiling parameters are mandatory for a tiled dataset
        if (
            prediction_config.tile_size is None
            or prediction_config.tile_overlap is None
        ):
            raise ValueError(
                "Tile size and overlap must be provided for tiled prediction."
            )

        self.prediction_config = prediction_config
        self.data_files = src_files
        self.axes = prediction_config.axes
        self.tile_size = prediction_config.tile_size
        self.tile_overlap = prediction_config.tile_overlap
        self.read_source_func = read_source_func

        # check mean and std and create normalize transform
        if (
            self.prediction_config.image_means is None
            or self.prediction_config.image_stds is None
        ):
            raise ValueError("Mean and std must be provided for prediction.")

        self.image_means = self.prediction_config.image_means
        self.image_stds = self.prediction_config.image_stds

        # instantiate normalize transform
        self.patch_transform = Compose(
            transform_list=[
                NormalizeModel(
                    image_means=self.image_means,
                    image_stds=self.image_stds,
                )
            ],
        )

    def __iter__(
        self,
    ) -> Generator[tuple[NDArray, TileInformation], None, None]:
        """
        Iterate over data source and yield single normalized tiles.

        Yields
        ------
        Generator of NDArray and TileInformation tuple
            Generator of single normalized tiles and their tile information.
        """
        # Invariant established in __init__; also narrows the Optional types.
        assert (
            self.image_means is not None and self.image_stds is not None
        ), "Mean and std must be provided"

        for sample, _ in iterate_over_files(
            self.prediction_config,
            self.data_files,
            read_source_func=self.read_source_func,
        ):
            # generate patches, return a generator of single tiles
            patch_gen = extract_tiles(
                arr=sample,
                tile_size=self.tile_size,
                overlaps=self.tile_overlap,
            )

            # apply transform to patches; the second returned value is
            # discarded (presumably the absent target — confirm in Compose)
            for patch_array, tile_info in patch_gen:
                transformed_patch, _ = self.patch_transform(patch=patch_array)

                yield transformed_patch, tile_info
|
|
@@ -1,9 +1,6 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Tiling submodule.
|
|
3
|
-
|
|
4
|
-
These functions are used to tile images into patches or tiles.
|
|
5
|
-
"""
|
|
1
|
+
"""Patching functions."""
|
|
6
2
|
|
|
3
|
+
from dataclasses import dataclass
|
|
7
4
|
from pathlib import Path
|
|
8
5
|
from typing import Callable, List, Tuple, Union
|
|
9
6
|
|
|
@@ -11,30 +8,69 @@ import numpy as np
|
|
|
11
8
|
|
|
12
9
|
from ...utils.logging import get_logger
|
|
13
10
|
from ..dataset_utils import reshape_array
|
|
11
|
+
from ..dataset_utils.running_stats import compute_normalization_stats
|
|
14
12
|
from .sequential_patching import extract_patches_sequential
|
|
15
13
|
|
|
16
14
|
logger = get_logger(__name__)
|
|
17
15
|
|
|
18
16
|
|
|
17
|
+
@dataclass
class Stats:
    """Container pairing normalization means with standard deviations.

    Each field may be a NumPy array, a tuple, a list, or ``None``. For example,
    the unsupervised patching functions use empty tuples when there are no
    target statistics.
    """

    # Mean value(s) of the data.
    means: Union[None, np.ndarray, tuple, list]
    # Standard deviation value(s) of the data.
    stds: Union[None, np.ndarray, tuple, list]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class PatchedOutput:
    """Dataclass to store patches, optional targets and their statistics.

    Attributes
    ----------
    patches : np.ndarray
        Array of extracted patches.
    targets : np.ndarray or None
        Array of target patches, or ``None`` when there are no targets.
    image_stats : Stats
        Normalization statistics of the input images.
    target_stats : Stats
        Normalization statistics of the targets.
    """

    patches: np.ndarray
    targets: Union[np.ndarray, None]
    image_stats: Stats
    target_stats: Stats
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class StatsOutput:
    """Dataclass to store image and target statistics, without patches.

    Attributes
    ----------
    image_stats : Stats
        Normalization statistics of the input images.
    target_stats : Stats
        Normalization statistics of the targets.
    """

    image_stats: Stats
    target_stats: Stats
|
|
41
|
+
|
|
42
|
+
|
|
19
43
|
# called by in memory dataset
|
|
20
44
|
def prepare_patches_supervised(
|
|
21
45
|
train_files: List[Path],
|
|
22
46
|
target_files: List[Path],
|
|
23
47
|
axes: str,
|
|
24
|
-
patch_size: Union[List[int], Tuple[int]],
|
|
48
|
+
patch_size: Union[List[int], Tuple[int, ...]],
|
|
25
49
|
read_source_func: Callable,
|
|
26
|
-
) ->
|
|
50
|
+
) -> PatchedOutput:
|
|
27
51
|
"""
|
|
28
52
|
Iterate over data source and create an array of patches and corresponding targets.
|
|
29
53
|
|
|
54
|
+
The lists of Paths should be pre-sorted.
|
|
55
|
+
|
|
56
|
+
Parameters
|
|
57
|
+
----------
|
|
58
|
+
train_files : List[Path]
|
|
59
|
+
List of paths to training data.
|
|
60
|
+
target_files : List[Path]
|
|
61
|
+
List of paths to target data.
|
|
62
|
+
axes : str
|
|
63
|
+
Axes of the data.
|
|
64
|
+
patch_size : Union[List[int], Tuple[int]]
|
|
65
|
+
Size of the patches.
|
|
66
|
+
read_source_func : Callable
|
|
67
|
+
Function to read the data.
|
|
68
|
+
|
|
30
69
|
Returns
|
|
31
70
|
-------
|
|
32
71
|
np.ndarray
|
|
33
72
|
Array of patches.
|
|
34
73
|
"""
|
|
35
|
-
train_files.sort()
|
|
36
|
-
target_files.sort()
|
|
37
|
-
|
|
38
74
|
means, stds, num_samples = 0, 0, 0
|
|
39
75
|
all_patches, all_targets = [], []
|
|
40
76
|
for train_filename, target_filename in zip(train_files, target_files):
|
|
@@ -74,17 +110,18 @@ def prepare_patches_supervised(
|
|
|
74
110
|
f"{target_files}."
|
|
75
111
|
)
|
|
76
112
|
|
|
77
|
-
|
|
113
|
+
image_means, image_stds = compute_normalization_stats(np.concatenate(all_patches))
|
|
114
|
+
target_means, target_stds = compute_normalization_stats(np.concatenate(all_targets))
|
|
78
115
|
|
|
79
116
|
patch_array: np.ndarray = np.concatenate(all_patches, axis=0)
|
|
80
117
|
target_array: np.ndarray = np.concatenate(all_targets, axis=0)
|
|
81
118
|
logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
|
|
82
119
|
|
|
83
|
-
return (
|
|
120
|
+
return PatchedOutput(
|
|
84
121
|
patch_array,
|
|
85
122
|
target_array,
|
|
86
|
-
|
|
87
|
-
|
|
123
|
+
Stats(image_means, image_stds),
|
|
124
|
+
Stats(target_means, target_stds),
|
|
88
125
|
)
|
|
89
126
|
|
|
90
127
|
|
|
@@ -94,14 +131,26 @@ def prepare_patches_unsupervised(
|
|
|
94
131
|
axes: str,
|
|
95
132
|
patch_size: Union[List[int], Tuple[int]],
|
|
96
133
|
read_source_func: Callable,
|
|
97
|
-
) ->
|
|
98
|
-
"""
|
|
99
|
-
|
|
134
|
+
) -> PatchedOutput:
|
|
135
|
+
"""Iterate over data source and create an array of patches.
|
|
136
|
+
|
|
137
|
+
This method returns the mean and standard deviation of the image.
|
|
138
|
+
|
|
139
|
+
Parameters
|
|
140
|
+
----------
|
|
141
|
+
train_files : List[Path]
|
|
142
|
+
List of paths to training data.
|
|
143
|
+
axes : str
|
|
144
|
+
Axes of the data.
|
|
145
|
+
patch_size : Union[List[int], Tuple[int]]
|
|
146
|
+
Size of the patches.
|
|
147
|
+
read_source_func : Callable
|
|
148
|
+
Function to read the data.
|
|
100
149
|
|
|
101
150
|
Returns
|
|
102
151
|
-------
|
|
103
|
-
np.ndarray
|
|
104
|
-
|
|
152
|
+
Tuple[np.ndarray, None, float, float]
|
|
153
|
+
Source and target patches, mean and standard deviation.
|
|
105
154
|
"""
|
|
106
155
|
means, stds, num_samples = 0, 0, 0
|
|
107
156
|
all_patches = []
|
|
@@ -128,12 +177,14 @@ def prepare_patches_unsupervised(
|
|
|
128
177
|
if num_samples == 0:
|
|
129
178
|
raise ValueError(f"No valid samples found in the input data: {train_files}.")
|
|
130
179
|
|
|
131
|
-
|
|
180
|
+
image_means, image_stds = compute_normalization_stats(np.concatenate(all_patches))
|
|
132
181
|
|
|
133
182
|
patch_array: np.ndarray = np.concatenate(all_patches)
|
|
134
183
|
logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
|
|
135
184
|
|
|
136
|
-
return
|
|
185
|
+
return PatchedOutput(
|
|
186
|
+
patch_array, None, Stats(image_means, image_stds), Stats((), ())
|
|
187
|
+
)
|
|
137
188
|
|
|
138
189
|
|
|
139
190
|
# called on arrays by in memory dataset
|
|
@@ -142,7 +193,7 @@ def prepare_patches_supervised_array(
|
|
|
142
193
|
axes: str,
|
|
143
194
|
data_target: np.ndarray,
|
|
144
195
|
patch_size: Union[List[int], Tuple[int]],
|
|
145
|
-
) ->
|
|
196
|
+
) -> PatchedOutput:
|
|
146
197
|
"""Iterate over data source and create an array of patches.
|
|
147
198
|
|
|
148
199
|
This method expects an array of shape SC(Z)YX, where S and C can be singleton
|
|
@@ -150,19 +201,30 @@ def prepare_patches_supervised_array(
|
|
|
150
201
|
|
|
151
202
|
Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
|
|
152
203
|
|
|
204
|
+
Parameters
|
|
205
|
+
----------
|
|
206
|
+
data : np.ndarray
|
|
207
|
+
Input data array.
|
|
208
|
+
axes : str
|
|
209
|
+
Axes of the data.
|
|
210
|
+
data_target : np.ndarray
|
|
211
|
+
Target data array.
|
|
212
|
+
patch_size : Union[List[int], Tuple[int]]
|
|
213
|
+
Size of the patches.
|
|
214
|
+
|
|
153
215
|
Returns
|
|
154
216
|
-------
|
|
155
|
-
np.ndarray
|
|
156
|
-
|
|
217
|
+
Tuple[np.ndarray, np.ndarray, float, float]
|
|
218
|
+
Source and target patches, mean and standard deviation.
|
|
157
219
|
"""
|
|
158
|
-
# compute statistics
|
|
159
|
-
mean = data.mean()
|
|
160
|
-
std = data.std()
|
|
161
|
-
|
|
162
220
|
# reshape array
|
|
163
221
|
reshaped_sample = reshape_array(data, axes)
|
|
164
222
|
reshaped_target = reshape_array(data_target, axes)
|
|
165
223
|
|
|
224
|
+
# compute statistics
|
|
225
|
+
image_means, image_stds = compute_normalization_stats(reshaped_sample)
|
|
226
|
+
target_means, target_stds = compute_normalization_stats(reshaped_target)
|
|
227
|
+
|
|
166
228
|
# generate patches, return a generator
|
|
167
229
|
patches, patch_targets = extract_patches_sequential(
|
|
168
230
|
reshaped_sample, patch_size=patch_size, target=reshaped_target
|
|
@@ -173,11 +235,11 @@ def prepare_patches_supervised_array(
|
|
|
173
235
|
|
|
174
236
|
logger.info(f"Extracted {patches.shape[0]} patches from input array.")
|
|
175
237
|
|
|
176
|
-
return (
|
|
238
|
+
return PatchedOutput(
|
|
177
239
|
patches,
|
|
178
240
|
patch_targets,
|
|
179
|
-
|
|
180
|
-
|
|
241
|
+
Stats(image_means, image_stds),
|
|
242
|
+
Stats(target_means, target_stds),
|
|
181
243
|
)
|
|
182
244
|
|
|
183
245
|
|
|
@@ -186,7 +248,7 @@ def prepare_patches_unsupervised_array(
|
|
|
186
248
|
data: np.ndarray,
|
|
187
249
|
axes: str,
|
|
188
250
|
patch_size: Union[List[int], Tuple[int]],
|
|
189
|
-
) ->
|
|
251
|
+
) -> PatchedOutput:
|
|
190
252
|
"""
|
|
191
253
|
Iterate over data source and create an array of patches.
|
|
192
254
|
|
|
@@ -195,19 +257,27 @@ def prepare_patches_unsupervised_array(
|
|
|
195
257
|
|
|
196
258
|
Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
|
|
197
259
|
|
|
260
|
+
Parameters
|
|
261
|
+
----------
|
|
262
|
+
data : np.ndarray
|
|
263
|
+
Input data array.
|
|
264
|
+
axes : str
|
|
265
|
+
Axes of the data.
|
|
266
|
+
patch_size : Union[List[int], Tuple[int]]
|
|
267
|
+
Size of the patches.
|
|
268
|
+
|
|
198
269
|
Returns
|
|
199
270
|
-------
|
|
200
|
-
np.ndarray
|
|
201
|
-
|
|
271
|
+
Tuple[np.ndarray, None, float, float]
|
|
272
|
+
Source patches, mean and standard deviation.
|
|
202
273
|
"""
|
|
203
|
-
# calculate mean and std
|
|
204
|
-
mean = data.mean()
|
|
205
|
-
std = data.std()
|
|
206
|
-
|
|
207
274
|
# reshape array
|
|
208
275
|
reshaped_sample = reshape_array(data, axes)
|
|
209
276
|
|
|
277
|
+
# calculate mean and std
|
|
278
|
+
means, stds = compute_normalization_stats(reshaped_sample)
|
|
279
|
+
|
|
210
280
|
# generate patches, return a generator
|
|
211
281
|
patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)
|
|
212
282
|
|
|
213
|
-
return patches,
|
|
283
|
+
return PatchedOutput(patches, None, Stats(means, stds), Stats((), ()))
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Random patching utilities."""
|
|
2
|
+
|
|
1
3
|
from typing import Generator, List, Optional, Tuple, Union
|
|
2
4
|
|
|
3
5
|
import numpy as np
|
|
@@ -11,6 +13,7 @@ def extract_patches_random(
|
|
|
11
13
|
arr: np.ndarray,
|
|
12
14
|
patch_size: Union[List[int], Tuple[int, ...]],
|
|
13
15
|
target: Optional[np.ndarray] = None,
|
|
16
|
+
seed: Optional[int] = None,
|
|
14
17
|
) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
|
|
15
18
|
"""
|
|
16
19
|
Generate patches from an array in a random manner.
|
|
@@ -30,12 +33,18 @@ def extract_patches_random(
|
|
|
30
33
|
Input image array.
|
|
31
34
|
patch_size : Tuple[int]
|
|
32
35
|
Patch sizes in each dimension.
|
|
36
|
+
target : Optional[np.ndarray], optional
|
|
37
|
+
Target array, by default None.
|
|
38
|
+
seed : Optional[int], optional
|
|
39
|
+
Random seed, by default None.
|
|
33
40
|
|
|
34
41
|
Yields
|
|
35
42
|
------
|
|
36
43
|
Generator[np.ndarray, None, None]
|
|
37
44
|
Generator of patches.
|
|
38
45
|
"""
|
|
46
|
+
rng = np.random.default_rng(seed=seed)
|
|
47
|
+
|
|
39
48
|
is_3d_patch = len(patch_size) == 3
|
|
40
49
|
|
|
41
50
|
# patches sanity check
|
|
@@ -44,9 +53,6 @@ def extract_patches_random(
|
|
|
44
53
|
# Update patch size to encompass S and C dimensions
|
|
45
54
|
patch_size = [1, arr.shape[1], *patch_size]
|
|
46
55
|
|
|
47
|
-
# random generator
|
|
48
|
-
rng = np.random.default_rng()
|
|
49
|
-
|
|
50
56
|
# iterate over the number of samples (S or T)
|
|
51
57
|
for sample_idx in range(arr.shape[0]):
|
|
52
58
|
# get sample array
|
|
@@ -109,6 +115,7 @@ def extract_patches_random_from_chunks(
|
|
|
109
115
|
patch_size: Union[List[int], Tuple[int, ...]],
|
|
110
116
|
chunk_size: Union[List[int], Tuple[int, ...]],
|
|
111
117
|
chunk_limit: Optional[int] = None,
|
|
118
|
+
seed: Optional[int] = None,
|
|
112
119
|
) -> Generator[np.ndarray, None, None]:
|
|
113
120
|
"""
|
|
114
121
|
Generate patches from an array in a random manner.
|
|
@@ -120,10 +127,14 @@ def extract_patches_random_from_chunks(
|
|
|
120
127
|
----------
|
|
121
128
|
arr : np.ndarray
|
|
122
129
|
Input image array.
|
|
123
|
-
patch_size : Tuple[int]
|
|
130
|
+
patch_size : Union[List[int], Tuple[int, ...]]
|
|
124
131
|
Patch sizes in each dimension.
|
|
125
|
-
chunk_size : Tuple[int]
|
|
132
|
+
chunk_size : Union[List[int], Tuple[int, ...]]
|
|
126
133
|
Chunk sizes to load from the.
|
|
134
|
+
chunk_limit : Optional[int], optional
|
|
135
|
+
Number of chunks to load, by default None.
|
|
136
|
+
seed : Optional[int], optional
|
|
137
|
+
Random seed, by default None.
|
|
127
138
|
|
|
128
139
|
Yields
|
|
129
140
|
------
|
|
@@ -135,7 +146,7 @@ def extract_patches_random_from_chunks(
|
|
|
135
146
|
# Patches sanity check
|
|
136
147
|
validate_patch_dimensions(arr, patch_size, is_3d_patch)
|
|
137
148
|
|
|
138
|
-
rng = np.random.default_rng()
|
|
149
|
+
rng = np.random.default_rng(seed=seed)
|
|
139
150
|
num_chunks = chunk_limit if chunk_limit else np.prod(arr._cdata_shape)
|
|
140
151
|
|
|
141
152
|
# Iterate over num chunks in the array
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Sequential patching functions."""
|
|
2
|
+
|
|
1
3
|
from typing import List, Optional, Tuple, Union
|
|
2
4
|
|
|
3
5
|
import numpy as np
|
|
@@ -14,14 +16,14 @@ def _compute_number_of_patches(
|
|
|
14
16
|
|
|
15
17
|
Parameters
|
|
16
18
|
----------
|
|
17
|
-
|
|
19
|
+
arr_shape : Tuple[int, ...]
|
|
18
20
|
Shape of the input array.
|
|
19
|
-
patch_sizes : Tuple[int]
|
|
21
|
+
patch_sizes : Union[List[int], Tuple[int, ...]
|
|
20
22
|
Shape of the patches.
|
|
21
23
|
|
|
22
24
|
Returns
|
|
23
25
|
-------
|
|
24
|
-
Tuple[int]
|
|
26
|
+
Tuple[int, ...]
|
|
25
27
|
Number of patches in each dimension.
|
|
26
28
|
"""
|
|
27
29
|
if len(arr_shape) != len(patch_sizes):
|
|
@@ -55,14 +57,14 @@ def _compute_overlap(
|
|
|
55
57
|
|
|
56
58
|
Parameters
|
|
57
59
|
----------
|
|
58
|
-
|
|
60
|
+
arr_shape : Tuple[int, ...]
|
|
59
61
|
Input array shape.
|
|
60
|
-
patch_sizes : Tuple[int]
|
|
62
|
+
patch_sizes : Union[List[int], Tuple[int, ...]]
|
|
61
63
|
Size of the patches.
|
|
62
64
|
|
|
63
65
|
Returns
|
|
64
66
|
-------
|
|
65
|
-
Tuple[int]
|
|
67
|
+
Tuple[int, ...]
|
|
66
68
|
Overlap between patches in each dimension.
|
|
67
69
|
"""
|
|
68
70
|
n_patches = _compute_number_of_patches(arr_shape, patch_sizes)
|
|
@@ -123,6 +125,8 @@ def _compute_patch_views(
|
|
|
123
125
|
Steps between views.
|
|
124
126
|
output_shape : Tuple[int]
|
|
125
127
|
Shape of the output array.
|
|
128
|
+
target : Optional[np.ndarray], optional
|
|
129
|
+
Target array, by default None.
|
|
126
130
|
|
|
127
131
|
Returns
|
|
128
132
|
-------
|
|
@@ -161,11 +165,13 @@ def extract_patches_sequential(
|
|
|
161
165
|
Input image array.
|
|
162
166
|
patch_size : Tuple[int]
|
|
163
167
|
Patch sizes in each dimension.
|
|
168
|
+
target : Optional[np.ndarray], optional
|
|
169
|
+
Target array, by default None.
|
|
164
170
|
|
|
165
171
|
Returns
|
|
166
172
|
-------
|
|
167
|
-
|
|
168
|
-
|
|
173
|
+
Tuple[np.ndarray, Optional[np.ndarray]]
|
|
174
|
+
Patches.
|
|
169
175
|
"""
|
|
170
176
|
is_3d_patch = len(patch_size) == 3
|
|
171
177
|
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Patch validation functions."""
|
|
2
|
+
|
|
1
3
|
from typing import List, Tuple, Union
|
|
2
4
|
|
|
3
5
|
import numpy as np
|
|
@@ -43,18 +45,20 @@ def validate_patch_dimensions(
|
|
|
43
45
|
if len(patch_size) != len(arr.shape[2:]):
|
|
44
46
|
raise ValueError(
|
|
45
47
|
f"There must be a patch size for each spatial dimensions "
|
|
46
|
-
f"(got {patch_size} patches for dims {arr.shape})."
|
|
48
|
+
f"(got {patch_size} patches for dims {arr.shape}). Check the axes order."
|
|
47
49
|
)
|
|
48
50
|
|
|
49
51
|
# Sanity checks on patch sizes versus array dimension
|
|
50
52
|
if is_3d_patch and patch_size[0] > arr.shape[-3]:
|
|
51
53
|
raise ValueError(
|
|
52
54
|
f"Z patch size is inconsistent with image shape "
|
|
53
|
-
f"(got {patch_size[0]} patches for dim {arr.shape[1]})."
|
|
55
|
+
f"(got {patch_size[0]} patches for dim {arr.shape[1]}). Check the axes "
|
|
56
|
+
f"order."
|
|
54
57
|
)
|
|
55
58
|
|
|
56
59
|
if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
|
|
57
60
|
raise ValueError(
|
|
58
61
|
f"At least one of YX patch dimensions is larger than the corresponding "
|
|
59
|
-
f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]})."
|
|
62
|
+
f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]}). "
|
|
63
|
+
f"Check the axes order."
|
|
60
64
|
)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Collate function for tiling."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, List, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from torch.utils.data.dataloader import default_collate
|
|
7
|
+
|
|
8
|
+
from careamics.config.tile_information import TileInformation
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def collate_tiles(
    batch: List[Tuple[np.ndarray, TileInformation]]
) -> Tuple[Any, List[TileInformation]]:
    """
    Collate tiles received from CAREamics prediction dataloader.

    CAREamics tiled prediction datasets yield tuples of a tile array and its
    TileInformation. The tile arrays are collated with PyTorch's
    ``default_collate``. The TileInformation objects are not passed to
    ``default_collate``; they are returned unchanged as a list, in batch order.

    Parameters
    ----------
    batch : List[Tuple[np.ndarray, TileInformation]]
        Batch of tiles and their tile information.

    Returns
    -------
    Tuple[Any, List[TileInformation]]
        Collated tile arrays and the list of corresponding tile information.
    """
    tiles = [tile for tile, _ in batch]
    tile_infos = [tile_info for _, tile_info in batch]

    return default_collate(tiles), tile_infos
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Tiled patching utilities."""
|
|
2
|
+
|
|
1
3
|
import itertools
|
|
2
4
|
from typing import Generator, List, Tuple, Union
|
|
3
5
|
|
|
@@ -8,7 +10,7 @@ from careamics.config.tile_information import TileInformation
|
|
|
8
10
|
|
|
9
11
|
def _compute_crop_and_stitch_coords_1d(
|
|
10
12
|
axis_size: int, tile_size: int, overlap: int
|
|
11
|
-
) -> Tuple[List[Tuple[int,
|
|
13
|
+
) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]:
|
|
12
14
|
"""
|
|
13
15
|
Compute the coordinates of each tile along an axis, given the overlap.
|
|
14
16
|
|
|
@@ -82,15 +84,15 @@ def extract_tiles(
|
|
|
82
84
|
tile_size: Union[List[int], Tuple[int, ...]],
|
|
83
85
|
overlaps: Union[List[int], Tuple[int, ...]],
|
|
84
86
|
) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
|
|
85
|
-
"""
|
|
86
|
-
Generate tiles from the input array with specified overlap.
|
|
87
|
+
"""Generate tiles from the input array with specified overlap.
|
|
87
88
|
|
|
88
89
|
The tiles cover the whole array. The method returns a generator that yields
|
|
89
90
|
tuples of array and tile information, the latter includes whether
|
|
90
91
|
the tile is the last one, the coordinates of the overlap crop, and the coordinates
|
|
91
92
|
of the stitched tile.
|
|
92
93
|
|
|
93
|
-
|
|
94
|
+
Input array should have shape SC(Z)YX, while the returned tiles have shape C(Z)YX,
|
|
95
|
+
where C can be a singleton.
|
|
94
96
|
|
|
95
97
|
Parameters
|
|
96
98
|
----------
|
|
@@ -153,10 +155,10 @@ def extract_tiles(
|
|
|
153
155
|
# create tile information
|
|
154
156
|
tile_info = TileInformation(
|
|
155
157
|
array_shape=sample.squeeze().shape,
|
|
156
|
-
tiled=True,
|
|
157
158
|
last_tile=last_tile,
|
|
158
159
|
overlap_crop_coords=overlap_crop_coords,
|
|
159
160
|
stitch_coords=stitch_coords,
|
|
161
|
+
sample_id=sample_idx,
|
|
160
162
|
)
|
|
161
163
|
|
|
162
164
|
yield tile, tile_info
|