careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +1 -14
- careamics/careamist.py +212 -294
- careamics/config/__init__.py +0 -3
- careamics/config/algorithm_model.py +8 -15
- careamics/config/architectures/architecture_model.py +1 -0
- careamics/config/architectures/custom_model.py +5 -3
- careamics/config/architectures/unet_model.py +19 -0
- careamics/config/architectures/vae_model.py +1 -0
- careamics/config/callback_model.py +76 -34
- careamics/config/configuration_factory.py +18 -98
- careamics/config/configuration_model.py +23 -18
- careamics/config/data_model.py +103 -54
- careamics/config/inference_model.py +41 -19
- careamics/config/optimizer_models.py +13 -7
- careamics/config/support/supported_data.py +29 -4
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +36 -58
- careamics/config/training_model.py +5 -1
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -7
- careamics/dataset/dataset_utils/file_utils.py +2 -2
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +84 -173
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +97 -250
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/patching.py +97 -52
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/file_io/__init__.py +7 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
- careamics/file_io/write/__init__.py +9 -0
- careamics/file_io/write/get_func.py +59 -0
- careamics/file_io/write/tiff.py +39 -0
- careamics/lightning/__init__.py +17 -0
- careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
- careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
- careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +6 -3
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/prediction_outputs.py +137 -0
- careamics/prediction_utils/stitch_prediction.py +103 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/__init__.py +2 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
- careamics-0.1.0rc8.dist-info/RECORD +135 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
- careamics/config/configuration_example.py +0 -89
- careamics/dataset/dataset_utils/read_utils.py +0 -27
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- careamics-0.1.0rc6.dist-info/RECORD +0 -107
- /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
- /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
- /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,86 +3,27 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import copy
|
|
6
|
+
from collections.abc import Generator
|
|
6
7
|
from pathlib import Path
|
|
7
|
-
from typing import
|
|
8
|
+
from typing import Callable, Optional
|
|
8
9
|
|
|
9
10
|
import numpy as np
|
|
10
|
-
from torch.utils.data import IterableDataset
|
|
11
|
+
from torch.utils.data import IterableDataset
|
|
11
12
|
|
|
13
|
+
from careamics.config import DataConfig
|
|
14
|
+
from careamics.config.transformations import NormalizeModel
|
|
15
|
+
from careamics.file_io.read import read_tiff
|
|
12
16
|
from careamics.transforms import Compose
|
|
13
17
|
|
|
14
|
-
from ..config import DataConfig, InferenceConfig
|
|
15
|
-
from ..config.tile_information import TileInformation
|
|
16
|
-
from ..config.transformations import NormalizeModel
|
|
17
18
|
from ..utils.logging import get_logger
|
|
18
|
-
from .dataset_utils import
|
|
19
|
+
from .dataset_utils import iterate_over_files
|
|
20
|
+
from .dataset_utils.running_stats import WelfordStatistics
|
|
21
|
+
from .patching.patching import Stats
|
|
19
22
|
from .patching.random_patching import extract_patches_random
|
|
20
|
-
from .patching.tiled_patching import extract_tiles
|
|
21
23
|
|
|
22
24
|
logger = get_logger(__name__)
|
|
23
25
|
|
|
24
26
|
|
|
25
|
-
def _iterate_over_files(
|
|
26
|
-
data_config: Union[DataConfig, InferenceConfig],
|
|
27
|
-
data_files: List[Path],
|
|
28
|
-
target_files: Optional[List[Path]] = None,
|
|
29
|
-
read_source_func: Callable = read_tiff,
|
|
30
|
-
) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
|
|
31
|
-
"""
|
|
32
|
-
Iterate over data source and yield whole image.
|
|
33
|
-
|
|
34
|
-
Parameters
|
|
35
|
-
----------
|
|
36
|
-
data_config : Union[DataConfig, InferenceConfig]
|
|
37
|
-
Data configuration.
|
|
38
|
-
data_files : List[Path]
|
|
39
|
-
List of data files.
|
|
40
|
-
target_files : Optional[List[Path]]
|
|
41
|
-
List of target files, by default None.
|
|
42
|
-
read_source_func : Optional[Callable]
|
|
43
|
-
Function to read the source, by default read_tiff.
|
|
44
|
-
|
|
45
|
-
Yields
|
|
46
|
-
------
|
|
47
|
-
np.ndarray
|
|
48
|
-
Image.
|
|
49
|
-
"""
|
|
50
|
-
# When num_workers > 0, each worker process will have a different copy of the
|
|
51
|
-
# dataset object
|
|
52
|
-
# Configuring each copy independently to avoid having duplicate data returned
|
|
53
|
-
# from the workers
|
|
54
|
-
worker_info = get_worker_info()
|
|
55
|
-
worker_id = worker_info.id if worker_info is not None else 0
|
|
56
|
-
num_workers = worker_info.num_workers if worker_info is not None else 1
|
|
57
|
-
|
|
58
|
-
# iterate over the files
|
|
59
|
-
for i, filename in enumerate(data_files):
|
|
60
|
-
# retrieve file corresponding to the worker id
|
|
61
|
-
if i % num_workers == worker_id:
|
|
62
|
-
try:
|
|
63
|
-
# read data
|
|
64
|
-
sample = read_source_func(filename, data_config.axes)
|
|
65
|
-
|
|
66
|
-
# read target, if available
|
|
67
|
-
if target_files is not None:
|
|
68
|
-
if filename.name != target_files[i].name:
|
|
69
|
-
raise ValueError(
|
|
70
|
-
f"File {filename} does not match target file "
|
|
71
|
-
f"{target_files[i]}. Have you passed sorted "
|
|
72
|
-
f"arrays?"
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
# read target
|
|
76
|
-
target = read_source_func(target_files[i], data_config.axes)
|
|
77
|
-
|
|
78
|
-
yield sample, target
|
|
79
|
-
else:
|
|
80
|
-
yield sample, None
|
|
81
|
-
|
|
82
|
-
except Exception as e:
|
|
83
|
-
logger.error(f"Error reading file {filename}: {e}")
|
|
84
|
-
|
|
85
|
-
|
|
86
27
|
class PathIterableDataset(IterableDataset):
|
|
87
28
|
"""
|
|
88
29
|
Dataset allowing extracting patches w/o loading whole data into memory.
|
|
@@ -91,38 +32,26 @@ class PathIterableDataset(IterableDataset):
|
|
|
91
32
|
----------
|
|
92
33
|
data_config : DataConfig
|
|
93
34
|
Data configuration.
|
|
94
|
-
src_files :
|
|
35
|
+
src_files : list of pathlib.Path
|
|
95
36
|
List of data files.
|
|
96
|
-
target_files :
|
|
37
|
+
target_files : list of pathlib.Path, optional
|
|
97
38
|
Optional list of target files, by default None.
|
|
98
39
|
read_source_func : Callable, optional
|
|
99
40
|
Read source function for custom types, by default read_tiff.
|
|
100
41
|
|
|
101
42
|
Attributes
|
|
102
43
|
----------
|
|
103
|
-
data_path :
|
|
44
|
+
data_path : list of pathlib.Path
|
|
104
45
|
Path to the data, must be a directory.
|
|
105
46
|
axes : str
|
|
106
47
|
Description of axes in format STCZYX.
|
|
107
|
-
patch_extraction_method : Union[ExtractionStrategies, None]
|
|
108
|
-
Patch extraction strategy, as defined in extraction_strategy.
|
|
109
|
-
patch_size : Optional[Union[List[int], Tuple[int]]], optional
|
|
110
|
-
Size of the patches in each dimension, by default None.
|
|
111
|
-
patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
|
|
112
|
-
Overlap of the patches in each dimension, by default None.
|
|
113
|
-
mean : Optional[float], optional
|
|
114
|
-
Expected mean of the dataset, by default None.
|
|
115
|
-
std : Optional[float], optional
|
|
116
|
-
Expected standard deviation of the dataset, by default None.
|
|
117
|
-
patch_transform : Optional[Callable], optional
|
|
118
|
-
Patch transform callable, by default None.
|
|
119
48
|
"""
|
|
120
49
|
|
|
121
50
|
def __init__(
|
|
122
51
|
self,
|
|
123
52
|
data_config: DataConfig,
|
|
124
|
-
src_files:
|
|
125
|
-
target_files: Optional[
|
|
53
|
+
src_files: list[Path],
|
|
54
|
+
target_files: Optional[list[Path]] = None,
|
|
126
55
|
read_source_func: Callable = read_tiff,
|
|
127
56
|
) -> None:
|
|
128
57
|
"""Constructor.
|
|
@@ -131,9 +60,9 @@ class PathIterableDataset(IterableDataset):
|
|
|
131
60
|
----------
|
|
132
61
|
data_config : DataConfig
|
|
133
62
|
Data configuration.
|
|
134
|
-
src_files :
|
|
63
|
+
src_files : list[Path]
|
|
135
64
|
List of data files.
|
|
136
|
-
target_files :
|
|
65
|
+
target_files : list[Path] or None, optional
|
|
137
66
|
Optional list of target files, by default None.
|
|
138
67
|
read_source_func : Callable, optional
|
|
139
68
|
Read source function for custom types, by default read_tiff.
|
|
@@ -141,55 +70,99 @@ class PathIterableDataset(IterableDataset):
|
|
|
141
70
|
self.data_config = data_config
|
|
142
71
|
self.data_files = src_files
|
|
143
72
|
self.target_files = target_files
|
|
144
|
-
self.data_config = data_config
|
|
145
73
|
self.read_source_func = read_source_func
|
|
146
74
|
|
|
147
75
|
# compute mean and std over the dataset
|
|
148
|
-
|
|
149
|
-
|
|
76
|
+
# only checking the image_mean because the DataConfig class ensures that
|
|
77
|
+
# if image_mean is provided, image_std is also provided
|
|
78
|
+
if not self.data_config.image_means:
|
|
79
|
+
self.image_stats, self.target_stats = self._calculate_mean_and_std()
|
|
80
|
+
logger.info(
|
|
81
|
+
f"Computed dataset mean: {self.image_stats.means},"
|
|
82
|
+
f"std: {self.image_stats.stds}"
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
# update the mean in the config
|
|
86
|
+
self.data_config.set_means_and_stds(
|
|
87
|
+
image_means=self.image_stats.means,
|
|
88
|
+
image_stds=self.image_stats.stds,
|
|
89
|
+
target_means=(
|
|
90
|
+
list(self.target_stats.means)
|
|
91
|
+
if self.target_stats.means is not None
|
|
92
|
+
else None
|
|
93
|
+
),
|
|
94
|
+
target_stds=(
|
|
95
|
+
list(self.target_stats.stds)
|
|
96
|
+
if self.target_stats.stds is not None
|
|
97
|
+
else None
|
|
98
|
+
),
|
|
99
|
+
)
|
|
150
100
|
|
|
151
|
-
# update mean and std in configuration
|
|
152
|
-
# the object is mutable and should then be recorded in the CAREamist
|
|
153
|
-
data_config.set_mean_and_std(self.mean, self.std)
|
|
154
101
|
else:
|
|
155
|
-
|
|
156
|
-
self.
|
|
102
|
+
# if mean and std are provided in the config, use them
|
|
103
|
+
self.image_stats, self.target_stats = (
|
|
104
|
+
Stats(self.data_config.image_means, self.data_config.image_stds),
|
|
105
|
+
Stats(self.data_config.target_means, self.data_config.target_stds),
|
|
106
|
+
)
|
|
157
107
|
|
|
158
|
-
#
|
|
159
|
-
self.patch_transform = Compose(
|
|
108
|
+
# create transform composed of normalization and other transforms
|
|
109
|
+
self.patch_transform = Compose(
|
|
110
|
+
transform_list=[
|
|
111
|
+
NormalizeModel(
|
|
112
|
+
image_means=self.image_stats.means,
|
|
113
|
+
image_stds=self.image_stats.stds,
|
|
114
|
+
target_means=self.target_stats.means,
|
|
115
|
+
target_stds=self.target_stats.stds,
|
|
116
|
+
)
|
|
117
|
+
]
|
|
118
|
+
+ data_config.transforms
|
|
119
|
+
)
|
|
160
120
|
|
|
161
|
-
def _calculate_mean_and_std(self) ->
|
|
121
|
+
def _calculate_mean_and_std(self) -> tuple[Stats, Stats]:
|
|
162
122
|
"""
|
|
163
123
|
Calculate mean and std of the dataset.
|
|
164
124
|
|
|
165
125
|
Returns
|
|
166
126
|
-------
|
|
167
|
-
|
|
168
|
-
|
|
127
|
+
tuple of Stats and optional Stats
|
|
128
|
+
Data classes containing the image and target statistics.
|
|
169
129
|
"""
|
|
170
|
-
means, stds = 0, 0
|
|
171
130
|
num_samples = 0
|
|
131
|
+
image_stats = WelfordStatistics()
|
|
132
|
+
if self.target_files is not None:
|
|
133
|
+
target_stats = WelfordStatistics()
|
|
172
134
|
|
|
173
|
-
for sample,
|
|
135
|
+
for sample, target in iterate_over_files(
|
|
174
136
|
self.data_config, self.data_files, self.target_files, self.read_source_func
|
|
175
137
|
):
|
|
176
|
-
|
|
177
|
-
|
|
138
|
+
# update the image statistics
|
|
139
|
+
image_stats.update(sample, num_samples)
|
|
140
|
+
|
|
141
|
+
# update the target statistics if target is available
|
|
142
|
+
if target is not None:
|
|
143
|
+
target_stats.update(target, num_samples)
|
|
144
|
+
|
|
178
145
|
num_samples += 1
|
|
179
146
|
|
|
180
147
|
if num_samples == 0:
|
|
181
148
|
raise ValueError("No samples found in the dataset.")
|
|
182
149
|
|
|
183
|
-
|
|
184
|
-
|
|
150
|
+
# Finalize the running statistics accumulated over all samples into means and stds
|
|
151
|
+
image_means, image_stds = image_stats.finalize()
|
|
185
152
|
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
153
|
+
if target is not None:
|
|
154
|
+
target_means, target_stds = target_stats.finalize()
|
|
155
|
+
|
|
156
|
+
return (
|
|
157
|
+
Stats(image_means, image_stds),
|
|
158
|
+
Stats(np.array(target_means), np.array(target_stds)),
|
|
159
|
+
)
|
|
160
|
+
else:
|
|
161
|
+
return Stats(image_means, image_stds), Stats(None, None)
|
|
189
162
|
|
|
190
163
|
def __iter__(
|
|
191
164
|
self,
|
|
192
|
-
) -> Generator[
|
|
165
|
+
) -> Generator[tuple[np.ndarray, ...], None, None]:
|
|
193
166
|
"""
|
|
194
167
|
Iterate over data source and yield single patch.
|
|
195
168
|
|
|
@@ -199,24 +172,17 @@ class PathIterableDataset(IterableDataset):
|
|
|
199
172
|
Single patch.
|
|
200
173
|
"""
|
|
201
174
|
assert (
|
|
202
|
-
self.
|
|
175
|
+
self.image_stats.means is not None and self.image_stats.stds is not None
|
|
203
176
|
), "Mean and std must be provided"
|
|
204
177
|
|
|
205
178
|
# iterate over files
|
|
206
|
-
for sample_input, sample_target in
|
|
179
|
+
for sample_input, sample_target in iterate_over_files(
|
|
207
180
|
self.data_config, self.data_files, self.target_files, self.read_source_func
|
|
208
181
|
):
|
|
209
|
-
reshaped_sample = reshape_array(sample_input, self.data_config.axes)
|
|
210
|
-
reshaped_target = (
|
|
211
|
-
None
|
|
212
|
-
if sample_target is None
|
|
213
|
-
else reshape_array(sample_target, self.data_config.axes)
|
|
214
|
-
)
|
|
215
|
-
|
|
216
182
|
patches = extract_patches_random(
|
|
217
|
-
arr=
|
|
183
|
+
arr=sample_input,
|
|
218
184
|
patch_size=self.data_config.patch_size,
|
|
219
|
-
target=
|
|
185
|
+
target=sample_target,
|
|
220
186
|
)
|
|
221
187
|
|
|
222
188
|
# iterate over patches
|
|
@@ -229,6 +195,16 @@ class PathIterableDataset(IterableDataset):
|
|
|
229
195
|
target=patch_data[1],
|
|
230
196
|
)
|
|
231
197
|
|
|
198
|
+
def get_data_statistics(self) -> tuple[list[float], list[float]]:
|
|
199
|
+
"""Return training data statistics.
|
|
200
|
+
|
|
201
|
+
Returns
|
|
202
|
+
-------
|
|
203
|
+
tuple of list of floats
|
|
204
|
+
Means and standard deviations across channels of the training data.
|
|
205
|
+
"""
|
|
206
|
+
return self.image_stats.get_statistics()
|
|
207
|
+
|
|
232
208
|
def get_number_of_files(self) -> int:
|
|
233
209
|
"""
|
|
234
210
|
Return the number of files in the dataset.
|
|
@@ -317,132 +293,3 @@ class PathIterableDataset(IterableDataset):
|
|
|
317
293
|
dataset.target_files = val_target_files
|
|
318
294
|
|
|
319
295
|
return dataset
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
class IterablePredictionDataset(IterableDataset):
|
|
323
|
-
"""
|
|
324
|
-
Prediction dataset.
|
|
325
|
-
|
|
326
|
-
Parameters
|
|
327
|
-
----------
|
|
328
|
-
prediction_config : InferenceConfig
|
|
329
|
-
Inference configuration.
|
|
330
|
-
src_files : List[Path]
|
|
331
|
-
List of data files.
|
|
332
|
-
read_source_func : Callable, optional
|
|
333
|
-
Read source function for custom types, by default read_tiff.
|
|
334
|
-
**kwargs : Any
|
|
335
|
-
Additional keyword arguments, unused.
|
|
336
|
-
|
|
337
|
-
Attributes
|
|
338
|
-
----------
|
|
339
|
-
data_path : Union[str, Path]
|
|
340
|
-
Path to the data, must be a directory.
|
|
341
|
-
axes : str
|
|
342
|
-
Description of axes in format STCZYX.
|
|
343
|
-
mean : Optional[float], optional
|
|
344
|
-
Expected mean of the dataset, by default None.
|
|
345
|
-
std : Optional[float], optional
|
|
346
|
-
Expected standard deviation of the dataset, by default None.
|
|
347
|
-
patch_transform : Optional[Callable], optional
|
|
348
|
-
Patch transform callable, by default None.
|
|
349
|
-
"""
|
|
350
|
-
|
|
351
|
-
def __init__(
|
|
352
|
-
self,
|
|
353
|
-
prediction_config: InferenceConfig,
|
|
354
|
-
src_files: List[Path],
|
|
355
|
-
read_source_func: Callable = read_tiff,
|
|
356
|
-
**kwargs: Any,
|
|
357
|
-
) -> None:
|
|
358
|
-
"""Constructor.
|
|
359
|
-
|
|
360
|
-
Parameters
|
|
361
|
-
----------
|
|
362
|
-
prediction_config : InferenceConfig
|
|
363
|
-
Inference configuration.
|
|
364
|
-
src_files : List[Path]
|
|
365
|
-
List of data files.
|
|
366
|
-
read_source_func : Callable, optional
|
|
367
|
-
Read source function for custom types, by default read_tiff.
|
|
368
|
-
**kwargs : Any
|
|
369
|
-
Additional keyword arguments, unused.
|
|
370
|
-
|
|
371
|
-
Raises
|
|
372
|
-
------
|
|
373
|
-
ValueError
|
|
374
|
-
If mean and std are not provided in the inference configuration.
|
|
375
|
-
"""
|
|
376
|
-
self.prediction_config = prediction_config
|
|
377
|
-
self.data_files = src_files
|
|
378
|
-
self.axes = prediction_config.axes
|
|
379
|
-
self.tile_size = self.prediction_config.tile_size
|
|
380
|
-
self.tile_overlap = self.prediction_config.tile_overlap
|
|
381
|
-
self.read_source_func = read_source_func
|
|
382
|
-
|
|
383
|
-
# tile only if both tile size and overlaps are provided
|
|
384
|
-
self.tile = self.tile_size is not None and self.tile_overlap is not None
|
|
385
|
-
|
|
386
|
-
# check mean and std and create normalize transform
|
|
387
|
-
if self.prediction_config.mean is None or self.prediction_config.std is None:
|
|
388
|
-
raise ValueError("Mean and std must be provided for prediction.")
|
|
389
|
-
else:
|
|
390
|
-
self.mean = self.prediction_config.mean
|
|
391
|
-
self.std = self.prediction_config.std
|
|
392
|
-
|
|
393
|
-
# instantiate normalize transform
|
|
394
|
-
self.patch_transform = Compose(
|
|
395
|
-
transform_list=[
|
|
396
|
-
NormalizeModel(
|
|
397
|
-
mean=prediction_config.mean, std=prediction_config.std
|
|
398
|
-
)
|
|
399
|
-
],
|
|
400
|
-
)
|
|
401
|
-
|
|
402
|
-
def __iter__(
|
|
403
|
-
self,
|
|
404
|
-
) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
|
|
405
|
-
"""
|
|
406
|
-
Iterate over data source and yield single patch.
|
|
407
|
-
|
|
408
|
-
Yields
|
|
409
|
-
------
|
|
410
|
-
np.ndarray
|
|
411
|
-
Single patch.
|
|
412
|
-
"""
|
|
413
|
-
assert (
|
|
414
|
-
self.mean is not None and self.std is not None
|
|
415
|
-
), "Mean and std must be provided"
|
|
416
|
-
|
|
417
|
-
for sample, _ in _iterate_over_files(
|
|
418
|
-
self.prediction_config,
|
|
419
|
-
self.data_files,
|
|
420
|
-
read_source_func=self.read_source_func,
|
|
421
|
-
):
|
|
422
|
-
# reshape array
|
|
423
|
-
reshaped_sample = reshape_array(sample, self.axes)
|
|
424
|
-
|
|
425
|
-
if (
|
|
426
|
-
self.tile
|
|
427
|
-
and self.tile_size is not None
|
|
428
|
-
and self.tile_overlap is not None
|
|
429
|
-
):
|
|
430
|
-
# generate patches, return a generator
|
|
431
|
-
patch_gen = extract_tiles(
|
|
432
|
-
arr=reshaped_sample,
|
|
433
|
-
tile_size=self.tile_size,
|
|
434
|
-
overlaps=self.tile_overlap,
|
|
435
|
-
)
|
|
436
|
-
else:
|
|
437
|
-
# just wrap the sample in a generator with default tiling info
|
|
438
|
-
array_shape = reshaped_sample.squeeze().shape
|
|
439
|
-
patch_gen = (
|
|
440
|
-
(reshaped_sample, TileInformation(array_shape=array_shape))
|
|
441
|
-
for _ in range(1)
|
|
442
|
-
)
|
|
443
|
-
|
|
444
|
-
# apply transform to patches
|
|
445
|
-
for patch_array, tile_info in patch_gen:
|
|
446
|
-
transformed_patch, _ = self.patch_transform(patch=patch_array)
|
|
447
|
-
|
|
448
|
-
yield transformed_patch, tile_info
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
"""Iterable prediction dataset used to load data file by file."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Callable, Generator
|
|
7
|
+
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
from torch.utils.data import IterableDataset
|
|
10
|
+
|
|
11
|
+
from careamics.file_io.read import read_tiff
|
|
12
|
+
from careamics.transforms import Compose
|
|
13
|
+
|
|
14
|
+
from ..config import InferenceConfig
|
|
15
|
+
from ..config.transformations import NormalizeModel
|
|
16
|
+
from .dataset_utils import iterate_over_files
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class IterablePredDataset(IterableDataset):
|
|
20
|
+
"""Simple iterable prediction dataset.
|
|
21
|
+
|
|
22
|
+
Parameters
|
|
23
|
+
----------
|
|
24
|
+
prediction_config : InferenceConfig
|
|
25
|
+
Inference configuration.
|
|
26
|
+
src_files : list of pathlib.Path
|
|
27
|
+
List of data files.
|
|
28
|
+
read_source_func : Callable, optional
|
|
29
|
+
Read source function for custom types, by default read_tiff.
|
|
30
|
+
**kwargs : Any
|
|
31
|
+
Additional keyword arguments, unused.
|
|
32
|
+
|
|
33
|
+
Attributes
|
|
34
|
+
----------
|
|
35
|
+
data_path : Union[str, Path]
|
|
36
|
+
Path to the data, must be a directory.
|
|
37
|
+
axes : str
|
|
38
|
+
Description of axes in format STCZYX.
|
|
39
|
+
mean : Optional[float], optional
|
|
40
|
+
Expected mean of the dataset, by default None.
|
|
41
|
+
std : Optional[float], optional
|
|
42
|
+
Expected standard deviation of the dataset, by default None.
|
|
43
|
+
patch_transform : Optional[Callable], optional
|
|
44
|
+
Patch transform callable, by default None.
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
def __init__(
|
|
48
|
+
self,
|
|
49
|
+
prediction_config: InferenceConfig,
|
|
50
|
+
src_files: list[Path],
|
|
51
|
+
read_source_func: Callable = read_tiff,
|
|
52
|
+
**kwargs: Any,
|
|
53
|
+
) -> None:
|
|
54
|
+
"""Constructor.
|
|
55
|
+
|
|
56
|
+
Parameters
|
|
57
|
+
----------
|
|
58
|
+
prediction_config : InferenceConfig
|
|
59
|
+
Inference configuration.
|
|
60
|
+
src_files : list of pathlib.Path
|
|
61
|
+
List of data files.
|
|
62
|
+
read_source_func : Callable, optional
|
|
63
|
+
Read source function for custom types, by default read_tiff.
|
|
64
|
+
**kwargs : Any
|
|
65
|
+
Additional keyword arguments, unused.
|
|
66
|
+
|
|
67
|
+
Raises
|
|
68
|
+
------
|
|
69
|
+
ValueError
|
|
70
|
+
If mean and std are not provided in the inference configuration.
|
|
71
|
+
"""
|
|
72
|
+
self.prediction_config = prediction_config
|
|
73
|
+
self.data_files = src_files
|
|
74
|
+
self.axes = prediction_config.axes
|
|
75
|
+
self.read_source_func = read_source_func
|
|
76
|
+
|
|
77
|
+
# check mean and std and create normalize transform
|
|
78
|
+
if (
|
|
79
|
+
self.prediction_config.image_means is None
|
|
80
|
+
or self.prediction_config.image_stds is None
|
|
81
|
+
):
|
|
82
|
+
raise ValueError("Mean and std must be provided for prediction.")
|
|
83
|
+
else:
|
|
84
|
+
self.image_means = self.prediction_config.image_means
|
|
85
|
+
self.image_stds = self.prediction_config.image_stds
|
|
86
|
+
|
|
87
|
+
# instantiate normalize transform
|
|
88
|
+
self.patch_transform = Compose(
|
|
89
|
+
transform_list=[
|
|
90
|
+
NormalizeModel(
|
|
91
|
+
image_means=self.image_means,
|
|
92
|
+
image_stds=self.image_stds,
|
|
93
|
+
)
|
|
94
|
+
],
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
def __iter__(
|
|
98
|
+
self,
|
|
99
|
+
) -> Generator[NDArray, None, None]:
|
|
100
|
+
"""
|
|
101
|
+
Iterate over data source and yield single patch.
|
|
102
|
+
|
|
103
|
+
Yields
|
|
104
|
+
------
|
|
105
|
+
NDArray
|
|
106
|
+
Single patch.
|
|
107
|
+
"""
|
|
108
|
+
assert (
|
|
109
|
+
self.image_means is not None and self.image_stds is not None
|
|
110
|
+
), "Mean and std must be provided"
|
|
111
|
+
|
|
112
|
+
for sample, _ in iterate_over_files(
|
|
113
|
+
self.prediction_config,
|
|
114
|
+
self.data_files,
|
|
115
|
+
read_source_func=self.read_source_func,
|
|
116
|
+
):
|
|
117
|
+
# sample has S dimension
|
|
118
|
+
for i in range(sample.shape[0]):
|
|
119
|
+
|
|
120
|
+
transformed_sample, _ = self.patch_transform(patch=sample[i])
|
|
121
|
+
|
|
122
|
+
yield transformed_sample
|