careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +163 -266
- careamics/config/algorithm_model.py +0 -15
- careamics/config/architectures/custom_model.py +3 -3
- careamics/config/configuration_example.py +0 -3
- careamics/config/configuration_factory.py +23 -25
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +80 -50
- careamics/config/inference_model.py +29 -17
- careamics/config/optimizer_models.py +7 -7
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +26 -58
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +0 -9
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +66 -171
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +92 -249
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +54 -25
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/lightning_datamodule.py +1 -6
- careamics/lightning_module.py +11 -7
- careamics/lightning_prediction_datamodule.py +52 -72
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Patching functions."""
|
|
2
2
|
|
|
3
|
+
from dataclasses import dataclass
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
from typing import Callable, List, Tuple, Union
|
|
5
6
|
|
|
@@ -7,11 +8,38 @@ import numpy as np
|
|
|
7
8
|
|
|
8
9
|
from ...utils.logging import get_logger
|
|
9
10
|
from ..dataset_utils import reshape_array
|
|
11
|
+
from ..dataset_utils.running_stats import compute_normalization_stats
|
|
10
12
|
from .sequential_patching import extract_patches_sequential
|
|
11
13
|
|
|
12
14
|
logger = get_logger(__name__)
|
|
13
15
|
|
|
14
16
|
|
|
17
|
+
@dataclass
|
|
18
|
+
class Stats:
|
|
19
|
+
"""Dataclass to store statistics."""
|
|
20
|
+
|
|
21
|
+
means: Union[np.ndarray, tuple, list, None]
|
|
22
|
+
stds: Union[np.ndarray, tuple, list, None]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
|
|
26
|
+
class PatchedOutput:
|
|
27
|
+
"""Dataclass to store patches and statistics."""
|
|
28
|
+
|
|
29
|
+
patches: Union[np.ndarray]
|
|
30
|
+
targets: Union[np.ndarray, None]
|
|
31
|
+
image_stats: Stats
|
|
32
|
+
target_stats: Stats
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
|
|
36
|
+
class StatsOutput:
|
|
37
|
+
"""Dataclass to store patches and statistics."""
|
|
38
|
+
|
|
39
|
+
image_stats: Stats
|
|
40
|
+
target_stats: Stats
|
|
41
|
+
|
|
42
|
+
|
|
15
43
|
# called by in memory dataset
|
|
16
44
|
def prepare_patches_supervised(
|
|
17
45
|
train_files: List[Path],
|
|
@@ -19,10 +47,12 @@ def prepare_patches_supervised(
|
|
|
19
47
|
axes: str,
|
|
20
48
|
patch_size: Union[List[int], Tuple[int, ...]],
|
|
21
49
|
read_source_func: Callable,
|
|
22
|
-
) ->
|
|
50
|
+
) -> PatchedOutput:
|
|
23
51
|
"""
|
|
24
52
|
Iterate over data source and create an array of patches and corresponding targets.
|
|
25
53
|
|
|
54
|
+
The lists of Paths should be pre-sorted.
|
|
55
|
+
|
|
26
56
|
Parameters
|
|
27
57
|
----------
|
|
28
58
|
train_files : List[Path]
|
|
@@ -41,9 +71,6 @@ def prepare_patches_supervised(
|
|
|
41
71
|
np.ndarray
|
|
42
72
|
Array of patches.
|
|
43
73
|
"""
|
|
44
|
-
train_files.sort()
|
|
45
|
-
target_files.sort()
|
|
46
|
-
|
|
47
74
|
means, stds, num_samples = 0, 0, 0
|
|
48
75
|
all_patches, all_targets = [], []
|
|
49
76
|
for train_filename, target_filename in zip(train_files, target_files):
|
|
@@ -83,17 +110,18 @@ def prepare_patches_supervised(
|
|
|
83
110
|
f"{target_files}."
|
|
84
111
|
)
|
|
85
112
|
|
|
86
|
-
|
|
113
|
+
image_means, image_stds = compute_normalization_stats(np.concatenate(all_patches))
|
|
114
|
+
target_means, target_stds = compute_normalization_stats(np.concatenate(all_targets))
|
|
87
115
|
|
|
88
116
|
patch_array: np.ndarray = np.concatenate(all_patches, axis=0)
|
|
89
117
|
target_array: np.ndarray = np.concatenate(all_targets, axis=0)
|
|
90
118
|
logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
|
|
91
119
|
|
|
92
|
-
return (
|
|
120
|
+
return PatchedOutput(
|
|
93
121
|
patch_array,
|
|
94
122
|
target_array,
|
|
95
|
-
|
|
96
|
-
|
|
123
|
+
Stats(image_means, image_stds),
|
|
124
|
+
Stats(target_means, target_stds),
|
|
97
125
|
)
|
|
98
126
|
|
|
99
127
|
|
|
@@ -103,7 +131,7 @@ def prepare_patches_unsupervised(
|
|
|
103
131
|
axes: str,
|
|
104
132
|
patch_size: Union[List[int], Tuple[int]],
|
|
105
133
|
read_source_func: Callable,
|
|
106
|
-
) ->
|
|
134
|
+
) -> PatchedOutput:
|
|
107
135
|
"""Iterate over data source and create an array of patches.
|
|
108
136
|
|
|
109
137
|
This method returns the mean and standard deviation of the image.
|
|
@@ -149,12 +177,14 @@ def prepare_patches_unsupervised(
|
|
|
149
177
|
if num_samples == 0:
|
|
150
178
|
raise ValueError(f"No valid samples found in the input data: {train_files}.")
|
|
151
179
|
|
|
152
|
-
|
|
180
|
+
image_means, image_stds = compute_normalization_stats(np.concatenate(all_patches))
|
|
153
181
|
|
|
154
182
|
patch_array: np.ndarray = np.concatenate(all_patches)
|
|
155
183
|
logger.info(f"Extracted {patch_array.shape[0]} patches from input array.")
|
|
156
184
|
|
|
157
|
-
return
|
|
185
|
+
return PatchedOutput(
|
|
186
|
+
patch_array, None, Stats(image_means, image_stds), Stats((), ())
|
|
187
|
+
)
|
|
158
188
|
|
|
159
189
|
|
|
160
190
|
# called on arrays by in memory dataset
|
|
@@ -163,7 +193,7 @@ def prepare_patches_supervised_array(
|
|
|
163
193
|
axes: str,
|
|
164
194
|
data_target: np.ndarray,
|
|
165
195
|
patch_size: Union[List[int], Tuple[int]],
|
|
166
|
-
) ->
|
|
196
|
+
) -> PatchedOutput:
|
|
167
197
|
"""Iterate over data source and create an array of patches.
|
|
168
198
|
|
|
169
199
|
This method expects an array of shape SC(Z)YX, where S and C can be singleton
|
|
@@ -187,14 +217,14 @@ def prepare_patches_supervised_array(
|
|
|
187
217
|
Tuple[np.ndarray, np.ndarray, float, float]
|
|
188
218
|
Source and target patches, mean and standard deviation.
|
|
189
219
|
"""
|
|
190
|
-
# compute statistics
|
|
191
|
-
mean = data.mean()
|
|
192
|
-
std = data.std()
|
|
193
|
-
|
|
194
220
|
# reshape array
|
|
195
221
|
reshaped_sample = reshape_array(data, axes)
|
|
196
222
|
reshaped_target = reshape_array(data_target, axes)
|
|
197
223
|
|
|
224
|
+
# compute statistics
|
|
225
|
+
image_means, image_stds = compute_normalization_stats(reshaped_sample)
|
|
226
|
+
target_means, target_stds = compute_normalization_stats(reshaped_target)
|
|
227
|
+
|
|
198
228
|
# generate patches, return a generator
|
|
199
229
|
patches, patch_targets = extract_patches_sequential(
|
|
200
230
|
reshaped_sample, patch_size=patch_size, target=reshaped_target
|
|
@@ -205,11 +235,11 @@ def prepare_patches_supervised_array(
|
|
|
205
235
|
|
|
206
236
|
logger.info(f"Extracted {patches.shape[0]} patches from input array.")
|
|
207
237
|
|
|
208
|
-
return (
|
|
238
|
+
return PatchedOutput(
|
|
209
239
|
patches,
|
|
210
240
|
patch_targets,
|
|
211
|
-
|
|
212
|
-
|
|
241
|
+
Stats(image_means, image_stds),
|
|
242
|
+
Stats(target_means, target_stds),
|
|
213
243
|
)
|
|
214
244
|
|
|
215
245
|
|
|
@@ -218,7 +248,7 @@ def prepare_patches_unsupervised_array(
|
|
|
218
248
|
data: np.ndarray,
|
|
219
249
|
axes: str,
|
|
220
250
|
patch_size: Union[List[int], Tuple[int]],
|
|
221
|
-
) ->
|
|
251
|
+
) -> PatchedOutput:
|
|
222
252
|
"""
|
|
223
253
|
Iterate over data source and create an array of patches.
|
|
224
254
|
|
|
@@ -241,14 +271,13 @@ def prepare_patches_unsupervised_array(
|
|
|
241
271
|
Tuple[np.ndarray, None, float, float]
|
|
242
272
|
Source patches, mean and standard deviation.
|
|
243
273
|
"""
|
|
244
|
-
# calculate mean and std
|
|
245
|
-
mean = data.mean()
|
|
246
|
-
std = data.std()
|
|
247
|
-
|
|
248
274
|
# reshape array
|
|
249
275
|
reshaped_sample = reshape_array(data, axes)
|
|
250
276
|
|
|
277
|
+
# calculate mean and std
|
|
278
|
+
means, stds = compute_normalization_stats(reshaped_sample)
|
|
279
|
+
|
|
251
280
|
# generate patches, return a generator
|
|
252
281
|
patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)
|
|
253
282
|
|
|
254
|
-
return patches,
|
|
283
|
+
return PatchedOutput(patches, None, Stats(means, stds), Stats((), ()))
|
|
@@ -13,6 +13,7 @@ def extract_patches_random(
|
|
|
13
13
|
arr: np.ndarray,
|
|
14
14
|
patch_size: Union[List[int], Tuple[int, ...]],
|
|
15
15
|
target: Optional[np.ndarray] = None,
|
|
16
|
+
seed: Optional[int] = None,
|
|
16
17
|
) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
|
|
17
18
|
"""
|
|
18
19
|
Generate patches from an array in a random manner.
|
|
@@ -34,12 +35,16 @@ def extract_patches_random(
|
|
|
34
35
|
Patch sizes in each dimension.
|
|
35
36
|
target : Optional[np.ndarray], optional
|
|
36
37
|
Target array, by default None.
|
|
38
|
+
seed : Optional[int], optional
|
|
39
|
+
Random seed, by default None.
|
|
37
40
|
|
|
38
41
|
Yields
|
|
39
42
|
------
|
|
40
43
|
Generator[np.ndarray, None, None]
|
|
41
44
|
Generator of patches.
|
|
42
45
|
"""
|
|
46
|
+
rng = np.random.default_rng(seed=seed)
|
|
47
|
+
|
|
43
48
|
is_3d_patch = len(patch_size) == 3
|
|
44
49
|
|
|
45
50
|
# patches sanity check
|
|
@@ -48,9 +53,6 @@ def extract_patches_random(
|
|
|
48
53
|
# Update patch size to encompass S and C dimensions
|
|
49
54
|
patch_size = [1, arr.shape[1], *patch_size]
|
|
50
55
|
|
|
51
|
-
# random generator
|
|
52
|
-
rng = np.random.default_rng()
|
|
53
|
-
|
|
54
56
|
# iterate over the number of samples (S or T)
|
|
55
57
|
for sample_idx in range(arr.shape[0]):
|
|
56
58
|
# get sample array
|
|
@@ -113,6 +115,7 @@ def extract_patches_random_from_chunks(
|
|
|
113
115
|
patch_size: Union[List[int], Tuple[int, ...]],
|
|
114
116
|
chunk_size: Union[List[int], Tuple[int, ...]],
|
|
115
117
|
chunk_limit: Optional[int] = None,
|
|
118
|
+
seed: Optional[int] = None,
|
|
116
119
|
) -> Generator[np.ndarray, None, None]:
|
|
117
120
|
"""
|
|
118
121
|
Generate patches from an array in a random manner.
|
|
@@ -130,6 +133,8 @@ def extract_patches_random_from_chunks(
|
|
|
130
133
|
Chunk sizes to load from the.
|
|
131
134
|
chunk_limit : Optional[int], optional
|
|
132
135
|
Number of chunks to load, by default None.
|
|
136
|
+
seed : Optional[int], optional
|
|
137
|
+
Random seed, by default None.
|
|
133
138
|
|
|
134
139
|
Yields
|
|
135
140
|
------
|
|
@@ -141,7 +146,7 @@ def extract_patches_random_from_chunks(
|
|
|
141
146
|
# Patches sanity check
|
|
142
147
|
validate_patch_dimensions(arr, patch_size, is_3d_patch)
|
|
143
148
|
|
|
144
|
-
rng = np.random.default_rng()
|
|
149
|
+
rng = np.random.default_rng(seed=seed)
|
|
145
150
|
num_chunks = chunk_limit if chunk_limit else np.prod(arr._cdata_shape)
|
|
146
151
|
|
|
147
152
|
# Iterate over num chunks in the array
|
|
@@ -45,18 +45,20 @@ def validate_patch_dimensions(
|
|
|
45
45
|
if len(patch_size) != len(arr.shape[2:]):
|
|
46
46
|
raise ValueError(
|
|
47
47
|
f"There must be a patch size for each spatial dimensions "
|
|
48
|
-
f"(got {patch_size} patches for dims {arr.shape})."
|
|
48
|
+
f"(got {patch_size} patches for dims {arr.shape}). Check the axes order."
|
|
49
49
|
)
|
|
50
50
|
|
|
51
51
|
# Sanity checks on patch sizes versus array dimension
|
|
52
52
|
if is_3d_patch and patch_size[0] > arr.shape[-3]:
|
|
53
53
|
raise ValueError(
|
|
54
54
|
f"Z patch size is inconsistent with image shape "
|
|
55
|
-
f"(got {patch_size[0]} patches for dim {arr.shape[1]})."
|
|
55
|
+
f"(got {patch_size[0]} patches for dim {arr.shape[1]}). Check the axes "
|
|
56
|
+
f"order."
|
|
56
57
|
)
|
|
57
58
|
|
|
58
59
|
if patch_size[-2] > arr.shape[-2] or patch_size[-1] > arr.shape[-1]:
|
|
59
60
|
raise ValueError(
|
|
60
61
|
f"At least one of YX patch dimensions is larger than the corresponding "
|
|
61
|
-
f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]})."
|
|
62
|
+
f"image dimension (got {patch_size} patches for dims {arr.shape[-2:]}). "
|
|
63
|
+
f"Check the axes order."
|
|
62
64
|
)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Collate function for tiling."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, List, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from torch.utils.data.dataloader import default_collate
|
|
7
|
+
|
|
8
|
+
from careamics.config.tile_information import TileInformation
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
|
|
12
|
+
"""
|
|
13
|
+
Collate tiles received from CAREamics prediction dataloader.
|
|
14
|
+
|
|
15
|
+
CAREamics prediction dataloader returns tuples of arrays and TileInformation. In
|
|
16
|
+
case of non-tiled data, this function will return the arrays. In case of tiled data,
|
|
17
|
+
it will return the arrays, the last tile flag, the overlap crop coordinates and the
|
|
18
|
+
stitch coordinates.
|
|
19
|
+
|
|
20
|
+
Parameters
|
|
21
|
+
----------
|
|
22
|
+
batch : List[Tuple[np.ndarray, TileInformation], ...]
|
|
23
|
+
Batch of tiles.
|
|
24
|
+
|
|
25
|
+
Returns
|
|
26
|
+
-------
|
|
27
|
+
Any
|
|
28
|
+
Collated batch.
|
|
29
|
+
"""
|
|
30
|
+
new_batch = [tile for tile, _ in batch]
|
|
31
|
+
tiles_batch = [tile_info for _, tile_info in batch]
|
|
32
|
+
|
|
33
|
+
return default_collate(new_batch), tiles_batch
|
|
@@ -84,15 +84,15 @@ def extract_tiles(
|
|
|
84
84
|
tile_size: Union[List[int], Tuple[int, ...]],
|
|
85
85
|
overlaps: Union[List[int], Tuple[int, ...]],
|
|
86
86
|
) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
|
|
87
|
-
"""
|
|
88
|
-
Generate tiles from the input array with specified overlap.
|
|
87
|
+
"""Generate tiles from the input array with specified overlap.
|
|
89
88
|
|
|
90
89
|
The tiles cover the whole array. The method returns a generator that yields
|
|
91
90
|
tuples of array and tile information, the latter includes whether
|
|
92
91
|
the tile is the last one, the coordinates of the overlap crop, and the coordinates
|
|
93
92
|
of the stitched tile.
|
|
94
93
|
|
|
95
|
-
|
|
94
|
+
Input array should have shape SC(Z)YX, while the returned tiles have shape C(Z)YX,
|
|
95
|
+
where C can be a singleton.
|
|
96
96
|
|
|
97
97
|
Parameters
|
|
98
98
|
----------
|
|
@@ -155,10 +155,10 @@ def extract_tiles(
|
|
|
155
155
|
# create tile information
|
|
156
156
|
tile_info = TileInformation(
|
|
157
157
|
array_shape=sample.squeeze().shape,
|
|
158
|
-
tiled=True,
|
|
159
158
|
last_tile=last_tile,
|
|
160
159
|
overlap_crop_coords=overlap_crop_coords,
|
|
161
160
|
stitch_coords=stitch_coords,
|
|
161
|
+
sample_id=sample_idx,
|
|
162
162
|
)
|
|
163
163
|
|
|
164
164
|
yield tile, tile_info
|
|
@@ -583,12 +583,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
|
|
|
583
583
|
>>> my_array = np.arange(256).reshape(16, 16)
|
|
584
584
|
>>> my_transforms = [
|
|
585
585
|
... {
|
|
586
|
-
... "name": SupportedTransform.
|
|
587
|
-
... "mean": 0,
|
|
588
|
-
... "std": 1,
|
|
589
|
-
... },
|
|
590
|
-
... {
|
|
591
|
-
... "name": SupportedTransform.N2V_MANIPULATE.value,
|
|
586
|
+
... "name": SupportedTransform.XY_FLIP.value,
|
|
592
587
|
... }
|
|
593
588
|
... ]
|
|
594
589
|
>>> data_module = TrainingDataWrapper(
|
careamics/lightning_module.py
CHANGED
|
@@ -148,13 +148,17 @@ class CAREamicsModule(L.LightningModule):
|
|
|
148
148
|
Any
|
|
149
149
|
Model output.
|
|
150
150
|
"""
|
|
151
|
-
|
|
151
|
+
if self._trainer.datamodule.tiled:
|
|
152
|
+
x, *aux = batch
|
|
153
|
+
else:
|
|
154
|
+
x = batch
|
|
155
|
+
aux = []
|
|
152
156
|
|
|
153
157
|
# apply test-time augmentation if available
|
|
154
158
|
# TODO: probably wont work with batch size > 1
|
|
155
159
|
if self._trainer.datamodule.prediction_config.tta_transforms:
|
|
156
160
|
tta = ImageRestorationTTA()
|
|
157
|
-
augmented_batch = tta.forward(
|
|
161
|
+
augmented_batch = tta.forward(x) # list of augmented tensors
|
|
158
162
|
augmented_output = []
|
|
159
163
|
for augmented in augmented_batch:
|
|
160
164
|
augmented_pred = self.model(augmented)
|
|
@@ -165,13 +169,13 @@ class CAREamicsModule(L.LightningModule):
|
|
|
165
169
|
|
|
166
170
|
# Denormalize the output
|
|
167
171
|
denorm = Denormalize(
|
|
168
|
-
|
|
169
|
-
|
|
172
|
+
image_means=self._trainer.datamodule.predict_dataset.image_means,
|
|
173
|
+
image_stds=self._trainer.datamodule.predict_dataset.image_stds,
|
|
170
174
|
)
|
|
171
|
-
denormalized_output
|
|
175
|
+
denormalized_output = denorm(patch=output.cpu().numpy())
|
|
172
176
|
|
|
173
|
-
if len(aux) > 0:
|
|
174
|
-
return denormalized_output, aux
|
|
177
|
+
if len(aux) > 0: # aux can be tiling information
|
|
178
|
+
return denormalized_output, *aux
|
|
175
179
|
else:
|
|
176
180
|
return denormalized_output
|
|
177
181
|
|
|
@@ -1,68 +1,37 @@
|
|
|
1
1
|
"""Prediction Lightning data modules."""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Any, Callable, Dict,
|
|
4
|
+
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import pytorch_lightning as L
|
|
8
8
|
from torch.utils.data import DataLoader
|
|
9
|
-
from torch.utils.data.dataloader import default_collate
|
|
10
9
|
|
|
11
10
|
from careamics.config import InferenceConfig
|
|
12
11
|
from careamics.config.support import SupportedData
|
|
13
|
-
from careamics.
|
|
12
|
+
from careamics.dataset import (
|
|
13
|
+
InMemoryPredDataset,
|
|
14
|
+
InMemoryTiledPredDataset,
|
|
15
|
+
IterablePredDataset,
|
|
16
|
+
IterableTiledPredDataset,
|
|
17
|
+
)
|
|
14
18
|
from careamics.dataset.dataset_utils import (
|
|
15
19
|
get_read_func,
|
|
16
20
|
list_files,
|
|
17
21
|
)
|
|
18
|
-
from careamics.dataset.
|
|
19
|
-
InMemoryPredictionDataset,
|
|
20
|
-
)
|
|
21
|
-
from careamics.dataset.iterable_dataset import (
|
|
22
|
-
IterablePredictionDataset,
|
|
23
|
-
)
|
|
22
|
+
from careamics.dataset.tiling.collate_tiles import collate_tiles
|
|
24
23
|
from careamics.utils import get_logger
|
|
25
24
|
|
|
26
|
-
PredictDatasetType = Union[
|
|
25
|
+
PredictDatasetType = Union[
|
|
26
|
+
InMemoryPredDataset,
|
|
27
|
+
InMemoryTiledPredDataset,
|
|
28
|
+
IterablePredDataset,
|
|
29
|
+
IterableTiledPredDataset,
|
|
30
|
+
]
|
|
27
31
|
|
|
28
32
|
logger = get_logger(__name__)
|
|
29
33
|
|
|
30
34
|
|
|
31
|
-
def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
|
|
32
|
-
"""
|
|
33
|
-
Collate tiles received from CAREamics prediction dataloader.
|
|
34
|
-
|
|
35
|
-
CAREamics prediction dataloader returns tuples of arrays and TileInformation. In
|
|
36
|
-
case of non-tiled data, this function will return the arrays. In case of tiled data,
|
|
37
|
-
it will return the arrays, the last tile flag, the overlap crop coordinates and the
|
|
38
|
-
stitch coordinates.
|
|
39
|
-
|
|
40
|
-
Parameters
|
|
41
|
-
----------
|
|
42
|
-
batch : List[Tuple[np.ndarray, TileInformation], ...]
|
|
43
|
-
Batch of tiles.
|
|
44
|
-
|
|
45
|
-
Returns
|
|
46
|
-
-------
|
|
47
|
-
Any
|
|
48
|
-
Collated batch.
|
|
49
|
-
"""
|
|
50
|
-
first_tile_info: TileInformation = batch[0][1]
|
|
51
|
-
# if not tiled, then return arrays
|
|
52
|
-
if not first_tile_info.tiled:
|
|
53
|
-
arrays, _ = zip(*batch)
|
|
54
|
-
|
|
55
|
-
return default_collate(arrays)
|
|
56
|
-
# else we explicit the last_tile flag and coordinates
|
|
57
|
-
else:
|
|
58
|
-
new_batch = [
|
|
59
|
-
(tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords)
|
|
60
|
-
for tile, t in batch
|
|
61
|
-
]
|
|
62
|
-
|
|
63
|
-
return default_collate(new_batch)
|
|
64
|
-
|
|
65
|
-
|
|
66
35
|
class CAREamicsPredictData(L.LightningDataModule):
|
|
67
36
|
"""
|
|
68
37
|
CAREamics Lightning prediction data module.
|
|
@@ -182,6 +151,9 @@ class CAREamicsPredictData(L.LightningDataModule):
|
|
|
182
151
|
self.tile_size = pred_config.tile_size
|
|
183
152
|
self.tile_overlap = pred_config.tile_overlap
|
|
184
153
|
|
|
154
|
+
# check if it is tiled
|
|
155
|
+
self.tiled = self.tile_size is not None and self.tile_overlap is not None
|
|
156
|
+
|
|
185
157
|
# read source function
|
|
186
158
|
if pred_config.data_type == SupportedData.CUSTOM:
|
|
187
159
|
# mypy check
|
|
@@ -212,17 +184,29 @@ class CAREamicsPredictData(L.LightningDataModule):
|
|
|
212
184
|
"""
|
|
213
185
|
# if numpy array
|
|
214
186
|
if self.data_type == SupportedData.ARRAY:
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
187
|
+
if self.tiled:
|
|
188
|
+
self.predict_dataset: PredictDatasetType = InMemoryTiledPredDataset(
|
|
189
|
+
prediction_config=self.prediction_config,
|
|
190
|
+
inputs=self.pred_data,
|
|
191
|
+
)
|
|
192
|
+
else:
|
|
193
|
+
self.predict_dataset = InMemoryPredDataset(
|
|
194
|
+
prediction_config=self.prediction_config,
|
|
195
|
+
inputs=self.pred_data,
|
|
196
|
+
)
|
|
220
197
|
else:
|
|
221
|
-
self.
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
198
|
+
if self.tiled:
|
|
199
|
+
self.predict_dataset = IterableTiledPredDataset(
|
|
200
|
+
prediction_config=self.prediction_config,
|
|
201
|
+
src_files=self.pred_files,
|
|
202
|
+
read_source_func=self.read_source_func,
|
|
203
|
+
)
|
|
204
|
+
else:
|
|
205
|
+
self.predict_dataset = IterablePredDataset(
|
|
206
|
+
prediction_config=self.prediction_config,
|
|
207
|
+
src_files=self.pred_files,
|
|
208
|
+
read_source_func=self.read_source_func,
|
|
209
|
+
)
|
|
226
210
|
|
|
227
211
|
def predict_dataloader(self) -> DataLoader:
|
|
228
212
|
"""
|
|
@@ -236,7 +220,7 @@ class CAREamicsPredictData(L.LightningDataModule):
|
|
|
236
220
|
return DataLoader(
|
|
237
221
|
self.predict_dataset,
|
|
238
222
|
batch_size=self.batch_size,
|
|
239
|
-
collate_fn=
|
|
223
|
+
collate_fn=collate_tiles if self.tiled else None,
|
|
240
224
|
**self.dataloader_params,
|
|
241
225
|
) # TODO check workers are used
|
|
242
226
|
|
|
@@ -287,12 +271,10 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
287
271
|
Prediction data.
|
|
288
272
|
data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
|
|
289
273
|
Data type, see `SupportedData` for available options.
|
|
290
|
-
|
|
291
|
-
Mean
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
Standard deviation value for normalization, only used if Normalization is
|
|
295
|
-
defined in the transform.
|
|
274
|
+
image_means : list of float
|
|
275
|
+
Mean values for normalization, only used if Normalization is defined.
|
|
276
|
+
image_stds : list of float
|
|
277
|
+
Std values for normalization, only used if Normalization is defined.
|
|
296
278
|
tile_size : Tuple[int, ...]
|
|
297
279
|
Tile size, 2D or 3D tile size.
|
|
298
280
|
tile_overlap : Tuple[int, ...]
|
|
@@ -316,8 +298,8 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
316
298
|
self,
|
|
317
299
|
pred_data: Union[str, Path, np.ndarray],
|
|
318
300
|
data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
|
|
319
|
-
|
|
320
|
-
|
|
301
|
+
image_means=list[float],
|
|
302
|
+
image_stds=list[float],
|
|
321
303
|
tile_size: Optional[Tuple[int, ...]] = None,
|
|
322
304
|
tile_overlap: Optional[Tuple[int, ...]] = None,
|
|
323
305
|
axes: str = "YX",
|
|
@@ -336,12 +318,10 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
336
318
|
Prediction data.
|
|
337
319
|
data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
|
|
338
320
|
Data type, see `SupportedData` for available options.
|
|
339
|
-
|
|
340
|
-
Mean
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
Standard deviation value for normalization, only used if Normalization is
|
|
344
|
-
defined in the transform.
|
|
321
|
+
image_means : list of float
|
|
322
|
+
Mean values for normalization, only used if Normalization is defined.
|
|
323
|
+
image_stds : list of float
|
|
324
|
+
Std values for normalization, only used if Normalization is defined.
|
|
345
325
|
tile_size : List[int]
|
|
346
326
|
Tile size, 2D or 3D tile size.
|
|
347
327
|
tile_overlap : List[int]
|
|
@@ -367,8 +347,8 @@ class PredictDataWrapper(CAREamicsPredictData):
|
|
|
367
347
|
"tile_size": tile_size,
|
|
368
348
|
"tile_overlap": tile_overlap,
|
|
369
349
|
"axes": axes,
|
|
370
|
-
"
|
|
371
|
-
"
|
|
350
|
+
"image_means": image_means,
|
|
351
|
+
"image_stds": image_stds,
|
|
372
352
|
"tta": tta_transforms,
|
|
373
353
|
"batch_size": batch_size,
|
|
374
354
|
"transforms": [],
|
|
File without changes
|