careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of careamics has been flagged as potentially problematic by the registry.
- careamics/__init__.py +1 -14
- careamics/careamist.py +212 -294
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -15
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +5 -3
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +18 -98
- careamics/config/configuration_model.py +23 -18
- careamics/config/data_model.py +103 -54
- careamics/config/inference_model.py +41 -19
- careamics/config/optimizer_models.py +13 -7
- careamics/config/support/supported_data.py +29 -4
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +36 -58
- careamics/config/training_model.py +5 -1
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -7
- careamics/dataset/dataset_utils/file_utils.py +2 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +84 -173
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +97 -250
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/patching.py +97 -52
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +6 -3
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +137 -0
- careamics/prediction_utils/stitch_prediction.py +103 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
- careamics-0.1.0rc8.dist-info/RECORD +135 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
- careamics/config/configuration_example.py +0 -89
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc6.dist-info/RECORD +0 -107
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
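The moves above reorganize the package: the Lightning classes now live under careamics/lightning, readers and writers under careamics/file_io, tiling helpers under careamics/dataset/tiling, and the prediction helpers under careamics/prediction_utils. A minimal sketch of the import updates this implies for downstream code follows; the old paths come from the removed files, the new read_tiff and extract_tiles locations appear in the diffs further down, and anything beyond that should be treated as an assumption about the new public API.

# Sketch of import-path updates implied by the file moves (0.1.0rc6 -> 0.1.0rc8).
# The new locations of read_tiff and extract_tiles are confirmed by the diffs
# below; other re-exports are assumptions about the new package layout.

# before (0.1.0rc6)
# from careamics.dataset.dataset_utils import read_tiff
# from careamics.dataset.patching.tiled_patching import extract_tiles

# after (0.1.0rc8)
from careamics.file_io.read import read_tiff        # readers moved to file_io.read
from careamics.dataset.tiling import extract_tiles  # tiling split out of patching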
careamics/dataset/in_memory_dataset.py (+84 -173)

@@ -4,25 +4,25 @@ from __future__ import annotations
 
 import copy
 from pathlib import Path
-from typing import Any, Callable,
+from typing import Any, Callable, Optional, Union
 
 import numpy as np
 from torch.utils.data import Dataset
 
+from careamics.file_io.read import read_tiff
 from careamics.transforms import Compose
 
-from ..config import DataConfig
-from ..config.tile_information import TileInformation
+from ..config import DataConfig
 from ..config.transformations import NormalizeModel
 from ..utils.logging import get_logger
-from .dataset_utils import read_tiff, reshape_array
 from .patching.patching import (
+    PatchedOutput,
+    Stats,
     prepare_patches_supervised,
     prepare_patches_supervised_array,
     prepare_patches_unsupervised,
     prepare_patches_unsupervised_array,
 )
-from .patching.tiled_patching import extract_tiles
 
 logger = get_logger(__name__)
 
@@ -32,11 +32,12 @@ class InMemoryDataset(Dataset):
 
     Parameters
     ----------
-    data_config : DataConfig
+    data_config : CAREamics DataConfig
+        (see careamics.config.data_model.DataConfig)
         Data configuration.
-    inputs :
+    inputs : numpy.ndarray or list[pathlib.Path]
         Input data.
-    input_target :
+    input_target : numpy.ndarray or list[pathlib.Path], optional
         Target data, by default None.
     read_source_func : Callable, optional
         Read source function for custom types, by default read_tiff.
@@ -47,8 +48,8 @@ class InMemoryDataset(Dataset):
     def __init__(
         self,
         data_config: DataConfig,
-        inputs: Union[np.ndarray,
-        input_target: Optional[Union[np.ndarray,
+        inputs: Union[np.ndarray, list[Path]],
+        input_target: Optional[Union[np.ndarray, list[Path]]] = None,
         read_source_func: Callable = read_tiff,
         **kwargs: Any,
     ) -> None:
@@ -59,9 +60,9 @@ class InMemoryDataset(Dataset):
        ----------
        data_config : DataConfig
            Data configuration.
-        inputs :
+        inputs : numpy.ndarray or list[pathlib.Path]
            Input data.
-        input_target :
+        input_target : numpy.ndarray or list[pathlib.Path], optional
            Target data, by default None.
        read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
@@ -77,31 +78,56 @@ class InMemoryDataset(Dataset):
        # read function
        self.read_source_func = read_source_func
 
-        #
+        # generate patches
        supervised = self.input_targets is not None
-
-
-        #
-        self.
-
-
-
-
+        patches_data = self._prepare_patches(supervised)
+
+        # unpack the dataclass
+        self.data = patches_data.patches
+        self.data_targets = patches_data.targets
+
+        # set image statistics
+        if self.data_config.image_means is None:
+            self.image_stats = patches_data.image_stats
+            logger.info(
+                f"Computed dataset mean: {self.image_stats.means}, "
+                f"std: {self.image_stats.stds}"
+            )
+        else:
+            self.image_stats = Stats(
+                self.data_config.image_means, self.data_config.image_stds
+            )
 
-
-
-        self.
+        # set target statistics
+        if self.data_config.target_means is None:
+            self.target_stats = patches_data.target_stats
        else:
-        self.
+            self.target_stats = Stats(
+                self.data_config.target_means, self.data_config.target_stds
+            )
 
+        # update mean and std in configuration
+        # the object is mutable and should then be recorded in the CAREamist obj
+        self.data_config.set_means_and_stds(
+            image_means=self.image_stats.means,
+            image_stds=self.image_stats.stds,
+            target_means=self.target_stats.means,
+            target_stds=self.target_stats.stds,
+        )
        # get transforms
        self.patch_transform = Compose(
-            transform_list=
+            transform_list=[
+                NormalizeModel(
+                    image_means=self.image_stats.means,
+                    image_stds=self.image_stats.stds,
+                    target_means=self.target_stats.means,
+                    target_stds=self.target_stats.stds,
+                )
+            ]
+            + self.data_config.transforms,
        )
 
-    def _prepare_patches(
-        self, supervised: bool
-    ) -> Tuple[np.ndarray, Optional[np.ndarray], float, float]:
+    def _prepare_patches(self, supervised: bool) -> PatchedOutput:
        """
        Iterate over data source and create an array of patches.
 
@@ -112,7 +138,7 @@ class InMemoryDataset(Dataset):
 
        Returns
        -------
-
+        numpy.ndarray
            Array of patches.
        """
        if supervised:
@@ -163,9 +189,9 @@ class InMemoryDataset(Dataset):
        int
            Length of the dataset.
        """
-        return
+        return self.data.shape[0]
 
-    def __getitem__(self, index: int) ->
+    def __getitem__(self, index: int) -> tuple[np.ndarray, ...]:
        """
        Return the patch corresponding to the provided index.
 
@@ -176,7 +202,7 @@ class InMemoryDataset(Dataset):
 
        Returns
        -------
-
+        tuple of numpy.ndarray
            Patch.
 
        Raises
@@ -184,16 +210,16 @@ class InMemoryDataset(Dataset):
        ValueError
            If dataset mean and std are not set.
        """
-        patch = self.
+        patch = self.data[index]
 
        # if there is a target
-        if self.
+        if self.data_targets is not None:
            # get target
-            target = self.
+            target = self.data_targets[index]
 
            return self.patch_transform(patch=patch, target=target)
 
-        elif self.data_config.has_n2v_manipulate():
+        elif self.data_config.has_n2v_manipulate():  # TODO not compatible with HDN
            return self.patch_transform(patch=patch)
        else:
            raise ValueError(
@@ -201,6 +227,18 @@ class InMemoryDataset(Dataset):
                "and no N2V manipulation (no N2V training)."
            )
 
+    def get_data_statistics(self) -> tuple[list[float], list[float]]:
+        """Return training data statistics.
+
+        This does not return the target data statistics, only those of the input.
+
+        Returns
+        -------
+        tuple of list of floats
+            Means and standard deviations across channels of the training data.
+        """
+        return self.image_stats.get_statistics()
+
    def split_dataset(
        self,
        percentage: float = 0.1,
@@ -219,7 +257,7 @@ class InMemoryDataset(Dataset):
 
        Returns
        -------
-        InMemoryDataset
+        CAREamics InMemoryDataset
            New dataset with the extracted patches.
 
        Raises
@@ -249,151 +287,24 @@ class InMemoryDataset(Dataset):
        indices = np.random.choice(total_patches, n_patches, replace=False)
 
        # extract patches
-        val_patches = self.
+        val_patches = self.data[indices]
 
        # remove patches from self.patch
-        self.
+        self.data = np.delete(self.data, indices, axis=0)
 
        # same for targets
-        if self.
-            val_targets = self.
-            self.
+        if self.data_targets is not None:
+            val_targets = self.data_targets[indices]
+            self.data_targets = np.delete(self.data_targets, indices, axis=0)
 
        # clone the dataset
        dataset = copy.deepcopy(self)
 
        # reassign patches
-        dataset.
+        dataset.data = val_patches
 
        # reassign targets
-        if self.
-            dataset.
+        if self.data_targets is not None:
+            dataset.data_targets = val_targets
 
        return dataset
-
-
-class InMemoryPredictionDataset(Dataset):
-    """
-    Dataset storing data in memory and allowing generating patches from it.
-
-    Parameters
-    ----------
-    prediction_config : InferenceConfig
-        Prediction configuration.
-    inputs : np.ndarray
-        Input data.
-    data_target : Optional[np.ndarray], optional
-        Target data, by default None.
-    read_source_func : Optional[Callable], optional
-        Read source function for custom types, by default read_tiff.
-    """
-
-    def __init__(
-        self,
-        prediction_config: InferenceConfig,
-        inputs: np.ndarray,
-        data_target: Optional[np.ndarray] = None,
-        read_source_func: Optional[Callable] = read_tiff,
-    ) -> None:
-        """Constructor.
-
-        Parameters
-        ----------
-        prediction_config : InferenceConfig
-            Prediction configuration.
-        inputs : np.ndarray
-            Input data.
-        data_target : Optional[np.ndarray], optional
-            Target data, by default None.
-        read_source_func : Optional[Callable], optional
-            Read source function for custom types, by default read_tiff.
-
-        Raises
-        ------
-        ValueError
-            If data_path is not a directory.
-        """
-        self.pred_config = prediction_config
-        self.input_array = inputs
-        self.axes = self.pred_config.axes
-        self.tile_size = self.pred_config.tile_size
-        self.tile_overlap = self.pred_config.tile_overlap
-        self.mean = self.pred_config.mean
-        self.std = self.pred_config.std
-        self.data_target = data_target
-
-        # tiling only if both tile size and overlap are provided
-        self.tiling = self.tile_size is not None and self.tile_overlap is not None
-
-        # read function
-        self.read_source_func = read_source_func
-
-        # Generate patches
-        self.data = self._prepare_tiles()
-        self.mean, self.std = self.pred_config.mean, self.pred_config.std
-
-        # get transforms
-        self.patch_transform = Compose(
-            transform_list=[NormalizeModel(mean=self.mean, std=self.std)],
-        )
-
-    def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
-        """
-        Iterate over data source and create an array of patches.
-
-        Returns
-        -------
-        List[XArrayTile]
-            List of tiles.
-        """
-        # reshape array
-        reshaped_sample = reshape_array(self.input_array, self.axes)
-
-        if self.tiling and self.tile_size is not None and self.tile_overlap is not None:
-            # generate patches, which returns a generator
-            patch_generator = extract_tiles(
-                arr=reshaped_sample,
-                tile_size=self.tile_size,
-                overlaps=self.tile_overlap,
-            )
-            patches_list = list(patch_generator)
-
-            if len(patches_list) == 0:
-                raise ValueError("No tiles generated, ")
-
-            return patches_list
-        else:
-            array_shape = reshaped_sample.squeeze().shape
-            return [(reshaped_sample, TileInformation(array_shape=array_shape))]
-
-    def __len__(self) -> int:
-        """
-        Return the length of the dataset.
-
-        Returns
-        -------
-        int
-            Length of the dataset.
-        """
-        return len(self.data)
-
-    def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]:
-        """
-        Return the patch corresponding to the provided index.
-
-        Parameters
-        ----------
-        index : int
-            Index of the patch to return.
-
-        Returns
-        -------
-        Tuple[np.ndarray, TileInformation]
-            Transformed patch.
-        """
-        tile_array, tile_info = self.data[index]
-
-        # Apply transforms
-        transformed_tile, _ = self.patch_transform(patch=tile_array)
-
-        return transformed_tile, tile_info
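The reworked InMemoryDataset unpacks a PatchedOutput, stores per-channel Stats, writes them back into the DataConfig via set_means_and_stds, and exposes them through the new get_data_statistics method. A rough usage sketch follows, using only names visible in this diff; the DataConfig fields and values are illustrative assumptions, not taken from the package.

import numpy as np

from careamics.config import DataConfig
from careamics.dataset.in_memory_dataset import InMemoryDataset

# Illustrative training stack (S, Y, X); the DataConfig fields are assumptions.
train_array = np.random.rand(8, 64, 64).astype(np.float32)
config = DataConfig(data_type="array", axes="SYX", patch_size=[32, 32], batch_size=2)

dataset = InMemoryDataset(data_config=config, inputs=train_array)

# New in rc8: per-channel means and stds of the input patches.
means, stds = dataset.get_data_statistics()

# split_dataset removes a fraction of the patches and returns them as a new
# InMemoryDataset, e.g. for validation.
val_dataset = dataset.split_dataset(percentage=0.1)
print(len(dataset), len(val_dataset), means, stds)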
careamics/dataset/in_memory_pred_dataset.py (new file, +88)

@@ -0,0 +1,88 @@
+"""In-memory prediction dataset."""
+
+from __future__ import annotations
+
+from numpy.typing import NDArray
+from torch.utils.data import Dataset
+
+from careamics.transforms import Compose
+
+from ..config import InferenceConfig
+from ..config.transformations import NormalizeModel
+from .dataset_utils import reshape_array
+
+
+class InMemoryPredDataset(Dataset):
+    """Simple prediction dataset returning images along the sample axis.
+
+    Parameters
+    ----------
+    prediction_config : InferenceConfig
+        Prediction configuration.
+    inputs : NDArray
+        Input data.
+    """
+
+    def __init__(
+        self,
+        prediction_config: InferenceConfig,
+        inputs: NDArray,
+    ) -> None:
+        """Constructor.
+
+        Parameters
+        ----------
+        prediction_config : InferenceConfig
+            Prediction configuration.
+        inputs : NDArray
+            Input data.
+
+        Raises
+        ------
+        ValueError
+            If data_path is not a directory.
+        """
+        self.pred_config = prediction_config
+        self.input_array = inputs
+        self.axes = self.pred_config.axes
+        self.image_means = self.pred_config.image_means
+        self.image_stds = self.pred_config.image_stds
+
+        # Reshape data
+        self.data = reshape_array(self.input_array, self.axes)
+
+        # get transforms
+        self.patch_transform = Compose(
+            transform_list=[
+                NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
+            ],
+        )
+
+    def __len__(self) -> int:
+        """
+        Return the length of the dataset.
+
+        Returns
+        -------
+        int
+            Length of the dataset.
+        """
+        return len(self.data)
+
+    def __getitem__(self, index: int) -> NDArray:
+        """
+        Return the patch corresponding to the provided index.
+
+        Parameters
+        ----------
+        index : int
+            Index of the patch to return.
+
+        Returns
+        -------
+        NDArray
+            Transformed patch.
+        """
+        transformed_patch, _ = self.patch_transform(patch=self.data[index])
+
+        return transformed_patch
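A rough usage sketch of the new untiled prediction dataset; only axes, image_means and image_stds are read by the code above, so the remaining InferenceConfig constructor arguments are assumptions.

import numpy as np

from careamics.config import InferenceConfig
from careamics.dataset.in_memory_pred_dataset import InMemoryPredDataset

# Illustrative config; fields other than axes/image_means/image_stds are assumed.
pred_config = InferenceConfig(
    data_type="array",
    axes="SYX",
    image_means=[0.5],
    image_stds=[0.2],
)

images = np.random.rand(4, 64, 64).astype(np.float32)
dataset = InMemoryPredDataset(prediction_config=pred_config, inputs=images)

normalized_image = dataset[0]  # one normalized sample along the S axis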
careamics/dataset/in_memory_tiled_pred_dataset.py (new file, +129)

@@ -0,0 +1,129 @@
+"""In-memory tiled prediction dataset."""
+
+from __future__ import annotations
+
+from numpy.typing import NDArray
+from torch.utils.data import Dataset
+
+from careamics.transforms import Compose
+
+from ..config import InferenceConfig
+from ..config.tile_information import TileInformation
+from ..config.transformations import NormalizeModel
+from .dataset_utils import reshape_array
+from .tiling import extract_tiles
+
+
+class InMemoryTiledPredDataset(Dataset):
+    """Prediction dataset storing data in memory and returning tiles of each image.
+
+    Parameters
+    ----------
+    prediction_config : InferenceConfig
+        Prediction configuration.
+    inputs : NDArray
+        Input data.
+    """
+
+    def __init__(
+        self,
+        prediction_config: InferenceConfig,
+        inputs: NDArray,
+    ) -> None:
+        """Constructor.
+
+        Parameters
+        ----------
+        prediction_config : InferenceConfig
+            Prediction configuration.
+        inputs : NDArray
+            Input data.
+
+        Raises
+        ------
+        ValueError
+            If data_path is not a directory.
+        """
+        if (
+            prediction_config.tile_size is None
+            or prediction_config.tile_overlap is None
+        ):
+            raise ValueError(
+                "Tile size and overlap must be provided to use the tiled prediction "
+                "dataset."
+            )
+
+        self.pred_config = prediction_config
+        self.input_array = inputs
+        self.axes = self.pred_config.axes
+        self.tile_size = prediction_config.tile_size
+        self.tile_overlap = prediction_config.tile_overlap
+        self.image_means = self.pred_config.image_means
+        self.image_stds = self.pred_config.image_stds
+
+        # Generate patches
+        self.data = self._prepare_tiles()
+
+        # get transforms
+        self.patch_transform = Compose(
+            transform_list=[
+                NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
+            ],
+        )
+
+    def _prepare_tiles(self) -> list[tuple[NDArray, TileInformation]]:
+        """
+        Iterate over data source and create an array of patches.
+
+        Returns
+        -------
+        list of tuples of NDArray and TileInformation
+            List of tiles and tile information.
+        """
+        # reshape array
+        reshaped_sample = reshape_array(self.input_array, self.axes)
+
+        # generate patches, which returns a generator
+        patch_generator = extract_tiles(
+            arr=reshaped_sample,
+            tile_size=self.tile_size,
+            overlaps=self.tile_overlap,
+        )
+        patches_list = list(patch_generator)
+
+        if len(patches_list) == 0:
+            raise ValueError("No tiles generated, ")
+
+        return patches_list
+
+    def __len__(self) -> int:
+        """
+        Return the length of the dataset.
+
+        Returns
+        -------
+        int
+            Length of the dataset.
+        """
+        return len(self.data)
+
+    def __getitem__(self, index: int) -> tuple[NDArray, TileInformation]:
+        """
+        Return the patch corresponding to the provided index.
+
+        Parameters
+        ----------
+        index : int
+            Index of the patch to return.
+
+        Returns
+        -------
+        tuple of NDArray and TileInformation
+            Transformed patch.
+        """
+        tile_array, tile_info = self.data[index]
+
+        # Apply transforms
+        transformed_tile, _ = self.patch_transform(patch=tile_array)
+
+        return transformed_tile, tile_info
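A rough usage sketch for the tiled variant; tile_size and tile_overlap are required by the constructor above, while the other InferenceConfig fields and their values are assumptions.

import numpy as np

from careamics.config import InferenceConfig
from careamics.dataset.in_memory_tiled_pred_dataset import InMemoryTiledPredDataset

# Illustrative config; tile_size/tile_overlap are mandatory for this dataset,
# the remaining fields are assumed.
pred_config = InferenceConfig(
    data_type="array",
    axes="SYX",
    tile_size=[64, 64],
    tile_overlap=[16, 16],
    image_means=[0.5],
    image_stds=[0.2],
)

image = np.random.rand(1, 256, 256).astype(np.float32)
dataset = InMemoryTiledPredDataset(prediction_config=pred_config, inputs=image)

# Each item is a normalized tile plus the TileInformation needed for stitching.
tile, tile_info = dataset[0]
print(len(dataset), tile.shape)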