careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
"""Iterable dataset used to load data file by file."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
from collections.abc import Generator
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Callable, Optional
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from torch.utils.data import IterableDataset
|
|
12
|
+
|
|
13
|
+
from careamics.config import DataConfig
|
|
14
|
+
from careamics.config.transformations import NormalizeModel
|
|
15
|
+
from careamics.file_io.read import read_tiff
|
|
16
|
+
from careamics.transforms import Compose
|
|
17
|
+
|
|
18
|
+
from ..utils.logging import get_logger
|
|
19
|
+
from .dataset_utils import iterate_over_files
|
|
20
|
+
from .dataset_utils.running_stats import WelfordStatistics
|
|
21
|
+
from .patching.patching import Stats
|
|
22
|
+
from .patching.random_patching import extract_patches_random
|
|
23
|
+
|
|
24
|
+
logger = get_logger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class PathIterableDataset(IterableDataset):
    """
    Dataset allowing extracting patches w/o loading whole data into memory.

    Parameters
    ----------
    data_config : DataConfig
        Data configuration.
    src_files : list of pathlib.Path
        List of data files.
    target_files : list of pathlib.Path, optional
        Optional list of target files, by default None.
    read_source_func : Callable, optional
        Read source function for custom types, by default read_tiff.

    Attributes
    ----------
    data_path : list of pathlib.Path
        Path to the data, must be a directory.
    axes : str
        Description of axes in format STCZYX.
    """

    def __init__(
        self,
        data_config: DataConfig,
        src_files: list[Path],
        target_files: Optional[list[Path]] = None,
        read_source_func: Callable = read_tiff,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        data_config : DataConfig
            Data configuration.
        src_files : list[Path]
            List of data files.
        target_files : list[Path] or None, optional
            Optional list of target files, by default None.
        read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
        """
        self.data_config = data_config
        self.data_files = src_files
        self.target_files = target_files
        self.read_source_func = read_source_func

        # compute mean and std over the dataset
        # only checking the image_mean because the DataConfig class ensures that
        # if image_mean is provided, image_std is also provided
        if not self.data_config.image_means:
            self.image_stats, self.target_stats = self._calculate_mean_and_std()
            logger.info(
                f"Computed dataset mean: {self.image_stats.means},"
                f"std: {self.image_stats.stds}"
            )

            # update the mean and std in the config so they are persisted
            self.data_config.set_means_and_stds(
                image_means=self.image_stats.means,
                image_stds=self.image_stats.stds,
                target_means=(
                    list(self.target_stats.means)
                    if self.target_stats.means is not None
                    else None
                ),
                target_stds=(
                    list(self.target_stats.stds)
                    if self.target_stats.stds is not None
                    else None
                ),
            )
        else:
            # if mean and std are provided in the config, use them
            self.image_stats, self.target_stats = (
                Stats(self.data_config.image_means, self.data_config.image_stds),
                Stats(self.data_config.target_means, self.data_config.target_stds),
            )

        # create transform composed of normalization and other transforms
        self.patch_transform = Compose(
            transform_list=[
                NormalizeModel(
                    image_means=self.image_stats.means,
                    image_stds=self.image_stats.stds,
                    target_means=self.target_stats.means,
                    target_stds=self.target_stats.stds,
                )
            ]
            + data_config.transforms
        )

    def _calculate_mean_and_std(self) -> tuple[Stats, Stats]:
        """
        Calculate mean and std of the dataset.

        Statistics are accumulated file by file with Welford's online
        algorithm, so the whole dataset never needs to fit in memory.

        Returns
        -------
        tuple of Stats and optional Stats
            Data classes containing the image and target statistics.

        Raises
        ------
        ValueError
            If no samples are found in the dataset.
        """
        num_samples = 0
        image_stats = WelfordStatistics()

        # target statistics are only tracked when target files were supplied;
        # decide on self.target_files (fixed at construction) rather than on
        # the loop variable, which would only reflect the last file read
        target_stats = (
            WelfordStatistics() if self.target_files is not None else None
        )

        for sample, target in iterate_over_files(
            self.data_config, self.data_files, self.target_files, self.read_source_func
        ):
            # update the image statistics
            image_stats.update(sample, num_samples)

            # update the target statistics if target is available
            if target_stats is not None and target is not None:
                target_stats.update(target, num_samples)

            num_samples += 1

        if num_samples == 0:
            raise ValueError("No samples found in the dataset.")

        # average the means and stds per sample
        image_means, image_stds = image_stats.finalize()

        if target_stats is not None:
            target_means, target_stds = target_stats.finalize()
            return (
                Stats(image_means, image_stds),
                Stats(np.array(target_means), np.array(target_stds)),
            )
        else:
            return Stats(image_means, image_stds), Stats(None, None)

    def __iter__(
        self,
    ) -> Generator[tuple[np.ndarray, ...], None, None]:
        """
        Iterate over data source and yield single patch.

        Yields
        ------
        np.ndarray
            Single patch.
        """
        assert (
            self.image_stats.means is not None and self.image_stats.stds is not None
        ), "Mean and std must be provided"

        # iterate over files
        for sample_input, sample_target in iterate_over_files(
            self.data_config, self.data_files, self.target_files, self.read_source_func
        ):
            patches = extract_patches_random(
                arr=sample_input,
                patch_size=self.data_config.patch_size,
                target=sample_target,
            )

            # iterate over patches
            # patches are tuples of (patch, target) if target is available
            # or (patch, None) only if no target is available
            # patch is of dimensions (C)ZYX
            for patch_data in patches:
                yield self.patch_transform(
                    patch=patch_data[0],
                    target=patch_data[1],
                )

    def get_data_statistics(self) -> tuple[list[float], list[float]]:
        """Return training data statistics.

        Returns
        -------
        tuple of list of floats
            Means and standard deviations across channels of the training data.
        """
        return self.image_stats.get_statistics()

    def get_number_of_files(self) -> int:
        """
        Return the number of files in the dataset.

        Returns
        -------
        int
            Number of files in the dataset.
        """
        return len(self.data_files)

    def split_dataset(
        self,
        percentage: float = 0.1,
        minimum_number: int = 5,
    ) -> PathIterableDataset:
        """Split up dataset in two.

        Splits the dataset using a percentage of the data (files) to extract,
        or the minimum number if the percentage yields fewer files than the
        minimum number. The extracted files are removed from this dataset and
        returned in a new dataset.

        Parameters
        ----------
        percentage : float, optional
            Percentage of files to split up, by default 0.1.
        minimum_number : int, optional
            Minimum number of files to split up, by default 5.

        Returns
        -------
        PathIterableDataset
            Dataset containing the split data.

        Raises
        ------
        ValueError
            If the percentage is smaller than 0 or larger than 1.
        ValueError
            If the minimum number is smaller than 1 or larger than the number of files.
        """
        if percentage < 0 or percentage > 1:
            raise ValueError(f"Percentage must be between 0 and 1, got {percentage}.")

        if minimum_number < 1 or minimum_number > self.get_number_of_files():
            raise ValueError(
                f"Minimum number of files must be between 1 and "
                f"{self.get_number_of_files()} (number of files), got "
                f"{minimum_number}."
            )

        # compute number of files to extract
        total_files = self.get_number_of_files()
        n_files = max(round(percentage * total_files), minimum_number)

        # get random indices
        indices = np.random.choice(total_files, n_files, replace=False)

        # set for O(1) membership tests instead of scanning the array per file
        index_set = set(indices.tolist())

        # extract files for the new dataset
        val_files = [self.data_files[i] for i in indices]

        # remove the extracted files from this dataset
        self.data_files = [
            file for i, file in enumerate(self.data_files) if i not in index_set
        ]

        # same for targets
        if self.target_files is not None:
            val_target_files = [self.target_files[i] for i in indices]
            self.target_files = [
                file
                for i, file in enumerate(self.target_files)
                if i not in index_set
            ]

        # clone the dataset (after removal, so the clone starts from the
        # reduced file lists and shares the rest of the state)
        dataset = copy.deepcopy(self)

        # reassign files
        dataset.data_files = val_files

        # reassign targets
        if self.target_files is not None:
            dataset.target_files = val_target_files

        return dataset
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
"""Iterable prediction dataset used to load data file by file."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Callable, Generator
|
|
7
|
+
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
from torch.utils.data import IterableDataset
|
|
10
|
+
|
|
11
|
+
from careamics.file_io.read import read_tiff
|
|
12
|
+
from careamics.transforms import Compose
|
|
13
|
+
|
|
14
|
+
from ..config import InferenceConfig
|
|
15
|
+
from ..config.transformations import NormalizeModel
|
|
16
|
+
from .dataset_utils import iterate_over_files
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class IterablePredDataset(IterableDataset):
    """Simple iterable prediction dataset.

    Samples are read lazily, file by file, normalized, and yielded one by one.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Inference configuration.
    src_files : List[Path]
        List of data files.
    read_source_func : Callable, optional
        Read source function for custom types, by default read_tiff.
    **kwargs : Any
        Additional keyword arguments, unused.

    Attributes
    ----------
    data_path : Union[str, Path]
        Path to the data, must be a directory.
    axes : str
        Description of axes in format STCZYX.
    mean : Optional[float], optional
        Expected mean of the dataset, by default None.
    std : Optional[float], optional
        Expected standard deviation of the dataset, by default None.
    patch_transform : Optional[Callable], optional
        Patch transform callable, by default None.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        src_files: list[Path],
        read_source_func: Callable = read_tiff,
        **kwargs: Any,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Inference configuration.
        src_files : list of pathlib.Path
            List of data files.
        read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
        **kwargs : Any
            Additional keyword arguments, unused.

        Raises
        ------
        ValueError
            If mean and std are not provided in the inference configuration.
        """
        self.prediction_config = prediction_config
        self.data_files = src_files
        self.axes = prediction_config.axes
        self.read_source_func = read_source_func

        # normalization statistics are mandatory for prediction
        means = self.prediction_config.image_means
        stds = self.prediction_config.image_stds
        if means is None or stds is None:
            raise ValueError("Mean and std must be provided for prediction.")
        self.image_means = means
        self.image_stds = stds

        # normalization-only transform applied to every sample
        self.patch_transform = Compose(
            transform_list=[
                NormalizeModel(
                    image_means=self.image_means,
                    image_stds=self.image_stds,
                )
            ],
        )

    def __iter__(
        self,
    ) -> Generator[NDArray, None, None]:
        """
        Iterate over data source and yield single patch.

        Yields
        ------
        NDArray
            Single patch.
        """
        assert (
            self.image_means is not None and self.image_stds is not None
        ), "Mean and std must be provided"

        for sample, _ in iterate_over_files(
            self.prediction_config,
            self.data_files,
            read_source_func=self.read_source_func,
        ):
            # the first axis of the sample is the S (sample) dimension:
            # normalize and yield each slice independently
            n_slices = sample.shape[0]
            for slice_idx in range(n_slices):
                normalized, _ = self.patch_transform(patch=sample[slice_idx])
                yield normalized
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""Iterable tiled prediction dataset used to load data file by file."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Callable, Generator
|
|
7
|
+
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
from torch.utils.data import IterableDataset
|
|
10
|
+
|
|
11
|
+
from careamics.file_io.read import read_tiff
|
|
12
|
+
from careamics.transforms import Compose
|
|
13
|
+
|
|
14
|
+
from ..config import InferenceConfig
|
|
15
|
+
from ..config.tile_information import TileInformation
|
|
16
|
+
from ..config.transformations import NormalizeModel
|
|
17
|
+
from .dataset_utils import iterate_over_files
|
|
18
|
+
from .tiling import extract_tiles
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class IterableTiledPredDataset(IterableDataset):
    """Tiled prediction dataset.

    Each file is read lazily, cut into overlapping tiles, normalized, and the
    tiles are yielded one by one together with their tiling metadata.

    Parameters
    ----------
    prediction_config : InferenceConfig
        Inference configuration.
    src_files : list of pathlib.Path
        List of data files.
    read_source_func : Callable, optional
        Read source function for custom types, by default read_tiff.
    **kwargs : Any
        Additional keyword arguments, unused.

    Attributes
    ----------
    data_path : str or pathlib.Path
        Path to the data, must be a directory.
    axes : str
        Description of axes in format STCZYX.
    mean : float, optional
        Expected mean of the dataset, by default None.
    std : float, optional
        Expected standard deviation of the dataset, by default None.
    patch_transform : Callable, optional
        Patch transform callable, by default None.
    """

    def __init__(
        self,
        prediction_config: InferenceConfig,
        src_files: list[Path],
        read_source_func: Callable = read_tiff,
        **kwargs: Any,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        prediction_config : InferenceConfig
            Inference configuration.
        src_files : List[Path]
            List of data files.
        read_source_func : Callable, optional
            Read source function for custom types, by default read_tiff.
        **kwargs : Any
            Additional keyword arguments, unused.

        Raises
        ------
        ValueError
            If mean and std are not provided in the inference configuration.
        """
        # tiling parameters are mandatory for a tiled dataset
        tile_size = prediction_config.tile_size
        tile_overlap = prediction_config.tile_overlap
        if tile_size is None or tile_overlap is None:
            raise ValueError(
                "Tile size and overlap must be provided for tiled prediction."
            )

        self.prediction_config = prediction_config
        self.data_files = src_files
        self.axes = prediction_config.axes
        self.tile_size = tile_size
        self.tile_overlap = tile_overlap
        self.read_source_func = read_source_func

        # normalization statistics are mandatory for prediction
        means = self.prediction_config.image_means
        stds = self.prediction_config.image_stds
        if means is None or stds is None:
            raise ValueError("Mean and std must be provided for prediction.")
        self.image_means = means
        self.image_stds = stds

        # normalization-only transform applied to every tile
        self.patch_transform = Compose(
            transform_list=[
                NormalizeModel(
                    image_means=self.image_means,
                    image_stds=self.image_stds,
                )
            ],
        )

    def __iter__(
        self,
    ) -> Generator[tuple[NDArray, TileInformation], None, None]:
        """
        Iterate over data source and yield single patch.

        Yields
        ------
        Generator of NDArray and TileInformation tuple
            Generator of single tiles.
        """
        assert (
            self.image_means is not None and self.image_stds is not None
        ), "Mean and std must be provided"

        for sample, _ in iterate_over_files(
            self.prediction_config,
            self.data_files,
            read_source_func=self.read_source_func,
        ):
            # lazily cut the sample into overlapping tiles
            tile_generator = extract_tiles(
                arr=sample,
                tile_size=self.tile_size,
                overlaps=self.tile_overlap,
            )

            # normalize each tile and yield it with its tiling metadata
            for tile, tile_info in tile_generator:
                normalized_tile, _ = self.patch_transform(patch=tile)
                yield normalized_tile, tile_info
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Patching and tiling functions."""
|