careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +92 -55
- careamics/config/__init__.py +0 -1
- careamics/config/algorithm_model.py +5 -3
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +8 -1
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +4 -15
- careamics/config/configuration_example.py +4 -4
- careamics/config/configuration_factory.py +113 -55
- careamics/config/configuration_model.py +14 -16
- careamics/config/data_model.py +63 -165
- careamics/config/inference_model.py +9 -75
- careamics/config/optimizer_models.py +4 -4
- careamics/config/references/algorithm_descriptions.py +1 -0
- careamics/config/references/references.py +1 -0
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -15
- careamics/config/tile_information.py +2 -0
- careamics/config/training_model.py +1 -0
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/n2v_manipulate_model.py +1 -0
- careamics/config/transformations/normalize_model.py +1 -0
- careamics/config/transformations/transform_model.py +1 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +13 -7
- careamics/config/validators/validator_utils.py +1 -0
- careamics/conftest.py +13 -0
- careamics/dataset/dataset_utils/__init__.py +0 -1
- careamics/dataset/dataset_utils/dataset_utils.py +5 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/read_tiff.py +6 -2
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/in_memory_dataset.py +84 -76
- careamics/dataset/iterable_dataset.py +166 -134
- careamics/dataset/patching/__init__.py +0 -7
- careamics/dataset/patching/patching.py +56 -14
- careamics/dataset/patching/random_patching.py +8 -2
- careamics/dataset/patching/sequential_patching.py +20 -14
- careamics/dataset/patching/tiled_patching.py +13 -7
- careamics/dataset/patching/validate_patch_dimension.py +2 -0
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +63 -41
- careamics/lightning_module.py +9 -3
- careamics/lightning_prediction_datamodule.py +15 -20
- careamics/lightning_prediction_loop.py +8 -6
- careamics/losses/__init__.py +1 -3
- careamics/losses/loss_factory.py +2 -1
- careamics/losses/losses.py +11 -7
- careamics/model_io/__init__.py +0 -1
- careamics/model_io/bioimage/_readme_factory.py +2 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -0
- careamics/model_io/bioimage/model_description.py +1 -0
- careamics/model_io/bmz_io.py +4 -3
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +122 -25
- careamics/models/model_factory.py +2 -1
- careamics/models/unet.py +114 -19
- careamics/prediction/stitch_prediction.py +2 -5
- careamics/transforms/__init__.py +4 -25
- careamics/transforms/compose.py +124 -0
- careamics/transforms/n2v_manipulate.py +65 -34
- careamics/transforms/normalize.py +91 -28
- careamics/transforms/pixel_manipulation.py +7 -7
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +2 -2
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +66 -60
- careamics/utils/__init__.py +0 -1
- careamics/utils/base_enum.py +28 -0
- careamics/utils/context.py +1 -0
- careamics/utils/logging.py +1 -0
- careamics/utils/metrics.py +1 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +2 -0
- careamics/utils/receptive_field.py +93 -87
- careamics/utils/torch_utils.py +1 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +17 -61
- careamics-0.1.0rc6.dist-info/RECORD +107 -0
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -24
- careamics/config/transformations/nd_flip_model.py +0 -32
- careamics/dataset/patching/patch_transform.py +0 -44
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/transforms/nd_flip.py +0 -93
- careamics-0.1.0rc4.dist-info/RECORD +0 -110
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Iterable dataset used to load data file by file."""
|
|
2
|
+
|
|
1
3
|
from __future__ import annotations
|
|
2
4
|
|
|
3
5
|
import copy
|
|
@@ -7,26 +9,98 @@ from typing import Any, Callable, Generator, List, Optional, Tuple, Union
|
|
|
7
9
|
import numpy as np
|
|
8
10
|
from torch.utils.data import IterableDataset, get_worker_info
|
|
9
11
|
|
|
12
|
+
from careamics.transforms import Compose
|
|
13
|
+
|
|
10
14
|
from ..config import DataConfig, InferenceConfig
|
|
11
15
|
from ..config.tile_information import TileInformation
|
|
16
|
+
from ..config.transformations import NormalizeModel
|
|
12
17
|
from ..utils.logging import get_logger
|
|
13
18
|
from .dataset_utils import read_tiff, reshape_array
|
|
14
|
-
from .patching import (
|
|
15
|
-
get_patch_transform,
|
|
16
|
-
)
|
|
17
19
|
from .patching.random_patching import extract_patches_random
|
|
18
20
|
from .patching.tiled_patching import extract_tiles
|
|
19
21
|
|
|
20
22
|
logger = get_logger(__name__)
|
|
21
23
|
|
|
22
24
|
|
|
25
|
+
def _iterate_over_files(
|
|
26
|
+
data_config: Union[DataConfig, InferenceConfig],
|
|
27
|
+
data_files: List[Path],
|
|
28
|
+
target_files: Optional[List[Path]] = None,
|
|
29
|
+
read_source_func: Callable = read_tiff,
|
|
30
|
+
) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
|
|
31
|
+
"""
|
|
32
|
+
Iterate over data source and yield whole image.
|
|
33
|
+
|
|
34
|
+
Parameters
|
|
35
|
+
----------
|
|
36
|
+
data_config : Union[DataConfig, InferenceConfig]
|
|
37
|
+
Data configuration.
|
|
38
|
+
data_files : List[Path]
|
|
39
|
+
List of data files.
|
|
40
|
+
target_files : Optional[List[Path]]
|
|
41
|
+
List of target files, by default None.
|
|
42
|
+
read_source_func : Optional[Callable]
|
|
43
|
+
Function to read the source, by default read_tiff.
|
|
44
|
+
|
|
45
|
+
Yields
|
|
46
|
+
------
|
|
47
|
+
np.ndarray
|
|
48
|
+
Image.
|
|
49
|
+
"""
|
|
50
|
+
# When num_workers > 0, each worker process will have a different copy of the
|
|
51
|
+
# dataset object
|
|
52
|
+
# Configuring each copy independently to avoid having duplicate data returned
|
|
53
|
+
# from the workers
|
|
54
|
+
worker_info = get_worker_info()
|
|
55
|
+
worker_id = worker_info.id if worker_info is not None else 0
|
|
56
|
+
num_workers = worker_info.num_workers if worker_info is not None else 1
|
|
57
|
+
|
|
58
|
+
# iterate over the files
|
|
59
|
+
for i, filename in enumerate(data_files):
|
|
60
|
+
# retrieve file corresponding to the worker id
|
|
61
|
+
if i % num_workers == worker_id:
|
|
62
|
+
try:
|
|
63
|
+
# read data
|
|
64
|
+
sample = read_source_func(filename, data_config.axes)
|
|
65
|
+
|
|
66
|
+
# read target, if available
|
|
67
|
+
if target_files is not None:
|
|
68
|
+
if filename.name != target_files[i].name:
|
|
69
|
+
raise ValueError(
|
|
70
|
+
f"File {filename} does not match target file "
|
|
71
|
+
f"{target_files[i]}. Have you passed sorted "
|
|
72
|
+
f"arrays?"
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
# read target
|
|
76
|
+
target = read_source_func(target_files[i], data_config.axes)
|
|
77
|
+
|
|
78
|
+
yield sample, target
|
|
79
|
+
else:
|
|
80
|
+
yield sample, None
|
|
81
|
+
|
|
82
|
+
except Exception as e:
|
|
83
|
+
logger.error(f"Error reading file {filename}: {e}")
|
|
84
|
+
|
|
85
|
+
|
|
23
86
|
class PathIterableDataset(IterableDataset):
|
|
24
87
|
"""
|
|
25
88
|
Dataset allowing extracting patches without loading the whole data into memory.
|
|
26
89
|
|
|
27
90
|
Parameters
|
|
28
91
|
----------
|
|
29
|
-
|
|
92
|
+
data_config : DataConfig
|
|
93
|
+
Data configuration.
|
|
94
|
+
src_files : List[Path]
|
|
95
|
+
List of data files.
|
|
96
|
+
target_files : Optional[List[Path]], optional
|
|
97
|
+
Optional list of target files, by default None.
|
|
98
|
+
read_source_func : Callable, optional
|
|
99
|
+
Read source function for custom types, by default read_tiff.
|
|
100
|
+
|
|
101
|
+
Attributes
|
|
102
|
+
----------
|
|
103
|
+
data_path : List[Path]
|
|
30
104
|
Path to the data, must be a directory.
|
|
31
105
|
axes : str
|
|
32
106
|
Description of axes in format STCZYX.
|
|
@@ -46,11 +120,24 @@ class PathIterableDataset(IterableDataset):
|
|
|
46
120
|
|
|
47
121
|
def __init__(
|
|
48
122
|
self,
|
|
49
|
-
data_config:
|
|
123
|
+
data_config: DataConfig,
|
|
50
124
|
src_files: List[Path],
|
|
51
125
|
target_files: Optional[List[Path]] = None,
|
|
52
126
|
read_source_func: Callable = read_tiff,
|
|
53
127
|
) -> None:
|
|
128
|
+
"""Constructors.
|
|
129
|
+
|
|
130
|
+
Parameters
|
|
131
|
+
----------
|
|
132
|
+
data_config : DataConfig
|
|
133
|
+
Data configuration.
|
|
134
|
+
src_files : List[Path]
|
|
135
|
+
List of data files.
|
|
136
|
+
target_files : Optional[List[Path]], optional
|
|
137
|
+
Optional list of target files, by default None.
|
|
138
|
+
read_source_func : Callable, optional
|
|
139
|
+
Read source function for custom types, by default read_tiff.
|
|
140
|
+
"""
|
|
54
141
|
self.data_config = data_config
|
|
55
142
|
self.data_files = src_files
|
|
56
143
|
self.target_files = target_files
|
|
@@ -61,26 +148,15 @@ class PathIterableDataset(IterableDataset):
|
|
|
61
148
|
if not data_config.mean or not data_config.std:
|
|
62
149
|
self.mean, self.std = self._calculate_mean_and_std()
|
|
63
150
|
|
|
64
|
-
#
|
|
65
|
-
#
|
|
66
|
-
|
|
67
|
-
if hasattr(data_config, "has_transform_list"):
|
|
68
|
-
if data_config.has_transform_list():
|
|
69
|
-
# update mean and std in configuration
|
|
70
|
-
# the object is mutable and should then be recorded in the CAREamist
|
|
71
|
-
data_config.set_mean_and_std(self.mean, self.std)
|
|
72
|
-
else:
|
|
73
|
-
data_config.set_mean_and_std(self.mean, self.std)
|
|
74
|
-
|
|
151
|
+
# update mean and std in configuration
|
|
152
|
+
# the object is mutable and should then be recorded in the CAREamist
|
|
153
|
+
data_config.set_mean_and_std(self.mean, self.std)
|
|
75
154
|
else:
|
|
76
155
|
self.mean = data_config.mean
|
|
77
156
|
self.std = data_config.std
|
|
78
157
|
|
|
79
158
|
# get transforms
|
|
80
|
-
self.patch_transform =
|
|
81
|
-
patch_transforms=data_config.transforms,
|
|
82
|
-
with_target=target_files is not None,
|
|
83
|
-
)
|
|
159
|
+
self.patch_transform = Compose(transform_list=data_config.transforms)
|
|
84
160
|
|
|
85
161
|
def _calculate_mean_and_std(self) -> Tuple[float, float]:
|
|
86
162
|
"""
|
|
@@ -94,7 +170,9 @@ class PathIterableDataset(IterableDataset):
|
|
|
94
170
|
means, stds = 0, 0
|
|
95
171
|
num_samples = 0
|
|
96
172
|
|
|
97
|
-
for sample, _ in
|
|
173
|
+
for sample, _ in _iterate_over_files(
|
|
174
|
+
self.data_config, self.data_files, self.target_files, self.read_source_func
|
|
175
|
+
):
|
|
98
176
|
means += sample.mean()
|
|
99
177
|
stds += sample.std()
|
|
100
178
|
num_samples += 1
|
|
@@ -109,57 +187,9 @@ class PathIterableDataset(IterableDataset):
|
|
|
109
187
|
logger.info(f"Mean: {result_mean}, std: {result_std}")
|
|
110
188
|
return result_mean, result_std
|
|
111
189
|
|
|
112
|
-
def _iterate_over_files(
|
|
113
|
-
self,
|
|
114
|
-
) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
|
|
115
|
-
"""
|
|
116
|
-
Iterate over data source and yield whole image.
|
|
117
|
-
|
|
118
|
-
Yields
|
|
119
|
-
------
|
|
120
|
-
np.ndarray
|
|
121
|
-
Image.
|
|
122
|
-
"""
|
|
123
|
-
# When num_workers > 0, each worker process will have a different copy of the
|
|
124
|
-
# dataset object
|
|
125
|
-
# Configuring each copy independently to avoid having duplicate data returned
|
|
126
|
-
# from the workers
|
|
127
|
-
worker_info = get_worker_info()
|
|
128
|
-
worker_id = worker_info.id if worker_info is not None else 0
|
|
129
|
-
num_workers = worker_info.num_workers if worker_info is not None else 1
|
|
130
|
-
|
|
131
|
-
# iterate over the files
|
|
132
|
-
for i, filename in enumerate(self.data_files):
|
|
133
|
-
# retrieve file corresponding to the worker id
|
|
134
|
-
if i % num_workers == worker_id:
|
|
135
|
-
try:
|
|
136
|
-
# read data
|
|
137
|
-
sample = self.read_source_func(filename, self.data_config.axes)
|
|
138
|
-
|
|
139
|
-
# read target, if available
|
|
140
|
-
if self.target_files is not None:
|
|
141
|
-
if filename.name != self.target_files[i].name:
|
|
142
|
-
raise ValueError(
|
|
143
|
-
f"File {filename} does not match target file "
|
|
144
|
-
f"{self.target_files[i]}. Have you passed sorted "
|
|
145
|
-
f"arrays?"
|
|
146
|
-
)
|
|
147
|
-
|
|
148
|
-
# read target
|
|
149
|
-
target = self.read_source_func(
|
|
150
|
-
self.target_files[i], self.data_config.axes
|
|
151
|
-
)
|
|
152
|
-
|
|
153
|
-
yield sample, target
|
|
154
|
-
else:
|
|
155
|
-
yield sample, None
|
|
156
|
-
|
|
157
|
-
except Exception as e:
|
|
158
|
-
logger.error(f"Error reading file {filename}: {e}")
|
|
159
|
-
|
|
160
190
|
def __iter__(
|
|
161
191
|
self,
|
|
162
|
-
) -> Generator[Tuple[np.ndarray,
|
|
192
|
+
) -> Generator[Tuple[np.ndarray, ...], None, None]:
|
|
163
193
|
"""
|
|
164
194
|
Iterate over data source and yield single patch.
|
|
165
195
|
|
|
@@ -173,7 +203,9 @@ class PathIterableDataset(IterableDataset):
|
|
|
173
203
|
), "Mean and std must be provided"
|
|
174
204
|
|
|
175
205
|
# iterate over files
|
|
176
|
-
for sample_input, sample_target in
|
|
206
|
+
for sample_input, sample_target in _iterate_over_files(
|
|
207
|
+
self.data_config, self.data_files, self.target_files, self.read_source_func
|
|
208
|
+
):
|
|
177
209
|
reshaped_sample = reshape_array(sample_input, self.data_config.axes)
|
|
178
210
|
reshaped_target = (
|
|
179
211
|
None
|
|
@@ -192,49 +224,10 @@ class PathIterableDataset(IterableDataset):
|
|
|
192
224
|
# or (patch, None) only if no target is available
|
|
193
225
|
# patch is of dimensions (C)ZYX
|
|
194
226
|
for patch_data in patches:
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
c_patch = np.moveaxis(patch_data[0], 0, -1)
|
|
200
|
-
c_target = np.moveaxis(patch_data[1], 0, -1)
|
|
201
|
-
|
|
202
|
-
# apply the transform to the patch and the target
|
|
203
|
-
transformed = self.patch_transform(
|
|
204
|
-
image=c_patch,
|
|
205
|
-
target=c_target,
|
|
206
|
-
)
|
|
207
|
-
|
|
208
|
-
# move the axes back to the original position
|
|
209
|
-
c_patch = np.moveaxis(transformed["image"], -1, 0)
|
|
210
|
-
c_target = np.moveaxis(transformed["target"], -1, 0)
|
|
211
|
-
|
|
212
|
-
yield (c_patch, c_target)
|
|
213
|
-
elif self.data_config.has_n2v_manipulate():
|
|
214
|
-
# Albumentations expects the channel dimension to be last
|
|
215
|
-
# Taking the first element because patch_data can include target
|
|
216
|
-
patch = np.moveaxis(patch_data[0], 0, -1)
|
|
217
|
-
|
|
218
|
-
# apply transform
|
|
219
|
-
transformed = self.patch_transform(image=patch)
|
|
220
|
-
|
|
221
|
-
# retrieve the output of ManipulateN2V
|
|
222
|
-
results = transformed["image"]
|
|
223
|
-
masked_patch: np.ndarray = results[0]
|
|
224
|
-
original_patch: np.ndarray = results[1]
|
|
225
|
-
mask: np.ndarray = results[2]
|
|
226
|
-
|
|
227
|
-
# move C axes back
|
|
228
|
-
masked_patch = np.moveaxis(masked_patch, -1, 0)
|
|
229
|
-
original_patch = np.moveaxis(original_patch, -1, 0)
|
|
230
|
-
mask = np.moveaxis(mask, -1, 0)
|
|
231
|
-
|
|
232
|
-
yield (masked_patch, original_patch, mask)
|
|
233
|
-
else:
|
|
234
|
-
raise ValueError(
|
|
235
|
-
"Something went wrong! Not target file (no supervised "
|
|
236
|
-
"training) and no N2V transform (no n2v training either)."
|
|
237
|
-
)
|
|
227
|
+
yield self.patch_transform(
|
|
228
|
+
patch=patch_data[0],
|
|
229
|
+
target=patch_data[1],
|
|
230
|
+
)
|
|
238
231
|
|
|
239
232
|
def get_number_of_files(self) -> int:
|
|
240
233
|
"""
|
|
@@ -260,9 +253,9 @@ class PathIterableDataset(IterableDataset):
|
|
|
260
253
|
Parameters
|
|
261
254
|
----------
|
|
262
255
|
percentage : float, optional
|
|
263
|
-
Percentage of files to split up, by default 0.1
|
|
256
|
+
Percentage of files to split up, by default 0.1.
|
|
264
257
|
minimum_number : int, optional
|
|
265
|
-
Minimum number of files to split up, by default 5
|
|
258
|
+
Minimum number of files to split up, by default 5.
|
|
266
259
|
|
|
267
260
|
Returns
|
|
268
261
|
-------
|
|
@@ -326,12 +319,23 @@ class PathIterableDataset(IterableDataset):
|
|
|
326
319
|
return dataset
|
|
327
320
|
|
|
328
321
|
|
|
329
|
-
class IterablePredictionDataset(
|
|
322
|
+
class IterablePredictionDataset(IterableDataset):
|
|
330
323
|
"""
|
|
331
|
-
|
|
324
|
+
Prediction dataset.
|
|
332
325
|
|
|
333
326
|
Parameters
|
|
334
327
|
----------
|
|
328
|
+
prediction_config : InferenceConfig
|
|
329
|
+
Inference configuration.
|
|
330
|
+
src_files : List[Path]
|
|
331
|
+
List of data files.
|
|
332
|
+
read_source_func : Callable, optional
|
|
333
|
+
Read source function for custom types, by default read_tiff.
|
|
334
|
+
**kwargs : Any
|
|
335
|
+
Additional keyword arguments, unused.
|
|
336
|
+
|
|
337
|
+
Attributes
|
|
338
|
+
----------
|
|
335
339
|
data_path : Union[str, Path]
|
|
336
340
|
Path to the data, must be a directory.
|
|
337
341
|
axes : str
|
|
@@ -351,13 +355,26 @@ class IterablePredictionDataset(PathIterableDataset):
|
|
|
351
355
|
read_source_func: Callable = read_tiff,
|
|
352
356
|
**kwargs: Any,
|
|
353
357
|
) -> None:
|
|
354
|
-
|
|
355
|
-
data_config=prediction_config,
|
|
356
|
-
src_files=src_files,
|
|
357
|
-
read_source_func=read_source_func,
|
|
358
|
-
)
|
|
358
|
+
"""Constructor.
|
|
359
359
|
|
|
360
|
+
Parameters
|
|
361
|
+
----------
|
|
362
|
+
prediction_config : InferenceConfig
|
|
363
|
+
Inference configuration.
|
|
364
|
+
src_files : List[Path]
|
|
365
|
+
List of data files.
|
|
366
|
+
read_source_func : Callable, optional
|
|
367
|
+
Read source function for custom types, by default read_tiff.
|
|
368
|
+
**kwargs : Any
|
|
369
|
+
Additional keyword arguments, unused.
|
|
370
|
+
|
|
371
|
+
Raises
|
|
372
|
+
------
|
|
373
|
+
ValueError
|
|
374
|
+
If mean and std are not provided in the inference configuration.
|
|
375
|
+
"""
|
|
360
376
|
self.prediction_config = prediction_config
|
|
377
|
+
self.data_files = src_files
|
|
361
378
|
self.axes = prediction_config.axes
|
|
362
379
|
self.tile_size = self.prediction_config.tile_size
|
|
363
380
|
self.tile_overlap = self.prediction_config.tile_overlap
|
|
@@ -366,11 +383,21 @@ class IterablePredictionDataset(PathIterableDataset):
|
|
|
366
383
|
# tile only if both tile size and overlaps are provided
|
|
367
384
|
self.tile = self.tile_size is not None and self.tile_overlap is not None
|
|
368
385
|
|
|
369
|
-
#
|
|
370
|
-
self.
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
386
|
+
# check mean and std and create normalize transform
|
|
387
|
+
if self.prediction_config.mean is None or self.prediction_config.std is None:
|
|
388
|
+
raise ValueError("Mean and std must be provided for prediction.")
|
|
389
|
+
else:
|
|
390
|
+
self.mean = self.prediction_config.mean
|
|
391
|
+
self.std = self.prediction_config.std
|
|
392
|
+
|
|
393
|
+
# instantiate normalize transform
|
|
394
|
+
self.patch_transform = Compose(
|
|
395
|
+
transform_list=[
|
|
396
|
+
NormalizeModel(
|
|
397
|
+
mean=prediction_config.mean, std=prediction_config.std
|
|
398
|
+
)
|
|
399
|
+
],
|
|
400
|
+
)
|
|
374
401
|
|
|
375
402
|
def __iter__(
|
|
376
403
|
self,
|
|
@@ -387,11 +414,19 @@ class IterablePredictionDataset(PathIterableDataset):
|
|
|
387
414
|
self.mean is not None and self.std is not None
|
|
388
415
|
), "Mean and std must be provided"
|
|
389
416
|
|
|
390
|
-
for sample, _ in
|
|
417
|
+
for sample, _ in _iterate_over_files(
|
|
418
|
+
self.prediction_config,
|
|
419
|
+
self.data_files,
|
|
420
|
+
read_source_func=self.read_source_func,
|
|
421
|
+
):
|
|
391
422
|
# reshape array
|
|
392
423
|
reshaped_sample = reshape_array(sample, self.axes)
|
|
393
424
|
|
|
394
|
-
if
|
|
425
|
+
if (
|
|
426
|
+
self.tile
|
|
427
|
+
and self.tile_size is not None
|
|
428
|
+
and self.tile_overlap is not None
|
|
429
|
+
):
|
|
395
430
|
# generate patches, return a generator
|
|
396
431
|
patch_gen = extract_tiles(
|
|
397
432
|
arr=reshaped_sample,
|
|
@@ -408,9 +443,6 @@ class IterablePredictionDataset(PathIterableDataset):
|
|
|
408
443
|
|
|
409
444
|
# apply transform to patches
|
|
410
445
|
for patch_array, tile_info in patch_gen:
|
|
411
|
-
|
|
412
|
-
patch = np.moveaxis(patch_array, 0, -1)
|
|
413
|
-
transformed_patch = self.patch_transform(image=patch)
|
|
414
|
-
transformed_patch = np.moveaxis(transformed_patch["image"], -1, 0)
|
|
446
|
+
transformed_patch, _ = self.patch_transform(patch=patch_array)
|
|
415
447
|
|
|
416
448
|
yield transformed_patch, tile_info
|
|
@@ -1,8 +1,5 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Tiling submodule.
|
|
1
|
+
"""Patching functions."""
|
|
3
2
|
|
|
4
|
-
These functions are used to tile images into patches or tiles.
|
|
5
|
-
"""
|
|
6
3
|
from pathlib import Path
|
|
7
4
|
from typing import Callable, List, Tuple, Union
|
|
8
5
|
|
|
@@ -20,12 +17,25 @@ def prepare_patches_supervised(
|
|
|
20
17
|
train_files: List[Path],
|
|
21
18
|
target_files: List[Path],
|
|
22
19
|
axes: str,
|
|
23
|
-
patch_size: Union[List[int], Tuple[int]],
|
|
20
|
+
patch_size: Union[List[int], Tuple[int, ...]],
|
|
24
21
|
read_source_func: Callable,
|
|
25
22
|
) -> Tuple[np.ndarray, np.ndarray, float, float]:
|
|
26
23
|
"""
|
|
27
24
|
Iterate over data source and create an array of patches and corresponding targets.
|
|
28
25
|
|
|
26
|
+
Parameters
|
|
27
|
+
----------
|
|
28
|
+
train_files : List[Path]
|
|
29
|
+
List of paths to training data.
|
|
30
|
+
target_files : List[Path]
|
|
31
|
+
List of paths to target data.
|
|
32
|
+
axes : str
|
|
33
|
+
Axes of the data.
|
|
34
|
+
patch_size : Union[List[int], Tuple[int, ...]]
|
|
35
|
+
Size of the patches.
|
|
36
|
+
read_source_func : Callable
|
|
37
|
+
Function to read the data.
|
|
38
|
+
|
|
29
39
|
Returns
|
|
30
40
|
-------
|
|
31
41
|
np.ndarray
|
|
@@ -94,13 +104,25 @@ def prepare_patches_unsupervised(
|
|
|
94
104
|
patch_size: Union[List[int], Tuple[int]],
|
|
95
105
|
read_source_func: Callable,
|
|
96
106
|
) -> Tuple[np.ndarray, None, float, float]:
|
|
97
|
-
"""
|
|
98
|
-
|
|
107
|
+
"""Iterate over data source and create an array of patches.
|
|
108
|
+
|
|
109
|
+
This method returns the mean and standard deviation of the image.
|
|
110
|
+
|
|
111
|
+
Parameters
|
|
112
|
+
----------
|
|
113
|
+
train_files : List[Path]
|
|
114
|
+
List of paths to training data.
|
|
115
|
+
axes : str
|
|
116
|
+
Axes of the data.
|
|
117
|
+
patch_size : Union[List[int], Tuple[int]]
|
|
118
|
+
Size of the patches.
|
|
119
|
+
read_source_func : Callable
|
|
120
|
+
Function to read the data.
|
|
99
121
|
|
|
100
122
|
Returns
|
|
101
123
|
-------
|
|
102
|
-
np.ndarray
|
|
103
|
-
|
|
124
|
+
Tuple[np.ndarray, None, float, float]
|
|
125
|
+
Source patches, None (no target), mean and standard deviation.
|
|
104
126
|
"""
|
|
105
127
|
means, stds, num_samples = 0, 0, 0
|
|
106
128
|
all_patches = []
|
|
@@ -149,10 +171,21 @@ def prepare_patches_supervised_array(
|
|
|
149
171
|
|
|
150
172
|
Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
|
|
151
173
|
|
|
174
|
+
Parameters
|
|
175
|
+
----------
|
|
176
|
+
data : np.ndarray
|
|
177
|
+
Input data array.
|
|
178
|
+
axes : str
|
|
179
|
+
Axes of the data.
|
|
180
|
+
data_target : np.ndarray
|
|
181
|
+
Target data array.
|
|
182
|
+
patch_size : Union[List[int], Tuple[int]]
|
|
183
|
+
Size of the patches.
|
|
184
|
+
|
|
152
185
|
Returns
|
|
153
186
|
-------
|
|
154
|
-
np.ndarray
|
|
155
|
-
|
|
187
|
+
Tuple[np.ndarray, np.ndarray, float, float]
|
|
188
|
+
Source and target patches, mean and standard deviation.
|
|
156
189
|
"""
|
|
157
190
|
# compute statistics
|
|
158
191
|
mean = data.mean()
|
|
@@ -194,10 +227,19 @@ def prepare_patches_unsupervised_array(
|
|
|
194
227
|
|
|
195
228
|
Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
|
|
196
229
|
|
|
230
|
+
Parameters
|
|
231
|
+
----------
|
|
232
|
+
data : np.ndarray
|
|
233
|
+
Input data array.
|
|
234
|
+
axes : str
|
|
235
|
+
Axes of the data.
|
|
236
|
+
patch_size : Union[List[int], Tuple[int]]
|
|
237
|
+
Size of the patches.
|
|
238
|
+
|
|
197
239
|
Returns
|
|
198
240
|
-------
|
|
199
|
-
np.ndarray
|
|
200
|
-
|
|
241
|
+
Tuple[np.ndarray, None, float, float]
|
|
242
|
+
Source patches, mean and standard deviation.
|
|
201
243
|
"""
|
|
202
244
|
# calculate mean and std
|
|
203
245
|
mean = data.mean()
|
|
@@ -209,4 +251,4 @@ def prepare_patches_unsupervised_array(
|
|
|
209
251
|
# generate patches, return a generator
|
|
210
252
|
patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)
|
|
211
253
|
|
|
212
|
-
return patches, _, mean, std # TODO inelegant, replace
|
|
254
|
+
return patches, _, mean, std # TODO inelegant, replace by dataclass?
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Random patching utilities."""
|
|
2
|
+
|
|
1
3
|
from typing import Generator, List, Optional, Tuple, Union
|
|
2
4
|
|
|
3
5
|
import numpy as np
|
|
@@ -30,6 +32,8 @@ def extract_patches_random(
|
|
|
30
32
|
Input image array.
|
|
31
33
|
patch_size : Tuple[int]
|
|
32
34
|
Patch sizes in each dimension.
|
|
35
|
+
target : Optional[np.ndarray], optional
|
|
36
|
+
Target array, by default None.
|
|
33
37
|
|
|
34
38
|
Yields
|
|
35
39
|
------
|
|
@@ -120,10 +124,12 @@ def extract_patches_random_from_chunks(
|
|
|
120
124
|
----------
|
|
121
125
|
arr : np.ndarray
|
|
122
126
|
Input image array.
|
|
123
|
-
patch_size : Tuple[int]
|
|
127
|
+
patch_size : Union[List[int], Tuple[int, ...]]
|
|
124
128
|
Patch sizes in each dimension.
|
|
125
|
-
chunk_size : Tuple[int]
|
|
129
|
+
chunk_size : Union[List[int], Tuple[int, ...]]
|
|
126
130
|
Chunk sizes to load from the array.
|
|
131
|
+
chunk_limit : Optional[int], optional
|
|
132
|
+
Number of chunks to load, by default None.
|
|
127
133
|
|
|
128
134
|
Yields
|
|
129
135
|
------
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
"""Sequential patching functions."""
|
|
2
|
+
|
|
1
3
|
from typing import List, Optional, Tuple, Union
|
|
2
4
|
|
|
3
5
|
import numpy as np
|
|
@@ -14,14 +16,14 @@ def _compute_number_of_patches(
|
|
|
14
16
|
|
|
15
17
|
Parameters
|
|
16
18
|
----------
|
|
17
|
-
|
|
19
|
+
arr_shape : Tuple[int, ...]
|
|
18
20
|
Shape of the input array.
|
|
19
|
-
patch_sizes : Tuple[int]
|
|
21
|
+
patch_sizes : Union[List[int], Tuple[int, ...]]
|
|
20
22
|
Shape of the patches.
|
|
21
23
|
|
|
22
24
|
Returns
|
|
23
25
|
-------
|
|
24
|
-
Tuple[int]
|
|
26
|
+
Tuple[int, ...]
|
|
25
27
|
Number of patches in each dimension.
|
|
26
28
|
"""
|
|
27
29
|
if len(arr_shape) != len(patch_sizes):
|
|
@@ -55,14 +57,14 @@ def _compute_overlap(
|
|
|
55
57
|
|
|
56
58
|
Parameters
|
|
57
59
|
----------
|
|
58
|
-
|
|
60
|
+
arr_shape : Tuple[int, ...]
|
|
59
61
|
Input array shape.
|
|
60
|
-
patch_sizes : Tuple[int]
|
|
62
|
+
patch_sizes : Union[List[int], Tuple[int, ...]]
|
|
61
63
|
Size of the patches.
|
|
62
64
|
|
|
63
65
|
Returns
|
|
64
66
|
-------
|
|
65
|
-
Tuple[int]
|
|
67
|
+
Tuple[int, ...]
|
|
66
68
|
Overlap between patches in each dimension.
|
|
67
69
|
"""
|
|
68
70
|
n_patches = _compute_number_of_patches(arr_shape, patch_sizes)
|
|
@@ -123,6 +125,8 @@ def _compute_patch_views(
|
|
|
123
125
|
Steps between views.
|
|
124
126
|
output_shape : Tuple[int]
|
|
125
127
|
Shape of the output array.
|
|
128
|
+
target : Optional[np.ndarray], optional
|
|
129
|
+
Target array, by default None.
|
|
126
130
|
|
|
127
131
|
Returns
|
|
128
132
|
-------
|
|
@@ -135,15 +139,12 @@ def _compute_patch_views(
|
|
|
135
139
|
arr = np.stack([arr, target], axis=0)
|
|
136
140
|
window_shape = [arr.shape[0], *window_shape]
|
|
137
141
|
step = (arr.shape[0], *step)
|
|
138
|
-
output_shape = [arr.shape[0],
|
|
142
|
+
output_shape = [-1, arr.shape[0], arr.shape[2], *output_shape[2:]]
|
|
139
143
|
|
|
140
144
|
patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape(
|
|
141
145
|
*output_shape
|
|
142
146
|
)
|
|
143
|
-
|
|
144
|
-
rng.shuffle(patches, axis=1)
|
|
145
|
-
else:
|
|
146
|
-
rng.shuffle(patches, axis=0)
|
|
147
|
+
rng.shuffle(patches, axis=0)
|
|
147
148
|
return patches
|
|
148
149
|
|
|
149
150
|
|
|
@@ -164,11 +165,13 @@ def extract_patches_sequential(
|
|
|
164
165
|
Input image array.
|
|
165
166
|
patch_size : Tuple[int]
|
|
166
167
|
Patch sizes in each dimension.
|
|
168
|
+
target : Optional[np.ndarray], optional
|
|
169
|
+
Target array, by default None.
|
|
167
170
|
|
|
168
171
|
Returns
|
|
169
172
|
-------
|
|
170
|
-
|
|
171
|
-
|
|
173
|
+
Tuple[np.ndarray, Optional[np.ndarray]]
|
|
174
|
+
Patches.
|
|
172
175
|
"""
|
|
173
176
|
is_3d_patch = len(patch_size) == 3
|
|
174
177
|
|
|
@@ -201,6 +204,9 @@ def extract_patches_sequential(
|
|
|
201
204
|
|
|
202
205
|
if target is not None:
|
|
203
206
|
# target was concatenated to patches in _compute_reshaped_view
|
|
204
|
-
return (
|
|
207
|
+
return (
|
|
208
|
+
patches[:, 0, ...],
|
|
209
|
+
patches[:, 1, ...],
|
|
210
|
+
) # TODO in _compute_reshaped_view?
|
|
205
211
|
else:
|
|
206
212
|
return patches, None
|