careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc6__py3-none-any.whl
- careamics/callbacks/hyperparameters_callback.py +10 -3
- careamics/callbacks/progress_bar_callback.py +37 -4
- careamics/careamist.py +80 -44
- careamics/config/algorithm_model.py +5 -3
- careamics/config/architectures/architecture_model.py +7 -0
- careamics/config/architectures/custom_model.py +8 -1
- careamics/config/architectures/register_model.py +3 -1
- careamics/config/architectures/unet_model.py +2 -0
- careamics/config/architectures/vae_model.py +2 -0
- careamics/config/callback_model.py +3 -15
- careamics/config/configuration_example.py +4 -2
- careamics/config/configuration_factory.py +4 -16
- careamics/config/data_model.py +10 -14
- careamics/config/inference_model.py +0 -65
- careamics/config/optimizer_models.py +4 -4
- careamics/config/support/__init__.py +0 -2
- careamics/config/support/supported_activations.py +2 -0
- careamics/config/support/supported_algorithms.py +3 -1
- careamics/config/support/supported_architectures.py +2 -0
- careamics/config/support/supported_data.py +2 -0
- careamics/config/support/supported_loggers.py +2 -0
- careamics/config/support/supported_losses.py +2 -0
- careamics/config/support/supported_optimizers.py +2 -0
- careamics/config/support/supported_pixel_manipulations.py +3 -3
- careamics/config/support/supported_struct_axis.py +2 -0
- careamics/config/support/supported_transforms.py +4 -15
- careamics/config/tile_information.py +2 -0
- careamics/config/transformations/__init__.py +3 -2
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +11 -3
- careamics/conftest.py +12 -0
- careamics/dataset/dataset_utils/dataset_utils.py +4 -4
- careamics/dataset/dataset_utils/file_utils.py +4 -3
- careamics/dataset/dataset_utils/read_tiff.py +6 -2
- careamics/dataset/dataset_utils/read_utils.py +2 -0
- careamics/dataset/dataset_utils/read_zarr.py +11 -7
- careamics/dataset/in_memory_dataset.py +71 -32
- careamics/dataset/iterable_dataset.py +155 -68
- careamics/dataset/patching/patching.py +56 -15
- careamics/dataset/patching/random_patching.py +8 -2
- careamics/dataset/patching/sequential_patching.py +14 -8
- careamics/dataset/patching/tiled_patching.py +3 -1
- careamics/dataset/patching/validate_patch_dimension.py +2 -0
- careamics/dataset/zarr_dataset.py +2 -0
- careamics/lightning_datamodule.py +45 -19
- careamics/lightning_module.py +8 -2
- careamics/lightning_prediction_datamodule.py +3 -13
- careamics/lightning_prediction_loop.py +8 -6
- careamics/losses/__init__.py +2 -3
- careamics/losses/loss_factory.py +1 -1
- careamics/losses/losses.py +11 -7
- careamics/model_io/bmz_io.py +3 -3
- careamics/models/activation.py +2 -0
- careamics/models/layers.py +121 -25
- careamics/models/model_factory.py +1 -1
- careamics/models/unet.py +35 -14
- careamics/prediction/stitch_prediction.py +2 -6
- careamics/transforms/__init__.py +2 -2
- careamics/transforms/compose.py +33 -7
- careamics/transforms/n2v_manipulate.py +49 -13
- careamics/transforms/normalize.py +55 -3
- careamics/transforms/pixel_manipulation.py +5 -5
- careamics/transforms/struct_mask_parameters.py +3 -1
- careamics/transforms/transform.py +10 -19
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +38 -5
- careamics/utils/base_enum.py +28 -0
- careamics/utils/path_utils.py +2 -0
- careamics/utils/ram.py +2 -0
- careamics/utils/receptive_field.py +93 -87
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/METADATA +2 -1
- careamics-0.1.0rc6.dist-info/RECORD +107 -0
- careamics/config/noise_models.py +0 -162
- careamics/config/support/supported_extraction_strategies.py +0 -25
- careamics/config/transformations/nd_flip_model.py +0 -27
- careamics/losses/noise_model_factory.py +0 -40
- careamics/losses/noise_models.py +0 -524
- careamics/transforms/nd_flip.py +0 -67
- careamics-0.1.0rc5.dist-info/RECORD +0 -111
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc6.dist-info}/licenses/LICENSE +0 -0
careamics/dataset/in_memory_dataset.py

@@ -13,6 +13,7 @@ from careamics.transforms import Compose
 
 from ..config import DataConfig, InferenceConfig
 from ..config.tile_information import TileInformation
+from ..config.transformations import NormalizeModel
 from ..utils.logging import get_logger
 from .dataset_utils import read_tiff, reshape_array
 from .patching.patching import (
@@ -27,24 +28,49 @@ logger = get_logger(__name__)
 
 
 class InMemoryDataset(Dataset):
-    """Dataset storing data in memory and allowing generating patches from it.
+    """Dataset storing data in memory and allowing generating patches from it.
+
+    Parameters
+    ----------
+    data_config : DataConfig
+        Data configuration.
+    inputs : Union[np.ndarray, List[Path]]
+        Input data.
+    input_target : Optional[Union[np.ndarray, List[Path]]], optional
+        Target data, by default None.
+    read_source_func : Callable, optional
+        Read source function for custom types, by default read_tiff.
+    **kwargs : Any
+        Additional keyword arguments, unused.
+    """
 
     def __init__(
         self,
         data_config: DataConfig,
         inputs: Union[np.ndarray, List[Path]],
-
+        input_target: Optional[Union[np.ndarray, List[Path]]] = None,
         read_source_func: Callable = read_tiff,
         **kwargs: Any,
     ) -> None:
         """
         Constructor.
 
-
+        Parameters
+        ----------
+        data_config : DataConfig
+            Data configuration.
+        inputs : Union[np.ndarray, List[Path]]
+            Input data.
+        input_target : Optional[Union[np.ndarray, List[Path]]], optional
+            Target data, by default None.
+        read_source_func : Callable, optional
+            Read source function for custom types, by default read_tiff.
+        **kwargs : Any
+            Additional keyword arguments, unused.
         """
         self.data_config = data_config
         self.inputs = inputs
-        self.
+        self.input_targets = input_target
         self.axes = self.data_config.axes
         self.patch_size = self.data_config.patch_size
 
@@ -52,11 +78,11 @@ class InMemoryDataset(Dataset):
         self.read_source_func = read_source_func
 
         # Generate patches
-        supervised = self.
-
+        supervised = self.input_targets is not None
+        patch_data = self._prepare_patches(supervised)
 
         # Add results to members
-        self.
+        self.patches, self.patch_targets, computed_mean, computed_std = patch_data
 
         if not self.data_config.mean or not self.data_config.std:
             self.mean, self.std = computed_mean, computed_std
@@ -91,18 +117,18 @@ class InMemoryDataset(Dataset):
         """
         if supervised:
             if isinstance(self.inputs, np.ndarray) and isinstance(
-                self.
+                self.input_targets, np.ndarray
             ):
                 return prepare_patches_supervised_array(
                     self.inputs,
                     self.axes,
-                    self.
+                    self.input_targets,
                     self.patch_size,
                 )
-            elif isinstance(self.inputs, list) and isinstance(self.
+            elif isinstance(self.inputs, list) and isinstance(self.input_targets, list):
                 return prepare_patches_supervised(
                     self.inputs,
-                    self.
+                    self.input_targets,
                     self.axes,
                     self.patch_size,
                     self.read_source_func,
@@ -111,7 +137,7 @@ class InMemoryDataset(Dataset):
             raise ValueError(
                 f"Data and target must be of the same type, either both numpy "
                 f"arrays or both lists of paths, got {type(self.inputs)} (data) "
-                f"and {type(self.
+                f"and {type(self.input_targets)} (target)."
             )
         else:
             if isinstance(self.inputs, np.ndarray):
@@ -137,9 +163,9 @@ class InMemoryDataset(Dataset):
         int
             Length of the dataset.
         """
-        return len(self.
+        return len(self.patches)
 
-    def __getitem__(self, index: int) -> Tuple[np.ndarray]:
+    def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]:
         """
         Return the patch corresponding to the provided index.
 
@@ -158,12 +184,12 @@ class InMemoryDataset(Dataset):
         ValueError
             If dataset mean and std are not set.
         """
-        patch = self.
+        patch = self.patches[index]
 
         # if there is a target
-        if self.
+        if self.patch_targets is not None:
             # get target
-            target = self.
+            target = self.patch_targets[index]
 
             return self.patch_transform(patch=patch, target=target)
 
@@ -223,25 +249,25 @@ class InMemoryDataset(Dataset):
         indices = np.random.choice(total_patches, n_patches, replace=False)
 
         # extract patches
-        val_patches = self.
+        val_patches = self.patches[indices]
 
         # remove patches from self.patch
-        self.
+        self.patches = np.delete(self.patches, indices, axis=0)
 
         # same for targets
-        if self.
-            val_targets = self.
-            self.
+        if self.patch_targets is not None:
+            val_targets = self.patch_targets[indices]
+            self.patch_targets = np.delete(self.patch_targets, indices, axis=0)
 
         # clone the dataset
         dataset = copy.deepcopy(self)
 
         # reassign patches
-        dataset.
+        dataset.patches = val_patches
 
         # reassign targets
-        if self.
-            dataset.
+        if self.patch_targets is not None:
+            dataset.patch_targets = val_targets
 
         return dataset
 
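The hunk above threads the renamed patches/patch_targets attributes through the validation-split logic. The extract-and-clone pattern it uses reads, in isolation, roughly as follows. This is a minimal sketch with a hypothetical split_random_patches helper, assuming patches are stacked along the first (S) axis:

import copy

import numpy as np


def split_random_patches(dataset, n_patches):
    # pick random patch indices without replacement
    indices = np.random.choice(dataset.patches.shape[0], n_patches, replace=False)

    # extract the selected patches and remove them from the source dataset
    val_patches = dataset.patches[indices]
    dataset.patches = np.delete(dataset.patches, indices, axis=0)

    # same for the targets, when present
    val_targets = None
    if dataset.patch_targets is not None:
        val_targets = dataset.patch_targets[indices]
        dataset.patch_targets = np.delete(dataset.patch_targets, indices, axis=0)

    # clone, then overwrite the clone's data: the clone keeps the source's
    # configuration and transforms but holds only the validation patches
    val_dataset = copy.deepcopy(dataset)
    val_dataset.patches = val_patches
    val_dataset.patch_targets = val_targets
    return val_dataset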
@@ -250,7 +276,16 @@ class InMemoryPredictionDataset(Dataset):
     """
     Dataset storing data in memory and allowing generating patches from it.
 
-
+    Parameters
+    ----------
+    prediction_config : InferenceConfig
+        Prediction configuration.
+    inputs : np.ndarray
+        Input data.
+    data_target : Optional[np.ndarray], optional
+        Target data, by default None.
+    read_source_func : Optional[Callable], optional
+        Read source function for custom types, by default read_tiff.
     """
 
     def __init__(
@@ -264,10 +299,14 @@ class InMemoryPredictionDataset(Dataset):
 
         Parameters
         ----------
-
-
-
-
+        prediction_config : InferenceConfig
+            Prediction configuration.
+        inputs : np.ndarray
+            Input data.
+        data_target : Optional[np.ndarray], optional
+            Target data, by default None.
+        read_source_func : Optional[Callable], optional
+            Read source function for custom types, by default read_tiff.
 
         Raises
         ------
@@ -295,7 +334,7 @@ class InMemoryPredictionDataset(Dataset):
 
         # get transforms
         self.patch_transform = Compose(
-            transform_list=self.
+            transform_list=[NormalizeModel(mean=self.mean, std=self.std)],
         )
 
     def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
@@ -310,7 +349,7 @@ class InMemoryPredictionDataset(Dataset):
         # reshape array
         reshaped_sample = reshape_array(self.input_array, self.axes)
 
-        if self.tiling:
+        if self.tiling and self.tile_size is not None and self.tile_overlap is not None:
            # generate patches, which returns a generator
            patch_generator = extract_tiles(
                arr=reshaped_sample,
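For prediction, the only transform composed is normalization built from training-time statistics. A minimal sketch of the z-score operation that a NormalizeModel-configured transform performs (the eps guard is an assumption for the sketch, not taken from the diff):

import numpy as np


def zscore(img: np.ndarray, mean: float, std: float) -> np.ndarray:
    # normalize with the statistics recorded at training time
    eps = 1e-6  # assumed guard against a zero std
    return (img - mean) / (std + eps)

Denormalizing the network output then simply inverts this: img * std + mean.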
careamics/dataset/iterable_dataset.py

@@ -1,3 +1,5 @@
+"""Iterable dataset used to load data file by file."""
+
 from __future__ import annotations
 
 import copy
@@ -11,6 +13,7 @@ from careamics.transforms import Compose
 
 from ..config import DataConfig, InferenceConfig
 from ..config.tile_information import TileInformation
+from ..config.transformations import NormalizeModel
 from ..utils.logging import get_logger
 from .dataset_utils import read_tiff, reshape_array
 from .patching.random_patching import extract_patches_random
@@ -19,13 +22,85 @@ from .patching.tiled_patching import extract_tiles
 logger = get_logger(__name__)
 
 
+def _iterate_over_files(
+    data_config: Union[DataConfig, InferenceConfig],
+    data_files: List[Path],
+    target_files: Optional[List[Path]] = None,
+    read_source_func: Callable = read_tiff,
+) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
+    """
+    Iterate over data source and yield whole image.
+
+    Parameters
+    ----------
+    data_config : Union[DataConfig, InferenceConfig]
+        Data configuration.
+    data_files : List[Path]
+        List of data files.
+    target_files : Optional[List[Path]]
+        List of target files, by default None.
+    read_source_func : Optional[Callable]
+        Function to read the source, by default read_tiff.
+
+    Yields
+    ------
+    np.ndarray
+        Image.
+    """
+    # When num_workers > 0, each worker process will have a different copy of the
+    # dataset object
+    # Configuring each copy independently to avoid having duplicate data returned
+    # from the workers
+    worker_info = get_worker_info()
+    worker_id = worker_info.id if worker_info is not None else 0
+    num_workers = worker_info.num_workers if worker_info is not None else 1
+
+    # iterate over the files
+    for i, filename in enumerate(data_files):
+        # retrieve file corresponding to the worker id
+        if i % num_workers == worker_id:
+            try:
+                # read data
+                sample = read_source_func(filename, data_config.axes)
+
+                # read target, if available
+                if target_files is not None:
+                    if filename.name != target_files[i].name:
+                        raise ValueError(
+                            f"File {filename} does not match target file "
+                            f"{target_files[i]}. Have you passed sorted "
+                            f"arrays?"
+                        )
+
+                    # read target
+                    target = read_source_func(target_files[i], data_config.axes)
+
+                    yield sample, target
+                else:
+                    yield sample, None
+
+            except Exception as e:
+                logger.error(f"Error reading file {filename}: {e}")
+
+
 class PathIterableDataset(IterableDataset):
     """
     Dataset allowing extracting patches w/o loading whole data into memory.
 
     Parameters
     ----------
-
+    data_config : DataConfig
+        Data configuration.
+    src_files : List[Path]
+        List of data files.
+    target_files : Optional[List[Path]], optional
+        Optional list of target files, by default None.
+    read_source_func : Callable, optional
+        Read source function for custom types, by default read_tiff.
+
+    Attributes
+    ----------
+    data_path : List[Path]
         Path to the data, must be a directory.
     axes : str
         Description of axes in format STCZYX.
@@ -45,11 +120,24 @@ class PathIterableDataset(IterableDataset):
 
     def __init__(
         self,
-        data_config:
+        data_config: DataConfig,
         src_files: List[Path],
         target_files: Optional[List[Path]] = None,
         read_source_func: Callable = read_tiff,
     ) -> None:
+        """Constructors.
+
+        Parameters
+        ----------
+        data_config : DataConfig
+            Data configuration.
+        src_files : List[Path]
+            List of data files.
+        target_files : Optional[List[Path]], optional
+            Optional list of target files, by default None.
+        read_source_func : Callable, optional
+            Read source function for custom types, by default read_tiff.
+        """
         self.data_config = data_config
         self.data_files = src_files
         self.target_files = target_files
@@ -82,7 +170,9 @@ class PathIterableDataset(IterableDataset):
         means, stds = 0, 0
         num_samples = 0
 
-        for sample, _ in
+        for sample, _ in _iterate_over_files(
+            self.data_config, self.data_files, self.target_files, self.read_source_func
+        ):
             means += sample.mean()
             stds += sample.std()
             num_samples += 1
@@ -97,57 +187,9 @@ class PathIterableDataset(IterableDataset):
         logger.info(f"Mean: {result_mean}, std: {result_std}")
         return result_mean, result_std
 
-    def _iterate_over_files(
-        self,
-    ) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
-        """
-        Iterate over data source and yield whole image.
-
-        Yields
-        ------
-        np.ndarray
-            Image.
-        """
-        # When num_workers > 0, each worker process will have a different copy of the
-        # dataset object
-        # Configuring each copy independently to avoid having duplicate data returned
-        # from the workers
-        worker_info = get_worker_info()
-        worker_id = worker_info.id if worker_info is not None else 0
-        num_workers = worker_info.num_workers if worker_info is not None else 1
-
-        # iterate over the files
-        for i, filename in enumerate(self.data_files):
-            # retrieve file corresponding to the worker id
-            if i % num_workers == worker_id:
-                try:
-                    # read data
-                    sample = self.read_source_func(filename, self.data_config.axes)
-
-                    # read target, if available
-                    if self.target_files is not None:
-                        if filename.name != self.target_files[i].name:
-                            raise ValueError(
-                                f"File {filename} does not match target file "
-                                f"{self.target_files[i]}. Have you passed sorted "
-                                f"arrays?"
-                            )
-
-                        # read target
-                        target = self.read_source_func(
-                            self.target_files[i], self.data_config.axes
-                        )
-
-                        yield sample, target
-                    else:
-                        yield sample, None
-
-            except Exception as e:
-                logger.error(f"Error reading file {filename}: {e}")
-
     def __iter__(
         self,
-    ) -> Generator[Tuple[np.ndarray,
+    ) -> Generator[Tuple[np.ndarray, ...], None, None]:
         """
         Iterate over data source and yield single patch.
 
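The biggest change in this file is that _iterate_over_files moves from a PathIterableDataset method (deleted above) to the module-level generator added earlier, so the training and prediction datasets can share it. The i % num_workers == worker_id test is the piece that matters for correctness: each DataLoader worker process gets its own copy of an IterableDataset, so without sharding every worker would yield every file. A self-contained toy version of the same pattern (illustrative class, not the careamics API):

from typing import Iterator, List

from torch.utils.data import IterableDataset, get_worker_info


class ShardedFileDataset(IterableDataset):
    def __init__(self, files: List[str]) -> None:
        self.files = files

    def __iter__(self) -> Iterator[str]:
        # get_worker_info() returns None in the main process (num_workers=0)
        info = get_worker_info()
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1

        # each worker yields only the files whose index maps to it
        for i, filename in enumerate(self.files):
            if i % num_workers == worker_id:
                yield filename  # a real dataset would read and patch the file

With DataLoader(ShardedFileDataset(files), num_workers=2), worker 0 yields files 0, 2, 4, ... and worker 1 yields files 1, 3, 5, ...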
@@ -161,7 +203,9 @@ class PathIterableDataset(IterableDataset):
         ), "Mean and std must be provided"
 
         # iterate over files
-        for sample_input, sample_target in
+        for sample_input, sample_target in _iterate_over_files(
+            self.data_config, self.data_files, self.target_files, self.read_source_func
+        ):
             reshaped_sample = reshape_array(sample_input, self.data_config.axes)
             reshaped_target = (
                 None
@@ -209,9 +253,9 @@ class PathIterableDataset(IterableDataset):
         Parameters
         ----------
         percentage : float, optional
-            Percentage of files to split up, by default 0.1
+            Percentage of files to split up, by default 0.1.
         minimum_number : int, optional
-            Minimum number of files to split up, by default 5
+            Minimum number of files to split up, by default 5.
 
         Returns
         -------
@@ -275,12 +319,23 @@ class PathIterableDataset(IterableDataset):
         return dataset
 
 
-class IterablePredictionDataset(
+class IterablePredictionDataset(IterableDataset):
     """
-
+    Prediction dataset.
 
     Parameters
     ----------
+    prediction_config : InferenceConfig
+        Inference configuration.
+    src_files : List[Path]
+        List of data files.
+    read_source_func : Callable, optional
+        Read source function for custom types, by default read_tiff.
+    **kwargs : Any
+        Additional keyword arguments, unused.
+
+    Attributes
+    ----------
     data_path : Union[str, Path]
         Path to the data, must be a directory.
     axes : str
@@ -300,13 +355,26 @@ class IterablePredictionDataset(PathIterableDataset):
         read_source_func: Callable = read_tiff,
         **kwargs: Any,
     ) -> None:
-
-        data_config=prediction_config,
-        src_files=src_files,
-        read_source_func=read_source_func,
-        )
+        """Constructor.
 
+        Parameters
+        ----------
+        prediction_config : InferenceConfig
+            Inference configuration.
+        src_files : List[Path]
+            List of data files.
+        read_source_func : Callable, optional
+            Read source function for custom types, by default read_tiff.
+        **kwargs : Any
+            Additional keyword arguments, unused.
+
+        Raises
+        ------
+        ValueError
+            If mean and std are not provided in the inference configuration.
+        """
         self.prediction_config = prediction_config
+        self.data_files = src_files
         self.axes = prediction_config.axes
         self.tile_size = self.prediction_config.tile_size
         self.tile_overlap = self.prediction_config.tile_overlap
@@ -315,10 +383,21 @@ class IterablePredictionDataset(PathIterableDataset):
         # tile only if both tile size and overlaps are provided
         self.tile = self.tile_size is not None and self.tile_overlap is not None
 
-        #
-        self.
-
-
+        # check mean and std and create normalize transform
+        if self.prediction_config.mean is None or self.prediction_config.std is None:
+            raise ValueError("Mean and std must be provided for prediction.")
+        else:
+            self.mean = self.prediction_config.mean
+            self.std = self.prediction_config.std
+
+            # instantiate normalize transform
+            self.patch_transform = Compose(
+                transform_list=[
+                    NormalizeModel(
+                        mean=prediction_config.mean, std=prediction_config.std
+                    )
+                ],
+            )
 
     def __iter__(
         self,
@@ -335,11 +414,19 @@ class IterablePredictionDataset(PathIterableDataset):
             self.mean is not None and self.std is not None
         ), "Mean and std must be provided"
 
-        for sample, _ in
+        for sample, _ in _iterate_over_files(
+            self.prediction_config,
+            self.data_files,
+            read_source_func=self.read_source_func,
+        ):
             # reshape array
             reshaped_sample = reshape_array(sample, self.axes)
 
-            if
+            if (
+                self.tile
+                and self.tile_size is not None
+                and self.tile_overlap is not None
+            ):
                 # generate patches, return a generator
                 patch_gen = extract_tiles(
                     arr=reshaped_sample,
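IterablePredictionDataset.__iter__ tiles only when tile_size and tile_overlap were both supplied; the expanded condition restates self.tile so each optional is checked explicitly. For intuition, a simplified 2D version of what an overlapping tiler like extract_tiles computes. This sketch ignores the S/C/Z axes and the TileInformation metadata the real function yields for stitching, and assumes the array is at least one tile in each dimension with overlap smaller than the tile:

from itertools import product
from typing import Generator, Tuple

import numpy as np


def iter_tiles_2d(
    arr: np.ndarray, tile_size: Tuple[int, int], overlap: Tuple[int, int]
) -> Generator[Tuple[np.ndarray, Tuple[int, int]], None, None]:
    # stride by tile size minus overlap so neighboring tiles share pixels
    step_y = tile_size[0] - overlap[0]
    step_x = tile_size[1] - overlap[1]
    for y, x in product(
        range(0, arr.shape[0] - overlap[0], step_y),
        range(0, arr.shape[1] - overlap[1], step_x),
    ):
        # clamp so the last tiles end exactly at the image border
        y = min(y, arr.shape[0] - tile_size[0])
        x = min(x, arr.shape[1] - tile_size[1])
        yield arr[y : y + tile_size[0], x : x + tile_size[1]], (y, x)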
careamics/dataset/patching/patching.py

@@ -1,8 +1,4 @@
-"""
-Tiling submodule.
-
-These functions are used to tile images into patches or tiles.
-"""
+"""Patching functions."""
 
 from pathlib import Path
 from typing import Callable, List, Tuple, Union
@@ -21,12 +17,25 @@ def prepare_patches_supervised(
     train_files: List[Path],
     target_files: List[Path],
     axes: str,
-    patch_size: Union[List[int], Tuple[int]],
+    patch_size: Union[List[int], Tuple[int, ...]],
     read_source_func: Callable,
 ) -> Tuple[np.ndarray, np.ndarray, float, float]:
     """
     Iterate over data source and create an array of patches and corresponding targets.
 
+    Parameters
+    ----------
+    train_files : List[Path]
+        List of paths to training data.
+    target_files : List[Path]
+        List of paths to target data.
+    axes : str
+        Axes of the data.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+    read_source_func : Callable
+        Function to read the data.
+
     Returns
     -------
     np.ndarray
@@ -95,13 +104,25 @@ def prepare_patches_unsupervised(
     patch_size: Union[List[int], Tuple[int]],
     read_source_func: Callable,
 ) -> Tuple[np.ndarray, None, float, float]:
-    """
-
+    """Iterate over data source and create an array of patches.
+
+    This method returns the mean and standard deviation of the image.
+
+    Parameters
+    ----------
+    train_files : List[Path]
+        List of paths to training data.
+    axes : str
+        Axes of the data.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+    read_source_func : Callable
+        Function to read the data.
 
     Returns
     -------
-    np.ndarray
-
+    Tuple[np.ndarray, None, float, float]
+        Source and target patches, mean and standard deviation.
     """
     means, stds, num_samples = 0, 0, 0
     all_patches = []
@@ -150,10 +171,21 @@ def prepare_patches_supervised_array(
 
     Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
 
+    Parameters
+    ----------
+    data : np.ndarray
+        Input data array.
+    axes : str
+        Axes of the data.
+    data_target : np.ndarray
+        Target data array.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+
     Returns
     -------
-    np.ndarray
-
+    Tuple[np.ndarray, np.ndarray, float, float]
+        Source and target patches, mean and standard deviation.
     """
     # compute statistics
     mean = data.mean()
@@ -195,10 +227,19 @@ def prepare_patches_unsupervised_array(
 
     Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
 
+    Parameters
+    ----------
+    data : np.ndarray
+        Input data array.
+    axes : str
+        Axes of the data.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+
     Returns
     -------
-    np.ndarray
-
+    Tuple[np.ndarray, None, float, float]
+        Source patches, mean and standard deviation.
     """
     # calculate mean and std
     mean = data.mean()
@@ -210,4 +251,4 @@ def prepare_patches_unsupervised_array(
     # generate patches, return a generator
     patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)
 
-    return patches, _, mean, std  # TODO inelegant, replace
+    return patches, _, mean, std  # TODO inelegant, replace by dataclass?
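The prepare_patches_* family still returns bare (patches, targets, mean, std) tuples, and the updated TODO floats a dataclass instead. Purely as an illustration of that suggestion (nothing like this ships in rc6), the container could look like:

from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class PatchedOutput:
    patches: np.ndarray  # SC(Z)YX, with S the patch dimension
    targets: Optional[np.ndarray]  # same layout, or None when unsupervised
    mean: float  # image statistics computed before patching
    std: float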
careamics/dataset/patching/random_patching.py

@@ -1,3 +1,5 @@
+"""Random patching utilities."""
+
 from typing import Generator, List, Optional, Tuple, Union
 
 import numpy as np
@@ -30,6 +32,8 @@ def extract_patches_random(
         Input image array.
     patch_size : Tuple[int]
         Patch sizes in each dimension.
+    target : Optional[np.ndarray], optional
+        Target array, by default None.
 
     Yields
     ------
@@ -120,10 +124,12 @@ def extract_patches_random_from_chunks(
     ----------
     arr : np.ndarray
         Input image array.
-    patch_size : Tuple[int]
+    patch_size : Union[List[int], Tuple[int, ...]]
         Patch sizes in each dimension.
-    chunk_size : Tuple[int]
+    chunk_size : Union[List[int], Tuple[int, ...]]
         Chunk sizes to load from the.
+    chunk_limit : Optional[int], optional
+        Number of chunks to load, by default None.
 
     Yields
    ------