kaiko-eva 0.1.8__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/data/datasets/base.py +7 -2
- eva/core/data/datasets/classification/embeddings.py +2 -2
- eva/core/data/datasets/classification/multi_embeddings.py +2 -2
- eva/core/data/datasets/embeddings.py +4 -4
- eva/core/data/samplers/classification/balanced.py +19 -18
- eva/core/loggers/utils/wandb.py +33 -0
- eva/core/models/modules/head.py +5 -3
- eva/core/models/modules/typings.py +2 -2
- eva/core/models/transforms/__init__.py +2 -1
- eva/core/models/transforms/as_discrete.py +57 -0
- eva/core/models/wrappers/_utils.py +121 -1
- eva/core/trainers/functional.py +8 -5
- eva/core/trainers/trainer.py +32 -17
- eva/core/utils/suppress_logs.py +28 -0
- eva/vision/data/__init__.py +2 -2
- eva/vision/data/dataloaders/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/__init__.py +5 -0
- eva/vision/data/dataloaders/collate_fn/collection.py +22 -0
- eva/vision/data/datasets/__init__.py +10 -2
- eva/vision/data/datasets/classification/__init__.py +9 -0
- eva/vision/data/datasets/classification/bach.py +3 -4
- eva/vision/data/datasets/classification/bracs.py +111 -0
- eva/vision/data/datasets/classification/breakhis.py +209 -0
- eva/vision/data/datasets/classification/camelyon16.py +4 -5
- eva/vision/data/datasets/classification/crc.py +3 -4
- eva/vision/data/datasets/classification/gleason_arvaniti.py +171 -0
- eva/vision/data/datasets/classification/mhist.py +3 -4
- eva/vision/data/datasets/classification/panda.py +4 -5
- eva/vision/data/datasets/classification/patch_camelyon.py +3 -4
- eva/vision/data/datasets/classification/unitopatho.py +158 -0
- eva/vision/data/datasets/classification/wsi.py +6 -5
- eva/vision/data/datasets/segmentation/__init__.py +2 -2
- eva/vision/data/datasets/segmentation/_utils.py +47 -0
- eva/vision/data/datasets/segmentation/bcss.py +7 -8
- eva/vision/data/datasets/segmentation/btcv.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +6 -7
- eva/vision/data/datasets/segmentation/embeddings.py +2 -2
- eva/vision/data/datasets/segmentation/lits.py +9 -8
- eva/vision/data/datasets/segmentation/lits_balanced.py +2 -1
- eva/vision/data/datasets/segmentation/monusac.py +4 -5
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +12 -10
- eva/vision/data/datasets/vision.py +95 -4
- eva/vision/data/datasets/wsi.py +5 -5
- eva/vision/data/transforms/__init__.py +22 -3
- eva/vision/data/transforms/common/__init__.py +1 -2
- eva/vision/data/transforms/croppad/__init__.py +11 -0
- eva/vision/data/transforms/croppad/crop_foreground.py +110 -0
- eva/vision/data/transforms/croppad/rand_crop_by_pos_neg_label.py +109 -0
- eva/vision/data/transforms/croppad/spatial_pad.py +67 -0
- eva/vision/data/transforms/intensity/__init__.py +11 -0
- eva/vision/data/transforms/intensity/rand_scale_intensity.py +59 -0
- eva/vision/data/transforms/intensity/rand_shift_intensity.py +55 -0
- eva/vision/data/transforms/intensity/scale_intensity_ranged.py +56 -0
- eva/vision/data/transforms/spatial/__init__.py +7 -0
- eva/vision/data/transforms/spatial/flip.py +72 -0
- eva/vision/data/transforms/spatial/rotate.py +53 -0
- eva/vision/data/transforms/spatial/spacing.py +69 -0
- eva/vision/data/transforms/utility/__init__.py +5 -0
- eva/vision/data/transforms/utility/ensure_channel_first.py +51 -0
- eva/vision/data/tv_tensors/__init__.py +5 -0
- eva/vision/data/tv_tensors/volume.py +61 -0
- eva/vision/metrics/segmentation/monai_dice.py +9 -2
- eva/vision/models/modules/semantic_segmentation.py +28 -20
- eva/vision/models/networks/backbones/__init__.py +9 -2
- eva/vision/models/networks/backbones/pathology/__init__.py +11 -2
- eva/vision/models/networks/backbones/pathology/bioptimus.py +47 -1
- eva/vision/models/networks/backbones/pathology/hkust.py +69 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +18 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +46 -19
- eva/vision/models/networks/backbones/radiology/__init__.py +11 -0
- eva/vision/models/networks/backbones/radiology/swin_unetr.py +231 -0
- eva/vision/models/networks/backbones/radiology/voco.py +75 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +6 -2
- eva/vision/models/networks/decoders/segmentation/linear.py +5 -10
- eva/vision/models/networks/decoders/segmentation/semantic/__init__.py +8 -1
- eva/vision/models/networks/decoders/segmentation/semantic/swin_unetr.py +104 -0
- eva/vision/utils/io/__init__.py +2 -0
- eva/vision/utils/io/nifti.py +91 -11
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/METADATA +3 -1
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/RECORD +83 -62
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/WHEEL +1 -1
- eva/vision/data/datasets/classification/base.py +0 -96
- eva/vision/data/datasets/segmentation/base.py +0 -96
- eva/vision/data/transforms/common/resize_and_clamp.py +0 -51
- eva/vision/data/transforms/normalization/__init__.py +0 -6
- eva/vision/data/transforms/normalization/clamp.py +0 -43
- eva/vision/data/transforms/normalization/functional/__init__.py +0 -5
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +0 -28
- eva/vision/data/transforms/normalization/rescale_intensity.py +0 -53
- eva/vision/metrics/segmentation/BUILD +0 -1
- eva/vision/models/networks/backbones/torchhub/__init__.py +0 -5
- eva/vision/models/networks/backbones/torchhub/backbones.py +0 -61
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.1.8.dist-info → kaiko_eva-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,18 +1,27 @@
|
|
|
1
1
|
"""Image classification datasets API."""
|
|
2
2
|
|
|
3
3
|
from eva.vision.data.datasets.classification.bach import BACH
|
|
4
|
+
from eva.vision.data.datasets.classification.bracs import BRACS
|
|
5
|
+
from eva.vision.data.datasets.classification.breakhis import BreaKHis
|
|
4
6
|
from eva.vision.data.datasets.classification.camelyon16 import Camelyon16
|
|
5
7
|
from eva.vision.data.datasets.classification.crc import CRC
|
|
8
|
+
from eva.vision.data.datasets.classification.gleason_arvaniti import GleasonArvaniti
|
|
6
9
|
from eva.vision.data.datasets.classification.mhist import MHIST
|
|
7
10
|
from eva.vision.data.datasets.classification.panda import PANDA, PANDASmall
|
|
8
11
|
from eva.vision.data.datasets.classification.patch_camelyon import PatchCamelyon
|
|
12
|
+
from eva.vision.data.datasets.classification.unitopatho import UniToPatho
|
|
9
13
|
from eva.vision.data.datasets.classification.wsi import WsiClassificationDataset
|
|
10
14
|
|
|
11
15
|
__all__ = [
|
|
12
16
|
"BACH",
|
|
17
|
+
"BreaKHis",
|
|
18
|
+
"BRACS",
|
|
19
|
+
"Camelyon16",
|
|
13
20
|
"CRC",
|
|
21
|
+
"GleasonArvaniti",
|
|
14
22
|
"MHIST",
|
|
15
23
|
"PatchCamelyon",
|
|
24
|
+
"UniToPatho",
|
|
16
25
|
"WsiClassificationDataset",
|
|
17
26
|
"PANDA",
|
|
18
27
|
"PANDASmall",
|
|
@@ -8,12 +8,11 @@ from torchvision import tv_tensors
|
|
|
8
8
|
from torchvision.datasets import folder, utils
|
|
9
9
|
from typing_extensions import override
|
|
10
10
|
|
|
11
|
-
from eva.vision.data.datasets import _utils, _validators, structs
|
|
12
|
-
from eva.vision.data.datasets.classification import base
|
|
11
|
+
from eva.vision.data.datasets import _utils, _validators, structs, vision
|
|
13
12
|
from eva.vision.utils import io
|
|
14
13
|
|
|
15
14
|
|
|
16
|
-
class BACH(
|
|
15
|
+
class BACH(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
|
|
17
16
|
"""Dataset class for BACH images and corresponding targets."""
|
|
18
17
|
|
|
19
18
|
_train_index_ranges: List[Tuple[int, int]] = [
|
|
@@ -125,7 +124,7 @@ class BACH(base.ImageClassification):
|
|
|
125
124
|
)
|
|
126
125
|
|
|
127
126
|
@override
|
|
128
|
-
def
|
|
127
|
+
def load_data(self, index: int) -> tv_tensors.Image:
|
|
129
128
|
image_path, _ = self._samples[self._indices[index]]
|
|
130
129
|
return io.read_image_as_tensor(image_path)
|
|
131
130
|
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""BRACS dataset class."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Callable, Dict, List, Literal, Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torchvision import tv_tensors
|
|
8
|
+
from torchvision.datasets import folder
|
|
9
|
+
from typing_extensions import override
|
|
10
|
+
|
|
11
|
+
from eva.vision.data.datasets import _validators, vision
|
|
12
|
+
from eva.vision.utils import io
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class BRACS(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
    """BRACS image classification dataset.

    Samples are collected with torchvision's ``folder.make_dataset`` from
    ``<root>/BRACS_RoI/latest_version/<split>`` and labelled with one of
    the seven class names returned by :attr:`classes`.
    """

    _expected_dataset_lengths: Dict[str, int] = {
        "train": 3657,
        "val": 312,
        "test": 570,
    }
    """Number of samples each split is expected to contain."""

    _license: str = "CC BY-NC 4.0 (https://creativecommons.org/licenses/by-nc/4.0/)"
    """License the dataset is distributed under."""

    def __init__(
        self,
        root: str,
        split: Literal["train", "val", "test"],
        transforms: Callable | None = None,
    ) -> None:
        """Stores the dataset configuration; no files are read here.

        Args:
            root: Root directory that holds the dataset.
            split: Which of the dataset splits to load.
            transforms: Optional callable forwarded to the parent dataset
                class, applied to the raw samples.
        """
        super().__init__(transforms=transforms)

        self._root = root
        self._split = split

        # Populated by `configure`; each entry is (image path, class index).
        self._samples: List[Tuple[str, int]] = []

    @property
    @override
    def classes(self) -> List[str]:
        return ["0_N", "1_PB", "2_UDH", "3_FEA", "4_ADH", "5_DCIS", "6_IC"]

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        # The index of a class is its position in `classes`.
        names = self.classes
        return dict(zip(names, range(len(names))))

    @override
    def filename(self, index: int) -> str:
        # Only the first field (the image path) of the sample is needed.
        path, *_ = self._samples[index]
        return os.path.relpath(path, self._dataset_path)

    @override
    def prepare_data(self) -> None:
        _validators.check_dataset_exists(self._root, True)

    @override
    def configure(self) -> None:
        self._samples = self._make_dataset()

    @override
    def validate(self) -> None:
        expected_length = self._expected_dataset_lengths[self._split]
        _validators.check_dataset_integrity(
            self,
            length=expected_length,
            n_classes=7,
            first_and_last_labels=("0_N", "6_IC"),
        )

    @override
    def load_data(self, index: int) -> tv_tensors.Image:
        path, _label = self._samples[index]
        return io.read_image_as_tensor(path)

    @override
    def load_target(self, index: int) -> torch.Tensor:
        _path, label = self._samples[index]
        return torch.tensor(label, dtype=torch.long)

    @override
    def __len__(self) -> int:
        return len(self._samples)

    @property
    def _dataset_path(self) -> str:
        """Directory below the root that contains the split folders."""
        return os.path.join(self._root, "BRACS_RoI/latest_version")

    def _make_dataset(self) -> List[Tuple[str, int]]:
        """Collects the (image path, class index) pairs of the selected split."""
        split_directory = os.path.join(self._dataset_path, self._split)
        # A single extension is passed as a plain string, exactly as before.
        return folder.make_dataset(
            directory=split_directory,
            class_to_idx=self.class_to_idx,
            extensions=".png",
        )

    def _print_license(self) -> None:
        """Writes the dataset license to standard output."""
        print(f"Dataset license: {self._license}")
|
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
"""BreaKHis dataset class."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import glob
|
|
5
|
+
import os
|
|
6
|
+
from typing import Any, Callable, Dict, List, Literal, Set
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torchvision import tv_tensors
|
|
10
|
+
from torchvision.datasets import utils
|
|
11
|
+
from typing_extensions import override
|
|
12
|
+
|
|
13
|
+
from eva.vision.data.datasets import _validators, structs, vision
|
|
14
|
+
from eva.vision.utils import io
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BreaKHis(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
    """BreaKHis image classification dataset.

    Images are discovered under ``<root>/BreaKHis_v1/histology_slides`` for
    the selected magnifications. The train/val partition is made on patient
    IDs, so that images of one patient never end up in both splits.
    """

    _resources: List[structs.DownloadResource] = [
        structs.DownloadResource(
            filename="BreaKHis_v1.tar.gz",
            url="http://www.inf.ufpr.br/vri/databases/BreaKHis_v1.tar.gz",
        ),
    ]
    """Archives to download."""

    _val_patient_ids: Set[str] = {
        "18842D",
        "19979",
        "15275",
        "15792",
        "16875",
        "3909",
        "5287",
        "16716",
        "2773",
        "5695",
        "16184CD",
        "23060CD",
        "21998CD",
        "21998EF",
    }
    """Patients assigned to the validation split; all others go to train."""

    _expected_dataset_lengths: Dict[str | None, int] = {
        "train": 1132,
        "val": 339,
        None: 1471,
    }
    """Number of samples expected per split (`None` is the whole dataset)."""

    _default_magnifications = ["40X"]
    """Magnifications used when none are passed to the constructor."""

    _license: str = "CC BY 4.0 (https://creativecommons.org/licenses/by/4.0/)"
    """License the dataset is distributed under."""

    def __init__(
        self,
        root: str,
        split: Literal["train", "val"] | None = None,
        magnifications: List[Literal["40X", "100X", "200X", "400X"]] | None = None,
        download: bool = False,
        transforms: Callable | None = None,
    ) -> None:
        """Stores the dataset configuration; no files are read here.

        Train and validation samples are separated by patient ID in order
        to avoid data leakage between the splits.

        Args:
            root: Root directory of the dataset. When downloading is
                enabled, the archive is fetched and extracted here if the
                data is not already present.
            split: Split to load; `None` selects the complete dataset.
            magnifications: Magnifications to include. Falls back to the
                40X images only when omitted (or empty).
            download: Whether to download the data. The download only
                happens once :meth:`prepare_data` is called, and only if
                the data is not yet on disk.
            transforms: Optional callable forwarded to the parent dataset
                class, applied to the raw samples.
        """
        super().__init__(transforms=transforms)

        self._root = root
        self._split = split
        self._download = download

        self._magnifications = magnifications or self._default_magnifications
        # Populated by `configure`; positions into `_image_files`.
        self._indices: List[int] = []

    @property
    @override
    def classes(self) -> List[str]:
        return ["TA", "MC", "F", "DC"]

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        # The index of a class is its position in `classes`.
        names = self.classes
        return dict(zip(names, range(len(names))))

    @property
    def _dataset_path(self) -> str:
        """Directory that holds the image data of the dataset."""
        return os.path.join(self._root, "BreaKHis_v1", "histology_slides")

    @functools.cached_property
    def _image_files(self) -> List[str]:
        """Sorted paths of all PNG images of the selected magnifications.

        Returns:
            The image file paths, sorted lexicographically.
        """
        patterns = [
            os.path.join(self._dataset_path, f"**/{magnification}", "*.png")
            for magnification in self._magnifications
        ]
        return sorted(
            path for pattern in patterns for path in glob.glob(pattern, recursive=True)
        )

    @override
    def filename(self, index: int) -> str:
        path = self._image_files[self._indices[index]]
        return os.path.relpath(path, self._dataset_path)

    @override
    def prepare_data(self) -> None:
        if self._download:
            self._download_dataset()
        _validators.check_dataset_exists(self._root, True)

    @override
    def configure(self) -> None:
        self._indices = self._make_indices()

    @override
    def validate(self) -> None:
        expected_length = self._expected_dataset_lengths[self._split]
        _validators.check_dataset_integrity(
            self,
            length=expected_length,
            n_classes=4,
            first_and_last_labels=("TA", "DC"),
        )

    @override
    def load_data(self, index: int) -> tv_tensors.Image:
        path = self._image_files[self._indices[index]]
        return io.read_image_as_tensor(path)

    @override
    def load_target(self, index: int) -> torch.Tensor:
        path = self._image_files[self._indices[index]]
        label = self.class_to_idx[self._extract_class(path)]
        return torch.tensor(label, dtype=torch.long)

    @override
    def load_metadata(self, index: int) -> Dict[str, Any]:
        path = self._image_files[self._indices[index]]
        return {"patient_id": self._extract_patient_id(path)}

    @override
    def __len__(self) -> int:
        return len(self._indices)

    def _download_dataset(self) -> None:
        """Downloads and extracts the archives unless the data directory exists."""
        for resource in self._resources:
            # Re-checked per resource: once the directory exists nothing more is fetched.
            if not os.path.isdir(self._dataset_path):
                self._print_license()
                utils.download_and_extract_archive(
                    resource.url,
                    download_root=self._root,
                    filename=resource.filename,
                    remove_finished=True,
                )

    def _print_license(self) -> None:
        """Writes the dataset license to standard output."""
        print(f"Dataset license: {self._license}")

    def _extract_patient_id(self, image_file: str) -> str:
        """Returns the third dash-separated field of the file name."""
        fields = os.path.basename(image_file).split("-")
        return fields[2]

    def _extract_class(self, file: str) -> str:
        """Returns the text after the last underscore of the first dash-separated field."""
        leading_field = os.path.basename(file).partition("-")[0]
        return leading_field.rpartition("_")[-1]

    def _make_indices(self) -> List[int]:
        """Collects the positions in `_image_files` that belong to the selected split."""
        buckets: Dict[str, List[int]] = {"train": [], "val": []}

        for position, path in enumerate(self._image_files):
            # Files whose class is not one of `classes` are left out entirely.
            if self._extract_class(path) in self.classes:
                is_val = self._extract_patient_id(path) in self._val_patient_ids
                buckets["val" if is_val else "train"].append(position)

        if self._split is None:
            return buckets["train"] + buckets["val"]
        return buckets[self._split]
|
|
@@ -11,12 +11,11 @@ from torchvision import tv_tensors
|
|
|
11
11
|
from torchvision.transforms.v2 import functional
|
|
12
12
|
from typing_extensions import override
|
|
13
13
|
|
|
14
|
-
from eva.vision.data.datasets import _validators, wsi
|
|
15
|
-
from eva.vision.data.datasets.classification import base
|
|
14
|
+
from eva.vision.data.datasets import _validators, vision, wsi
|
|
16
15
|
from eva.vision.data.wsi.patching import samplers
|
|
17
16
|
|
|
18
17
|
|
|
19
|
-
class Camelyon16(wsi.MultiWsiDataset,
|
|
18
|
+
class Camelyon16(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
|
|
20
19
|
"""Dataset class for Camelyon16 images and corresponding targets."""
|
|
21
20
|
|
|
22
21
|
_val_slides = [
|
|
@@ -195,10 +194,10 @@ class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):
|
|
|
195
194
|
|
|
196
195
|
@override
|
|
197
196
|
def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
|
|
198
|
-
return
|
|
197
|
+
return vision.VisionDataset.__getitem__(self, index)
|
|
199
198
|
|
|
200
199
|
@override
|
|
201
|
-
def
|
|
200
|
+
def load_data(self, index: int) -> tv_tensors.Image:
|
|
202
201
|
image_array = wsi.MultiWsiDataset.__getitem__(self, index)
|
|
203
202
|
return functional.to_image(image_array)
|
|
204
203
|
|
|
@@ -8,12 +8,11 @@ from torchvision import tv_tensors
|
|
|
8
8
|
from torchvision.datasets import folder, utils
|
|
9
9
|
from typing_extensions import override
|
|
10
10
|
|
|
11
|
-
from eva.vision.data.datasets import _validators, structs
|
|
12
|
-
from eva.vision.data.datasets.classification import base
|
|
11
|
+
from eva.vision.data.datasets import _validators, structs, vision
|
|
13
12
|
from eva.vision.utils import io
|
|
14
13
|
|
|
15
14
|
|
|
16
|
-
class CRC(
|
|
15
|
+
class CRC(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
|
|
17
16
|
"""Dataset class for CRC images and corresponding targets."""
|
|
18
17
|
|
|
19
18
|
_train_resource: structs.DownloadResource = structs.DownloadResource(
|
|
@@ -117,7 +116,7 @@ class CRC(base.ImageClassification):
|
|
|
117
116
|
)
|
|
118
117
|
|
|
119
118
|
@override
|
|
120
|
-
def
|
|
119
|
+
def load_data(self, index: int) -> tv_tensors.Image:
|
|
121
120
|
image_path, _ = self._samples[index]
|
|
122
121
|
return io.read_image_as_tensor(image_path)
|
|
123
122
|
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
"""GleasonArvaniti dataset class."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import glob
|
|
5
|
+
import os
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Callable, Dict, List, Literal
|
|
8
|
+
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import torch
|
|
11
|
+
from loguru import logger
|
|
12
|
+
from torchvision import tv_tensors
|
|
13
|
+
from typing_extensions import override
|
|
14
|
+
|
|
15
|
+
from eva.vision.data.datasets import _validators, vision
|
|
16
|
+
from eva.vision.utils import io
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class GleasonArvaniti(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
    """GleasonArvaniti image classification dataset.

    JPEG images are discovered below ``train_validation_patches_750`` and
    ``test_patches_750/patho_1`` in the root directory. The split of an image
    is derived from the tissue micro array ID in its file name, the label
    from the trailing integer of its file name.
    """

    _expected_dataset_lengths: Dict[str | None, int] = {
        "train": 15303,
        "val": 2482,
        "test": 4967,
        None: 22752,
    }
    """Number of samples expected per split (`None` is the whole dataset)."""

    _license: str = "CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)"
    """License the dataset is distributed under."""

    def __init__(
        self,
        root: str,
        split: Literal["train", "val", "test"] | None = None,
        transforms: Callable | None = None,
    ) -> None:
        """Stores the dataset configuration; no files are read here.

        Args:
            root: Root directory that holds the dataset.
            split: Split to load; `None` selects the complete dataset.
            transforms: Optional callable forwarded to the parent dataset
                class, applied to the raw samples.
        """
        super().__init__(transforms=transforms)

        self._root = root
        self._split = split

        # Populated by `configure`; positions into `_image_files`.
        self._indices: List[int] = []

    @property
    @override
    def classes(self) -> List[str]:
        return ["benign", "gleason_3", "gleason_4", "gleason_5"]

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        # The index of a class is its position in `classes`.
        names = self.classes
        return dict(zip(names, range(len(names))))

    @functools.cached_property
    def _image_files(self) -> List[str]:
        """Sorted paths of all JPEG images found in the dataset folders.

        Returns:
            The image file paths, sorted lexicographically.
        """
        patterns = [
            os.path.join(self._root, subdir, "**/*.jpg")
            for subdir in ("train_validation_patches_750", "test_patches_750/patho_1")
        ]
        return sorted(
            path for pattern in patterns for path in glob.glob(pattern, recursive=True)
        )

    @functools.cached_property
    def _manifest(self) -> pd.DataFrame:
        """Loads train.csv and test.csv into one frame indexed by `image_id`.

        Rows from train.csv get the `split` value "train", rows from
        test.csv the value "val".
        """
        df_train = pd.read_csv(os.path.join(self._root, "train.csv"))
        df_val = pd.read_csv(os.path.join(self._root, "test.csv"))
        df_train["split"] = "train"
        df_val["split"] = "val"
        combined = pd.concat([df_train, df_val], axis=0)
        return combined.set_index("image_id")

    @override
    def filename(self, index: int) -> str:
        path = self._image_files[self._indices[index]]
        return os.path.relpath(path, self._root)

    @override
    def prepare_data(self) -> None:
        _validators.check_dataset_exists(self._root, download_available=False)

        # Both patch folders must be present, checked in this order.
        for required in ("train_validation_patches_750", "test_patches_750"):
            if not os.path.isdir(os.path.join(self._root, required)):
                raise FileNotFoundError(f"`{required}` directory not found in {self._root}")

        if self._split == "test":
            logger.warning(
                "The test split currently leads to unstable evaluation results. "
                "We recommend using the validation split instead."
            )

    @override
    def configure(self) -> None:
        self._indices = self._make_indices()

    @override
    def validate(self) -> None:
        expected_length = self._expected_dataset_lengths[self._split]
        _validators.check_dataset_integrity(
            self,
            length=expected_length,
            n_classes=4,
            first_and_last_labels=("benign", "gleason_5"),
        )

    @override
    def load_data(self, index: int) -> tv_tensors.Image:
        path = self._image_files[self._indices[index]]
        return io.read_image_as_tensor(path)

    @override
    def load_target(self, index: int) -> torch.Tensor:
        path = self._image_files[self._indices[index]]
        return torch.tensor(self._extract_class(path), dtype=torch.long)

    @override
    def __len__(self) -> int:
        return len(self._indices)

    def _print_license(self) -> None:
        """Writes the dataset license to standard output."""
        print(f"Dataset license: {self._license}")

    def _extract_micro_array_id(self, file: str) -> str:
        """Returns the part of the file stem before its first underscore."""
        return Path(file).stem.partition("_")[0]

    def _extract_class(self, file: str) -> int:
        """Returns the part of the file stem after its last underscore, as an integer."""
        return int(Path(file).stem.rpartition("_")[-1])

    def _make_indices(self) -> List[int]:
        """Collects the positions in `_image_files` that belong to the selected split.

        Raises:
            ValueError: If a file matches neither a known micro array ID
                nor the test patches folder.
        """
        buckets: Dict[str, List[int]] = {"train": [], "val": [], "test": []}

        for position, path in enumerate(self._image_files):
            array_id = self._extract_micro_array_id(path)

            # The micro array ID takes precedence over the folder name.
            if array_id == "ZT76":
                target = "val"
            elif array_id in {"ZT111", "ZT199", "ZT204"}:
                target = "train"
            elif "test_patches_750" in path:
                target = "test"
            else:
                raise ValueError(f"Invalid microarray value found for file {path}")
            buckets[target].append(position)

        if self._split is None:
            return buckets["train"] + buckets["val"] + buckets["test"]
        return buckets[self._split]
|
|
@@ -7,12 +7,11 @@ import torch
|
|
|
7
7
|
from torchvision import tv_tensors
|
|
8
8
|
from typing_extensions import override
|
|
9
9
|
|
|
10
|
-
from eva.vision.data.datasets import _validators
|
|
11
|
-
from eva.vision.data.datasets.classification import base
|
|
10
|
+
from eva.vision.data.datasets import _validators, vision
|
|
12
11
|
from eva.vision.utils import io
|
|
13
12
|
|
|
14
13
|
|
|
15
|
-
class MHIST(
|
|
14
|
+
class MHIST(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
|
|
16
15
|
"""MHIST dataset."""
|
|
17
16
|
|
|
18
17
|
def __init__(
|
|
@@ -69,7 +68,7 @@ class MHIST(base.ImageClassification):
|
|
|
69
68
|
)
|
|
70
69
|
|
|
71
70
|
@override
|
|
72
|
-
def
|
|
71
|
+
def load_data(self, index: int) -> tv_tensors.Image:
|
|
73
72
|
image_filename, _ = self._samples[index]
|
|
74
73
|
image_path = os.path.join(self._dataset_path, image_filename)
|
|
75
74
|
return io.read_image_as_tensor(image_path)
|
|
@@ -13,12 +13,11 @@ from torchvision.transforms.v2 import functional
|
|
|
13
13
|
from typing_extensions import override
|
|
14
14
|
|
|
15
15
|
from eva.core.data import splitting
|
|
16
|
-
from eva.vision.data.datasets import _validators, structs, wsi
|
|
17
|
-
from eva.vision.data.datasets.classification import base
|
|
16
|
+
from eva.vision.data.datasets import _validators, structs, vision, wsi
|
|
18
17
|
from eva.vision.data.wsi.patching import samplers
|
|
19
18
|
|
|
20
19
|
|
|
21
|
-
class PANDA(wsi.MultiWsiDataset,
|
|
20
|
+
class PANDA(wsi.MultiWsiDataset, vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
|
|
22
21
|
"""Dataset class for PANDA images and corresponding targets."""
|
|
23
22
|
|
|
24
23
|
_train_split_ratio: float = 0.7
|
|
@@ -121,10 +120,10 @@ class PANDA(wsi.MultiWsiDataset, base.ImageClassification):
|
|
|
121
120
|
|
|
122
121
|
@override
|
|
123
122
|
def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
|
|
124
|
-
return
|
|
123
|
+
return vision.VisionDataset.__getitem__(self, index)
|
|
125
124
|
|
|
126
125
|
@override
|
|
127
|
-
def
|
|
126
|
+
def load_data(self, index: int) -> tv_tensors.Image:
|
|
128
127
|
image_array = wsi.MultiWsiDataset.__getitem__(self, index)
|
|
129
128
|
return functional.to_image(image_array)
|
|
130
129
|
|
|
@@ -10,14 +10,13 @@ from torchvision.datasets import utils
|
|
|
10
10
|
from torchvision.transforms.v2 import functional
|
|
11
11
|
from typing_extensions import override
|
|
12
12
|
|
|
13
|
-
from eva.vision.data.datasets import _validators, structs
|
|
14
|
-
from eva.vision.data.datasets.classification import base
|
|
13
|
+
from eva.vision.data.datasets import _validators, structs, vision
|
|
15
14
|
|
|
16
15
|
_URL_TEMPLATE = "https://zenodo.org/records/2546921/files/{filename}.gz?download=1"
|
|
17
16
|
"""PatchCamelyon URL files templates."""
|
|
18
17
|
|
|
19
18
|
|
|
20
|
-
class PatchCamelyon(
|
|
19
|
+
class PatchCamelyon(vision.VisionDataset[tv_tensors.Image, torch.Tensor]):
|
|
21
20
|
"""Dataset class for PatchCamelyon images and corresponding targets."""
|
|
22
21
|
|
|
23
22
|
_train_resources: List[structs.DownloadResource] = [
|
|
@@ -127,7 +126,7 @@ class PatchCamelyon(base.ImageClassification):
|
|
|
127
126
|
)
|
|
128
127
|
|
|
129
128
|
@override
|
|
130
|
-
def
|
|
129
|
+
def load_data(self, index: int) -> tv_tensors.Image:
|
|
131
130
|
return self._load_from_h5("x", index)
|
|
132
131
|
|
|
133
132
|
@override
|