kaiko-eva 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of kaiko-eva has been flagged as potentially problematic by the registry scanner.
- eva/core/callbacks/__init__.py +2 -2
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +2 -2
- eva/core/data/datasets/classification/__init__.py +8 -0
- eva/core/data/datasets/classification/embeddings.py +34 -0
- eva/core/data/datasets/{embeddings/classification → classification}/multi_embeddings.py +13 -9
- eva/core/data/datasets/{embeddings/base.py → embeddings.py} +47 -32
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/loggers/experimental_loggers.py +2 -2
- eva/core/loggers/log/__init__.py +3 -2
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +10 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +10 -4
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/functional.py +1 -0
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +30 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +12 -1
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +16 -17
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/{total_segmentator.py → total_segmentator_2d.py} +130 -36
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +56 -13
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +33 -12
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.1.dist-info/METADATA +553 -0
- kaiko_eva-0.1.1.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/entry_points.txt +2 -0
- eva/.DS_Store +0 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/data/datasets/embeddings/__init__.py +0 -13
- eva/core/data/datasets/embeddings/classification/__init__.py +0 -10
- eva/core/data/datasets/embeddings/classification/embeddings.py +0 -66
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/models/.DS_Store +0 -0
- eva/vision/models/networks/.DS_Store +0 -0
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.2.dist-info/METADATA +0 -431
- kaiko_eva-0.0.2.dist-info/RECORD +0 -127
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.2.dist-info → kaiko_eva-0.1.1.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/classification/camelyon16.py (new file)
@@ -0,0 +1,244 @@
+"""Camelyon16 dataset class."""
+
+import functools
+import glob
+import os
+from typing import Any, Callable, Dict, List, Literal, Tuple
+
+import pandas as pd
+import torch
+from torchvision import tv_tensors
+from torchvision.transforms.v2 import functional
+from typing_extensions import override
+
+from eva.vision.data.datasets import _validators, wsi
+from eva.vision.data.datasets.classification import base
+from eva.vision.data.wsi.patching import samplers
+
+
+class Camelyon16(wsi.MultiWsiDataset, base.ImageClassification):
+    """Dataset class for Camelyon16 images and corresponding targets."""
+
+    _val_slides = [
+        "normal_010",
+        "normal_013",
+        "normal_016",
+        "normal_017",
+        "normal_019",
+        "normal_020",
+        "normal_025",
+        "normal_030",
+        "normal_031",
+        "normal_032",
+        "normal_052",
+        "normal_056",
+        "normal_057",
+        "normal_067",
+        "normal_076",
+        "normal_079",
+        "normal_085",
+        "normal_095",
+        "normal_098",
+        "normal_099",
+        "normal_101",
+        "normal_102",
+        "normal_105",
+        "normal_106",
+        "normal_109",
+        "normal_129",
+        "normal_132",
+        "normal_137",
+        "normal_142",
+        "normal_143",
+        "normal_148",
+        "normal_152",
+        "tumor_001",
+        "tumor_005",
+        "tumor_011",
+        "tumor_012",
+        "tumor_013",
+        "tumor_019",
+        "tumor_031",
+        "tumor_037",
+        "tumor_043",
+        "tumor_046",
+        "tumor_057",
+        "tumor_065",
+        "tumor_069",
+        "tumor_071",
+        "tumor_073",
+        "tumor_079",
+        "tumor_080",
+        "tumor_081",
+        "tumor_082",
+        "tumor_085",
+        "tumor_097",
+        "tumor_109",
+    ]
+    """Validation slide names, same as the ones in patch camelyon."""
+
+    def __init__(
+        self,
+        root: str,
+        sampler: samplers.Sampler,
+        split: Literal["train", "val", "test"] | None = None,
+        width: int = 224,
+        height: int = 224,
+        target_mpp: float = 0.5,
+        backend: str = "openslide",
+        image_transforms: Callable | None = None,
+        seed: int = 42,
+    ) -> None:
+        """Initializes the dataset.
+
+        Args:
+            root: Root directory of the dataset.
+            sampler: The sampler to use for sampling patch coordinates.
+            split: Dataset split to use. If `None`, the entire dataset is used.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            backend: The backend to use for reading the whole-slide images.
+            image_transforms: Transforms to apply to the extracted image patches.
+            seed: Random seed for reproducibility.
+        """
+        self._split = split
+        self._root = root
+        self._width = width
+        self._height = height
+        self._target_mpp = target_mpp
+        self._seed = seed
+
+        wsi.MultiWsiDataset.__init__(
+            self,
+            root=root,
+            file_paths=self._load_file_paths(split),
+            width=width,
+            height=height,
+            sampler=sampler,
+            target_mpp=target_mpp,
+            backend=backend,
+            image_transforms=image_transforms,
+        )
+
+    @property
+    @override
+    def classes(self) -> List[str]:
+        return ["normal", "tumor"]
+
+    @property
+    @override
+    def class_to_idx(self) -> Dict[str, int]:
+        return {"normal": 0, "tumor": 1}
+
+    @functools.cached_property
+    def annotations_test_set(self) -> Dict[str, str]:
+        """Loads the dataset labels."""
+        path = os.path.join(self._root, "testing/reference.csv")
+        reference_df = pd.read_csv(path, header=None)
+        return {k: v.lower() for k, v in reference_df[[0, 1]].itertuples(index=False)}
+
+    @functools.cached_property
+    def annotations(self) -> Dict[str, str]:
+        """Loads the dataset labels."""
+        annotations = {}
+        if self._split in ["test", None]:
+            path = os.path.join(self._root, "testing/reference.csv")
+            reference_df = pd.read_csv(path, header=None)
+            annotations.update(
+                {k: v.lower() for k, v in reference_df[[0, 1]].itertuples(index=False)}
+            )
+
+        if self._split in ["train", "val", None]:
+            annotations.update(
+                {
+                    self._get_id_from_path(file_path): self._get_class_from_path(file_path)
+                    for file_path in self._file_paths
+                    if "test" not in file_path
+                }
+            )
+        return annotations
+
+    @override
+    def prepare_data(self) -> None:
+        _validators.check_dataset_exists(self._root, False)
+
+        expected_directories = ["training/normal", "training/tumor", "testing/images"]
+        for resource in expected_directories:
+            if not os.path.isdir(os.path.join(self._root, resource)):
+                raise FileNotFoundError(f"'{resource}' not found in the root folder.")
+
+        if not os.path.isfile(os.path.join(self._root, "testing/reference.csv")):
+            raise FileNotFoundError("'reference.csv' file not found in the testing folder.")
+
+    @override
+    def validate(self) -> None:
+
+        expected_n_files = {
+            "train": 216,
+            "val": 54,
+            "test": 129,
+            None: 399,
+        }
+        _validators.check_number_of_files(
+            self._file_paths, expected_n_files[self._split], self._split
+        )
+        _validators.check_dataset_integrity(
+            self,
+            length=None,
+            n_classes=2,
+            first_and_last_labels=("normal", "tumor"),
+        )
+
+    @override
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
+        return base.ImageClassification.__getitem__(self, index)
+
+    @override
+    def load_image(self, index: int) -> tv_tensors.Image:
+        image_array = wsi.MultiWsiDataset.__getitem__(self, index)
+        return functional.to_image(image_array)
+
+    @override
+    def load_target(self, index: int) -> torch.Tensor:
+        file_path = self._file_paths[self._get_dataset_idx(index)]
+        class_name = self.annotations[self._get_id_from_path(file_path)]
+        return torch.tensor(self.class_to_idx[class_name], dtype=torch.int64)
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        return {"wsi_id": self.filename(index).split(".")[0]}
+
+    def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
+        """Loads the file paths of the corresponding dataset split."""
+        train_paths, val_paths = [], []
+        for path in glob.glob(os.path.join(self._root, "training/**/*.tif")):
+            if self._get_id_from_path(path) in self._val_slides:
+                val_paths.append(path)
+            else:
+                train_paths.append(path)
+        test_paths = glob.glob(os.path.join(self._root, "testing/images", "*.tif"))
+
+        match split:
+            case "train":
+                paths = train_paths
+            case "val":
+                paths = val_paths
+            case "test":
+                paths = test_paths
+            case None:
+                paths = train_paths + val_paths + test_paths
+            case _:
+                raise ValueError("Invalid split. Use 'train', 'val' or `None`.")
+        return sorted([os.path.relpath(path, self._root) for path in paths])
+
+    def _get_id_from_path(self, file_path: str) -> str:
+        """Extracts the slide ID from the file path."""
+        return os.path.basename(file_path).replace(".tif", "")
+
+    def _get_class_from_path(self, file_path: str) -> str:
+        """Extracts the class name from the file path."""
+        class_name = self._get_id_from_path(file_path).split("_")[0]
+        if class_name not in self.classes:
+            raise ValueError(f"Invalid class name '{class_name}' in file path '{file_path}'.")
+        return class_name
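Camelyon16 composes `wsi.MultiWsiDataset` (patch extraction) with `base.ImageClassification` (sample access), so each item is an extracted patch plus a slide-level label. A minimal usage sketch, assuming the class is re-exported from `eva.vision.data.datasets`, that `samplers.GridSampler` (see `eva/vision/data/wsi/patching/samplers/grid.py` in the file list) is default-constructible, and that the usual `prepare_data`/`setup` lifecycle applies; the root path is hypothetical:

```python
from eva.vision.data import datasets
from eva.vision.data.wsi.patching import samplers

dataset = datasets.Camelyon16(
    root="/data/camelyon16",         # hypothetical path; data must already be on disk
    sampler=samplers.GridSampler(),  # assumed default-constructible grid sampler
    split="val",
    width=224,
    height=224,
    target_mpp=0.5,
)
dataset.prepare_data()  # validates the expected folder layout and reference.csv
dataset.setup()         # assumed lifecycle hook that computes patch coordinates
image, target, metadata = dataset[0]  # tv_tensors.Image, torch.Tensor, {"wsi_id": ...}
```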
eva/vision/data/datasets/classification/crc.py
@@ -3,7 +3,8 @@
 import os
 from typing import Callable, Dict, List, Literal, Tuple
 
-import
+import torch
+from torchvision import tv_tensors
 from torchvision.datasets import folder, utils
 from typing_extensions import override
 
@@ -37,8 +38,7 @@ class CRC(base.ImageClassification):
         root: str,
         split: Literal["train", "val"],
         download: bool = False,
-
-        target_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initializes the dataset.
 
@@ -56,15 +56,10 @@ class CRC(base.ImageClassification):
                 Note that the download will be executed only by additionally
                 calling the :meth:`prepare_data` method and if the data does
                 not yet exist on disk.
-
-
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
+            transforms: A function/transform which returns a transformed
+                version of the raw data samples.
         """
-        super().__init__(
-            image_transforms=image_transforms,
-            target_transforms=target_transforms,
-        )
+        super().__init__(transforms=transforms)
 
         self._root = root
         self._split = split
@@ -122,14 +117,14 @@ class CRC(base.ImageClassification):
         )
 
     @override
-    def load_image(self, index: int) ->
+    def load_image(self, index: int) -> tv_tensors.Image:
         image_path, _ = self._samples[index]
-        return io.
+        return io.read_image_as_tensor(image_path)
 
     @override
-    def load_target(self, index: int) ->
+    def load_target(self, index: int) -> torch.Tensor:
         _, target = self._samples[index]
-        return
+        return torch.tensor(target, dtype=torch.long)
 
     @override
     def __len__(self) -> int:
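The CRC changes show the transform-API migration that recurs in MHIST and PatchCamelyon below: the separate `image_transforms`/`target_transforms` arguments collapse into a single `transforms` callable applied to the whole raw sample. A minimal sketch of the new convention, assuming `CRC` is importable from `eva.vision.data.datasets` and using torchvision v2 transforms, which operate on (image, target) pairs; the root path is hypothetical:

```python
import torch
from torchvision.transforms import v2

from eva.vision.data.datasets import CRC  # assumed re-export location

# 0.0.2 style (removed): CRC(..., image_transforms=..., target_transforms=...)
# 0.1.1 style: one callable transforms the raw sample as a whole.
transforms = v2.Compose([
    v2.Resize(size=(224, 224), antialias=True),
    v2.ToDtype(torch.float32, scale=True),
])
dataset = CRC(root="/data/crc", split="train", transforms=transforms)
```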
eva/vision/data/datasets/classification/mhist.py
@@ -3,7 +3,8 @@
 import os
 from typing import Callable, Dict, List, Literal, Tuple
 
-import
+import torch
+from torchvision import tv_tensors
 from typing_extensions import override
 
 from eva.vision.data.datasets import _validators
@@ -18,23 +19,17 @@ class MHIST(base.ImageClassification):
         self,
         root: str,
         split: Literal["train", "test"],
-
-        target_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initialize the dataset.
 
         Args:
             root: Path to the root directory of the dataset.
             split: Dataset split to use.
-
-
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
+            transforms: A function/transform which returns a transformed
+                version of the raw data samples.
         """
-        super().__init__(
-            image_transforms=image_transforms,
-            target_transforms=target_transforms,
-        )
+        super().__init__(transforms=transforms)
 
         self._root = root
         self._split = split
@@ -74,16 +69,16 @@ class MHIST(base.ImageClassification):
         )
 
     @override
-    def load_image(self, index: int) ->
+    def load_image(self, index: int) -> tv_tensors.Image:
         image_filename, _ = self._samples[index]
         image_path = os.path.join(self._dataset_path, image_filename)
-        return io.
+        return io.read_image_as_tensor(image_path)
 
     @override
-    def load_target(self, index: int) ->
+    def load_target(self, index: int) -> torch.Tensor:
         _, label = self._samples[index]
         target = self.class_to_idx[label]
-        return
+        return torch.tensor(target, dtype=torch.float32)
 
     @override
     def __len__(self) -> int:
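Note the target dtypes across these datasets: the binary tasks (MHIST above, PatchCamelyon below) now return `float32` scalar targets, which fit logits-based binary losses directly, while the multiclass tasks (CRC, Camelyon16, PANDA) return integer class indices. A small illustration of why the dtypes differ, using standard PyTorch losses (the loss pairing is an assumption, not something the diff specifies):

```python
import torch

# Binary targets as float scalars pair with BCE-style losses...
binary_target = torch.tensor(1, dtype=torch.float32)
bce = torch.nn.BCEWithLogitsLoss()(torch.tensor(0.3), binary_target)

# ...while integer class indices are what CrossEntropyLoss expects.
multiclass_target = torch.tensor([4], dtype=torch.int64)
ce = torch.nn.CrossEntropyLoss()(torch.randn(1, 6), multiclass_target)
```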
eva/vision/data/datasets/classification/panda.py (new file)
@@ -0,0 +1,184 @@
+"""PANDA dataset class."""
+
+import functools
+import glob
+import os
+from typing import Any, Callable, Dict, List, Literal, Tuple
+
+import pandas as pd
+import torch
+from torchvision import tv_tensors
+from torchvision.datasets import utils
+from torchvision.transforms.v2 import functional
+from typing_extensions import override
+
+from eva.core.data import splitting
+from eva.vision.data.datasets import _validators, structs, wsi
+from eva.vision.data.datasets.classification import base
+from eva.vision.data.wsi.patching import samplers
+
+
+class PANDA(wsi.MultiWsiDataset, base.ImageClassification):
+    """Dataset class for PANDA images and corresponding targets."""
+
+    _train_split_ratio: float = 0.7
+    """Train split ratio."""
+
+    _val_split_ratio: float = 0.15
+    """Validation split ratio."""
+
+    _test_split_ratio: float = 0.15
+    """Test split ratio."""
+
+    _resources: List[structs.DownloadResource] = [
+        structs.DownloadResource(
+            filename="train_with_noisy_labels.csv",
+            url="https://raw.githubusercontent.com/analokmaus/kaggle-panda-challenge-public/master/train.csv",
+            md5="5e4bfc78bda9603d2e2faf3ed4b21dfa",
+        )
+    ]
+    """Download resources."""
+
+    def __init__(
+        self,
+        root: str,
+        sampler: samplers.Sampler,
+        split: Literal["train", "val", "test"] | None = None,
+        width: int = 224,
+        height: int = 224,
+        target_mpp: float = 0.5,
+        backend: str = "openslide",
+        image_transforms: Callable | None = None,
+        seed: int = 42,
+    ) -> None:
+        """Initializes the dataset.
+
+        Args:
+            root: Root directory of the dataset.
+            sampler: The sampler to use for sampling patch coordinates.
+            split: Dataset split to use. If `None`, the entire dataset is used.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            backend: The backend to use for reading the whole-slide images.
+            image_transforms: Transforms to apply to the extracted image patches.
+            seed: Random seed for reproducibility.
+        """
+        self._split = split
+        self._root = root
+        self._seed = seed
+
+        self._download_resources()
+
+        wsi.MultiWsiDataset.__init__(
+            self,
+            root=root,
+            file_paths=self._load_file_paths(split),
+            width=width,
+            height=height,
+            sampler=sampler,
+            target_mpp=target_mpp,
+            backend=backend,
+            image_transforms=image_transforms,
+        )
+
+    @property
+    @override
+    def classes(self) -> List[str]:
+        return ["0", "1", "2", "3", "4", "5"]
+
+    @functools.cached_property
+    def annotations(self) -> pd.DataFrame:
+        """Loads the dataset labels."""
+        path = os.path.join(self._root, "train_with_noisy_labels.csv")
+        return pd.read_csv(path, index_col="image_id")
+
+    @override
+    def prepare_data(self) -> None:
+        _validators.check_dataset_exists(self._root, False)
+
+        if not os.path.isdir(os.path.join(self._root, "train_images")):
+            raise FileNotFoundError("'train_images' directory not found in the root folder.")
+        if not os.path.isfile(os.path.join(self._root, "train_with_noisy_labels.csv")):
+            raise FileNotFoundError("'train.csv' file not found in the root folder.")
+
+    def _download_resources(self) -> None:
+        """Downloads the dataset resources."""
+        for resource in self._resources:
+            utils.download_url(resource.url, self._root, resource.filename, resource.md5)
+
+    @override
+    def validate(self) -> None:
+        _validators.check_dataset_integrity(
+            self,
+            length=None,
+            n_classes=6,
+            first_and_last_labels=("0", "5"),
+        )
+
+    @override
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
+        return base.ImageClassification.__getitem__(self, index)
+
+    @override
+    def load_image(self, index: int) -> tv_tensors.Image:
+        image_array = wsi.MultiWsiDataset.__getitem__(self, index)
+        return functional.to_image(image_array)
+
+    @override
+    def load_target(self, index: int) -> torch.Tensor:
+        file_path = self._file_paths[self._get_dataset_idx(index)]
+        return torch.tensor(self._get_target_from_path(file_path), dtype=torch.int64)
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        return {"wsi_id": self.filename(index).split(".")[0]}
+
+    def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
+        """Loads the file paths of the corresponding dataset split."""
+        image_dir = os.path.join(self._root, "train_images")
+        file_paths = sorted(glob.glob(os.path.join(image_dir, "*.tiff")))
+        file_paths = [os.path.relpath(path, self._root) for path in file_paths]
+        if len(file_paths) != len(self.annotations):
+            raise ValueError(
+                f"Expected {len(self.annotations)} images, found {len(file_paths)} in {image_dir}."
+            )
+        file_paths = self._filter_noisy_labels(file_paths)
+        targets = [self._get_target_from_path(file_path) for file_path in file_paths]
+
+        train_indices, val_indices, test_indices = splitting.stratified_split(
+            samples=file_paths,
+            targets=targets,
+            train_ratio=self._train_split_ratio,
+            val_ratio=self._val_split_ratio,
+            test_ratio=self._test_split_ratio,
+            seed=self._seed,
+        )
+
+        match split:
+            case "train":
+                return [file_paths[i] for i in train_indices]
+            case "val":
+                return [file_paths[i] for i in val_indices]
+            case "test":
+                return [file_paths[i] for i in test_indices or []]
+            case None:
+                return file_paths
+            case _:
+                raise ValueError("Invalid split. Use 'train', 'val', 'test' or `None`.")
+
+    def _filter_noisy_labels(self, file_paths: List[str]):
+        is_noisy_filter = self.annotations["noise_ratio_10"] == 0
+        non_noisy_image_ids = set(self.annotations.loc[~is_noisy_filter].index)
+        filtered_file_paths = [
+            file_path
+            for file_path in file_paths
+            if self._get_id_from_path(file_path) in non_noisy_image_ids
+        ]
+        return filtered_file_paths
+
+    def _get_target_from_path(self, file_path: str) -> int:
+        return self.annotations.loc[self._get_id_from_path(file_path), "isup_grade"]
+
+    def _get_id_from_path(self, file_path: str) -> str:
+        return os.path.basename(file_path).replace(".tiff", "")
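PANDA is the first consumer of the new `eva.core.data.splitting` module added in this release (see `splitting/random.py` and `splitting/stratified.py` in the file list). A minimal sketch of `stratified_split` with the same keyword signature as the call above; the toy inputs are hypothetical, and the `test_indices or []` guard in `_load_file_paths` suggests the test index list can be `None`:

```python
from eva.core.data import splitting

samples = [f"slide_{i}.tiff" for i in range(100)]  # hypothetical file names
targets = [i % 5 for i in range(100)]              # five balanced classes

# Each returned list holds indices into `samples`, with class proportions
# preserved across the train/val/test splits.
train_idx, val_idx, test_idx = splitting.stratified_split(
    samples=samples,
    targets=targets,
    train_ratio=0.7,
    val_ratio=0.15,
    test_ratio=0.15,
    seed=42,
)
train_files = [samples[i] for i in train_idx]
```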
eva/vision/data/datasets/classification/patch_camelyon.py
@@ -4,8 +4,10 @@ import os
 from typing import Callable, Dict, List, Literal
 
 import h5py
-import
+import torch
+from torchvision import tv_tensors
 from torchvision.datasets import utils
+from torchvision.transforms.v2 import functional
 from typing_extensions import override
 
 from eva.vision.data.datasets import _validators, structs
@@ -70,8 +72,7 @@ class PatchCamelyon(base.ImageClassification):
         root: str,
         split: Literal["train", "val", "test"],
         download: bool = False,
-
-        target_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initializes the dataset.
 
@@ -82,15 +83,10 @@ class PatchCamelyon(base.ImageClassification):
             download: Whether to download the data for the specified split.
                 Note that the download will be executed only by additionally
                 calling the :meth:`prepare_data` method.
-
-
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
+            transforms: A function/transform which returns a transformed
+                version of the raw data samples.
         """
-        super().__init__(
-            image_transforms=image_transforms,
-            target_transforms=target_transforms,
-        )
+        super().__init__(transforms=transforms)
 
         self._root = root
         self._split = split
@@ -131,13 +127,13 @@ class PatchCamelyon(base.ImageClassification):
         )
 
     @override
-    def load_image(self, index: int) ->
+    def load_image(self, index: int) -> tv_tensors.Image:
         return self._load_from_h5("x", index)
 
     @override
-    def load_target(self, index: int) ->
+    def load_target(self, index: int) -> torch.Tensor:
         target = self._load_from_h5("y", index).squeeze()
-        return
+        return torch.tensor(target, dtype=torch.float32)
 
     @override
     def __len__(self) -> int:
@@ -162,7 +158,7 @@ class PatchCamelyon(base.ImageClassification):
         self,
         data_key: Literal["x", "y"],
         index: int | None = None,
-    ) ->
+    ) -> tv_tensors.Image:
         """Load data or targets from an HDF5 file.
 
         Args:
@@ -176,7 +172,8 @@ class PatchCamelyon(base.ImageClassification):
         h5_file = self._h5_file(data_key)
         with h5py.File(h5_file, "r") as file:
             data = file[data_key]
-
+            image_array = data[:] if index is None else data[index]  # type: ignore
+            return functional.to_image(image_array)  # type: ignore
 
     def _fetch_dataset_length(self) -> int:
         """Fetches the dataset split length from its HDF5 file."""
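All of the reworked `load_image` implementations funnel raw arrays through `torchvision.transforms.v2.functional.to_image`, which is why the return annotations changed to `tv_tensors.Image`. A small self-contained check of what that conversion does:

```python
import numpy as np
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional

# to_image wraps an HWC uint8 array into a CHW tv_tensors.Image tensor,
# the type that the updated load_image methods now return.
array = np.zeros((96, 96, 3), dtype=np.uint8)
image = functional.to_image(array)

assert isinstance(image, tv_tensors.Image)
assert tuple(image.shape) == (3, 96, 96)
```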