kaiko-eva 0.0.0.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kaiko-eva might be problematic. Click here for more details.
- eva/.DS_Store +0 -0
- eva/__init__.py +33 -0
- eva/__main__.py +18 -0
- eva/__version__.py +25 -0
- eva/core/__init__.py +19 -0
- eva/core/callbacks/__init__.py +5 -0
- eva/core/callbacks/writers/__init__.py +5 -0
- eva/core/callbacks/writers/embeddings.py +169 -0
- eva/core/callbacks/writers/typings.py +23 -0
- eva/core/cli/__init__.py +5 -0
- eva/core/cli/cli.py +19 -0
- eva/core/cli/logo.py +38 -0
- eva/core/cli/setup.py +89 -0
- eva/core/data/__init__.py +14 -0
- eva/core/data/dataloaders/__init__.py +5 -0
- eva/core/data/dataloaders/dataloader.py +80 -0
- eva/core/data/datamodules/__init__.py +6 -0
- eva/core/data/datamodules/call.py +33 -0
- eva/core/data/datamodules/datamodule.py +108 -0
- eva/core/data/datamodules/schemas.py +62 -0
- eva/core/data/datasets/__init__.py +7 -0
- eva/core/data/datasets/base.py +53 -0
- eva/core/data/datasets/classification/__init__.py +5 -0
- eva/core/data/datasets/classification/embeddings.py +154 -0
- eva/core/data/datasets/dataset.py +6 -0
- eva/core/data/samplers/__init__.py +5 -0
- eva/core/data/samplers/sampler.py +6 -0
- eva/core/data/transforms/__init__.py +5 -0
- eva/core/data/transforms/dtype/__init__.py +5 -0
- eva/core/data/transforms/dtype/array.py +28 -0
- eva/core/interface/__init__.py +5 -0
- eva/core/interface/interface.py +79 -0
- eva/core/metrics/__init__.py +17 -0
- eva/core/metrics/average_loss.py +47 -0
- eva/core/metrics/binary_balanced_accuracy.py +22 -0
- eva/core/metrics/defaults/__init__.py +6 -0
- eva/core/metrics/defaults/classification/__init__.py +6 -0
- eva/core/metrics/defaults/classification/binary.py +76 -0
- eva/core/metrics/defaults/classification/multiclass.py +80 -0
- eva/core/metrics/structs/__init__.py +9 -0
- eva/core/metrics/structs/collection.py +6 -0
- eva/core/metrics/structs/metric.py +6 -0
- eva/core/metrics/structs/module.py +115 -0
- eva/core/metrics/structs/schemas.py +47 -0
- eva/core/metrics/structs/typings.py +15 -0
- eva/core/models/__init__.py +13 -0
- eva/core/models/modules/__init__.py +7 -0
- eva/core/models/modules/head.py +113 -0
- eva/core/models/modules/inference.py +37 -0
- eva/core/models/modules/module.py +190 -0
- eva/core/models/modules/typings.py +23 -0
- eva/core/models/modules/utils/__init__.py +6 -0
- eva/core/models/modules/utils/batch_postprocess.py +57 -0
- eva/core/models/modules/utils/grad.py +23 -0
- eva/core/models/networks/__init__.py +6 -0
- eva/core/models/networks/_utils.py +25 -0
- eva/core/models/networks/mlp.py +69 -0
- eva/core/models/networks/transforms/__init__.py +5 -0
- eva/core/models/networks/transforms/extract_cls_features.py +25 -0
- eva/core/models/networks/wrappers/__init__.py +8 -0
- eva/core/models/networks/wrappers/base.py +47 -0
- eva/core/models/networks/wrappers/from_function.py +58 -0
- eva/core/models/networks/wrappers/huggingface.py +37 -0
- eva/core/models/networks/wrappers/onnx.py +47 -0
- eva/core/trainers/__init__.py +6 -0
- eva/core/trainers/_logging.py +81 -0
- eva/core/trainers/_recorder.py +149 -0
- eva/core/trainers/_utils.py +12 -0
- eva/core/trainers/functional.py +113 -0
- eva/core/trainers/trainer.py +97 -0
- eva/core/utils/__init__.py +1 -0
- eva/core/utils/io/__init__.py +5 -0
- eva/core/utils/io/dataframe.py +21 -0
- eva/core/utils/multiprocessing.py +44 -0
- eva/core/utils/workers.py +21 -0
- eva/vision/__init__.py +14 -0
- eva/vision/data/__init__.py +5 -0
- eva/vision/data/datasets/__init__.py +22 -0
- eva/vision/data/datasets/_utils.py +50 -0
- eva/vision/data/datasets/_validators.py +44 -0
- eva/vision/data/datasets/classification/__init__.py +15 -0
- eva/vision/data/datasets/classification/bach.py +174 -0
- eva/vision/data/datasets/classification/base.py +103 -0
- eva/vision/data/datasets/classification/crc.py +176 -0
- eva/vision/data/datasets/classification/mhist.py +106 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +203 -0
- eva/vision/data/datasets/classification/total_segmentator.py +212 -0
- eva/vision/data/datasets/segmentation/__init__.py +6 -0
- eva/vision/data/datasets/segmentation/base.py +112 -0
- eva/vision/data/datasets/segmentation/total_segmentator.py +212 -0
- eva/vision/data/datasets/structs.py +17 -0
- eva/vision/data/datasets/vision.py +43 -0
- eva/vision/data/transforms/__init__.py +5 -0
- eva/vision/data/transforms/common/__init__.py +5 -0
- eva/vision/data/transforms/common/resize_and_crop.py +44 -0
- eva/vision/models/__init__.py +5 -0
- eva/vision/models/networks/__init__.py +6 -0
- eva/vision/models/networks/abmil.py +176 -0
- eva/vision/models/networks/postprocesses/__init__.py +5 -0
- eva/vision/models/networks/postprocesses/cls.py +25 -0
- eva/vision/utils/__init__.py +5 -0
- eva/vision/utils/io/__init__.py +12 -0
- eva/vision/utils/io/_utils.py +29 -0
- eva/vision/utils/io/image.py +54 -0
- eva/vision/utils/io/nifti.py +50 -0
- eva/vision/utils/io/text.py +18 -0
- kaiko_eva-0.0.0.dev6.dist-info/METADATA +393 -0
- kaiko_eva-0.0.0.dev6.dist-info/RECORD +111 -0
- kaiko_eva-0.0.0.dev6.dist-info/WHEEL +4 -0
- kaiko_eva-0.0.0.dev6.dist-info/entry_points.txt +4 -0
- kaiko_eva-0.0.0.dev6.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""Dataset validation related functions."""
|
|
2
|
+
|
|
3
|
+
from typing_extensions import List, Tuple
|
|
4
|
+
|
|
5
|
+
from eva.vision.data.datasets import vision
|
|
6
|
+
|
|
7
|
+
_SUFFIX_ERROR_MESSAGE = "Please verify that the data are properly downloaded and stored."
|
|
8
|
+
"""Common suffix dataset verification error message."""
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def check_dataset_integrity(
    dataset: "vision.VisionDataset",
    *,
    length: int,
    n_classes: int,
    first_and_last_labels: Tuple[str, str],
) -> None:
    """Verifies the dataset's integrity against the expected reference values.

    The two class related checks are skipped when the dataset exposes no
    (or an empty) `classes` attribute.

    Args:
        dataset: The dataset to verify.
        length: The expected number of samples in the dataset.
        n_classes: The expected number of classes.
        first_and_last_labels: The expected first and last class names.

    Raises:
        ValueError: If the input dataset's values do not
            match the expected ones.
    """
    dataset_name = dataset.__class__.__qualname__

    if len(dataset) != length:
        raise ValueError(
            f"Dataset's '{dataset_name}' length "
            f"({len(dataset)}) does not match the expected one ({length}). "
            f"{_SUFFIX_ERROR_MESSAGE}"
        )

    # `classes` may be missing, `None` or empty; none of these can be verified.
    dataset_classes: List[str] = getattr(dataset, "classes", None) or []
    if not dataset_classes:
        return

    if len(dataset_classes) != n_classes:
        raise ValueError(
            f"Dataset's '{dataset_name}' number of classes "
            f"({len(dataset_classes)}) does not match the expected one ({n_classes}). "
            f"{_SUFFIX_ERROR_MESSAGE}"
        )

    # Coerce to a tuple so that an expected pair given as a list (e.g. when it
    # originates from a parsed configuration) still compares equal.
    actual_labels = (dataset_classes[0], dataset_classes[-1])
    if actual_labels != tuple(first_and_last_labels):
        raise ValueError(
            f"Dataset's '{dataset_name}' first and last labels "
            f"({actual_labels}) does not match the expected "
            f"ones ({first_and_last_labels}). {_SUFFIX_ERROR_MESSAGE}"
        )
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Image classification datasets API."""
|
|
2
|
+
|
|
3
|
+
from eva.vision.data.datasets.classification.bach import BACH
|
|
4
|
+
from eva.vision.data.datasets.classification.crc import CRC
|
|
5
|
+
from eva.vision.data.datasets.classification.mhist import MHIST
|
|
6
|
+
from eva.vision.data.datasets.classification.patch_camelyon import PatchCamelyon
|
|
7
|
+
from eva.vision.data.datasets.classification.total_segmentator import TotalSegmentatorClassification
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"BACH",
|
|
11
|
+
"CRC",
|
|
12
|
+
"MHIST",
|
|
13
|
+
"PatchCamelyon",
|
|
14
|
+
"TotalSegmentatorClassification",
|
|
15
|
+
]
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""BACH dataset class."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Callable, Dict, List, Literal, Tuple
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from torchvision.datasets import folder, utils
|
|
8
|
+
from typing_extensions import override
|
|
9
|
+
|
|
10
|
+
from eva.vision.data.datasets import _utils, _validators, structs
|
|
11
|
+
from eva.vision.data.datasets.classification import base
|
|
12
|
+
from eva.vision.utils import io
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class BACH(base.ImageClassification):
    """Dataset class for BACH images and corresponding targets.

    The samples are discovered on disk with torchvision's `folder.make_dataset`
    and then restricted to the requested split through fixed index ranges.
    """

    _train_index_ranges: List[Tuple[int, int]] = [
        (0, 41),
        (59, 60),
        (90, 139),
        (169, 240),
        (258, 260),
        (273, 345),
        (368, 400),
    ]
    """Train range indices."""

    _val_index_ranges: List[Tuple[int, int]] = [
        (41, 59),
        (60, 90),
        (139, 169),
        (240, 258),
        (260, 273),
        (345, 368),
    ]
    """Validation range indices."""

    _resources: List[structs.DownloadResource] = [
        structs.DownloadResource(
            filename="ICIAR2018_BACH_Challenge.zip",
            url="https://zenodo.org/records/3632035/files/ICIAR2018_BACH_Challenge.zip",
        ),
    ]
    """Dataset resources."""

    _license: str = "CC BY-NC-ND 4.0 (https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode)"
    """Dataset license."""

    def __init__(
        self,
        root: str,
        split: Literal["train", "val"] | None = None,
        download: bool = False,
        image_transforms: Callable | None = None,
        target_transforms: Callable | None = None,
    ) -> None:
        """Initialize the dataset.

        The dataset is split into train and validation by taking into account
        the patient IDs to avoid any data leakage.

        Args:
            root: Path to the root directory of the dataset. The dataset will
                be downloaded and extracted here, if it does not already exist.
            split: Dataset split to use. If `None`, the entire dataset is used.
            download: Whether to download the data for the specified split.
                Note that the download will be executed only by additionally
                calling the :meth:`prepare_data` method and if the data does
                not yet exist on disk.
            image_transforms: A function/transform that takes in an image
                and returns a transformed version.
            target_transforms: A function/transform that takes in the target
                and transforms it.
        """
        super().__init__(
            image_transforms=image_transforms,
            target_transforms=target_transforms,
        )

        self._root = root
        self._split = split
        self._download = download

        # Both are populated by `configure`; until then the dataset is empty.
        self._samples: List[Tuple[str, int]] = []
        self._indices: List[int] = []

    @property
    @override
    def classes(self) -> List[str]:
        return ["Benign", "InSitu", "Invasive", "Normal"]

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        return {"Benign": 0, "InSitu": 1, "Invasive": 2, "Normal": 3}

    @property
    def dataset_path(self) -> str:
        """Returns the path of the image data of the dataset."""
        return os.path.join(self._root, "ICIAR2018_BACH_Challenge", "Photos")

    @override
    def filename(self, index: int) -> str:
        # `index` is relative to the split; `_indices` maps it to a sample.
        image_path, _ = self._samples[self._indices[index]]
        return os.path.relpath(image_path, self.dataset_path)

    @override
    def prepare_data(self) -> None:
        if self._download:
            self._download_dataset()

    @override
    def configure(self) -> None:
        self._samples = folder.make_dataset(
            directory=self.dataset_path,
            class_to_idx=self.class_to_idx,
            # A real one-element tuple; `(".tif")` is just a parenthesised string.
            extensions=(".tif",),
        )
        self._indices = self._make_indices()

    @override
    def validate(self) -> None:
        # The expected sizes equal the number of indices covered by the split
        # ranges above; `None` selects the whole dataset (268 + 132 samples).
        expected_length = {
            "train": 268,
            "val": 132,
            None: 400,
        }
        _validators.check_dataset_integrity(
            self,
            length=expected_length.get(self._split, 0),
            n_classes=4,
            first_and_last_labels=("Benign", "Normal"),
        )

    @override
    def load_image(self, index: int) -> np.ndarray:
        image_path, _ = self._samples[self._indices[index]]
        return io.read_image(image_path)

    @override
    def load_target(self, index: int) -> np.ndarray:
        _, target = self._samples[self._indices[index]]
        return np.asarray(target, dtype=np.int64)

    @override
    def __len__(self) -> int:
        return len(self._indices)

    def _download_dataset(self) -> None:
        """Downloads and extracts the resources unless the image folder exists."""
        for resource in self._resources:
            if os.path.isdir(self.dataset_path):
                continue

            self._print_license()
            utils.download_and_extract_archive(
                resource.url,
                download_root=self._root,
                filename=resource.filename,
                remove_finished=True,
            )

    def _print_license(self) -> None:
        """Prints the dataset license."""
        print(f"Dataset license: {self._license}")

    def _make_indices(self) -> List[int]:
        """Builds the dataset indices for the specified split.

        Raises:
            ValueError: If the split is not 'train', 'val' or `None`.
        """
        split_index_ranges = {
            "train": self._train_index_ranges,
            "val": self._val_index_ranges,
            None: [(0, 400)],
        }
        index_ranges = split_index_ranges.get(self._split)
        if index_ranges is None:
            raise ValueError("Invalid data split. Use 'train', 'val' or `None`.")

        return _utils.ranges_to_indices(index_ranges)
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
"""Base for image classification datasets."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from typing import Any, Callable, Dict, List, Tuple
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
|
|
9
|
+
from eva.vision.data.datasets import vision
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ImageClassification(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
    """Abstract dataset which serves `(image, target)` pairs for classification."""

    def __init__(
        self,
        image_transforms: Callable | None = None,
        target_transforms: Callable | None = None,
    ) -> None:
        """Stores the optional transforms applied when a sample is fetched.

        Args:
            image_transforms: A callable applied to the loaded image. It is
                skipped when `None`.
            target_transforms: A callable applied to the loaded target. It is
                skipped when `None`.
        """
        super().__init__()

        self._image_transforms, self._target_transforms = image_transforms, target_transforms

    @property
    def classes(self) -> List[str] | None:
        """Returns the class names of the dataset; `None` in this base class."""
        return None

    @property
    def class_to_idx(self) -> Dict[str, int] | None:
        """Returns the class name to target index mapping; `None` in this base class."""
        return None

    def load_metadata(self, index: int | None) -> Dict[str, Any] | List[Dict[str, Any]] | None:
        """Returns the dataset metadata.

        This base implementation provides no metadata and returns `None`.

        Args:
            index: The index of the data sample to return the metadata of.
                If `None`, it will return the metadata of the current dataset.

        Returns:
            The sample metadata.
        """
        return None

    @abc.abstractmethod
    def load_image(self, index: int) -> np.ndarray:
        """Loads the image of the sample at position `index`.

        Args:
            index: The index of the data sample to load.

        Returns:
            The image as a numpy array.
        """

    @abc.abstractmethod
    def load_target(self, index: int) -> np.ndarray:
        """Loads the target of the sample at position `index`.

        Args:
            index: The index of the data sample to load.

        Returns:
            The sample target as an array.
        """

    @abc.abstractmethod
    @override
    def __len__(self) -> int:
        raise NotImplementedError

    @override
    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
        # The image is loaded before the target, then both are transformed.
        return self._apply_transforms(self.load_image(index), self.load_target(index))

    def _apply_transforms(
        self, image: np.ndarray, target: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Runs the configured transforms over a loaded sample.

        Args:
            image: The desired image.
            target: The target of the image.

        Returns:
            A tuple with the image and the target, each passed through its
            transform when one was configured and returned untouched otherwise.
        """
        image = image if self._image_transforms is None else self._image_transforms(image)
        target = target if self._target_transforms is None else self._target_transforms(target)
        return image, target
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""CRC dataset class."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Callable, Dict, List, Literal, Tuple
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from torchvision.datasets import folder, utils
|
|
8
|
+
from typing_extensions import override
|
|
9
|
+
|
|
10
|
+
from eva.vision.data.datasets import _validators, structs
|
|
11
|
+
from eva.vision.data.datasets.classification import base
|
|
12
|
+
from eva.vision.utils import io
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CRC(base.ImageClassification):
    """Dataset class for CRC images and corresponding targets."""

    _train_resource: structs.DownloadResource = structs.DownloadResource(
        filename="NCT-CRC-HE-100K.zip",
        url="https://zenodo.org/records/1214456/files/NCT-CRC-HE-100K.zip?download=1",
        md5="md5:035777cf327776a71a05c95da6d6325f",
    )
    """Train resource."""

    _val_resource: structs.DownloadResource = structs.DownloadResource(
        filename="CRC-VAL-HE-7K.zip",
        url="https://zenodo.org/records/1214456/files/CRC-VAL-HE-7K.zip?download=1",
        md5="md5:2fd1651b4f94ebd818ebf90ad2b6ce06",
    )
    """Validation resource."""

    _license: str = "CC BY 4.0 LEGAL CODE (https://creativecommons.org/licenses/by/4.0/legalcode)"
    """Dataset license."""

    def __init__(
        self,
        root: str,
        split: Literal["train", "val"],
        download: bool = False,
        image_transforms: Callable | None = None,
        target_transforms: Callable | None = None,
    ) -> None:
        """Initializes the dataset.

        The dataset is split into a train (train) and validation (val) set:
          - train: A set of 100,000 non-overlapping image patches from
            hematoxylin & eosin (H&E) stained histological images of human
            colorectal cancer (CRC) and normal tissue.
          - val: A set of 7180 image patches from N=50 patients with colorectal
            adenocarcinoma (no overlap with patients in NCT-CRC-HE-100K).

        Args:
            root: Path to the root directory of the dataset.
            split: Dataset split to use.
            download: Whether to download the data for the specified split.
                Note that the download will be executed only by additionally
                calling the :meth:`prepare_data` method and if the data does
                not yet exist on disk.
            image_transforms: A function/transform that takes in an image
                and returns a transformed version.
            target_transforms: A function/transform that takes in the target
                and transforms it.
        """
        super().__init__(
            image_transforms=image_transforms,
            target_transforms=target_transforms,
        )

        self._root = root
        self._split = split
        self._download = download

        # Filled in by `configure`; the dataset is empty until then.
        self._samples: List[Tuple[str, int]] = []

    @property
    @override
    def classes(self) -> List[str]:
        return ["ADI", "BACK", "DEB", "LYM", "MUC", "MUS", "NORM", "STR", "TUM"]

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        return {
            "ADI": 0,
            "BACK": 1,
            "DEB": 2,
            "LYM": 3,
            "MUC": 4,
            "MUS": 5,
            "NORM": 6,
            "STR": 7,
            "TUM": 8,
        }

    @override
    def filename(self, index: int) -> str:
        sample_path = self._samples[index][0]
        return os.path.relpath(sample_path, self._dataset_dir)

    @override
    def prepare_data(self) -> None:
        if not self._download:
            return
        self._download_dataset()

    @override
    def configure(self) -> None:
        self._samples = self._make_dataset()

    @override
    def validate(self) -> None:
        length_per_split = {"train": 100000, "val": 7180, None: 107180}
        # An unknown split falls back to an expected length of zero.
        _validators.check_dataset_integrity(
            self,
            length=length_per_split.get(self._split, 0),
            n_classes=9,
            first_and_last_labels=("ADI", "TUM"),
        )

    @override
    def load_image(self, index: int) -> np.ndarray:
        sample_path, _ = self._samples[index]
        return io.read_image(sample_path)

    @override
    def load_target(self, index: int) -> np.ndarray:
        _, class_index = self._samples[index]
        return np.asarray(class_index, dtype=np.int64)

    @override
    def __len__(self) -> int:
        return len(self._samples)

    @property
    def _dataset_dir(self) -> str:
        """Returns the directory under the root which holds the split's images.

        Raises:
            ValueError: If the split is neither 'train' nor 'val'.
        """
        split_folders = {"train": "NCT-CRC-HE-100K", "val": "CRC-VAL-HE-7K"}
        folder_name = split_folders.get(self._split)
        if folder_name is None:
            raise ValueError("Invalid data split. Use 'train' or 'val'.")

        return os.path.join(self._root, folder_name)

    def _make_dataset(self) -> List[Tuple[str, int]]:
        """Collects the `.tif` samples of the split as `(path, class index)` pairs."""
        return folder.make_dataset(
            directory=self._dataset_dir,
            class_to_idx=self.class_to_idx,
            extensions=".tif",
        )

    def _download_dataset(self) -> None:
        """Downloads every resource whose extracted folder is not yet on disk."""
        for resource in (self._train_resource, self._val_resource):
            # The archive is expected to extract into a folder named after
            # the file, minus its extension.
            extracted_dir = os.path.join(
                self._root, resource.filename.rsplit(".", maxsplit=1)[0]
            )
            if not os.path.isdir(extracted_dir):
                self._print_license()
                utils.download_and_extract_archive(
                    resource.url,
                    download_root=self._root,
                    filename=resource.filename,
                    remove_finished=True,
                )

    def _print_license(self) -> None:
        """Writes the dataset license to the standard output."""
        print(f"Dataset license: {self._license}")
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""MHIST dataset class."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Callable, Dict, List, Literal, Tuple
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
|
|
9
|
+
from eva.vision.data.datasets import _validators
|
|
10
|
+
from eva.vision.data.datasets.classification import base
|
|
11
|
+
from eva.vision.utils import io
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MHIST(base.ImageClassification):
    """MHIST dataset, read from an `images` folder and an `annotations.csv` file."""

    def __init__(
        self,
        root: str,
        split: Literal["train", "test"],
        image_transforms: Callable | None = None,
        target_transforms: Callable | None = None,
    ) -> None:
        """Initialize the dataset.

        Args:
            root: Path to the root directory of the dataset. The images are
                read from its `images` sub-folder and the annotations from
                its `annotations.csv` file.
            split: Dataset split to use.
            image_transforms: A function/transform that takes in an image
                and returns a transformed version.
            target_transforms: A function/transform that takes in the target
                and transforms it.
        """
        super().__init__(
            image_transforms=image_transforms,
            target_transforms=target_transforms,
        )

        self._root, self._split = root, split

        # Filled in by `configure` with `(image filename, label name)` pairs.
        self._samples: List[Tuple[str, str]] = []

    @property
    @override
    def classes(self) -> List[str]:
        return ["SSA", "HP"]

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        return {"SSA": 0, "HP": 1}

    @override
    def filename(self, index: int) -> str:
        name, _ = self._samples[index]
        return name

    @override
    def configure(self) -> None:
        self._samples = self._make_dataset()

    @override
    def validate(self) -> None:
        # Every split other than "train" is checked against the test size.
        if self._split == "train":
            expected_length = 2175
        else:
            expected_length = 977

        _validators.check_dataset_integrity(
            self,
            length=expected_length,
            n_classes=2,
            first_and_last_labels=("SSA", "HP"),
        )

    @override
    def load_image(self, index: int) -> np.ndarray:
        name, _ = self._samples[index]
        return io.read_image(os.path.join(self._dataset_path, name))

    @override
    def load_target(self, index: int) -> np.ndarray:
        _, label_name = self._samples[index]
        # The annotations store the label by name; map it to its class index.
        return np.asarray(self.class_to_idx[label_name], dtype=np.int64)

    @override
    def __len__(self) -> int:
        return len(self._samples)

    def _make_dataset(self) -> List[Tuple[str, str]]:
        """Collects the `(image_filename, label)` pairs of the selected split."""
        rows = io.read_csv(self._annotations_path)

        selected: List[Tuple[str, str]] = []
        for row in rows:
            if row["Partition"] == self._split:
                selected.append((row["Image Name"], row["Majority Vote Label"]))
        return selected

    @property
    def _dataset_path(self) -> str:
        """Returns the folder under the root which contains the images."""
        return os.path.join(self._root, "images")

    @property
    def _annotations_path(self) -> str:
        """Returns the location of the annotations CSV file under the root."""
        return os.path.join(self._root, "annotations.csv")
|