careamics 0.1.0rc2__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- careamics/__init__.py +16 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +31 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_example.py +89 -0
- careamics/config/configuration_factory.py +597 -0
- careamics/config/configuration_model.py +597 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +323 -134
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +743 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +396 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -14
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -221
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -12
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +112 -75
- careamics-0.1.0rc4.dist-info/METADATA +122 -0
- careamics-0.1.0rc4.dist-info/RECORD +110 -0
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -182
- careamics/bioimage/rdf.py +0 -105
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -297
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -111
- careamics/dataset/patching.py +0 -492
- careamics/dataset/prepare_dataset.py +0 -175
- careamics/dataset/tiff_dataset.py +0 -212
- careamics/engine.py +0 -1014
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -106
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -170
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc2.dist-info/METADATA +0 -81
- careamics-0.1.0rc2.dist-info/RECORD +0 -47
- {careamics-0.1.0rc2.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
--- /dev/null
+++ careamics/dataset/iterable_dataset.py
@@ -0,0 +1,416 @@
+from __future__ import annotations
+
+import copy
+from pathlib import Path
+from typing import Any, Callable, Generator, List, Optional, Tuple, Union
+
+import numpy as np
+from torch.utils.data import IterableDataset, get_worker_info
+
+from ..config import DataConfig, InferenceConfig
+from ..config.tile_information import TileInformation
+from ..utils.logging import get_logger
+from .dataset_utils import read_tiff, reshape_array
+from .patching import (
+    get_patch_transform,
+)
+from .patching.random_patching import extract_patches_random
+from .patching.tiled_patching import extract_tiles
+
+logger = get_logger(__name__)
+
+
+class PathIterableDataset(IterableDataset):
+    """
+    Dataset allowing extraction of patches without loading the whole data into memory.
+
+    Parameters
+    ----------
+    data_config : Union[DataConfig, InferenceConfig]
+        Data configuration.
+    src_files : List[Path]
+        List of data files.
+    target_files : Optional[List[Path]], optional
+        List of target files, by default None.
+    read_source_func : Callable, optional
+        Function used to read the source files, by default read_tiff.
+    """
+
+    def __init__(
+        self,
+        data_config: Union[DataConfig, InferenceConfig],
+        src_files: List[Path],
+        target_files: Optional[List[Path]] = None,
+        read_source_func: Callable = read_tiff,
+    ) -> None:
+        self.data_config = data_config
+        self.data_files = src_files
+        self.target_files = target_files
+        self.read_source_func = read_source_func
+
+        # compute mean and std over the dataset
+        if data_config.mean is None or data_config.std is None:
+            self.mean, self.std = self._calculate_mean_and_std()
+
+            # check whether the configuration exposes a transform list
+            # (isinstance checks between DataConfig and InferenceConfig are
+            # unreliable here)
+            if hasattr(data_config, "has_transform_list"):
+                # update mean and std in the configuration; the object is
+                # mutable and is then recorded in the CAREamist
+                data_config.set_mean_and_std(self.mean, self.std)
+
+        else:
+            self.mean = data_config.mean
+            self.std = data_config.std
+
+        # get transforms
+        self.patch_transform = get_patch_transform(
+            patch_transforms=data_config.transforms,
+            with_target=target_files is not None,
+        )
+
+    def _calculate_mean_and_std(self) -> Tuple[float, float]:
+        """
+        Calculate the mean and standard deviation of the dataset.
+
+        Returns
+        -------
+        Tuple[float, float]
+            Tuple containing the mean and the standard deviation.
+        """
+        means, stds = 0, 0
+        num_samples = 0
+
+        for sample, _ in self._iterate_over_files():
+            means += sample.mean()
+            stds += sample.std()
+            num_samples += 1
+
+        if num_samples == 0:
+            raise ValueError("No samples found in the dataset.")
+
+        result_mean = means / num_samples
+        result_std = stds / num_samples
+
+        logger.info(f"Calculated mean and std for {num_samples} images")
+        logger.info(f"Mean: {result_mean}, std: {result_std}")
+        return result_mean, result_std
+
+    def _iterate_over_files(
+        self,
+    ) -> Generator[Tuple[np.ndarray, Optional[np.ndarray]], None, None]:
+        """
+        Iterate over the data source and yield whole images.
+
+        Yields
+        ------
+        Tuple[np.ndarray, Optional[np.ndarray]]
+            Image and, if available, its target.
+        """
+        # When num_workers > 0, each worker process has a different copy of
+        # the dataset object; each copy is configured independently to avoid
+        # duplicate data being returned by the workers
+        worker_info = get_worker_info()
+        worker_id = worker_info.id if worker_info is not None else 0
+        num_workers = worker_info.num_workers if worker_info is not None else 1
+
+        # iterate over the files
+        for i, filename in enumerate(self.data_files):
+            # retrieve the file corresponding to the worker id
+            if i % num_workers == worker_id:
+                try:
+                    # read data
+                    sample = self.read_source_func(filename, self.data_config.axes)
+
+                    # read target, if available
+                    if self.target_files is not None:
+                        if filename.name != self.target_files[i].name:
+                            raise ValueError(
+                                f"File {filename} does not match target file "
+                                f"{self.target_files[i]}. Have you passed sorted "
+                                f"file lists?"
+                            )
+
+                        # read target
+                        target = self.read_source_func(
+                            self.target_files[i], self.data_config.axes
+                        )
+
+                        yield sample, target
+                    else:
+                        yield sample, None
+
+                except Exception as e:
+                    logger.error(f"Error reading file {filename}: {e}")
+
+    def __iter__(
+        self,
+    ) -> Generator[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]], None, None]:
+        """
+        Iterate over the data source and yield single patches.
+
+        Yields
+        ------
+        Tuple[np.ndarray, ...]
+            Single patch, plus target or mask depending on the training mode.
+        """
+        assert (
+            self.mean is not None and self.std is not None
+        ), "Mean and std must be provided"
+
+        # iterate over files
+        for sample_input, sample_target in self._iterate_over_files():
+            reshaped_sample = reshape_array(sample_input, self.data_config.axes)
+            reshaped_target = (
+                None
+                if sample_target is None
+                else reshape_array(sample_target, self.data_config.axes)
+            )
+
+            patches = extract_patches_random(
+                arr=reshaped_sample,
+                patch_size=self.data_config.patch_size,
+                target=reshaped_target,
+            )
+
+            # iterate over patches
+            # patches are tuples of (patch, target) if a target is available,
+            # or (patch, None) if no target is available
+            # patch has dimensions (C)ZYX
+            for patch_data in patches:
+                # if there is a target
+                if self.target_files is not None:
+                    # Albumentations expects the channel dimension to be last;
+                    # taking the first element because patch_data can include a target
+                    c_patch = np.moveaxis(patch_data[0], 0, -1)
+                    c_target = np.moveaxis(patch_data[1], 0, -1)
+
+                    # apply the transform to the patch and the target
+                    transformed = self.patch_transform(
+                        image=c_patch,
+                        target=c_target,
+                    )
+
+                    # move the axes back to the original position
+                    c_patch = np.moveaxis(transformed["image"], -1, 0)
+                    c_target = np.moveaxis(transformed["target"], -1, 0)
+
+                    yield (c_patch, c_target)
+                elif self.data_config.has_n2v_manipulate():
+                    # Albumentations expects the channel dimension to be last;
+                    # taking the first element because patch_data can include a target
+                    patch = np.moveaxis(patch_data[0], 0, -1)
+
+                    # apply transform
+                    transformed = self.patch_transform(image=patch)
+
+                    # retrieve the output of ManipulateN2V
+                    results = transformed["image"]
+                    masked_patch: np.ndarray = results[0]
+                    original_patch: np.ndarray = results[1]
+                    mask: np.ndarray = results[2]
+
+                    # move the C axis back
+                    masked_patch = np.moveaxis(masked_patch, -1, 0)
+                    original_patch = np.moveaxis(original_patch, -1, 0)
+                    mask = np.moveaxis(mask, -1, 0)
+
+                    yield (masked_patch, original_patch, mask)
+                else:
+                    raise ValueError(
+                        "Something went wrong! No target file (no supervised "
+                        "training) and no N2V transform (no N2V training either)."
+                    )
+
+    def get_number_of_files(self) -> int:
+        """
+        Return the number of files in the dataset.
+
+        Returns
+        -------
+        int
+            Number of files in the dataset.
+        """
+        return len(self.data_files)
+
+    def split_dataset(
+        self,
+        percentage: float = 0.1,
+        minimum_number: int = 5,
+    ) -> PathIterableDataset:
+        """Split the dataset in two.
+
+        Splits the dataset using a percentage of the data (files) to extract,
+        or the minimum number if the percentage yields fewer files than the
+        minimum number.
+
+        Parameters
+        ----------
+        percentage : float, optional
+            Percentage of files to split off, by default 0.1.
+        minimum_number : int, optional
+            Minimum number of files to split off, by default 5.
+
+        Returns
+        -------
+        IterableDataset
+            Dataset containing the split data.
+
+        Raises
+        ------
+        ValueError
+            If the percentage is smaller than 0 or larger than 1.
+        ValueError
+            If the minimum number is smaller than 1 or larger than the number of files.
+        """
+        if percentage < 0 or percentage > 1:
+            raise ValueError(f"Percentage must be between 0 and 1, got {percentage}.")
+
+        if minimum_number < 1 or minimum_number > self.get_number_of_files():
+            raise ValueError(
+                f"Minimum number of files must be between 1 and "
+                f"{self.get_number_of_files()} (number of files), got "
+                f"{minimum_number}."
+            )
+
+        # compute number of files
+        total_files = self.get_number_of_files()
+        n_files = max(round(percentage * total_files), minimum_number)
+
+        # get random indices
+        indices = np.random.choice(total_files, n_files, replace=False)
+
+        # extract files
+        val_files = [self.data_files[i] for i in indices]
+
+        # remove the selected files from self.data_files
+        data_files = []
+        for i, file in enumerate(self.data_files):
+            if i not in indices:
+                data_files.append(file)
+        self.data_files = data_files
+
+        # same for targets
+        if self.target_files is not None:
+            val_target_files = [self.target_files[i] for i in indices]
+
+            data_target_files = []
+            for i, file in enumerate(self.target_files):
+                if i not in indices:
+                    data_target_files.append(file)
+            self.target_files = data_target_files
+
+        # clone the dataset
+        dataset = copy.deepcopy(self)
+
+        # reassign files
+        dataset.data_files = val_files
+
+        # reassign targets
+        if self.target_files is not None:
+            dataset.target_files = val_target_files
+
+        return dataset
+
+
+class IterablePredictionDataset(PathIterableDataset):
+    """
+    Dataset allowing extraction of tiles without loading the whole data into memory.
+
+    Parameters
+    ----------
+    prediction_config : InferenceConfig
+        Inference configuration.
+    src_files : List[Path]
+        List of data files.
+    read_source_func : Callable, optional
+        Function used to read the source files, by default read_tiff.
+    **kwargs : Any
+        Additional keyword arguments, unused.
+    """
+
+    def __init__(
+        self,
+        prediction_config: InferenceConfig,
+        src_files: List[Path],
+        read_source_func: Callable = read_tiff,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            data_config=prediction_config,
+            src_files=src_files,
+            read_source_func=read_source_func,
+        )
+
+        self.prediction_config = prediction_config
+        self.axes = prediction_config.axes
+        self.tile_size = self.prediction_config.tile_size
+        self.tile_overlap = self.prediction_config.tile_overlap
+        self.read_source_func = read_source_func
+
+        # tile only if both tile size and overlaps are provided
+        self.tile = self.tile_size is not None and self.tile_overlap is not None
+
+        # get the patch transforms
+        self.patch_transform = get_patch_transform(
+            patch_transforms=prediction_config.transforms,
+            with_target=False,
+        )
+
+    def __iter__(
+        self,
+    ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]:
+        """
+        Iterate over the data source and yield single tiles.
+
+        Yields
+        ------
+        Tuple[np.ndarray, TileInformation]
+            Single tile and its tiling information.
+        """
+        assert (
+            self.mean is not None and self.std is not None
+        ), "Mean and std must be provided"
+
+        for sample, _ in self._iterate_over_files():
+            # reshape array
+            reshaped_sample = reshape_array(sample, self.axes)
+
+            if self.tile:
+                # generate tiles, return a generator
+                patch_gen = extract_tiles(
+                    arr=reshaped_sample,
+                    tile_size=self.tile_size,
+                    overlaps=self.tile_overlap,
+                )
+            else:
+                # just wrap the sample in a generator with default tiling info
+                array_shape = reshaped_sample.squeeze().shape
+                patch_gen = (
+                    (reshaped_sample, TileInformation(array_shape=array_shape))
+                    for _ in range(1)
+                )
+
+            # apply transform to tiles
+            for patch_array, tile_info in patch_gen:
+                # Albumentations expects the channel dimension to be last
+                patch = np.moveaxis(patch_array, 0, -1)
+                transformed_patch = self.patch_transform(image=patch)
+                transformed_patch = np.moveaxis(transformed_patch["image"], -1, 0)
+
+                yield transformed_patch, tile_info
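For context, a minimal usage sketch of the new dataset (not part of the diff): the `DataConfig` field names are assumptions based on `careamics/config/data_model.py`, and the data folder is hypothetical.

```python
# Hedged sketch: drive PathIterableDataset with a DataLoader.
from pathlib import Path

from torch.utils.data import DataLoader

from careamics.config import DataConfig
from careamics.dataset.iterable_dataset import PathIterableDataset

# assumed DataConfig fields; check careamics/config/data_model.py
config = DataConfig(data_type="tiff", axes="YX", patch_size=[64, 64])
train_files = sorted(Path("data/train").glob("*.tif"))  # hypothetical folder

dataset = PathIterableDataset(data_config=config, src_files=train_files)

# split off a validation set: 10% of the files, but at least 5
val_dataset = dataset.split_dataset(percentage=0.1, minimum_number=5)

# _iterate_over_files shards files as `i % num_workers == worker_id`, so each
# DataLoader worker reads a disjoint subset of the files
loader = DataLoader(dataset, batch_size=8, num_workers=2)
```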
--- /dev/null
+++ careamics/dataset/patching/patch_transform.py
@@ -0,0 +1,44 @@
+from typing import List, Union
+
+import albumentations as Aug
+
+from careamics.config.data_model import TRANSFORMS_UNION
+from careamics.transforms import get_all_transforms
+
+
+# TODO add some explanations on how additional_targets is used
+def get_patch_transform(
+    patch_transforms: Union[List[TRANSFORMS_UNION], Aug.Compose],
+    with_target: bool,
+    normalize_mask: bool = True,
+) -> Aug.Compose:
+    """Return a composed transform for patches."""
+    # if we were passed a Compose, we just return it
+    if isinstance(patch_transforms, Aug.Compose):
+        return patch_transforms
+
+    # an empty list of transforms is a NoOp
+    elif len(patch_transforms) == 0:
+        return Aug.Compose(
+            [Aug.NoOp()],
+            additional_targets={},  # TODO this part needs to be checked (wrt segmentation)
+        )
+
+    # otherwise we have a list of transforms
+    else:
+        # retrieve all transforms
+        all_transforms = get_all_transforms()
+
+        # instantiate all transforms
+        transforms = [
+            all_transforms[transform.name](**transform.model_dump())
+            for transform in patch_transforms
+        ]
+
+        return Aug.Compose(
+            transforms,
+            # apply the image augmentations to "target" as well
+            additional_targets={"target": "image"}
+            if (with_target and normalize_mask)  # TODO check this
+            else {},
+        )
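A short illustrative sketch (not from the package) of the two trivial paths through `get_patch_transform`:

```python
# Sketch: an empty transform list composes to a NoOp, and a pre-built
# Albumentations Compose is passed through unchanged.
import albumentations as Aug
import numpy as np

from careamics.dataset.patching.patch_transform import get_patch_transform

noop = get_patch_transform(patch_transforms=[], with_target=False)
patch = np.zeros((64, 64, 1), dtype=np.float32)  # channel-last, as the datasets expect
assert np.array_equal(noop(image=patch)["image"], patch)

prebuilt = Aug.Compose([Aug.HorizontalFlip(p=0.5)])
assert get_patch_transform(prebuilt, with_target=True) is prebuilt
```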
--- /dev/null
+++ careamics/dataset/patching/patching.py
@@ -0,0 +1,212 @@
+"""
+Tiling submodule.
+
+These functions are used to tile images into patches or tiles.
+"""
+from pathlib import Path
+from typing import Callable, List, Tuple, Union
+
+import numpy as np
+
+from ...utils.logging import get_logger
+from ..dataset_utils import reshape_array
+from .sequential_patching import extract_patches_sequential
+
+logger = get_logger(__name__)
+
+
+# called by the in-memory dataset
+def prepare_patches_supervised(
+    train_files: List[Path],
+    target_files: List[Path],
+    axes: str,
+    patch_size: Union[List[int], Tuple[int]],
+    read_source_func: Callable,
+) -> Tuple[np.ndarray, np.ndarray, float, float]:
+    """
+    Iterate over the data source and create an array of patches and corresponding targets.
+
+    Returns
+    -------
+    Tuple[np.ndarray, np.ndarray, float, float]
+        Array of patches, array of targets, mean, and standard deviation.
+    """
+    train_files.sort()
+    target_files.sort()
+
+    means, stds, num_samples = 0, 0, 0
+    all_patches, all_targets = [], []
+    for train_filename, target_filename in zip(train_files, target_files):
+        try:
+            sample: np.ndarray = read_source_func(train_filename, axes)
+            target: np.ndarray = read_source_func(target_filename, axes)
+            means += sample.mean()
+            stds += sample.std()
+            num_samples += 1
+
+            # reshape arrays
+            sample = reshape_array(sample, axes)
+            target = reshape_array(target, axes)
+
+            # generate patches
+            patches, targets = extract_patches_sequential(
+                sample, patch_size=patch_size, target=target
+            )
+
+            # add the patches to the list
+            all_patches.append(patches)
+
+            # ensure targets are not None (type checking)
+            if targets is not None:
+                all_targets.append(targets)
+            else:
+                raise ValueError(f"No target found for {target_filename}.")
+
+        except Exception as e:
+            # emit an error and continue
+            logger.error(f"Failed to read {train_filename} or {target_filename}: {e}")
+
+    # raise an error if no valid samples were found
+    if num_samples == 0 or len(all_patches) == 0:
+        raise ValueError(
+            f"No valid samples found in the input data: {train_files} and "
+            f"{target_files}."
+        )
+
+    result_mean, result_std = means / num_samples, stds / num_samples
+
+    patch_array: np.ndarray = np.concatenate(all_patches, axis=0)
+    target_array: np.ndarray = np.concatenate(all_targets, axis=0)
+    logger.info(f"Extracted {patch_array.shape[0]} patches from the input array.")
+
+    return (
+        patch_array,
+        target_array,
+        result_mean,
+        result_std,
+    )
+
+
+# called by the in-memory dataset
+def prepare_patches_unsupervised(
+    train_files: List[Path],
+    axes: str,
+    patch_size: Union[List[int], Tuple[int]],
+    read_source_func: Callable,
+) -> Tuple[np.ndarray, None, float, float]:
+    """
+    Iterate over the data source and create an array of patches.
+
+    Returns
+    -------
+    Tuple[np.ndarray, None, float, float]
+        Array of patches, None, mean, and standard deviation.
+    """
+    means, stds, num_samples = 0, 0, 0
+    all_patches = []
+    for filename in train_files:
+        try:
+            sample: np.ndarray = read_source_func(filename, axes)
+            means += sample.mean()
+            stds += sample.std()
+            num_samples += 1
+
+            # reshape array
+            sample = reshape_array(sample, axes)
+
+            # generate patches
+            patches, _ = extract_patches_sequential(sample, patch_size=patch_size)
+
+            # add the patches to the list
+            all_patches.append(patches)
+        except Exception as e:
+            # emit an error and continue
+            logger.error(f"Failed to read {filename}: {e}")
+
+    # raise an error if no valid samples were found
+    if num_samples == 0:
+        raise ValueError(f"No valid samples found in the input data: {train_files}.")
+
+    result_mean, result_std = means / num_samples, stds / num_samples
+
+    patch_array: np.ndarray = np.concatenate(all_patches)
+    logger.info(f"Extracted {patch_array.shape[0]} patches from the input array.")
+
+    return patch_array, None, result_mean, result_std  # TODO return object?
+
+
+# called on arrays by the in-memory dataset
+def prepare_patches_supervised_array(
+    data: np.ndarray,
+    axes: str,
+    data_target: np.ndarray,
+    patch_size: Union[List[int], Tuple[int]],
+) -> Tuple[np.ndarray, np.ndarray, float, float]:
+    """Iterate over the data source and create an array of patches.
+
+    This method expects an array of shape SC(Z)YX, where S and C can be singleton
+    dimensions.
+
+    The returned patches are of shape SC(Z)YX, where S is now the patch dimension.
+
+    Returns
+    -------
+    Tuple[np.ndarray, np.ndarray, float, float]
+        Array of patches, array of targets, mean, and standard deviation.
+    """
+    # compute statistics
+    mean = data.mean()
+    std = data.std()
+
+    # reshape arrays
+    reshaped_sample = reshape_array(data, axes)
+    reshaped_target = reshape_array(data_target, axes)
+
+    # generate patches
+    patches, patch_targets = extract_patches_sequential(
+        reshaped_sample, patch_size=patch_size, target=reshaped_target
+    )
+
+    if patch_targets is None:
+        raise ValueError("No target extracted.")
+
+    logger.info(f"Extracted {patches.shape[0]} patches from the input array.")
+
+    return (
+        patches,
+        patch_targets,
+        mean,
+        std,
+    )
+
+
+# called on arrays by the in-memory dataset
+def prepare_patches_unsupervised_array(
+    data: np.ndarray,
+    axes: str,
+    patch_size: Union[List[int], Tuple[int]],
+) -> Tuple[np.ndarray, None, float, float]:
+    """
+    Iterate over the data source and create an array of patches.
+
+    This method expects an array of shape SC(Z)YX, where S and C can be singleton
+    dimensions.
+
+    The returned patches are of shape SC(Z)YX, where S is now the patch dimension.
+
+    Returns
+    -------
+    Tuple[np.ndarray, None, float, float]
+        Array of patches, None, mean, and standard deviation.
+    """
+    # calculate mean and std
+    mean = data.mean()
+    std = data.std()
+
+    # reshape array
+    reshaped_sample = reshape_array(data, axes)
+
+    # generate patches
+    patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)
+
+    return patches, None, mean, std  # TODO inelegant, replace by a dataclass?
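And a sketch of the array entry point (illustrative only; the input array is fabricated):

```python
# Sketch: cut a single 2D image into 64x64 patches with the sequential
# strategy; mean/std are computed on the raw array before reshaping.
import numpy as np

from careamics.dataset.patching.patching import prepare_patches_unsupervised_array

image = np.random.rand(256, 256).astype(np.float32)
patches, _, mean, std = prepare_patches_unsupervised_array(
    image, axes="YX", patch_size=[64, 64]
)
# patches has shape (S, C, 64, 64), with S the number of patches
print(patches.shape, mean, std)
```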