kaiko-eva 0.0.0.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/.DS_Store +0 -0
- eva/__init__.py +33 -0
- eva/__main__.py +18 -0
- eva/__version__.py +25 -0
- eva/core/__init__.py +19 -0
- eva/core/callbacks/__init__.py +5 -0
- eva/core/callbacks/writers/__init__.py +5 -0
- eva/core/callbacks/writers/embeddings.py +169 -0
- eva/core/callbacks/writers/typings.py +23 -0
- eva/core/cli/__init__.py +5 -0
- eva/core/cli/cli.py +19 -0
- eva/core/cli/logo.py +38 -0
- eva/core/cli/setup.py +89 -0
- eva/core/data/__init__.py +14 -0
- eva/core/data/dataloaders/__init__.py +5 -0
- eva/core/data/dataloaders/dataloader.py +80 -0
- eva/core/data/datamodules/__init__.py +6 -0
- eva/core/data/datamodules/call.py +33 -0
- eva/core/data/datamodules/datamodule.py +108 -0
- eva/core/data/datamodules/schemas.py +62 -0
- eva/core/data/datasets/__init__.py +7 -0
- eva/core/data/datasets/base.py +53 -0
- eva/core/data/datasets/classification/__init__.py +5 -0
- eva/core/data/datasets/classification/embeddings.py +154 -0
- eva/core/data/datasets/dataset.py +6 -0
- eva/core/data/samplers/__init__.py +5 -0
- eva/core/data/samplers/sampler.py +6 -0
- eva/core/data/transforms/__init__.py +5 -0
- eva/core/data/transforms/dtype/__init__.py +5 -0
- eva/core/data/transforms/dtype/array.py +28 -0
- eva/core/interface/__init__.py +5 -0
- eva/core/interface/interface.py +79 -0
- eva/core/metrics/__init__.py +17 -0
- eva/core/metrics/average_loss.py +47 -0
- eva/core/metrics/binary_balanced_accuracy.py +22 -0
- eva/core/metrics/defaults/__init__.py +6 -0
- eva/core/metrics/defaults/classification/__init__.py +6 -0
- eva/core/metrics/defaults/classification/binary.py +76 -0
- eva/core/metrics/defaults/classification/multiclass.py +80 -0
- eva/core/metrics/structs/__init__.py +9 -0
- eva/core/metrics/structs/collection.py +6 -0
- eva/core/metrics/structs/metric.py +6 -0
- eva/core/metrics/structs/module.py +115 -0
- eva/core/metrics/structs/schemas.py +47 -0
- eva/core/metrics/structs/typings.py +15 -0
- eva/core/models/__init__.py +13 -0
- eva/core/models/modules/__init__.py +7 -0
- eva/core/models/modules/head.py +113 -0
- eva/core/models/modules/inference.py +37 -0
- eva/core/models/modules/module.py +190 -0
- eva/core/models/modules/typings.py +23 -0
- eva/core/models/modules/utils/__init__.py +6 -0
- eva/core/models/modules/utils/batch_postprocess.py +57 -0
- eva/core/models/modules/utils/grad.py +23 -0
- eva/core/models/networks/__init__.py +6 -0
- eva/core/models/networks/_utils.py +25 -0
- eva/core/models/networks/mlp.py +69 -0
- eva/core/models/networks/transforms/__init__.py +5 -0
- eva/core/models/networks/transforms/extract_cls_features.py +25 -0
- eva/core/models/networks/wrappers/__init__.py +8 -0
- eva/core/models/networks/wrappers/base.py +47 -0
- eva/core/models/networks/wrappers/from_function.py +58 -0
- eva/core/models/networks/wrappers/huggingface.py +37 -0
- eva/core/models/networks/wrappers/onnx.py +47 -0
- eva/core/trainers/__init__.py +6 -0
- eva/core/trainers/_logging.py +81 -0
- eva/core/trainers/_recorder.py +149 -0
- eva/core/trainers/_utils.py +12 -0
- eva/core/trainers/functional.py +113 -0
- eva/core/trainers/trainer.py +97 -0
- eva/core/utils/__init__.py +1 -0
- eva/core/utils/io/__init__.py +5 -0
- eva/core/utils/io/dataframe.py +21 -0
- eva/core/utils/multiprocessing.py +44 -0
- eva/core/utils/workers.py +21 -0
- eva/vision/__init__.py +14 -0
- eva/vision/data/__init__.py +5 -0
- eva/vision/data/datasets/__init__.py +22 -0
- eva/vision/data/datasets/_utils.py +50 -0
- eva/vision/data/datasets/_validators.py +44 -0
- eva/vision/data/datasets/classification/__init__.py +15 -0
- eva/vision/data/datasets/classification/bach.py +174 -0
- eva/vision/data/datasets/classification/base.py +103 -0
- eva/vision/data/datasets/classification/crc.py +176 -0
- eva/vision/data/datasets/classification/mhist.py +106 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +203 -0
- eva/vision/data/datasets/classification/total_segmentator.py +212 -0
- eva/vision/data/datasets/segmentation/__init__.py +6 -0
- eva/vision/data/datasets/segmentation/base.py +112 -0
- eva/vision/data/datasets/segmentation/total_segmentator.py +212 -0
- eva/vision/data/datasets/structs.py +17 -0
- eva/vision/data/datasets/vision.py +43 -0
- eva/vision/data/transforms/__init__.py +5 -0
- eva/vision/data/transforms/common/__init__.py +5 -0
- eva/vision/data/transforms/common/resize_and_crop.py +44 -0
- eva/vision/models/__init__.py +5 -0
- eva/vision/models/networks/__init__.py +6 -0
- eva/vision/models/networks/abmil.py +176 -0
- eva/vision/models/networks/postprocesses/__init__.py +5 -0
- eva/vision/models/networks/postprocesses/cls.py +25 -0
- eva/vision/utils/__init__.py +5 -0
- eva/vision/utils/io/__init__.py +12 -0
- eva/vision/utils/io/_utils.py +29 -0
- eva/vision/utils/io/image.py +54 -0
- eva/vision/utils/io/nifti.py +50 -0
- eva/vision/utils/io/text.py +18 -0
- kaiko_eva-0.0.0.dev6.dist-info/METADATA +393 -0
- kaiko_eva-0.0.0.dev6.dist-info/RECORD +111 -0
- kaiko_eva-0.0.0.dev6.dist-info/WHEEL +4 -0
- kaiko_eva-0.0.0.dev6.dist-info/entry_points.txt +4 -0
- kaiko_eva-0.0.0.dev6.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
"""PatchCamelyon dataset."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Callable, Dict, List, Literal
|
|
5
|
+
|
|
6
|
+
import h5py
|
|
7
|
+
import numpy as np
|
|
8
|
+
from torchvision.datasets import utils
|
|
9
|
+
from typing_extensions import override
|
|
10
|
+
|
|
11
|
+
from eva.vision.data.datasets import _validators, structs
|
|
12
|
+
from eva.vision.data.datasets.classification import base
|
|
13
|
+
|
|
14
|
+
# Zenodo download URL template of the PatchCamelyon files; `{filename}` is the
# name of the uncompressed `.h5` file, which is served gzip-compressed (`.gz`).
_URL_TEMPLATE = "https://zenodo.org/records/2546921/files/{filename}.gz?download=1"


def _zenodo_resource(filename: str, md5: str) -> structs.DownloadResource:
    """Creates the download resource of a PatchCamelyon `.h5` file hosted on Zenodo.

    Args:
        filename: The name of the uncompressed `.h5` file.
        md5: The checksum stored alongside the resource.

    Returns:
        The download resource, with its URL derived from `_URL_TEMPLATE`.
    """
    return structs.DownloadResource(
        filename=filename,
        url=_URL_TEMPLATE.format(filename=filename),
        md5=md5,
    )


class PatchCamelyon(base.ImageClassification):
    """Dataset class for PatchCamelyon images and corresponding targets.

    Images (`x`) and targets (`y`) of the selected split are read from the
    split's HDF5 files in `root`, one item per access.
    """

    # Train split resources (images and targets).
    _train_resources: List[structs.DownloadResource] = [
        _zenodo_resource(
            "camelyonpatch_level_2_split_train_x.h5", "01844da899645b4d6f84946d417ba453"
        ),
        _zenodo_resource(
            "camelyonpatch_level_2_split_train_y.h5", "0781386bf6c2fb62d58ff18891466aca"
        ),
    ]

    # Validation split resources (images and targets).
    _val_resources: List[structs.DownloadResource] = [
        _zenodo_resource(
            "camelyonpatch_level_2_split_valid_x.h5", "81cf9680f1724c40673f10dc88e909b1"
        ),
        _zenodo_resource(
            "camelyonpatch_level_2_split_valid_y.h5", "94d8aacc249253159ce2a2e78a86e658"
        ),
    ]

    # Test split resources (images and targets).
    _test_resources: List[structs.DownloadResource] = [
        _zenodo_resource(
            "camelyonpatch_level_2_split_test_x.h5", "2614b2e6717d6356be141d9d6dbfcb7e"
        ),
        _zenodo_resource(
            "camelyonpatch_level_2_split_test_y.h5", "11ed647efe9fe457a4eb45df1dba19ba"
        ),
    ]

    # Dataset license, printed before each file download.
    _license: str = (
        "Creative Commons Zero v1.0 Universal (https://choosealicense.com/licenses/cc0-1.0/)"
    )

    def __init__(
        self,
        root: str,
        split: Literal["train", "val", "test"],
        download: bool = False,
        image_transforms: Callable | None = None,
        target_transforms: Callable | None = None,
    ) -> None:
        """Initializes the dataset.

        Args:
            root: The path to the dataset root. This path should contain
                the uncompressed h5 files and the metadata.
            split: The dataset split for training, validation, or testing.
            download: Whether to download the data for the specified split.
                Note that the download will be executed only by additionally
                calling the :meth:`prepare_data` method.
            image_transforms: A function/transform that takes in an image
                and returns a transformed version.
            target_transforms: A function/transform that takes in the target
                and transforms it.
        """
        super().__init__(
            image_transforms=image_transforms,
            target_transforms=target_transforms,
        )

        self._root = root
        self._split = split
        self._download = download

    @property
    @override
    def classes(self) -> List[str]:
        # Class names in target-index order (dicts preserve insertion order).
        return list(self.class_to_idx)

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        return {"no_tumor": 0, "tumor": 1}

    @override
    def filename(self, index: int) -> str:
        return "camelyonpatch_level_2_split_{}_x_{}".format(self._split, index)

    @override
    def prepare_data(self) -> None:
        if not self._download:
            return
        self._download_dataset()

    @override
    def validate(self) -> None:
        # Expected number of patches per split: 2**18 for train, 2**15 otherwise.
        n_samples_per_split = {"train": 2**18, "val": 2**15, "test": 2**15}
        _validators.check_dataset_integrity(
            self,
            length=n_samples_per_split.get(self._split, 0),
            n_classes=2,
            first_and_last_labels=("no_tumor", "tumor"),
        )

    @override
    def load_image(self, index: int) -> np.ndarray:
        return self._load_from_h5("x", index)

    @override
    def load_target(self, index: int) -> np.ndarray:
        raw_target = self._load_from_h5("y", index)
        return np.asarray(raw_target.squeeze(), dtype=np.int64)

    @override
    def __len__(self) -> int:
        return self._fetch_dataset_length()

    def _download_dataset(self) -> None:
        """Downloads and decompresses the PatchCamelyon files of all splits.

        A file is skipped when its local copy passes the integrity check.
        """
        all_resources = [*self._train_resources, *self._val_resources, *self._test_resources]
        for resource in all_resources:
            local_path = os.path.join(self._root, resource.filename)
            # NOTE(review): confirm `resource.md5` is the checksum of the extracted
            # `.h5` file rather than of the `.gz` archive; otherwise this check
            # never passes and every call re-downloads the file.
            if not utils.check_integrity(local_path, resource.md5):
                self._print_license()
                utils.download_and_extract_archive(
                    resource.url,
                    download_root=self._root,
                    filename=f"{resource.filename}.gz",
                    remove_finished=True,
                )

    def _load_from_h5(
        self,
        data_key: Literal["x", "y"],
        index: int | None = None,
    ) -> np.ndarray:
        """Loads data or targets from the split's HDF5 file.

        Args:
            data_key: Specify whether to load 'x' or 'y'.
            index: Optional parameter to load data/targets at a specific index.
                If `None`, the entire data/targets array is returned.

        Returns:
            An array containing the specified data.
        """
        with h5py.File(self._h5_file(data_key), "r") as h5:
            dataset = h5[data_key]
            if index is None:
                return dataset[:]  # type: ignore
            return dataset[index]  # type: ignore

    def _fetch_dataset_length(self) -> int:
        """Returns the number of items of the split, read from its targets HDF5 file."""
        with h5py.File(self._h5_file("y"), "r") as h5:
            return len(h5["y"])  # type: ignore

    def _h5_file(self, datatype: Literal["x", "y"]) -> str:
        """Builds the path of the H5 file for the given data type and the current split.

        Args:
            datatype: The type of data, where "x" and "y" represent the input
                and target datasets respectively.

        Returns:
            The file path (joined with `root`) of the H5 file.
        """
        # The files name the validation split "valid" rather than "val".
        split_name = {"val": "valid"}.get(self._split, self._split)
        return os.path.join(self._root, f"camelyonpatch_level_2_split_{split_name}_{datatype}.h5")

    def _print_license(self) -> None:
        """Prints the dataset license."""
        print(f"Dataset license: {self._license}")
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""TotalSegmentator 2D segmentation dataset class."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import os
|
|
5
|
+
from glob import glob
|
|
6
|
+
from typing import Callable, Dict, List, Literal, Tuple
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from torchvision.datasets import utils
|
|
10
|
+
from typing_extensions import override
|
|
11
|
+
|
|
12
|
+
from eva.vision.data.datasets import _utils, _validators, structs
|
|
13
|
+
from eva.vision.data.datasets.classification import base
|
|
14
|
+
from eva.vision.utils import io
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TotalSegmentatorClassification(base.ImageClassification):
    """TotalSegmentator multi-label classification dataset.

    Each CT scan (a sample directory in `root` holding `ct.nii.gz` and a
    `segmentations` folder of per-class masks) contributes `_n_slices_per_image`
    evenly spaced slices. Item `index` is slice `index % _n_slices_per_image` of
    sample `index // _n_slices_per_image` of the split. Its target is a binary
    vector with one entry per class, set to 1 when that class's mask contains a
    voxel equal to 1 in the slice.
    """

    # Ranges of positions in the sorted list of sample directories forming each
    # split (presumably half-open, as interpreted by `_utils.ranges_to_indices`).
    _train_index_ranges: List[Tuple[int, int]] = [(0, 83)]
    _val_index_ranges: List[Tuple[int, int]] = [(83, 103)]

    # The amount of slices to sample per 3D CT scan image.
    _n_slices_per_image: int = 20

    # Resources for the full dataset version.
    _resources_full: List[structs.DownloadResource] = [
        structs.DownloadResource(
            filename="Totalsegmentator_dataset_v201.zip",
            url="https://zenodo.org/records/10047292/files/Totalsegmentator_dataset_v201.zip",
            md5="fe250e5718e0a3b5df4c4ea9d58a62fe",
        ),
    ]

    # Resources for the small dataset version.
    _resources_small: List[structs.DownloadResource] = [
        structs.DownloadResource(
            filename="Totalsegmentator_dataset_small_v201.zip",
            url="https://zenodo.org/records/10047263/files/Totalsegmentator_dataset_small_v201.zip",
            md5="6b5524af4b15e6ba06ef2d700c0c73e0",
        ),
    ]

    def __init__(
        self,
        root: str,
        split: Literal["train", "val"] | None,
        version: Literal["small", "full"] = "small",
        download: bool = False,
        image_transforms: Callable | None = None,
        target_transforms: Callable | None = None,
    ) -> None:
        """Initialize dataset.

        Args:
            root: Path to the root directory of the dataset. The dataset will
                be downloaded and extracted here, if it does not already exist.
            split: Dataset split to use. If None, the entire dataset is used.
            version: The version of the dataset to initialize.
            download: Whether to download the data for the specified split.
                Note that the download will be executed only by additionally
                calling the :meth:`prepare_data` method and if the data does not
                exist yet on disk.
            image_transforms: A function/transform that takes in an image
                and returns a transformed version.
            target_transforms: A function/transform that takes in the target
                and transforms it.
        """
        super().__init__(
            image_transforms=image_transforms,
            target_transforms=target_transforms,
        )

        self._root = root
        self._split = split
        self._version = version
        self._download = download

        # Populated by `configure`.
        self._samples_dirs: List[str] = []
        self._indices: List[int] = []

    @functools.cached_property
    @override
    def classes(self) -> List[str]:
        """The sorted class names, read from the mask files of the first sample.

        Requires :meth:`configure` to have populated the sample directories.
        """

        def get_filename(path: str) -> str:
            """Returns the file name up to its first dot (e.g. `liver.nii.gz` -> `liver`)."""
            return os.path.basename(path).split(".")[0]

        first_sample_labels = os.path.join(
            self._root, self._samples_dirs[0], "segmentations", "*.nii.gz"
        )
        return sorted(map(get_filename, glob(first_sample_labels)))

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        return {label: index for index, label in enumerate(self.classes)}

    @override
    def filename(self, index: int) -> str:
        # `index` addresses a slice, so it must be mapped back to its sample
        # directory; indexing `self._indices` directly pointed at the wrong scan
        # (or raised IndexError) for any index >= len(self._indices).
        return os.path.join(self._get_sample_dir(index), "ct.nii.gz")

    @override
    def prepare_data(self) -> None:
        if self._download:
            self._download_dataset()

    @override
    def configure(self) -> None:
        self._samples_dirs = self._fetch_samples_dirs()
        self._indices = self._create_indices()

    @override
    def validate(self) -> None:
        # Expected number of items (sampled slices) per split:
        # 83 (train), 20 (val) and 103 (all) scans x 20 slices each.
        expected_length = {"train": 1660, "val": 400, None: 2060}
        _validators.check_dataset_integrity(
            self,
            length=expected_length.get(self._split, 0),
            n_classes=117,
            first_and_last_labels=("adrenal_gland_left", "vertebrae_T9"),
        )

    @override
    def __len__(self) -> int:
        return len(self._indices) * self._n_slices_per_image

    @override
    def load_image(self, index: int) -> np.ndarray:
        image_path = self._get_image_path(index)
        slice_index = self._get_sample_slice_index(index)
        image_array = io.read_nifti_slice(image_path, slice_index)
        # Replicate along axis 2 to get 3 channels (assumes a single-channel
        # (H, W, 1) slice from `read_nifti_slice` — TODO confirm).
        return image_array.repeat(3, axis=2)

    @override
    def load_target(self, index: int) -> np.ndarray:
        masks = self._load_masks(index)
        # A class is present in the slice if its mask (last axis) contains any 1;
        # reduce over all leading axes in a single vectorized pass.
        leading_axes = tuple(range(masks.ndim - 1))
        targets = np.any(masks == 1, axis=leading_axes)
        return np.asarray(targets, dtype=np.int64)

    def _load_masks(self, index: int) -> np.ndarray:
        """Returns the `index`'th target masks, one per class, stacked on the last axis."""
        masks_dir = self._get_masks_dir(index)
        slice_index = self._get_sample_slice_index(index)
        mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes)
        masks = [io.read_nifti_slice(path, slice_index) for path in mask_paths]
        return np.concatenate(masks, axis=-1)

    def _get_masks_dir(self, index: int) -> str:
        """Returns the directory of the corresponding masks."""
        sample_dir = self._get_sample_dir(index)
        return os.path.join(self._root, sample_dir, "segmentations")

    def _get_image_path(self, index: int) -> str:
        """Returns the corresponding image path."""
        sample_dir = self._get_sample_dir(index)
        return os.path.join(self._root, sample_dir, "ct.nii.gz")

    def _get_sample_dir(self, index: int) -> str:
        """Returns the sample directory (relative to `root`) of the `index`'th slice."""
        sample_index = self._indices[index // self._n_slices_per_image]
        return self._samples_dirs[sample_index]

    def _get_sample_slice_index(self, index: int) -> int:
        """Returns the scan slice index of the `index`'th item.

        Slices are picked evenly spaced between the first and last slice of the scan.
        """
        image_path = self._get_image_path(index)
        total_slices = io.fetch_total_nifti_slices(image_path)
        slice_indices = np.linspace(0, total_slices - 1, num=self._n_slices_per_image, dtype=int)
        return slice_indices[index % self._n_slices_per_image]

    def _fetch_samples_dirs(self) -> List[str]:
        """Returns the sorted names of all the sample directories of all the splits."""
        sample_filenames = [
            filename
            for filename in os.listdir(self._root)
            if os.path.isdir(os.path.join(self._root, filename))
        ]
        return sorted(sample_filenames)

    def _create_indices(self) -> List[int]:
        """Builds the dataset indices for the specified split."""
        split_index_ranges = {
            "train": self._train_index_ranges,
            "val": self._val_index_ranges,
            None: [(0, 103)],
        }
        index_ranges = split_index_ranges.get(self._split)
        if index_ranges is None:
            raise ValueError("Invalid data split. Use 'train', 'val' or `None`.")

        return _utils.ranges_to_indices(index_ranges)

    def _dataset_exists(self) -> bool:
        """Returns whether `root` already contains extracted sample directories."""
        return os.path.isdir(self._root) and len(self._fetch_samples_dirs()) > 0

    def _download_dataset(self) -> None:
        """Downloads and extracts the selected dataset version.

        The download is skipped when `root` already holds sample directories.

        Raises:
            ValueError: If `version` is neither 'small' nor 'full'.
        """
        dataset_resources = {
            "small": self._resources_small,
            "full": self._resources_full,
        }
        resources = dataset_resources.get(self._version)
        if resources is None:
            raise ValueError("Invalid data version. Use 'small' or 'full'.")

        if self._dataset_exists():
            return

        for resource in resources:
            utils.download_and_extract_archive(
                resource.url,
                download_root=self._root,
                filename=resource.filename,
                remove_finished=True,
            )
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""Base for image segmentation datasets."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from typing import Any, Callable, Dict, List, Tuple
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
|
|
9
|
+
from eva.vision.data.datasets import vision
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ImageSegmentation(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
    """Image segmentation abstract dataset.

    Concrete datasets implement :meth:`load_image`, :meth:`load_mask` and
    ``__len__``. Indexing the dataset returns an ``(image, mask)`` pair after
    the configured transforms have been applied.
    """

    def __init__(
        self,
        image_transforms: Callable | None = None,
        target_transforms: Callable | None = None,
        image_target_transforms: Callable | None = None,
    ) -> None:
        """Initializes the image segmentation base class.

        Args:
            image_transforms: A function/transform that takes in an image
                and returns a transformed version.
            target_transforms: A function/transform that takes in the target
                and transforms it.
            image_target_transforms: A function/transforms that takes in an
                image and a label and returns the transformed versions of both.
                This transform happens after the `image_transforms` and
                `target_transforms`.
        """
        super().__init__()

        self._image_transforms = image_transforms
        self._target_transforms = target_transforms
        self._image_target_transforms = image_target_transforms

    @property
    def classes(self) -> List[str] | None:
        """Returns the list with the names of the dataset classes (`None` by default)."""
        return None

    @property
    def class_to_idx(self) -> Dict[str, int] | None:
        """Returns a mapping of the class name to its target index (`None` by default)."""
        return None

    def load_metadata(self, index: int | None) -> Dict[str, Any] | List[Dict[str, Any]] | None:
        """Returns the dataset metadata.

        Args:
            index: The index of the data sample to return the metadata of.
                If `None`, it will return the metadata of the current dataset.

        Returns:
            The sample metadata; this base implementation provides none.
        """
        return None

    @abc.abstractmethod
    def load_image(self, index: int) -> np.ndarray:
        """Loads and returns the `index`'th image sample.

        Args:
            index: The index of the data sample to load.

        Returns:
            The image as a numpy array.
        """

    @abc.abstractmethod
    def load_mask(self, index: int) -> np.ndarray:
        """Returns the `index`'th target mask sample.

        Args:
            index: The index of the data sample target mask to load.

        Returns:
            The sample mask as a stack of binary mask arrays (label, height, width).
        """

    @abc.abstractmethod
    @override
    def __len__(self) -> int:
        raise NotImplementedError

    @override
    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
        # Arguments are evaluated left to right: the image is loaded before the mask.
        return self._apply_transforms(self.load_image(index), self.load_mask(index))

    def _apply_transforms(
        self, image: np.ndarray, target: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Applies the configured transforms to an image and its target.

        The individual image and target transforms run first; the joint
        image/target transform, when set, runs last on their outputs.

        Args:
            image: The desired image.
            target: The target of the image.

        Returns:
            A tuple with the image and the target transformed.
        """
        new_image = image if self._image_transforms is None else self._image_transforms(image)
        new_target = (
            target if self._target_transforms is None else self._target_transforms(target)
        )

        if self._image_target_transforms is not None:
            # Unpacking enforces that the joint transform yields exactly two items.
            new_image, new_target = self._image_target_transforms(new_image, new_target)

        return new_image, new_target
|