careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
"""In-memory dataset module."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Callable, Optional, Union
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from torch.utils.data import Dataset
|
|
11
|
+
|
|
12
|
+
from careamics.file_io.read import read_tiff
|
|
13
|
+
from careamics.transforms import Compose
|
|
14
|
+
|
|
15
|
+
from ..config import DataConfig
|
|
16
|
+
from ..config.transformations import NormalizeModel
|
|
17
|
+
from ..utils.logging import get_logger
|
|
18
|
+
from .patching.patching import (
|
|
19
|
+
PatchedOutput,
|
|
20
|
+
Stats,
|
|
21
|
+
prepare_patches_supervised,
|
|
22
|
+
prepare_patches_supervised_array,
|
|
23
|
+
prepare_patches_unsupervised,
|
|
24
|
+
prepare_patches_unsupervised_array,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
# Module-level logger; InMemoryDataset.__init__ uses it to report the
# image statistics it computes when none are provided in the configuration.
logger = get_logger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class InMemoryDataset(Dataset):
    """Dataset storing data in memory and allowing generating patches from it.

    Parameters
    ----------
    data_config : CAREamics DataConfig
        (see careamics.config.data_model.DataConfig)
        Data configuration.
    inputs : numpy.ndarray or list[pathlib.Path]
        Input data.
    input_target : numpy.ndarray or list[pathlib.Path], optional
        Target data, by default None.
    read_source_func : Callable, optional
        Read source function for custom types, by default read_tiff.
    **kwargs : Any
        Additional keyword arguments, unused.
    """

    def __init__(
        self,
        data_config: DataConfig,
        inputs: Union[np.ndarray, list[Path]],
        input_target: Optional[Union[np.ndarray, list[Path]]] = None,
        read_source_func: Callable = read_tiff,
        **kwargs: Any,
    ) -> None:
        """
        Constructor.

        Patches are extracted from `inputs` (and `input_target`, if provided)
        at construction time. Image and target statistics missing from the
        configuration are taken from the patching output. The resulting
        statistics are written back into `data_config` via
        `set_means_and_stds`, and a transform pipeline is built: normalization
        first, then the transforms listed in `data_config.transforms`.

        Parameters
        ----------
        data_config : DataConfig
            Data configuration.
        inputs : numpy.ndarray or list[pathlib.Path]
            Input data.
        input_target : numpy.ndarray or list[pathlib.Path], optional
            Target data, by default None.
        read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
        **kwargs : Any
            Additional keyword arguments, unused.
        """
        self.data_config = data_config
        self.inputs = inputs
        self.input_targets = input_target
        self.axes = self.data_config.axes
        self.patch_size = self.data_config.patch_size

        # read function
        self.read_source_func = read_source_func

        # generate patches
        supervised = self.input_targets is not None
        patches_data = self._prepare_patches(supervised)

        # unpack the dataclass
        self.data = patches_data.patches
        self.data_targets = patches_data.targets

        # set image statistics (computed from the patches unless configured)
        if self.data_config.image_means is None:
            self.image_stats = patches_data.image_stats
            logger.info(
                f"Computed dataset mean: {self.image_stats.means}, "
                f"std: {self.image_stats.stds}"
            )
        else:
            self.image_stats = Stats(
                self.data_config.image_means, self.data_config.image_stds
            )

        # set target statistics (computed from the patches unless configured)
        if self.data_config.target_means is None:
            self.target_stats = patches_data.target_stats
        else:
            self.target_stats = Stats(
                self.data_config.target_means, self.data_config.target_stds
            )

        # update mean and std in configuration
        # the object is mutable and should then be recorded in the CAREamist obj
        self.data_config.set_means_and_stds(
            image_means=self.image_stats.means,
            image_stds=self.image_stats.stds,
            target_means=self.target_stats.means,
            target_stds=self.target_stats.stds,
        )

        # get transforms: normalization always runs before configured transforms
        self.patch_transform = Compose(
            transform_list=[
                NormalizeModel(
                    image_means=self.image_stats.means,
                    image_stds=self.image_stats.stds,
                    target_means=self.target_stats.means,
                    target_stds=self.target_stats.stds,
                )
            ]
            + self.data_config.transforms,
        )

    def _prepare_patches(self, supervised: bool) -> PatchedOutput:
        """
        Iterate over data source and create an array of patches.

        Parameters
        ----------
        supervised : bool
            Whether the dataset is supervised or not.

        Returns
        -------
        PatchedOutput
            Patches, optional targets and the image/target statistics computed
            during patching.

        Raises
        ------
        ValueError
            If supervised and inputs and targets are not both numpy arrays or
            both lists of paths.
        """
        if supervised:
            if isinstance(self.inputs, np.ndarray) and isinstance(
                self.input_targets, np.ndarray
            ):
                return prepare_patches_supervised_array(
                    self.inputs,
                    self.axes,
                    self.input_targets,
                    self.patch_size,
                )
            elif isinstance(self.inputs, list) and isinstance(self.input_targets, list):
                return prepare_patches_supervised(
                    self.inputs,
                    self.input_targets,
                    self.axes,
                    self.patch_size,
                    self.read_source_func,
                )
            else:
                raise ValueError(
                    f"Data and target must be of the same type, either both numpy "
                    f"arrays or both lists of paths, got {type(self.inputs)} (data) "
                    f"and {type(self.input_targets)} (target)."
                )
        else:
            if isinstance(self.inputs, np.ndarray):
                return prepare_patches_unsupervised_array(
                    self.inputs,
                    self.axes,
                    self.patch_size,
                )
            else:
                # anything that is not an array is treated as a list of paths
                return prepare_patches_unsupervised(
                    self.inputs,
                    self.axes,
                    self.patch_size,
                    self.read_source_func,
                )

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Number of patches, i.e. the size of the first axis of `self.data`.
        """
        return self.data.shape[0]

    def __getitem__(self, index: int) -> tuple[np.ndarray, ...]:
        """
        Return the patch corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the patch to return.

        Returns
        -------
        tuple of numpy.ndarray
            Output of the patch transform applied to the patch (and to its
            target when targets exist). The exact tuple contents depend on the
            configured transforms; TODO confirm against `Compose`.

        Raises
        ------
        ValueError
            If the dataset has no targets and the configuration does not
            include N2V manipulation.
        """
        patch = self.data[index]

        # if there is a target
        if self.data_targets is not None:
            # get target
            target = self.data_targets[index]

            return self.patch_transform(patch=patch, target=target)

        elif self.data_config.has_n2v_manipulate():  # TODO not compatible with HDN
            return self.patch_transform(patch=patch)
        else:
            raise ValueError(
                "Something went wrong! No target provided (not supervised training) "
                "and no N2V manipulation (no N2V training)."
            )

    def get_data_statistics(self) -> tuple[list[float], list[float]]:
        """Return training data statistics.

        This does not return the target data statistics, only those of the input.

        Returns
        -------
        tuple of list of floats
            Means and standard deviations across channels of the training data.
        """
        return self.image_stats.get_statistics()

    def split_dataset(
        self,
        percentage: float = 0.1,
        minimum_patches: int = 1,
    ) -> InMemoryDataset:
        """Split a new dataset away from the current one.

        This method is used to extract random validation patches from the dataset.
        The extracted patches (and targets, if any) are removed from this dataset
        in place.

        Parameters
        ----------
        percentage : float, optional
            Percentage of patches to extract, by default 0.1.
        minimum_patches : int, optional
            Minimum number of patches to extract, by default 1.

        Returns
        -------
        CAREamics InMemoryDataset
            New dataset with the extracted patches.

        Raises
        ------
        ValueError
            If `percentage` is not between 0 and 1.
        ValueError
            If `minimum_patches` is not between 1 and the number of patches.
        """
        if percentage < 0 or percentage > 1:
            raise ValueError(f"Percentage must be between 0 and 1, got {percentage}.")

        if minimum_patches < 1 or minimum_patches > len(self):
            raise ValueError(
                f"Minimum number of patches must be between 1 and "
                f"{len(self)} (number of patches), got "
                f"{minimum_patches}. Adjust the patch size or the minimum number of "
                f"patches."
            )

        total_patches = len(self)

        # number of patches to extract (either percentage rounded or minimum number)
        n_patches = max(round(total_patches * percentage), minimum_patches)

        # get random indices
        indices = np.random.choice(total_patches, n_patches, replace=False)

        # extract validation patches (fancy indexing returns a copy)
        val_patches = self.data[indices]
        val_targets = (
            self.data_targets[indices] if self.data_targets is not None else None
        )

        # compute what remains for training
        remaining_patches = np.delete(self.data, indices, axis=0)
        remaining_targets = (
            np.delete(self.data_targets, indices, axis=0)
            if self.data_targets is not None
            else None
        )

        # Detach the patch arrays before cloning: deep-copying the remaining
        # training patches would temporarily duplicate them in memory, only for
        # the copy to be overwritten right after.
        self.data = None
        self.data_targets = None
        try:
            dataset = copy.deepcopy(self)
        finally:
            self.data = remaining_patches
            self.data_targets = remaining_targets

        # assign the extracted patches (and targets, None if unsupervised)
        dataset.data = val_patches
        dataset.data_targets = val_targets

        return dataset
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""In-memory prediction dataset."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from numpy.typing import NDArray
|
|
6
|
+
from torch.utils.data import Dataset
|
|
7
|
+
|
|
8
|
+
from careamics.transforms import Compose
|
|
9
|
+
|
|
10
|
+
from ..config import InferenceConfig
|
|
11
|
+
from ..config.transformations import NormalizeModel
|
|
12
|
+
from .dataset_utils import reshape_array
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class InMemoryPredDataset(Dataset):
    """Simple prediction dataset returning images along the sample axis.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Prediction configuration.
    inputs : NDArray
        Input data.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        inputs: NDArray,
    ) -> None:
        """Constructor.

        The input array is reshaped according to `prediction_config.axes`.
        A normalization transform is built from the configured image means and
        standard deviations.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Prediction configuration.
        inputs : NDArray
            Input data.
        """
        self.pred_config = prediction_config
        self.input_array = inputs
        self.axes = self.pred_config.axes
        self.image_means = self.pred_config.image_means
        self.image_stds = self.pred_config.image_stds

        # Reshape data
        self.data = reshape_array(self.input_array, self.axes)

        # get transforms (prediction only normalizes the inputs)
        self.patch_transform = Compose(
            transform_list=[
                NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
            ],
        )

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Length of the first axis of the reshaped data.
        """
        return len(self.data)

    def __getitem__(self, index: int) -> NDArray:
        """
        Return the normalized item at the provided index.

        Parameters
        ----------
        index : int
            Index along the first axis of the reshaped data.

        Returns
        -------
        NDArray
            Transformed patch; the second output of the transform is discarded.
        """
        transformed_patch, _ = self.patch_transform(patch=self.data[index])

        return transformed_patch
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""In-memory tiled prediction dataset."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from numpy.typing import NDArray
|
|
6
|
+
from torch.utils.data import Dataset
|
|
7
|
+
|
|
8
|
+
from careamics.transforms import Compose
|
|
9
|
+
|
|
10
|
+
from ..config import InferenceConfig
|
|
11
|
+
from ..config.tile_information import TileInformation
|
|
12
|
+
from ..config.transformations import NormalizeModel
|
|
13
|
+
from .dataset_utils import reshape_array
|
|
14
|
+
from .tiling import extract_tiles
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class InMemoryTiledPredDataset(Dataset):
    """Prediction dataset storing data in memory and returning tiles of each image.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Prediction configuration.
    inputs : NDArray
        Input data.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        inputs: NDArray,
    ) -> None:
        """Constructor.

        The input array is reshaped according to `prediction_config.axes` and
        split into tiles eagerly. A normalization transform is built from the
        configured image means and standard deviations.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Prediction configuration.
        inputs : NDArray
            Input data.

        Raises
        ------
        ValueError
            If `prediction_config.tile_size` or `prediction_config.tile_overlap`
            is None.
        ValueError
            If no tiles could be generated from the input.
        """
        if (
            prediction_config.tile_size is None
            or prediction_config.tile_overlap is None
        ):
            raise ValueError(
                "Tile size and overlap must be provided to use the tiled prediction "
                "dataset."
            )

        self.pred_config = prediction_config
        self.input_array = inputs
        self.axes = self.pred_config.axes
        self.tile_size = prediction_config.tile_size
        self.tile_overlap = prediction_config.tile_overlap
        self.image_means = self.pred_config.image_means
        self.image_stds = self.pred_config.image_stds

        # Generate patches
        self.data = self._prepare_tiles()

        # get transforms (prediction only normalizes the inputs)
        self.patch_transform = Compose(
            transform_list=[
                NormalizeModel(image_means=self.image_means, image_stds=self.image_stds)
            ],
        )

    def _prepare_tiles(self) -> list[tuple[NDArray, TileInformation]]:
        """
        Reshape the input array and split it into overlapping tiles.

        Returns
        -------
        list of tuples of NDArray and TileInformation
            List of tiles and tile information.

        Raises
        ------
        ValueError
            If no tiles are generated.
        """
        # reshape array
        reshaped_sample = reshape_array(self.input_array, self.axes)

        # generate patches, which returns a generator; materialize it so the
        # dataset supports random access and len()
        patch_generator = extract_tiles(
            arr=reshaped_sample,
            tile_size=self.tile_size,
            overlaps=self.tile_overlap,
        )
        patches_list = list(patch_generator)

        if not patches_list:
            raise ValueError(
                f"No tiles generated from input of shape {reshaped_sample.shape} "
                f"with tile size {self.tile_size} and overlap {self.tile_overlap}."
            )

        return patches_list

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Number of tiles.
        """
        return len(self.data)

    def __getitem__(self, index: int) -> tuple[NDArray, TileInformation]:
        """
        Return the tile corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the tile to return.

        Returns
        -------
        tuple of NDArray and TileInformation
            Transformed tile and its (untransformed) tile information.
        """
        tile_array, tile_info = self.data[index]

        # Apply transforms; the second output of the transform is discarded
        transformed_tile, _ = self.patch_transform(patch=tile_array)

        return transformed_tile, tile_info
|