careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +163 -266
- careamics/config/algorithm_model.py +0 -15
- careamics/config/architectures/custom_model.py +3 -3
- careamics/config/configuration_example.py +0 -3
- careamics/config/configuration_factory.py +23 -25
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +80 -50
- careamics/config/inference_model.py +29 -17
- careamics/config/optimizer_models.py +7 -7
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +26 -58
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +0 -9
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +66 -171
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +92 -249
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +54 -25
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/lightning_datamodule.py +1 -6
- careamics/lightning_module.py +11 -7
- careamics/lightning_prediction_datamodule.py +52 -72
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,86 +3,29 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import copy
|
|
6
|
+
from collections.abc import Generator
|
|
6
7
|
from pathlib import Path
|
|
7
|
-
from typing import
|
|
8
|
+
from typing import Callable, Optional
|
|
8
9
|
|
|
9
10
|
import numpy as np
|
|
10
|
-
from torch.utils.data import IterableDataset
|
|
11
|
+
from torch.utils.data import IterableDataset
|
|
11
12
|
|
|
13
|
+
from careamics.config import DataConfig
|
|
14
|
+
from careamics.config.transformations import NormalizeModel
|
|
12
15
|
from careamics.transforms import Compose
|
|
13
16
|
|
|
14
|
-
from ..config import DataConfig, InferenceConfig
|
|
15
|
-
from ..config.tile_information import TileInformation
|
|
16
|
-
from ..config.transformations import NormalizeModel
|
|
17
17
|
from ..utils.logging import get_logger
|
|
18
|
-
from .dataset_utils import
|
|
18
|
+
from .dataset_utils import (
|
|
19
|
+
iterate_over_files,
|
|
20
|
+
read_tiff,
|
|
21
|
+
)
|
|
22
|
+
from .dataset_utils.running_stats import WelfordStatistics
|
|
23
|
+
from .patching.patching import Stats, StatsOutput
|
|
19
24
|
from .patching.random_patching import extract_patches_random
|
|
20
|
-
from .patching.tiled_patching import extract_tiles
|
|
21
25
|
|
|
22
26
|
logger = get_logger(__name__)
|
|
23
27
|
|
|
24
28
|
|
|
25
|
-
def _iterate_over_files(
|
|
26
|
-
data_config: Union[DataConfig, InferenceConfig],
|
|
27
|
-
data_files: List[Path],
|
|
28
|
-
target_files: Optional[List[Path]] = None,
|
|
29
|
-
read_source_func: Callable = read_tiff,
|
|
30
|
-
) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
|
|
31
|
-
"""
|
|
32
|
-
Iterate over data source and yield whole image.
|
|
33
|
-
|
|
34
|
-
Parameters
|
|
35
|
-
----------
|
|
36
|
-
data_config : Union[DataConfig, InferenceConfig]
|
|
37
|
-
Data configuration.
|
|
38
|
-
data_files : List[Path]
|
|
39
|
-
List of data files.
|
|
40
|
-
target_files : Optional[List[Path]]
|
|
41
|
-
List of target files, by default None.
|
|
42
|
-
read_source_func : Optional[Callable]
|
|
43
|
-
Function to read the source, by default read_tiff.
|
|
44
|
-
|
|
45
|
-
Yields
|
|
46
|
-
------
|
|
47
|
-
np.ndarray
|
|
48
|
-
Image.
|
|
49
|
-
"""
|
|
50
|
-
# When num_workers > 0, each worker process will have a different copy of the
|
|
51
|
-
# dataset object
|
|
52
|
-
# Configuring each copy independently to avoid having duplicate data returned
|
|
53
|
-
# from the workers
|
|
54
|
-
worker_info = get_worker_info()
|
|
55
|
-
worker_id = worker_info.id if worker_info is not None else 0
|
|
56
|
-
num_workers = worker_info.num_workers if worker_info is not None else 1
|
|
57
|
-
|
|
58
|
-
# iterate over the files
|
|
59
|
-
for i, filename in enumerate(data_files):
|
|
60
|
-
# retrieve file corresponding to the worker id
|
|
61
|
-
if i % num_workers == worker_id:
|
|
62
|
-
try:
|
|
63
|
-
# read data
|
|
64
|
-
sample = read_source_func(filename, data_config.axes)
|
|
65
|
-
|
|
66
|
-
# read target, if available
|
|
67
|
-
if target_files is not None:
|
|
68
|
-
if filename.name != target_files[i].name:
|
|
69
|
-
raise ValueError(
|
|
70
|
-
f"File {filename} does not match target file "
|
|
71
|
-
f"{target_files[i]}. Have you passed sorted "
|
|
72
|
-
f"arrays?"
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
# read target
|
|
76
|
-
target = read_source_func(target_files[i], data_config.axes)
|
|
77
|
-
|
|
78
|
-
yield sample, target
|
|
79
|
-
else:
|
|
80
|
-
yield sample, None
|
|
81
|
-
|
|
82
|
-
except Exception as e:
|
|
83
|
-
logger.error(f"Error reading file {filename}: {e}")
|
|
84
|
-
|
|
85
|
-
|
|
86
29
|
class PathIterableDataset(IterableDataset):
|
|
87
30
|
"""
|
|
88
31
|
Dataset allowing extracting patches w/o loading whole data into memory.
|
|
@@ -91,38 +34,26 @@ class PathIterableDataset(IterableDataset):
|
|
|
91
34
|
----------
|
|
92
35
|
data_config : DataConfig
|
|
93
36
|
Data configuration.
|
|
94
|
-
src_files :
|
|
37
|
+
src_files : list of pathlib.Path
|
|
95
38
|
List of data files.
|
|
96
|
-
target_files :
|
|
39
|
+
target_files : list of pathlib.Path, optional
|
|
97
40
|
Optional list of target files, by default None.
|
|
98
41
|
read_source_func : Callable, optional
|
|
99
42
|
Read source function for custom types, by default read_tiff.
|
|
100
43
|
|
|
101
44
|
Attributes
|
|
102
45
|
----------
|
|
103
|
-
data_path :
|
|
46
|
+
data_path : list of pathlib.Path
|
|
104
47
|
Path to the data, must be a directory.
|
|
105
48
|
axes : str
|
|
106
49
|
Description of axes in format STCZYX.
|
|
107
|
-
patch_extraction_method : Union[ExtractionStrategies, None]
|
|
108
|
-
Patch extraction strategy, as defined in extraction_strategy.
|
|
109
|
-
patch_size : Optional[Union[List[int], Tuple[int]]], optional
|
|
110
|
-
Size of the patches in each dimension, by default None.
|
|
111
|
-
patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
|
|
112
|
-
Overlap of the patches in each dimension, by default None.
|
|
113
|
-
mean : Optional[float], optional
|
|
114
|
-
Expected mean of the dataset, by default None.
|
|
115
|
-
std : Optional[float], optional
|
|
116
|
-
Expected standard deviation of the dataset, by default None.
|
|
117
|
-
patch_transform : Optional[Callable], optional
|
|
118
|
-
Patch transform callable, by default None.
|
|
119
50
|
"""
|
|
120
51
|
|
|
121
52
|
def __init__(
|
|
122
53
|
self,
|
|
123
54
|
data_config: DataConfig,
|
|
124
|
-
src_files:
|
|
125
|
-
target_files: Optional[
|
|
55
|
+
src_files: list[Path],
|
|
56
|
+
target_files: Optional[list[Path]] = None,
|
|
126
57
|
read_source_func: Callable = read_tiff,
|
|
127
58
|
) -> None:
|
|
128
59
|
"""Constructors.
|
|
@@ -131,9 +62,9 @@ class PathIterableDataset(IterableDataset):
|
|
|
131
62
|
----------
|
|
132
63
|
data_config : DataConfig
|
|
133
64
|
Data configuration.
|
|
134
|
-
src_files :
|
|
65
|
+
src_files : list[Path]
|
|
135
66
|
List of data files.
|
|
136
|
-
target_files :
|
|
67
|
+
target_files : list[Path] or None, optional
|
|
137
68
|
Optional list of target files, by default None.
|
|
138
69
|
read_source_func : Callable, optional
|
|
139
70
|
Read source function for custom types, by default read_tiff.
|
|
@@ -141,55 +72,102 @@ class PathIterableDataset(IterableDataset):
|
|
|
141
72
|
self.data_config = data_config
|
|
142
73
|
self.data_files = src_files
|
|
143
74
|
self.target_files = target_files
|
|
144
|
-
self.data_config = data_config
|
|
145
75
|
self.read_source_func = read_source_func
|
|
146
76
|
|
|
147
77
|
# compute mean and std over the dataset
|
|
148
|
-
|
|
149
|
-
|
|
78
|
+
# only checking the image_mean because the DataConfig class ensures that
|
|
79
|
+
# if image_mean is provided, image_std is also provided
|
|
80
|
+
if not self.data_config.image_means:
|
|
81
|
+
self.data_stats = self._calculate_mean_and_std()
|
|
82
|
+
logger.info(
|
|
83
|
+
f"Computed dataset mean: {self.data_stats.image_stats.means},"
|
|
84
|
+
f"std: {self.data_stats.image_stats.stds}"
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
# update the mean in the config
|
|
88
|
+
self.data_config.set_mean_and_std(
|
|
89
|
+
image_means=self.data_stats.image_stats.means,
|
|
90
|
+
image_stds=self.data_stats.image_stats.stds,
|
|
91
|
+
target_means=(
|
|
92
|
+
list(self.data_stats.target_stats.means)
|
|
93
|
+
if self.data_stats.target_stats.means is not None
|
|
94
|
+
else None
|
|
95
|
+
),
|
|
96
|
+
target_stds=(
|
|
97
|
+
list(self.data_stats.target_stats.stds)
|
|
98
|
+
if self.data_stats.target_stats.stds is not None
|
|
99
|
+
else None
|
|
100
|
+
),
|
|
101
|
+
)
|
|
150
102
|
|
|
151
|
-
# update mean and std in configuration
|
|
152
|
-
# the object is mutable and should then be recorded in the CAREamist
|
|
153
|
-
data_config.set_mean_and_std(self.mean, self.std)
|
|
154
103
|
else:
|
|
155
|
-
|
|
156
|
-
self.
|
|
104
|
+
# if mean and std are provided in the config, use them
|
|
105
|
+
self.data_stats = StatsOutput(
|
|
106
|
+
Stats(self.data_config.image_means, self.data_config.image_stds),
|
|
107
|
+
Stats(self.data_config.target_means, self.data_config.target_stds),
|
|
108
|
+
)
|
|
157
109
|
|
|
158
|
-
#
|
|
159
|
-
self.patch_transform = Compose(
|
|
110
|
+
# create transform composed of normalization and other transforms
|
|
111
|
+
self.patch_transform = Compose(
|
|
112
|
+
transform_list=[
|
|
113
|
+
NormalizeModel(
|
|
114
|
+
image_means=self.data_stats.image_stats.means,
|
|
115
|
+
image_stds=self.data_stats.image_stats.stds,
|
|
116
|
+
target_means=self.data_stats.target_stats.means,
|
|
117
|
+
target_stds=self.data_stats.target_stats.stds,
|
|
118
|
+
)
|
|
119
|
+
]
|
|
120
|
+
+ data_config.transforms
|
|
121
|
+
)
|
|
160
122
|
|
|
161
|
-
def _calculate_mean_and_std(self) ->
|
|
123
|
+
def _calculate_mean_and_std(self) -> StatsOutput:
|
|
162
124
|
"""
|
|
163
125
|
Calculate mean and std of the dataset.
|
|
164
126
|
|
|
165
127
|
Returns
|
|
166
128
|
-------
|
|
167
|
-
|
|
168
|
-
|
|
129
|
+
PatchedOutput
|
|
130
|
+
Data class containing the image statistics.
|
|
169
131
|
"""
|
|
170
|
-
means, stds = 0, 0
|
|
171
132
|
num_samples = 0
|
|
133
|
+
image_stats = WelfordStatistics()
|
|
134
|
+
if self.target_files is not None:
|
|
135
|
+
target_stats = WelfordStatistics()
|
|
172
136
|
|
|
173
|
-
for sample,
|
|
137
|
+
for sample, target in iterate_over_files(
|
|
174
138
|
self.data_config, self.data_files, self.target_files, self.read_source_func
|
|
175
139
|
):
|
|
176
|
-
|
|
177
|
-
|
|
140
|
+
# update the image statistics
|
|
141
|
+
image_stats.update(sample, num_samples)
|
|
142
|
+
|
|
143
|
+
# update the target statistics if target is available
|
|
144
|
+
if target is not None:
|
|
145
|
+
target_stats.update(target, num_samples)
|
|
146
|
+
|
|
178
147
|
num_samples += 1
|
|
179
148
|
|
|
180
149
|
if num_samples == 0:
|
|
181
150
|
raise ValueError("No samples found in the dataset.")
|
|
182
151
|
|
|
183
|
-
|
|
184
|
-
|
|
152
|
+
# Average the means and stds per sample
|
|
153
|
+
image_means, image_stds = image_stats.finalize()
|
|
154
|
+
|
|
155
|
+
if target is not None:
|
|
156
|
+
target_means, target_stds = target_stats.finalize()
|
|
185
157
|
|
|
186
158
|
logger.info(f"Calculated mean and std for {num_samples} images")
|
|
187
|
-
logger.info(f"Mean: {
|
|
188
|
-
return
|
|
159
|
+
logger.info(f"Mean: {image_means}, std: {image_stds}")
|
|
160
|
+
return StatsOutput(
|
|
161
|
+
Stats(image_means, image_stds),
|
|
162
|
+
Stats(
|
|
163
|
+
np.array(target_means) if target is not None else None,
|
|
164
|
+
np.array(target_stds) if target is not None else None,
|
|
165
|
+
),
|
|
166
|
+
)
|
|
189
167
|
|
|
190
168
|
def __iter__(
|
|
191
169
|
self,
|
|
192
|
-
) -> Generator[
|
|
170
|
+
) -> Generator[tuple[np.ndarray, ...], None, None]:
|
|
193
171
|
"""
|
|
194
172
|
Iterate over data source and yield single patch.
|
|
195
173
|
|
|
@@ -199,24 +177,18 @@ class PathIterableDataset(IterableDataset):
|
|
|
199
177
|
Single patch.
|
|
200
178
|
"""
|
|
201
179
|
assert (
|
|
202
|
-
self.
|
|
180
|
+
self.data_stats.image_stats.means is not None
|
|
181
|
+
and self.data_stats.image_stats.stds is not None
|
|
203
182
|
), "Mean and std must be provided"
|
|
204
183
|
|
|
205
184
|
# iterate over files
|
|
206
|
-
for sample_input, sample_target in
|
|
185
|
+
for sample_input, sample_target in iterate_over_files(
|
|
207
186
|
self.data_config, self.data_files, self.target_files, self.read_source_func
|
|
208
187
|
):
|
|
209
|
-
reshaped_sample = reshape_array(sample_input, self.data_config.axes)
|
|
210
|
-
reshaped_target = (
|
|
211
|
-
None
|
|
212
|
-
if sample_target is None
|
|
213
|
-
else reshape_array(sample_target, self.data_config.axes)
|
|
214
|
-
)
|
|
215
|
-
|
|
216
188
|
patches = extract_patches_random(
|
|
217
|
-
arr=
|
|
189
|
+
arr=sample_input,
|
|
218
190
|
patch_size=self.data_config.patch_size,
|
|
219
|
-
target=
|
|
191
|
+
target=sample_target,
|
|
220
192
|
)
|
|
221
193
|
|
|
222
194
|
# iterate over patches
|
|
@@ -317,132 +289,3 @@ class PathIterableDataset(IterableDataset):
|
|
|
317
289
|
dataset.target_files = val_target_files
|
|
318
290
|
|
|
319
291
|
return dataset
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
class IterablePredictionDataset(IterableDataset):
|
|
323
|
-
"""
|
|
324
|
-
Prediction dataset.
|
|
325
|
-
|
|
326
|
-
Parameters
|
|
327
|
-
----------
|
|
328
|
-
prediction_config : InferenceConfig
|
|
329
|
-
Inference configuration.
|
|
330
|
-
src_files : List[Path]
|
|
331
|
-
List of data files.
|
|
332
|
-
read_source_func : Callable, optional
|
|
333
|
-
Read source function for custom types, by default read_tiff.
|
|
334
|
-
**kwargs : Any
|
|
335
|
-
Additional keyword arguments, unused.
|
|
336
|
-
|
|
337
|
-
Attributes
|
|
338
|
-
----------
|
|
339
|
-
data_path : Union[str, Path]
|
|
340
|
-
Path to the data, must be a directory.
|
|
341
|
-
axes : str
|
|
342
|
-
Description of axes in format STCZYX.
|
|
343
|
-
mean : Optional[float], optional
|
|
344
|
-
Expected mean of the dataset, by default None.
|
|
345
|
-
std : Optional[float], optional
|
|
346
|
-
Expected standard deviation of the dataset, by default None.
|
|
347
|
-
patch_transform : Optional[Callable], optional
|
|
348
|
-
Patch transform callable, by default None.
|
|
349
|
-
"""
|
|
350
|
-
|
|
351
|
-
def __init__(
|
|
352
|
-
self,
|
|
353
|
-
prediction_config: InferenceConfig,
|
|
354
|
-
src_files: List[Path],
|
|
355
|
-
read_source_func: Callable = read_tiff,
|
|
356
|
-
**kwargs: Any,
|
|
357
|
-
) -> None:
|
|
358
|
-
"""Constructor.
|
|
359
|
-
|
|
360
|
-
Parameters
|
|
361
|
-
----------
|
|
362
|
-
prediction_config : InferenceConfig
|
|
363
|
-
Inference configuration.
|
|
364
|
-
src_files : List[Path]
|
|
365
|
-
List of data files.
|
|
366
|
-
read_source_func : Callable, optional
|
|
367
|
-
Read source function for custom types, by default read_tiff.
|
|
368
|
-
**kwargs : Any
|
|
369
|
-
Additional keyword arguments, unused.
|
|
370
|
-
|
|
371
|
-
Raises
|
|
372
|
-
------
|
|
373
|
-
ValueError
|
|
374
|
-
If mean and std are not provided in the inference configuration.
|
|
375
|
-
"""
|
|
376
|
-
self.prediction_config = prediction_config
|
|
377
|
-
self.data_files = src_files
|
|
378
|
-
self.axes = prediction_config.axes
|
|
379
|
-
self.tile_size = self.prediction_config.tile_size
|
|
380
|
-
self.tile_overlap = self.prediction_config.tile_overlap
|
|
381
|
-
self.read_source_func = read_source_func
|
|
382
|
-
|
|
383
|
-
# tile only if both tile size and overlaps are provided
|
|
384
|
-
self.tile = self.tile_size is not None and self.tile_overlap is not None
|
|
385
|
-
|
|
386
|
-
# check mean and std and create normalize transform
|
|
387
|
-
if self.prediction_config.mean is None or self.prediction_config.std is None:
|
|
388
|
-
raise ValueError("Mean and std must be provided for prediction.")
|
|
389
|
-
else:
|
|
390
|
-
self.mean = self.prediction_config.mean
|
|
391
|
-
self.std = self.prediction_config.std
|
|
392
|
-
|
|
393
|
-
# instantiate normalize transform
|
|
394
|
-
self.patch_transform = Compose(
|
|
395
|
-
transform_list=[
|
|
396
|
-
NormalizeModel(
|
|
397
|
-
mean=prediction_config.mean, std=prediction_config.std
|
|
398
|
-
)
|
|
399
|
-
],
|
|
400
|
-
)
|
|
401
|
-
|
|
402
|
-
def __iter__(
|
|
403
|
-
self,
|
|
404
|
-
) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
|
|
405
|
-
"""
|
|
406
|
-
Iterate over data source and yield single patch.
|
|
407
|
-
|
|
408
|
-
Yields
|
|
409
|
-
------
|
|
410
|
-
np.ndarray
|
|
411
|
-
Single patch.
|
|
412
|
-
"""
|
|
413
|
-
assert (
|
|
414
|
-
self.mean is not None and self.std is not None
|
|
415
|
-
), "Mean and std must be provided"
|
|
416
|
-
|
|
417
|
-
for sample, _ in _iterate_over_files(
|
|
418
|
-
self.prediction_config,
|
|
419
|
-
self.data_files,
|
|
420
|
-
read_source_func=self.read_source_func,
|
|
421
|
-
):
|
|
422
|
-
# reshape array
|
|
423
|
-
reshaped_sample = reshape_array(sample, self.axes)
|
|
424
|
-
|
|
425
|
-
if (
|
|
426
|
-
self.tile
|
|
427
|
-
and self.tile_size is not None
|
|
428
|
-
and self.tile_overlap is not None
|
|
429
|
-
):
|
|
430
|
-
# generate patches, return a generator
|
|
431
|
-
patch_gen = extract_tiles(
|
|
432
|
-
arr=reshaped_sample,
|
|
433
|
-
tile_size=self.tile_size,
|
|
434
|
-
overlaps=self.tile_overlap,
|
|
435
|
-
)
|
|
436
|
-
else:
|
|
437
|
-
# just wrap the sample in a generator with default tiling info
|
|
438
|
-
array_shape = reshaped_sample.squeeze().shape
|
|
439
|
-
patch_gen = (
|
|
440
|
-
(reshaped_sample, TileInformation(array_shape=array_shape))
|
|
441
|
-
for _ in range(1)
|
|
442
|
-
)
|
|
443
|
-
|
|
444
|
-
# apply transform to patches
|
|
445
|
-
for patch_array, tile_info in patch_gen:
|
|
446
|
-
transformed_patch, _ = self.patch_transform(patch=patch_array)
|
|
447
|
-
|
|
448
|
-
yield transformed_patch, tile_info
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""Iterable prediction dataset used to load data file by file."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Callable, Generator
|
|
7
|
+
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
from torch.utils.data import IterableDataset
|
|
10
|
+
|
|
11
|
+
from careamics.transforms import Compose
|
|
12
|
+
|
|
13
|
+
from ..config import InferenceConfig
|
|
14
|
+
from ..config.transformations import NormalizeModel
|
|
15
|
+
from .dataset_utils import iterate_over_files, read_tiff
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class IterablePredDataset(IterableDataset):
|
|
19
|
+
"""Simple iterable prediction dataset.
|
|
20
|
+
|
|
21
|
+
Parameters
|
|
22
|
+
----------
|
|
23
|
+
prediction_config : InferenceConfig
|
|
24
|
+
Inference configuration.
|
|
25
|
+
src_files : List[Path]
|
|
26
|
+
List of data files.
|
|
27
|
+
read_source_func : Callable, optional
|
|
28
|
+
Read source function for custom types, by default read_tiff.
|
|
29
|
+
**kwargs : Any
|
|
30
|
+
Additional keyword arguments, unused.
|
|
31
|
+
|
|
32
|
+
Attributes
|
|
33
|
+
----------
|
|
34
|
+
data_path : Union[str, Path]
|
|
35
|
+
Path to the data, must be a directory.
|
|
36
|
+
axes : str
|
|
37
|
+
Description of axes in format STCZYX.
|
|
38
|
+
mean : Optional[float], optional
|
|
39
|
+
Expected mean of the dataset, by default None.
|
|
40
|
+
std : Optional[float], optional
|
|
41
|
+
Expected standard deviation of the dataset, by default None.
|
|
42
|
+
patch_transform : Optional[Callable], optional
|
|
43
|
+
Patch transform callable, by default None.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
def __init__(
|
|
47
|
+
self,
|
|
48
|
+
prediction_config: InferenceConfig,
|
|
49
|
+
src_files: list[Path],
|
|
50
|
+
read_source_func: Callable = read_tiff,
|
|
51
|
+
**kwargs: Any,
|
|
52
|
+
) -> None:
|
|
53
|
+
"""Constructor.
|
|
54
|
+
|
|
55
|
+
Parameters
|
|
56
|
+
----------
|
|
57
|
+
prediction_config : InferenceConfig
|
|
58
|
+
Inference configuration.
|
|
59
|
+
src_files : list of pathlib.Path
|
|
60
|
+
List of data files.
|
|
61
|
+
read_source_func : Callable, optional
|
|
62
|
+
Read source function for custom types, by default read_tiff.
|
|
63
|
+
**kwargs : Any
|
|
64
|
+
Additional keyword arguments, unused.
|
|
65
|
+
|
|
66
|
+
Raises
|
|
67
|
+
------
|
|
68
|
+
ValueError
|
|
69
|
+
If mean and std are not provided in the inference configuration.
|
|
70
|
+
"""
|
|
71
|
+
self.prediction_config = prediction_config
|
|
72
|
+
self.data_files = src_files
|
|
73
|
+
self.axes = prediction_config.axes
|
|
74
|
+
self.read_source_func = read_source_func
|
|
75
|
+
|
|
76
|
+
# check mean and std and create normalize transform
|
|
77
|
+
if (
|
|
78
|
+
self.prediction_config.image_means is None
|
|
79
|
+
or self.prediction_config.image_stds is None
|
|
80
|
+
):
|
|
81
|
+
raise ValueError("Mean and std must be provided for prediction.")
|
|
82
|
+
else:
|
|
83
|
+
self.image_means = self.prediction_config.image_means
|
|
84
|
+
self.image_stds = self.prediction_config.image_stds
|
|
85
|
+
|
|
86
|
+
# instantiate normalize transform
|
|
87
|
+
self.patch_transform = Compose(
|
|
88
|
+
transform_list=[
|
|
89
|
+
NormalizeModel(
|
|
90
|
+
image_means=self.image_means,
|
|
91
|
+
image_stds=self.image_stds,
|
|
92
|
+
)
|
|
93
|
+
],
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
def __iter__(
|
|
97
|
+
self,
|
|
98
|
+
) -> Generator[NDArray, None, None]:
|
|
99
|
+
"""
|
|
100
|
+
Iterate over data source and yield single patch.
|
|
101
|
+
|
|
102
|
+
Yields
|
|
103
|
+
------
|
|
104
|
+
NDArray
|
|
105
|
+
Single patch.
|
|
106
|
+
"""
|
|
107
|
+
assert (
|
|
108
|
+
self.image_means is not None and self.image_stds is not None
|
|
109
|
+
), "Mean and std must be provided"
|
|
110
|
+
|
|
111
|
+
for sample, _ in iterate_over_files(
|
|
112
|
+
self.prediction_config,
|
|
113
|
+
self.data_files,
|
|
114
|
+
read_source_func=self.read_source_func,
|
|
115
|
+
):
|
|
116
|
+
# sample has S dimension
|
|
117
|
+
for i in range(sample.shape[0]):
|
|
118
|
+
|
|
119
|
+
transformed_sample, _ = self.patch_transform(patch=sample[i])
|
|
120
|
+
|
|
121
|
+
yield transformed_sample
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""Iterable tiled prediction dataset used to load data file by file."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Callable, Generator
|
|
7
|
+
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
from torch.utils.data import IterableDataset
|
|
10
|
+
|
|
11
|
+
from careamics.transforms import Compose
|
|
12
|
+
|
|
13
|
+
from ..config import InferenceConfig
|
|
14
|
+
from ..config.tile_information import TileInformation
|
|
15
|
+
from ..config.transformations import NormalizeModel
|
|
16
|
+
from .dataset_utils import iterate_over_files, read_tiff
|
|
17
|
+
from .tiling import extract_tiles
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class IterableTiledPredDataset(IterableDataset):
|
|
21
|
+
"""Tiled prediction dataset.
|
|
22
|
+
|
|
23
|
+
Parameters
|
|
24
|
+
----------
|
|
25
|
+
prediction_config : InferenceConfig
|
|
26
|
+
Inference configuration.
|
|
27
|
+
src_files : list of pathlib.Path
|
|
28
|
+
List of data files.
|
|
29
|
+
read_source_func : Callable, optional
|
|
30
|
+
Read source function for custom types, by default read_tiff.
|
|
31
|
+
**kwargs : Any
|
|
32
|
+
Additional keyword arguments, unused.
|
|
33
|
+
|
|
34
|
+
Attributes
|
|
35
|
+
----------
|
|
36
|
+
data_path : str or pathlib.Path
|
|
37
|
+
Path to the data, must be a directory.
|
|
38
|
+
axes : str
|
|
39
|
+
Description of axes in format STCZYX.
|
|
40
|
+
mean : float, optional
|
|
41
|
+
Expected mean of the dataset, by default None.
|
|
42
|
+
std : float, optional
|
|
43
|
+
Expected standard deviation of the dataset, by default None.
|
|
44
|
+
patch_transform : Callable, optional
|
|
45
|
+
Patch transform callable, by default None.
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
def __init__(
|
|
49
|
+
self,
|
|
50
|
+
prediction_config: InferenceConfig,
|
|
51
|
+
src_files: list[Path],
|
|
52
|
+
read_source_func: Callable = read_tiff,
|
|
53
|
+
**kwargs: Any,
|
|
54
|
+
) -> None:
|
|
55
|
+
"""Constructor.
|
|
56
|
+
|
|
57
|
+
Parameters
|
|
58
|
+
----------
|
|
59
|
+
prediction_config : InferenceConfig
|
|
60
|
+
Inference configuration.
|
|
61
|
+
src_files : List[Path]
|
|
62
|
+
List of data files.
|
|
63
|
+
read_source_func : Callable, optional
|
|
64
|
+
Read source function for custom types, by default read_tiff.
|
|
65
|
+
**kwargs : Any
|
|
66
|
+
Additional keyword arguments, unused.
|
|
67
|
+
|
|
68
|
+
Raises
|
|
69
|
+
------
|
|
70
|
+
ValueError
|
|
71
|
+
If mean and std are not provided in the inference configuration.
|
|
72
|
+
"""
|
|
73
|
+
if (
|
|
74
|
+
prediction_config.tile_size is None
|
|
75
|
+
or prediction_config.tile_overlap is None
|
|
76
|
+
):
|
|
77
|
+
raise ValueError(
|
|
78
|
+
"Tile size and overlap must be provided for tiled prediction."
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
self.prediction_config = prediction_config
|
|
82
|
+
self.data_files = src_files
|
|
83
|
+
self.axes = prediction_config.axes
|
|
84
|
+
self.tile_size = prediction_config.tile_size
|
|
85
|
+
self.tile_overlap = prediction_config.tile_overlap
|
|
86
|
+
self.read_source_func = read_source_func
|
|
87
|
+
|
|
88
|
+
# check mean and std and create normalize transform
|
|
89
|
+
if (
|
|
90
|
+
self.prediction_config.image_means is None
|
|
91
|
+
or self.prediction_config.image_stds is None
|
|
92
|
+
):
|
|
93
|
+
raise ValueError("Mean and std must be provided for prediction.")
|
|
94
|
+
else:
|
|
95
|
+
self.image_means = self.prediction_config.image_means
|
|
96
|
+
self.image_stds = self.prediction_config.image_stds
|
|
97
|
+
|
|
98
|
+
# instantiate normalize transform
|
|
99
|
+
self.patch_transform = Compose(
|
|
100
|
+
transform_list=[
|
|
101
|
+
NormalizeModel(
|
|
102
|
+
image_means=self.image_means,
|
|
103
|
+
image_stds=self.image_stds,
|
|
104
|
+
)
|
|
105
|
+
],
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
def __iter__(
|
|
109
|
+
self,
|
|
110
|
+
) -> Generator[tuple[NDArray, TileInformation], None, None]:
|
|
111
|
+
"""
|
|
112
|
+
Iterate over data source and yield single patch.
|
|
113
|
+
|
|
114
|
+
Yields
|
|
115
|
+
------
|
|
116
|
+
Generator of NDArray and TileInformation tuple
|
|
117
|
+
Generator of single tiles.
|
|
118
|
+
"""
|
|
119
|
+
assert (
|
|
120
|
+
self.image_means is not None and self.image_stds is not None
|
|
121
|
+
), "Mean and std must be provided"
|
|
122
|
+
|
|
123
|
+
for sample, _ in iterate_over_files(
|
|
124
|
+
self.prediction_config,
|
|
125
|
+
self.data_files,
|
|
126
|
+
read_source_func=self.read_source_func,
|
|
127
|
+
):
|
|
128
|
+
# generate patches, return a generator of single tiles
|
|
129
|
+
patch_gen = extract_tiles(
|
|
130
|
+
arr=sample,
|
|
131
|
+
tile_size=self.tile_size,
|
|
132
|
+
overlaps=self.tile_overlap,
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
# apply transform to patches
|
|
136
|
+
for patch_array, tile_info in patch_gen:
|
|
137
|
+
transformed_patch, _ = self.patch_transform(patch=patch_array)
|
|
138
|
+
|
|
139
|
+
yield transformed_patch, tile_info
|