careamics 0.1.0rc1__py3-none-any.whl → 0.1.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +14 -4
- careamics/callbacks/__init__.py +6 -0
- careamics/callbacks/hyperparameters_callback.py +42 -0
- careamics/callbacks/progress_bar_callback.py +57 -0
- careamics/careamist.py +761 -0
- careamics/config/__init__.py +27 -3
- careamics/config/algorithm_model.py +167 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +29 -0
- careamics/config/architectures/custom_model.py +150 -0
- careamics/config/architectures/register_model.py +101 -0
- careamics/config/architectures/unet_model.py +96 -0
- careamics/config/architectures/vae_model.py +39 -0
- careamics/config/callback_model.py +92 -0
- careamics/config/configuration_factory.py +460 -0
- careamics/config/configuration_model.py +596 -0
- careamics/config/data_model.py +555 -0
- careamics/config/inference_model.py +283 -0
- careamics/config/noise_models.py +162 -0
- careamics/config/optimizer_models.py +181 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +131 -0
- careamics/config/references/references.py +38 -0
- careamics/config/support/__init__.py +33 -0
- careamics/config/support/supported_activations.py +24 -0
- careamics/config/support/supported_algorithms.py +18 -0
- careamics/config/support/supported_architectures.py +18 -0
- careamics/config/support/supported_data.py +82 -0
- careamics/{dataset/extraction_strategy.py → config/support/supported_extraction_strategies.py} +5 -2
- careamics/config/support/supported_loggers.py +8 -0
- careamics/config/support/supported_losses.py +25 -0
- careamics/config/support/supported_optimizers.py +55 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +19 -0
- careamics/config/support/supported_transforms.py +23 -0
- careamics/config/tile_information.py +104 -0
- careamics/config/training_model.py +65 -0
- careamics/config/transformations/__init__.py +14 -0
- careamics/config/transformations/n2v_manipulate_model.py +63 -0
- careamics/config/transformations/nd_flip_model.py +32 -0
- careamics/config/transformations/normalize_model.py +31 -0
- careamics/config/transformations/transform_model.py +44 -0
- careamics/config/transformations/xy_random_rotate90_model.py +29 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +100 -0
- careamics/conftest.py +26 -0
- careamics/dataset/__init__.py +5 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +100 -0
- careamics/dataset/dataset_utils/file_utils.py +140 -0
- careamics/dataset/dataset_utils/read_tiff.py +61 -0
- careamics/dataset/dataset_utils/read_utils.py +25 -0
- careamics/dataset/dataset_utils/read_zarr.py +56 -0
- careamics/dataset/in_memory_dataset.py +321 -131
- careamics/dataset/iterable_dataset.py +416 -0
- careamics/dataset/patching/__init__.py +8 -0
- careamics/dataset/patching/patch_transform.py +44 -0
- careamics/dataset/patching/patching.py +212 -0
- careamics/dataset/patching/random_patching.py +190 -0
- careamics/dataset/patching/sequential_patching.py +206 -0
- careamics/dataset/patching/tiled_patching.py +158 -0
- careamics/dataset/patching/validate_patch_dimension.py +60 -0
- careamics/dataset/zarr_dataset.py +149 -0
- careamics/lightning_datamodule.py +665 -0
- careamics/lightning_module.py +292 -0
- careamics/lightning_prediction_datamodule.py +390 -0
- careamics/lightning_prediction_loop.py +116 -0
- careamics/losses/__init__.py +4 -1
- careamics/losses/loss_factory.py +24 -13
- careamics/losses/losses.py +65 -5
- careamics/losses/noise_model_factory.py +40 -0
- careamics/losses/noise_models.py +524 -0
- careamics/model_io/__init__.py +8 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +120 -0
- careamics/model_io/bioimage/bioimage_utils.py +48 -0
- careamics/model_io/bioimage/model_description.py +318 -0
- careamics/model_io/bmz_io.py +231 -0
- careamics/model_io/model_io_utils.py +80 -0
- careamics/models/__init__.py +4 -1
- careamics/models/activation.py +35 -0
- careamics/models/layers.py +244 -0
- careamics/models/model_factory.py +21 -202
- careamics/models/unet.py +46 -20
- careamics/prediction/__init__.py +1 -3
- careamics/prediction/stitch_prediction.py +73 -0
- careamics/transforms/__init__.py +41 -0
- careamics/transforms/n2v_manipulate.py +113 -0
- careamics/transforms/nd_flip.py +93 -0
- careamics/transforms/normalize.py +109 -0
- careamics/transforms/pixel_manipulation.py +383 -0
- careamics/transforms/struct_mask_parameters.py +18 -0
- careamics/transforms/tta.py +74 -0
- careamics/transforms/xy_random_rotate90.py +95 -0
- careamics/utils/__init__.py +10 -13
- careamics/utils/base_enum.py +32 -0
- careamics/utils/context.py +22 -2
- careamics/utils/metrics.py +0 -46
- careamics/utils/path_utils.py +24 -0
- careamics/utils/ram.py +13 -0
- careamics/utils/receptive_field.py +102 -0
- careamics/utils/running_stats.py +43 -0
- careamics/utils/torch_utils.py +89 -56
- careamics-0.1.0rc3.dist-info/METADATA +122 -0
- careamics-0.1.0rc3.dist-info/RECORD +109 -0
- {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/WHEEL +1 -1
- careamics/bioimage/__init__.py +0 -15
- careamics/bioimage/docs/Noise2Void.md +0 -5
- careamics/bioimage/docs/__init__.py +0 -1
- careamics/bioimage/io.py +0 -271
- careamics/config/algorithm.py +0 -231
- careamics/config/config.py +0 -296
- careamics/config/config_filter.py +0 -44
- careamics/config/data.py +0 -194
- careamics/config/torch_optim.py +0 -118
- careamics/config/training.py +0 -534
- careamics/dataset/dataset_utils.py +0 -115
- careamics/dataset/patching.py +0 -493
- careamics/dataset/prepare_dataset.py +0 -174
- careamics/dataset/tiff_dataset.py +0 -211
- careamics/engine.py +0 -954
- careamics/manipulation/__init__.py +0 -4
- careamics/manipulation/pixel_manipulation.py +0 -158
- careamics/prediction/prediction_utils.py +0 -102
- careamics/utils/ascii_logo.txt +0 -9
- careamics/utils/augment.py +0 -65
- careamics/utils/normalization.py +0 -55
- careamics/utils/validators.py +0 -156
- careamics/utils/wandb.py +0 -121
- careamics-0.1.0rc1.dist-info/METADATA +0 -80
- careamics-0.1.0rc1.dist-info/RECORD +0 -46
- {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from fnmatch import fnmatch
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import tifffile
|
|
7
|
+
|
|
8
|
+
from careamics.config.support import SupportedData
|
|
9
|
+
from careamics.utils.logging import get_logger
|
|
10
|
+
|
|
11
|
+
logger = get_logger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
    """
    Read a tiff file and return a numpy array.

    Parameters
    ----------
    file_path : Path
        Path to a file. A string path is also accepted and converted to `Path`.
    *args : list
        Additional positional arguments, ignored.
    **kwargs : dict
        Additional keyword arguments, ignored.

    Returns
    -------
    np.ndarray
        Resulting array.

    Raises
    ------
    ValueError
        If the file suffix does not match the supported tiff extension.
    ValueError, OSError
        If `tifffile` fails to read the file (the original exception is
        logged and re-raised).
    ValueError
        If the data has fewer than 2 or more than 6 dimensions.
    """
    # accept plain strings as well, `.suffix` is only available on Path
    file_path = Path(file_path)

    tiff_pattern = SupportedData.get_extension(SupportedData.TIFF)
    if not fnmatch(file_path.suffix, tiff_pattern):
        raise ValueError(f"File {file_path} is not a valid tiff.")

    try:
        array = tifffile.imread(file_path)
    except (ValueError, OSError) as e:
        # log with the module logger, then let the caller decide what to do
        logger.exception("Exception in file %s: %s", file_path, e)
        raise

    # check dimensions
    # TODO or should this really be done here? probably in the LightningDataModule
    # TODO this should also be centralized somewhere else (validate_dimensions)
    if len(array.shape) < 2 or len(array.shape) > 6:
        raise ValueError(
            f"Incorrect data dimensions. Must be between 2 and 6 (got {array.shape} "
            f"for file {file_path})."
        )

    return array
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from typing import Callable, Union
|
|
2
|
+
|
|
3
|
+
from careamics.config.support import SupportedData
|
|
4
|
+
|
|
5
|
+
from .read_tiff import read_tiff
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def get_read_func(data_type: Union[SupportedData, str]) -> Callable:
    """
    Return the function used to read files of the given data type.

    Parameters
    ----------
    data_type : SupportedData or str
        Data type whose reader is requested.

    Returns
    -------
    Callable
        Reader function; only `read_tiff` is currently available.

    Raises
    ------
    NotImplementedError
        If no reader exists for `data_type`.
    """
    is_tiff = data_type == SupportedData.TIFF
    if not is_tiff:
        raise NotImplementedError(f"Data type {data_type} is not supported.")

    return read_tiff
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
from zarr import Group, core, hierarchy, storage
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def read_zarr(
    zarr_source: Group, axes: str
) -> Union[core.Array, storage.DirectoryStore, hierarchy.Group]:
    """Select an array from a zarr source and sanity-check its dimensions.

    No array values are read here; the selected zarr object is returned as is.

    Parameters
    ----------
    zarr_source : Group
        Zarr object to read from. A `hierarchy.Group` or a `core.Array` is
        accepted; a `storage.DirectoryStore` is not supported yet.
    axes : str
        Description of the axes of the array, its length must match the
        number of array dimensions.

    Returns
    -------
    Union[core.Array, storage.DirectoryStore, hierarchy.Group]
        The selected zarr object.

    Raises
    ------
    NotImplementedError
        If `zarr_source` is a `storage.DirectoryStore`, or an array of object
        dtype.
    ValueError
        If `zarr_source` is of an unsupported type.
    ValueError
        If the array does not have 2, 3 or 4 dimensions.
    ValueError
        If the length of `axes` differs from the number of array dimensions.
    """
    if isinstance(zarr_source, hierarchy.Group):
        # NOTE(review): indexing a Group with the integer 0 is presumably meant
        # to select its first array — confirm against zarr's Group.__getitem__.
        array = zarr_source[0]

    elif isinstance(zarr_source, storage.DirectoryStore):
        raise NotImplementedError("DirectoryStore not supported yet")

    elif isinstance(zarr_source, core.Array):
        # array should be of shape (S, (C), (Z), Y, X), iterating over S ?
        if zarr_source.dtype == "O":
            raise NotImplementedError("Object type not supported yet")
        else:
            array = zarr_source
    else:
        raise ValueError(f"Unsupported zarr object type {type(zarr_source)}")

    # sanity check on dimensions
    if len(array.shape) < 2 or len(array.shape) > 4:
        raise ValueError(
            f"Incorrect data dimensions. Must be 2, 3 or 4 (got {array.shape})."
        )

    # sanity check on axes length
    if len(axes) != len(array.shape):
        raise ValueError(
            f"Incorrect axes length (got {axes} for array of shape {array.shape})."
        )

    # arr = fix_axes(arr, axes)
    return array
|
|
@@ -1,154 +1,356 @@
|
|
|
1
1
|
"""In-memory dataset module."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import copy
|
|
2
5
|
from pathlib import Path
|
|
3
|
-
from typing import
|
|
6
|
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
|
4
7
|
|
|
5
8
|
import numpy as np
|
|
6
|
-
import
|
|
9
|
+
from torch.utils.data import Dataset
|
|
7
10
|
|
|
8
|
-
from ..
|
|
11
|
+
from ..config import DataModel, InferenceModel
|
|
12
|
+
from ..config.tile_information import TileInformation
|
|
9
13
|
from ..utils.logging import get_logger
|
|
10
|
-
from .dataset_utils import
|
|
11
|
-
|
|
12
|
-
|
|
14
|
+
from .dataset_utils import read_tiff, reshape_array
|
|
15
|
+
from .patching.patch_transform import get_patch_transform
|
|
16
|
+
from .patching.patching import (
|
|
17
|
+
prepare_patches_supervised,
|
|
18
|
+
prepare_patches_supervised_array,
|
|
19
|
+
prepare_patches_unsupervised,
|
|
20
|
+
prepare_patches_unsupervised_array,
|
|
13
21
|
)
|
|
14
|
-
from .
|
|
15
|
-
from .patching import generate_patches
|
|
22
|
+
from .patching.tiled_patching import extract_tiles
|
|
16
23
|
|
|
17
24
|
logger = get_logger(__name__)
|
|
18
25
|
|
|
19
26
|
|
|
20
|
-
class InMemoryDataset(
|
|
27
|
+
class InMemoryDataset(Dataset):
    """Dataset storing data in memory and allowing generating patches from it.

    All patches (and targets, for supervised training) are extracted once in
    the constructor and kept in memory; `__getitem__` only applies transforms.
    """

    def __init__(
        self,
        data_config: DataModel,
        inputs: Union[np.ndarray, List[Path]],
        data_target: Optional[Union[np.ndarray, List[Path]]] = None,
        read_source_func: Callable = read_tiff,
        **kwargs: Any,
    ) -> None:
        """
        Constructor.

        Parameters
        ----------
        data_config : DataModel
            Data configuration, provides `axes`, `patch_size`, `mean`, `std`
            and `transforms`. If mean or std is missing, the computed values
            may be written back into this (mutable) object.
        inputs : Union[np.ndarray, List[Path]]
            Input data, either an array or a list of file paths.
        data_target : Optional[Union[np.ndarray, List[Path]]], optional
            Target data, of the same kind as `inputs`, by default None
            (unsupervised training).
        read_source_func : Callable, optional
            Function used to read files when `inputs` is a list of paths, by
            default `read_tiff`.
        **kwargs : Any
            Additional keyword arguments, unused.
        """
        self.data_config = data_config
        self.inputs = inputs
        self.data_target = data_target
        self.axes = self.data_config.axes
        self.patch_size = self.data_config.patch_size

        # read function
        self.read_source_func = read_source_func

        # Generate patches
        supervised = self.data_target is not None
        patches = self._prepare_patches(supervised)

        # Add results to members
        self.data, self.data_targets, computed_mean, computed_std = patches

        # only fall back to the computed statistics when they are missing; an
        # explicit `is None` check keeps a legitimately configured mean of 0.0
        if self.data_config.mean is None or self.data_config.std is None:
            self.mean, self.std = computed_mean, computed_std
            logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")

            # if the transforms are not an instance of Compose
            if self.data_config.has_transform_list():
                # update mean and std in configuration
                # the object is mutable and should then be recorded in the CAREamist obj
                self.data_config.set_mean_and_std(self.mean, self.std)
        else:
            self.mean, self.std = self.data_config.mean, self.data_config.std

        # get transforms (after the configuration may have been updated above)
        self.patch_transform = get_patch_transform(
            patch_transforms=self.data_config.transforms,
            with_target=self.data_target is not None,
        )

    def _prepare_patches(
        self, supervised: bool
    ) -> Tuple[np.ndarray, Optional[np.ndarray], float, float]:
        """
        Iterate over data source and create an array of patches.

        Parameters
        ----------
        supervised : bool
            Whether the dataset is supervised or not.

        Returns
        -------
        Tuple[np.ndarray, Optional[np.ndarray], float, float]
            Patches, target patches (or None when unsupervised), computed mean
            and computed standard deviation, as returned by the patching
            helpers.

        Raises
        ------
        ValueError
            If supervised and inputs and targets are not both arrays or both
            lists of paths.
        """
        if supervised:
            if isinstance(self.inputs, np.ndarray) and isinstance(
                self.data_target, np.ndarray
            ):
                return prepare_patches_supervised_array(
                    self.inputs,
                    self.axes,
                    self.data_target,
                    self.patch_size,
                )
            elif isinstance(self.inputs, list) and isinstance(self.data_target, list):
                return prepare_patches_supervised(
                    self.inputs,
                    self.data_target,
                    self.axes,
                    self.patch_size,
                    self.read_source_func,
                )
            else:
                raise ValueError(
                    f"Data and target must be of the same type, either both numpy "
                    f"arrays or both lists of paths, got {type(self.inputs)} (data) "
                    f"and {type(self.data_target)} (target)."
                )
        else:
            if isinstance(self.inputs, np.ndarray):
                return prepare_patches_unsupervised_array(
                    self.inputs,
                    self.axes,
                    self.patch_size,
                )
            else:
                return prepare_patches_unsupervised(
                    self.inputs,
                    self.axes,
                    self.patch_size,
                    self.read_source_func,
                )

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Length of the dataset.
        """
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[np.ndarray]:
        """
        Return the patch corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the patch to return.

        Returns
        -------
        Tuple[np.ndarray]
            `(patch, target)` for supervised data, or
            `(manipulated patch, original patch, mask)` when N2V manipulation
            is configured.

        Raises
        ------
        ValueError
            If there is no target and no N2V manipulation is configured.
        """
        patch = self.data[index]

        # if there is a target
        if self.data_target is not None:
            # get target
            target = self.data_targets[index]

            # Albumentations requires Channel last
            c_patch = np.moveaxis(patch, 0, -1)
            c_target = np.moveaxis(target, 0, -1)

            # Apply transforms
            transformed = self.patch_transform(image=c_patch, target=c_target)

            # move axes back
            patch = np.moveaxis(transformed["image"], -1, 0)
            target = np.moveaxis(transformed["target"], -1, 0)

            return patch, target

        elif self.data_config.has_n2v_manipulate():
            # Albumentations requires Channel last
            patch = np.moveaxis(patch, 0, -1)

            # Apply transforms
            transformed_patch = self.patch_transform(image=patch)["image"]
            manip_patch, patch, mask = transformed_patch

            # move C axes back
            manip_patch = np.moveaxis(manip_patch, -1, 0)
            patch = np.moveaxis(patch, -1, 0)
            mask = np.moveaxis(mask, -1, 0)

            return (manip_patch, patch, mask)
        else:
            raise ValueError(
                "Something went wrong! No target provided (not supervised training) "
                "and no N2V manipulation (no N2V training)."
            )

    def split_dataset(
        self,
        percentage: float = 0.1,
        minimum_patches: int = 1,
    ) -> InMemoryDataset:
        """Split a new dataset away from the current one.

        This method is used to extract random validation patches from the
        dataset. The extracted patches (and targets) are removed from the
        current dataset.

        Parameters
        ----------
        percentage : float, optional
            Percentage of patches to extract, by default 0.1.
        minimum_patches : int, optional
            Minimum number of patches to extract, by default 1.

        Returns
        -------
        InMemoryDataset
            New dataset with the extracted patches.

        Raises
        ------
        ValueError
            If `percentage` is not between 0 and 1.
        ValueError
            If `minimum_patches` is not between 1 and the number of patches.
        """
        if percentage < 0 or percentage > 1:
            raise ValueError(f"Percentage must be between 0 and 1, got {percentage}.")

        if minimum_patches < 1 or minimum_patches > len(self):
            raise ValueError(
                f"Minimum number of patches must be between 1 and "
                f"{len(self)} (number of patches), got "
                f"{minimum_patches}. Adjust the patch size or the minimum number of "
                f"patches."
            )

        total_patches = len(self)

        # number of patches to extract (either percentage rounded or minimum number)
        n_patches = max(round(total_patches * percentage), minimum_patches)

        # get random indices, without replacement, from numpy's global RNG
        indices = np.random.choice(total_patches, n_patches, replace=False)

        # extract patches and remove them from the current dataset
        val_patches = self.data[indices]
        self.data = np.delete(self.data, indices, axis=0)

        # same for targets
        val_targets = None
        if self.data_targets is not None:
            val_targets = self.data_targets[indices]
            self.data_targets = np.delete(self.data_targets, indices, axis=0)

        # clone the dataset
        # NOTE: deepcopy duplicates every attribute (including the remaining
        # patches and `inputs`); only `data` and `data_targets` are replaced
        # below, so this can be memory-heavy for large datasets.
        dataset = copy.deepcopy(self)

        # reassign patches and targets
        dataset.data = val_patches
        if val_targets is not None:
            dataset.data_targets = val_targets

        return dataset
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class InMemoryPredictionDataset(Dataset):
    """
    Dataset storing data in memory and allowing generating patches from it.

    The input array is reshaped and, when both a tile size and a tile overlap
    are set in the prediction configuration, split into tiles. Each item is a
    transformed tile together with its `TileInformation`.
    """

    def __init__(
        self,
        prediction_config: InferenceModel,
        inputs: np.ndarray,
        data_target: Optional[np.ndarray] = None,
        read_source_func: Optional[Callable] = read_tiff,
    ) -> None:
        """Constructor.

        Parameters
        ----------
        prediction_config : InferenceModel
            Prediction configuration, provides `axes`, `tile_size`,
            `tile_overlap`, `mean`, `std` and `transforms`.
        inputs : np.ndarray
            Array containing the data.
        data_target : Optional[np.ndarray], optional
            Target array, by default None. Only used to decide whether the
            transforms are built with a target.
        read_source_func : Optional[Callable], optional
            Read function, by default `read_tiff`. It is stored but not used
            by this class, since the data is provided as an array.

        Raises
        ------
        ValueError
            If tiling is enabled and no tile could be generated.
        """
        self.pred_config = prediction_config
        self.input_array = inputs
        self.axes = self.pred_config.axes
        self.tile_size = self.pred_config.tile_size
        self.tile_overlap = self.pred_config.tile_overlap
        self.mean = self.pred_config.mean
        self.std = self.pred_config.std
        self.data_target = data_target

        # tiling only if both tile size and overlap are provided
        self.tiling = self.tile_size is not None and self.tile_overlap is not None

        # read function
        self.read_source_func = read_source_func

        # Generate patches
        self.data = self._prepare_tiles()

        # get transforms
        self.patch_transform = get_patch_transform(
            patch_transforms=self.pred_config.transforms,
            with_target=self.data_target is not None,
        )

    def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
        """
        Reshape the input array and split it into tiles.

        Returns
        -------
        List[Tuple[np.ndarray, TileInformation]]
            List of (tile, tile information) pairs. Without tiling, a single
            pair holding the whole reshaped array is returned.

        Raises
        ------
        ValueError
            If tiling is enabled and no tile was generated.
        """
        # reshape array
        reshaped_sample = reshape_array(self.input_array, self.axes)

        if not self.tiling:
            # NOTE(review): the whole reshaped array is returned as one item and
            # `__getitem__` treats its axis 0 as the channel axis — confirm this
            # holds for the layout produced by `reshape_array`.
            array_shape = reshaped_sample.squeeze().shape
            return [(reshaped_sample, TileInformation(array_shape=array_shape))]

        # extract_tiles returns a generator, materialize it
        tiles = list(
            extract_tiles(
                arr=reshaped_sample,
                tile_size=self.tile_size,
                overlaps=self.tile_overlap,
            )
        )

        if not tiles:
            raise ValueError(
                f"No tiles generated from array of shape {reshaped_sample.shape} "
                f"with tile size {self.tile_size} and overlap {self.tile_overlap}."
            )

        return tiles

    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns
        -------
        int
            Length of the dataset.
        """
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]:
        """
        Return the patch corresponding to the provided index.

        Parameters
        ----------
        index : int
            Index of the patch to return.

        Returns
        -------
        Tuple[np.ndarray, TileInformation]
            Transformed patch, with its axis 0 restored to the front, and the
            corresponding tile information.
        """
        tile_array, tile_info = self.data[index]

        # Albumentations requires channel last
        patch = np.moveaxis(tile_array, 0, -1)

        # Apply transforms
        transformed_patch = self.patch_transform(image=patch)["image"]

        # move C axes back
        transformed_patch = np.moveaxis(transformed_patch, -1, 0)

        return transformed_patch, tile_info
|