careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of careamics might be problematic.
- careamics/careamist.py +163 -266
- careamics/config/algorithm_model.py +0 -15
- careamics/config/architectures/custom_model.py +3 -3
- careamics/config/configuration_example.py +0 -3
- careamics/config/configuration_factory.py +23 -25
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +80 -50
- careamics/config/inference_model.py +29 -17
- careamics/config/optimizer_models.py +7 -7
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +26 -58
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +0 -9
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +66 -171
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +92 -249
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +54 -25
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/lightning_datamodule.py +1 -6
- careamics/lightning_module.py +11 -7
- careamics/lightning_prediction_datamodule.py +52 -72
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/dataset/in_memory_dataset.py

@@ -4,25 +4,24 @@ from __future__ import annotations
 
 import copy
 from pathlib import Path
-from typing import Any, Callable,
+from typing import Any, Callable, Optional, Union
 
 import numpy as np
 from torch.utils.data import Dataset
 
 from careamics.transforms import Compose
 
-from ..config import DataConfig
-from ..config.tile_information import TileInformation
+from ..config import DataConfig
 from ..config.transformations import NormalizeModel
 from ..utils.logging import get_logger
-from .dataset_utils import read_tiff
+from .dataset_utils import read_tiff
 from .patching.patching import (
+    PatchedOutput,
     prepare_patches_supervised,
     prepare_patches_supervised_array,
     prepare_patches_unsupervised,
     prepare_patches_unsupervised_array,
 )
-from .patching.tiled_patching import extract_tiles
 
 logger = get_logger(__name__)
 
@@ -32,11 +31,12 @@ class InMemoryDataset(Dataset):
 
     Parameters
     ----------
-    data_config : DataConfig
+    data_config : CAREamics DataConfig
+        (see careamics.config.data_model.DataConfig)
         Data configuration.
-    inputs :
+    inputs : numpy.ndarray or list[pathlib.Path]
         Input data.
-    input_target :
+    input_target : numpy.ndarray or list[pathlib.Path], optional
         Target data, by default None.
     read_source_func : Callable, optional
         Read source function for custom types, by default read_tiff.
@@ -47,8 +47,8 @@ class InMemoryDataset(Dataset):
     def __init__(
         self,
         data_config: DataConfig,
-        inputs: Union[np.ndarray,
-        input_target: Optional[Union[np.ndarray,
+        inputs: Union[np.ndarray, list[Path]],
+        input_target: Optional[Union[np.ndarray, list[Path]]] = None,
         read_source_func: Callable = read_tiff,
         **kwargs: Any,
     ) -> None:
@@ -59,9 +59,9 @@ class InMemoryDataset(Dataset):
         ----------
         data_config : DataConfig
            Data configuration.
-        inputs :
+        inputs : numpy.ndarray or list[pathlib.Path]
            Input data.
-        input_target :
+        input_target : numpy.ndarray or list[pathlib.Path], optional
            Target data, by default None.
        read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
@@ -79,29 +79,51 @@ class InMemoryDataset(Dataset):
 
        # Generate patches
        supervised = self.input_targets is not None
-
+        patches_data = self._prepare_patches(supervised)
 
-        #
-        self.
+        # Unpack the dataclass
+        self.data = patches_data.patches
+        self.data_targets = patches_data.targets
 
-        if
-            self.
-
-
-
-
-            self.data_config.set_mean_and_std(self.mean, self.std)
+        if self.data_config.image_means is None:
+            self.image_means = patches_data.image_stats.means
+            self.image_stds = patches_data.image_stats.stds
+            logger.info(
+                f"Computed dataset mean: {self.image_means}, std: {self.image_stds}"
+            )
        else:
-            self.
+            self.image_means = self.data_config.image_means
+            self.image_stds = self.data_config.image_stds
 
+        if self.data_config.target_means is None:
+            self.target_means = patches_data.target_stats.means
+            self.target_stds = patches_data.target_stats.stds
+        else:
+            self.target_means = self.data_config.target_means
+            self.target_stds = self.data_config.target_stds
+
+        # update mean and std in configuration
+        # the object is mutable and should then be recorded in the CAREamist obj
+        self.data_config.set_mean_and_std(
+            image_means=self.image_means,
+            image_stds=self.image_stds,
+            target_means=self.target_means,
+            target_stds=self.target_stds,
+        )
        # get transforms
        self.patch_transform = Compose(
-            transform_list=
+            transform_list=[
+                NormalizeModel(
+                    image_means=self.image_means,
+                    image_stds=self.image_stds,
+                    target_means=self.target_means,
+                    target_stds=self.target_stds,
+                )
+            ]
+            + self.data_config.transforms,
        )
 
-    def _prepare_patches(
-        self, supervised: bool
-    ) -> Tuple[np.ndarray, Optional[np.ndarray], float, float]:
+    def _prepare_patches(self, supervised: bool) -> PatchedOutput:
        """
        Iterate over data source and create an array of patches.
 
@@ -112,7 +134,7 @@ class InMemoryDataset(Dataset):
 
        Returns
        -------
-
+        numpy.ndarray
            Array of patches.
        """
        if supervised:
@@ -163,9 +185,9 @@ class InMemoryDataset(Dataset):
        int
            Length of the dataset.
        """
-        return
+        return self.data.shape[0]
 
-    def __getitem__(self, index: int) ->
+    def __getitem__(self, index: int) -> tuple[np.ndarray, ...]:
        """
        Return the patch corresponding to the provided index.
 
@@ -176,7 +198,7 @@ class InMemoryDataset(Dataset):
 
        Returns
        -------
-
+        tuple of numpy.ndarray
            Patch.
 
        Raises
@@ -184,16 +206,16 @@ class InMemoryDataset(Dataset):
        ValueError
            If dataset mean and std are not set.
        """
-        patch = self.
+        patch = self.data[index]
 
        # if there is a target
-        if self.
+        if self.data_targets is not None:
            # get target
-            target = self.
+            target = self.data_targets[index]
 
            return self.patch_transform(patch=patch, target=target)
 
-        elif self.data_config.has_n2v_manipulate():
+        elif self.data_config.has_n2v_manipulate():  # TODO not compatible with HDN
            return self.patch_transform(patch=patch)
        else:
            raise ValueError(
@@ -219,7 +241,7 @@ class InMemoryDataset(Dataset):
 
        Returns
        -------
-        InMemoryDataset
+        CAREamics InMemoryDataset
            New dataset with the extracted patches.
 
        Raises
@@ -249,151 +271,24 @@ class InMemoryDataset(Dataset):
        indices = np.random.choice(total_patches, n_patches, replace=False)
 
        # extract patches
-        val_patches = self.
+        val_patches = self.data[indices]
 
        # remove patches from self.patch
-        self.
+        self.data = np.delete(self.data, indices, axis=0)
 
        # same for targets
-        if self.
-            val_targets = self.
-            self.
+        if self.data_targets is not None:
+            val_targets = self.data_targets[indices]
+            self.data_targets = np.delete(self.data_targets, indices, axis=0)
 
        # clone the dataset
        dataset = copy.deepcopy(self)
 
        # reassign patches
-        dataset.
+        dataset.data = val_patches
 
        # reassign targets
-        if self.
-            dataset.
+        if self.data_targets is not None:
+            dataset.data_targets = val_targets
 
        return dataset
-
-
-class InMemoryPredictionDataset(Dataset):
-    """
-    Dataset storing data in memory and allowing generating patches from it.
-
-    Parameters
-    ----------
-    prediction_config : InferenceConfig
-        Prediction configuration.
-    inputs : np.ndarray
-        Input data.
-    data_target : Optional[np.ndarray], optional
-        Target data, by default None.
-    read_source_func : Optional[Callable], optional
-        Read source function for custom types, by default read_tiff.
-    """
-
-    def __init__(
-        self,
-        prediction_config: InferenceConfig,
-        inputs: np.ndarray,
-        data_target: Optional[np.ndarray] = None,
-        read_source_func: Optional[Callable] = read_tiff,
-    ) -> None:
-        """Constructor.
-
-        Parameters
-        ----------
-        prediction_config : InferenceConfig
-            Prediction configuration.
-        inputs : np.ndarray
-            Input data.
-        data_target : Optional[np.ndarray], optional
-            Target data, by default None.
-        read_source_func : Optional[Callable], optional
-            Read source function for custom types, by default read_tiff.
-
-        Raises
-        ------
-        ValueError
-            If data_path is not a directory.
-        """
-        self.pred_config = prediction_config
-        self.input_array = inputs
-        self.axes = self.pred_config.axes
-        self.tile_size = self.pred_config.tile_size
-        self.tile_overlap = self.pred_config.tile_overlap
-        self.mean = self.pred_config.mean
-        self.std = self.pred_config.std
-        self.data_target = data_target
-
-        # tiling only if both tile size and overlap are provided
-        self.tiling = self.tile_size is not None and self.tile_overlap is not None
-
-        # read function
-        self.read_source_func = read_source_func
-
-        # Generate patches
-        self.data = self._prepare_tiles()
-        self.mean, self.std = self.pred_config.mean, self.pred_config.std
-
-        # get transforms
-        self.patch_transform = Compose(
-            transform_list=[NormalizeModel(mean=self.mean, std=self.std)],
-        )
-
-    def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
-        """
-        Iterate over data source and create an array of patches.
-
-        Returns
-        -------
-        List[XArrayTile]
-            List of tiles.
-        """
-        # reshape array
-        reshaped_sample = reshape_array(self.input_array, self.axes)
-
-        if self.tiling and self.tile_size is not None and self.tile_overlap is not None:
-            # generate patches, which returns a generator
-            patch_generator = extract_tiles(
-                arr=reshaped_sample,
-                tile_size=self.tile_size,
-                overlaps=self.tile_overlap,
-            )
-            patches_list = list(patch_generator)
-
-            if len(patches_list) == 0:
-                raise ValueError("No tiles generated, ")
-
-            return patches_list
-        else:
-            array_shape = reshaped_sample.squeeze().shape
-            return [(reshaped_sample, TileInformation(array_shape=array_shape))]
-
-    def __len__(self) -> int:
-        """
-        Return the length of the dataset.
-
-        Returns
-        -------
-        int
-            Length of the dataset.
-        """
-        return len(self.data)
-
-    def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]:
-        """
-        Return the patch corresponding to the provided index.
-
-        Parameters
-        ----------
-        index : int
-            Index of the patch to return.
-
-        Returns
-        -------
-        Tuple[np.ndarray, TileInformation]
-            Transformed patch.
-        """
-        tile_array, tile_info = self.data[index]
-
-        # Apply transforms
-        transformed_tile, _ = self.patch_transform(patch=tile_array)
-
-        return transformed_tile, tile_info
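The reworked InMemoryDataset above no longer carries a single scalar mean/std pair: statistics are computed per channel, kept separately for images and targets, and written back into the mutable DataConfig so that the CAREamist object can record them. A minimal sketch of that per-channel bookkeeping in plain NumPy (per_channel_stats below is an illustration, not part of the careamics API):

    # Illustration only: per-channel statistics over patches shaped (N, C, Y, X),
    # matching the lists that NormalizeModel and set_mean_and_std now expect.
    import numpy as np

    def per_channel_stats(patches: np.ndarray) -> tuple[list[float], list[float]]:
        # one mean and one std per channel, reduced over samples and spatial dims
        means = patches.mean(axis=(0, 2, 3)).tolist()
        stds = patches.std(axis=(0, 2, 3)).tolist()
        return means, stds

    patches = np.random.rand(100, 2, 64, 64).astype(np.float32)
    image_means, image_stds = per_channel_stats(patches)

    # NormalizeModel then applies, per channel, essentially (x - mean) / std
    normalized = (
        patches - np.array(image_means)[None, :, None, None]
    ) / np.array(image_stds)[None, :, None, None]

These lists are what the constructor passes to NormalizeModel(image_means=..., image_stds=..., target_means=..., target_stds=...) and to data_config.set_mean_and_std(...), as shown in the diff above.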
careamics/dataset/in_memory_pred_dataset.py (new file)

@@ -0,0 +1,88 @@
+"""In-memory prediction dataset."""
+
+from __future__ import annotations
+
+from numpy.typing import NDArray
+from torch.utils.data import Dataset
+
+from careamics.transforms import Compose
+
+from ..config import InferenceConfig
+from ..config.transformations import NormalizeModel
+from .dataset_utils import reshape_array
+
+
+class InMemoryPredDataset(Dataset):
+    """Simple prediction dataset returning images along the sample axis.
+
+    Parameters
+    ----------
+    prediction_config : InferenceConfig
+        Prediction configuration.
+    inputs : NDArray
+        Input data.
+    """
+
+    def __init__(
+        self,
+        prediction_config: InferenceConfig,
+        inputs: NDArray,
+    ) -> None:
+        """Constructor.
+
+        Parameters
+        ----------
+        prediction_config : InferenceConfig
+            Prediction configuration.
+        inputs : NDArray
+            Input data.
+
+        Raises
+        ------
+        ValueError
+            If data_path is not a directory.
+        """
+        self.pred_config = prediction_config
+        self.input_array = inputs
+        self.axes = self.pred_config.axes
+        self.image_means = self.pred_config.image_means
+        self.image_stds = self.pred_config.image_stds
+
+        # Reshape data
+        self.data = reshape_array(self.input_array, self.axes)
+
+        # get transforms
+        self.patch_transform = Compose(
+            transform_list=[
+                NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
+            ],
+        )
+
+    def __len__(self) -> int:
+        """
+        Return the length of the dataset.
+
+        Returns
+        -------
+        int
+            Length of the dataset.
+        """
+        return len(self.data)
+
+    def __getitem__(self, index: int) -> NDArray:
+        """
+        Return the patch corresponding to the provided index.
+
+        Parameters
+        ----------
+        index : int
+            Index of the patch to return.
+
+        Returns
+        -------
+        NDArray
+            Transformed patch.
+        """
+        transformed_patch, _ = self.patch_transform(patch=self.data[index])
+
+        return transformed_patch
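The new InMemoryPredDataset covers the untiled case of the InMemoryPredictionDataset removed above: it reshapes the input to the expected axes layout, normalizes with the configured image_means/image_stds, and yields one sample per index. A hypothetical usage sketch, where SimpleNamespace stands in for a real careamics InferenceConfig and only the attributes the dataset actually reads are provided:

    # Hypothetical usage sketch (stand-in config, random data), not careamics docs.
    from types import SimpleNamespace

    import numpy as np
    from torch.utils.data import DataLoader

    from careamics.dataset.in_memory_pred_dataset import InMemoryPredDataset

    cfg = SimpleNamespace(axes="SYX", image_means=[0.5], image_stds=[0.2])
    images = np.random.rand(4, 128, 128).astype(np.float32)  # S, Y, X

    ds = InMemoryPredDataset(prediction_config=cfg, inputs=images)
    loader = DataLoader(ds, batch_size=2)
    for batch in loader:
        print(batch.shape)  # normalized samples, e.g. (2, 1, 128, 128)

In a real pipeline the configuration, including the normalization statistics, would come from the trained model rather than being hard-coded.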
careamics/dataset/in_memory_tiled_pred_dataset.py (new file)

@@ -0,0 +1,129 @@
+"""In-memory tiled prediction dataset."""
+
+from __future__ import annotations
+
+from numpy.typing import NDArray
+from torch.utils.data import Dataset
+
+from careamics.transforms import Compose
+
+from ..config import InferenceConfig
+from ..config.tile_information import TileInformation
+from ..config.transformations import NormalizeModel
+from .dataset_utils import reshape_array
+from .tiling import extract_tiles
+
+
+class InMemoryTiledPredDataset(Dataset):
+    """Prediction dataset storing data in memory and returning tiles of each image.
+
+    Parameters
+    ----------
+    prediction_config : InferenceConfig
+        Prediction configuration.
+    inputs : NDArray
+        Input data.
+    """
+
+    def __init__(
+        self,
+        prediction_config: InferenceConfig,
+        inputs: NDArray,
+    ) -> None:
+        """Constructor.
+
+        Parameters
+        ----------
+        prediction_config : InferenceConfig
+            Prediction configuration.
+        inputs : NDArray
+            Input data.
+
+        Raises
+        ------
+        ValueError
+            If data_path is not a directory.
+        """
+        if (
+            prediction_config.tile_size is None
+            or prediction_config.tile_overlap is None
+        ):
+            raise ValueError(
+                "Tile size and overlap must be provided to use the tiled prediction "
+                "dataset."
+            )
+
+        self.pred_config = prediction_config
+        self.input_array = inputs
+        self.axes = self.pred_config.axes
+        self.tile_size = prediction_config.tile_size
+        self.tile_overlap = prediction_config.tile_overlap
+        self.image_means = self.pred_config.image_means
+        self.image_stds = self.pred_config.image_stds
+
+        # Generate patches
+        self.data = self._prepare_tiles()
+
+        # get transforms
+        self.patch_transform = Compose(
+            transform_list=[
+                NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
+            ],
+        )
+
+    def _prepare_tiles(self) -> list[tuple[NDArray, TileInformation]]:
+        """
+        Iterate over data source and create an array of patches.
+
+        Returns
+        -------
+        list of tuples of NDArray and TileInformation
+            List of tiles and tile information.
+        """
+        # reshape array
+        reshaped_sample = reshape_array(self.input_array, self.axes)
+
+        # generate patches, which returns a generator
+        patch_generator = extract_tiles(
+            arr=reshaped_sample,
+            tile_size=self.tile_size,
+            overlaps=self.tile_overlap,
+        )
+        patches_list = list(patch_generator)
+
+        if len(patches_list) == 0:
+            raise ValueError("No tiles generated, ")
+
+        return patches_list
+
+    def __len__(self) -> int:
+        """
+        Return the length of the dataset.
+
+        Returns
+        -------
+        int
+            Length of the dataset.
+        """
+        return len(self.data)
+
+    def __getitem__(self, index: int) -> tuple[NDArray, TileInformation]:
+        """
+        Return the patch corresponding to the provided index.
+
+        Parameters
+        ----------
+        index : int
+            Index of the patch to return.
+
+        Returns
+        -------
+        tuple of NDArray and TileInformation
+            Transformed patch.
+        """
+        tile_array, tile_info = self.data[index]
+
+        # Apply transforms
+        transformed_tile, _ = self.patch_transform(patch=tile_array)
+
+        return transformed_tile, tile_info
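InMemoryTiledPredDataset covers the tiled half of the old class: tile_size and tile_overlap are now mandatory, and each item is a (tile, TileInformation) pair that the new prediction utilities (see careamics/prediction_utils/stitch_prediction.py in the file list above) can use to reassemble full images. A hypothetical sketch along the same lines, again with a stand-in configuration object:

    # Hypothetical usage sketch (stand-in config, random data), not careamics docs.
    from types import SimpleNamespace

    import numpy as np

    from careamics.dataset.in_memory_tiled_pred_dataset import InMemoryTiledPredDataset

    cfg = SimpleNamespace(
        axes="YX",
        image_means=[0.5],
        image_stds=[0.2],
        tile_size=[64, 64],
        tile_overlap=[16, 16],
    )
    image = np.random.rand(256, 256).astype(np.float32)

    ds = InMemoryTiledPredDataset(prediction_config=cfg, inputs=image)
    tile, tile_info = ds[0]
    # tile is a normalized array covering a 64x64 region; tile_info records where
    # the tile sits in the full image so predictions can be stitched back together.
    print(tile.shape, len(ds))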