kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/callbacks/__init__.py +3 -2
- eva/core/callbacks/config.py +143 -0
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +10 -2
- eva/core/data/datasets/classification/__init__.py +5 -2
- eva/core/data/datasets/classification/embeddings.py +15 -135
- eva/core/data/datasets/classification/multi_embeddings.py +110 -0
- eva/core/data/datasets/embeddings.py +167 -0
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/data/transforms/__init__.py +3 -1
- eva/core/data/transforms/padding/__init__.py +5 -0
- eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
- eva/core/data/transforms/sampling/__init__.py +5 -0
- eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
- eva/core/loggers/__init__.py +7 -0
- eva/core/loggers/dummy.py +38 -0
- eva/core/loggers/experimental_loggers.py +8 -0
- eva/core/loggers/log/__init__.py +6 -0
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +74 -0
- eva/core/loggers/log/utils.py +13 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +16 -15
- eva/core/models/modules/module.py +25 -1
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/_recorder.py +69 -7
- eva/core/trainers/functional.py +23 -5
- eva/core/trainers/trainer.py +20 -6
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +24 -4
- eva/vision/data/datasets/_utils.py +3 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +6 -2
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +31 -47
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +67 -0
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +40 -15
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/data/datasets/classification/total_segmentator.py +0 -213
- eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.1.dist-info/METADATA +0 -405
- kaiko_eva-0.0.1.dist-info/RECORD +0 -110
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0
eva/vision/data/datasets/classification/panda.py

@@ -0,0 +1,184 @@
+"""PANDA dataset class."""
+
+import functools
+import glob
+import os
+from typing import Any, Callable, Dict, List, Literal, Tuple
+
+import pandas as pd
+import torch
+from torchvision import tv_tensors
+from torchvision.datasets import utils
+from torchvision.transforms.v2 import functional
+from typing_extensions import override
+
+from eva.core.data import splitting
+from eva.vision.data.datasets import _validators, structs, wsi
+from eva.vision.data.datasets.classification import base
+from eva.vision.data.wsi.patching import samplers
+
+
+class PANDA(wsi.MultiWsiDataset, base.ImageClassification):
+    """Dataset class for PANDA images and corresponding targets."""
+
+    _train_split_ratio: float = 0.7
+    """Train split ratio."""
+
+    _val_split_ratio: float = 0.15
+    """Validation split ratio."""
+
+    _test_split_ratio: float = 0.15
+    """Test split ratio."""
+
+    _resources: List[structs.DownloadResource] = [
+        structs.DownloadResource(
+            filename="train_with_noisy_labels.csv",
+            url="https://raw.githubusercontent.com/analokmaus/kaggle-panda-challenge-public/master/train.csv",
+            md5="5e4bfc78bda9603d2e2faf3ed4b21dfa",
+        )
+    ]
+    """Download resources."""
+
+    def __init__(
+        self,
+        root: str,
+        sampler: samplers.Sampler,
+        split: Literal["train", "val", "test"] | None = None,
+        width: int = 224,
+        height: int = 224,
+        target_mpp: float = 0.5,
+        backend: str = "openslide",
+        image_transforms: Callable | None = None,
+        seed: int = 42,
+    ) -> None:
+        """Initializes the dataset.
+
+        Args:
+            root: Root directory of the dataset.
+            sampler: The sampler to use for sampling patch coordinates.
+            split: Dataset split to use. If `None`, the entire dataset is used.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            backend: The backend to use for reading the whole-slide images.
+            image_transforms: Transforms to apply to the extracted image patches.
+            seed: Random seed for reproducibility.
+        """
+        self._split = split
+        self._root = root
+        self._seed = seed
+
+        self._download_resources()
+
+        wsi.MultiWsiDataset.__init__(
+            self,
+            root=root,
+            file_paths=self._load_file_paths(split),
+            width=width,
+            height=height,
+            sampler=sampler,
+            target_mpp=target_mpp,
+            backend=backend,
+            image_transforms=image_transforms,
+        )
+
+    @property
+    @override
+    def classes(self) -> List[str]:
+        return ["0", "1", "2", "3", "4", "5"]
+
+    @functools.cached_property
+    def annotations(self) -> pd.DataFrame:
+        """Loads the dataset labels."""
+        path = os.path.join(self._root, "train_with_noisy_labels.csv")
+        return pd.read_csv(path, index_col="image_id")
+
+    @override
+    def prepare_data(self) -> None:
+        _validators.check_dataset_exists(self._root, False)
+
+        if not os.path.isdir(os.path.join(self._root, "train_images")):
+            raise FileNotFoundError("'train_images' directory not found in the root folder.")
+        if not os.path.isfile(os.path.join(self._root, "train_with_noisy_labels.csv")):
+            raise FileNotFoundError("'train.csv' file not found in the root folder.")
+
+    def _download_resources(self) -> None:
+        """Downloads the dataset resources."""
+        for resource in self._resources:
+            utils.download_url(resource.url, self._root, resource.filename, resource.md5)
+
+    @override
+    def validate(self) -> None:
+        _validators.check_dataset_integrity(
+            self,
+            length=None,
+            n_classes=6,
+            first_and_last_labels=("0", "5"),
+        )
+
+    @override
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
+        return base.ImageClassification.__getitem__(self, index)
+
+    @override
+    def load_image(self, index: int) -> tv_tensors.Image:
+        image_array = wsi.MultiWsiDataset.__getitem__(self, index)
+        return functional.to_image(image_array)
+
+    @override
+    def load_target(self, index: int) -> torch.Tensor:
+        file_path = self._file_paths[self._get_dataset_idx(index)]
+        return torch.tensor(self._get_target_from_path(file_path), dtype=torch.int64)
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        return {"wsi_id": self.filename(index).split(".")[0]}
+
+    def _load_file_paths(self, split: Literal["train", "val", "test"] | None = None) -> List[str]:
+        """Loads the file paths of the corresponding dataset split."""
+        image_dir = os.path.join(self._root, "train_images")
+        file_paths = sorted(glob.glob(os.path.join(image_dir, "*.tiff")))
+        file_paths = [os.path.relpath(path, self._root) for path in file_paths]
+        if len(file_paths) != len(self.annotations):
+            raise ValueError(
+                f"Expected {len(self.annotations)} images, found {len(file_paths)} in {image_dir}."
+            )
+        file_paths = self._filter_noisy_labels(file_paths)
+        targets = [self._get_target_from_path(file_path) for file_path in file_paths]
+
+        train_indices, val_indices, test_indices = splitting.stratified_split(
+            samples=file_paths,
+            targets=targets,
+            train_ratio=self._train_split_ratio,
+            val_ratio=self._val_split_ratio,
+            test_ratio=self._test_split_ratio,
+            seed=self._seed,
+        )
+
+        match split:
+            case "train":
+                return [file_paths[i] for i in train_indices]
+            case "val":
+                return [file_paths[i] for i in val_indices]
+            case "test":
+                return [file_paths[i] for i in test_indices or []]
+            case None:
+                return file_paths
+            case _:
+                raise ValueError("Invalid split. Use 'train', 'val', 'test' or `None`.")
+
+    def _filter_noisy_labels(self, file_paths: List[str]):
+        is_noisy_filter = self.annotations["noise_ratio_10"] == 0
+        non_noisy_image_ids = set(self.annotations.loc[~is_noisy_filter].index)
+        filtered_file_paths = [
+            file_path
+            for file_path in file_paths
+            if self._get_id_from_path(file_path) in non_noisy_image_ids
+        ]
+        return filtered_file_paths
+
+    def _get_target_from_path(self, file_path: str) -> int:
+        return self.annotations.loc[self._get_id_from_path(file_path), "isup_grade"]
+
+    def _get_id_from_path(self, file_path: str) -> str:
+        return os.path.basename(file_path).replace(".tiff", "")
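
Taken together, this new module wires the PANDA whole-slide images into eva's patch-based classification pipeline: the constructor downloads the noisy-label CSV, drops noisy slides, stratifies the train/val/test split, and delegates patch extraction to MultiWsiDataset. The snippet below is a minimal usage sketch, not taken from the package: the RandomSampler name is an assumption based on the samplers/random.py file listed above, and the re-export of PANDA from eva.vision.data.datasets is likewise assumed.

from eva.vision.data.datasets import PANDA  # re-export path is an assumption
from eva.vision.data.wsi.patching import samplers

# `RandomSampler` is a hypothetical name; only samplers/random.py is visible in this diff.
dataset = PANDA(
    root="/data/panda",                # must contain train_images/ and the labels CSV
    sampler=samplers.RandomSampler(),  # hypothetical sampler and constructor signature
    split="train",                     # "train" | "val" | "test" | None
    width=224,
    height=224,
    target_mpp=0.5,
)
dataset.prepare_data()                 # checks that train_images/ and the CSV are present
image, target, metadata = dataset[0]   # per __getitem__: (tv_tensors.Image, torch.Tensor, dict)
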
eva/vision/data/datasets/classification/patch_camelyon.py

@@ -4,8 +4,10 @@ import os
 from typing import Callable, Dict, List, Literal
 
 import h5py
-import
+import torch
+from torchvision import tv_tensors
 from torchvision.datasets import utils
+from torchvision.transforms.v2 import functional
 from typing_extensions import override
 
 from eva.vision.data.datasets import _validators, structs
@@ -70,8 +72,7 @@ class PatchCamelyon(base.ImageClassification):
         root: str,
         split: Literal["train", "val", "test"],
         download: bool = False,
-        image_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
+        transforms: Callable | None = None,
     ) -> None:
         """Initializes the dataset.
 
@@ -82,15 +83,10 @@ class PatchCamelyon(base.ImageClassification):
             download: Whether to download the data for the specified split.
                 Note that the download will be executed only by additionally
                 calling the :meth:`prepare_data` method.
-            image_transforms: A function/transform that takes in an image
-                and returns a transformed version.
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
+            transforms: A function/transform which returns a transformed
+                version of the raw data samples.
         """
-        super().__init__(
-            image_transforms=image_transforms,
-            target_transforms=target_transforms,
-        )
+        super().__init__(transforms=transforms)
 
         self._root = root
         self._split = split
@@ -131,13 +127,13 @@ class PatchCamelyon(base.ImageClassification):
         )
 
     @override
-    def load_image(self, index: int) ->
+    def load_image(self, index: int) -> tv_tensors.Image:
         return self._load_from_h5("x", index)
 
     @override
-    def load_target(self, index: int) ->
+    def load_target(self, index: int) -> torch.Tensor:
         target = self._load_from_h5("y", index).squeeze()
-        return
+        return torch.tensor(target, dtype=torch.float32)
 
     @override
     def __len__(self) -> int:
@@ -162,7 +158,7 @@ class PatchCamelyon(base.ImageClassification):
         self,
         data_key: Literal["x", "y"],
         index: int | None = None,
-    ) ->
+    ) -> tv_tensors.Image:
         """Load data or targets from an HDF5 file.
 
         Args:
@@ -176,7 +172,8 @@ class PatchCamelyon(base.ImageClassification):
         h5_file = self._h5_file(data_key)
         with h5py.File(h5_file, "r") as file:
             data = file[data_key]
-
+            image_array = data[:] if index is None else data[index]  # type: ignore
+            return functional.to_image(image_array)  # type: ignore
 
     def _fetch_dataset_length(self) -> int:
         """Fetches the dataset split length from its HDF5 file."""
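
The net effect of these hunks is an API change rather than new behavior: the separate image_transforms/target_transforms arguments collapse into a single transforms callable applied to the raw sample, and the load methods now return torchvision tv_tensors.Image and torch.Tensor values instead of raw arrays. A minimal sketch of the updated call, assuming PatchCamelyon remains importable from eva.vision.data.datasets:

import torch
from torchvision.transforms import v2

from eva.vision.data.datasets import PatchCamelyon  # import path is an assumption

# A single callable now transforms the whole raw sample.
dataset = PatchCamelyon(
    root="/data/patch_camelyon",
    split="train",
    download=True,
    transforms=v2.Compose([v2.Resize(size=224), v2.ToDtype(torch.float32, scale=True)]),
)
dataset.prepare_data()  # the download only runs once this is called
sample = dataset[0]     # transformed image and target, as assembled by the base class
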
eva/vision/data/datasets/classification/wsi.py

@@ -0,0 +1,105 @@
+"""WSI classification dataset."""
+
+import os
+from typing import Any, Callable, Dict, Literal, Tuple
+
+import numpy as np
+import pandas as pd
+import torch
+from torchvision import tv_tensors
+from typing_extensions import override
+
+from eva.vision.data.datasets import wsi
+from eva.vision.data.datasets.classification import base
+from eva.vision.data.wsi.patching import samplers
+
+
+class WsiClassificationDataset(wsi.MultiWsiDataset, base.ImageClassification):
+    """A general dataset class for whole-slide image classification using manifest files."""
+
+    default_column_mapping: Dict[str, str] = {
+        "path": "path",
+        "target": "target",
+        "split": "split",
+    }
+
+    def __init__(
+        self,
+        root: str,
+        manifest_file: str,
+        width: int,
+        height: int,
+        target_mpp: float,
+        sampler: samplers.Sampler,
+        backend: str = "openslide",
+        split: Literal["train", "val", "test"] | None = None,
+        image_transforms: Callable | None = None,
+        column_mapping: Dict[str, str] = default_column_mapping,
+    ):
+        """Initializes the dataset.
+
+        Args:
+            root: Root directory of the dataset.
+            manifest_file: The path to the manifest file, relative to
+                the `root` argument. The `path` column is expected to contain
+                relative paths to the whole-slide images.
+            width: Width of the patches to be extracted, in pixels.
+            height: Height of the patches to be extracted, in pixels.
+            target_mpp: Target microns per pixel (mpp) for the patches.
+            sampler: The sampler to use for sampling patch coordinates.
+            backend: The backend to use for reading the whole-slide images.
+            split: The split of the dataset to load.
+            image_transforms: Transforms to apply to the extracted image patches.
+            column_mapping: Mapping of the columns in the manifest file.
+        """
+        self._split = split
+        self._column_mapping = self.default_column_mapping | column_mapping
+        self._manifest = self._load_manifest(os.path.join(root, manifest_file))
+
+        wsi.MultiWsiDataset.__init__(
+            self,
+            root=root,
+            file_paths=self._manifest[self._column_mapping["path"]].tolist(),
+            width=width,
+            height=height,
+            sampler=sampler,
+            target_mpp=target_mpp,
+            backend=backend,
+            image_transforms=image_transforms,
+        )
+
+    @override
+    def filename(self, index: int) -> str:
+        path = self._manifest.at[self._get_dataset_idx(index), self._column_mapping["path"]]
+        return os.path.basename(path) if os.path.isabs(path) else path
+
+    @override
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, torch.Tensor, Dict[str, Any]]:
+        return base.ImageClassification.__getitem__(self, index)
+
+    @override
+    def load_image(self, index: int) -> tv_tensors.Image:
+        return wsi.MultiWsiDataset.__getitem__(self, index)
+
+    @override
+    def load_target(self, index: int) -> np.ndarray:
+        target = self._manifest.at[self._get_dataset_idx(index), self._column_mapping["target"]]
+        return np.asarray(target)
+
+    @override
+    def load_metadata(self, index: int) -> Dict[str, Any]:
+        return {"wsi_id": self.filename(index).split(".")[0]}
+
+    def _load_manifest(self, manifest_path: str) -> pd.DataFrame:
+        df = pd.read_csv(manifest_path)
+
+        missing_columns = set(self._column_mapping.values()) - set(df.columns)
+        if self._split is None:
+            missing_columns = missing_columns - {self._column_mapping["split"]}
+        if missing_columns:
+            raise ValueError(f"Missing columns in the manifest file: {missing_columns}")
+
+        if self._split is not None:
+            df = df.loc[df[self._column_mapping["split"]] == self._split]
+
+        return df.reset_index(drop=True)
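
Because this class is manifest-driven, the shape of the manifest is the main contract. Below is a hypothetical example under the default column mapping; the file names, paths, and the RandomSampler name are illustrative assumptions, not taken from the package.

# manifest.csv, with paths relative to `root`:
#
#   path,target,split
#   slides/case_0001.tiff,1,train
#   slides/case_0002.tiff,0,val
#
from eva.vision.data.datasets import WsiClassificationDataset  # import path is an assumption
from eva.vision.data.wsi.patching import samplers

dataset = WsiClassificationDataset(
    root="/data/my_wsi_task",
    manifest_file="manifest.csv",
    width=224,
    height=224,
    target_mpp=0.5,
    sampler=samplers.RandomSampler(),  # hypothetical sampler name
    split="train",                     # keeps only rows whose `split` column is "train"
)

Note that column_mapping is merged over the defaults with the `|` operator, so passing for example {"target": "label"} renames only the target column, and when split=None the split column may be absent from the manifest entirely.
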
eva/vision/data/datasets/segmentation/__init__.py

@@ -1,6 +1,19 @@
 """Segmentation datasets API."""
 
 from eva.vision.data.datasets.segmentation.base import ImageSegmentation
-from eva.vision.data.datasets.segmentation.
+from eva.vision.data.datasets.segmentation.bcss import BCSS
+from eva.vision.data.datasets.segmentation.consep import CoNSeP
+from eva.vision.data.datasets.segmentation.embeddings import EmbeddingsSegmentationDataset
+from eva.vision.data.datasets.segmentation.lits import LiTS
+from eva.vision.data.datasets.segmentation.monusac import MoNuSAC
+from eva.vision.data.datasets.segmentation.total_segmentator_2d import TotalSegmentator2D
 
-__all__ = [
+__all__ = [
+    "ImageSegmentation",
+    "BCSS",
+    "CoNSeP",
+    "EmbeddingsSegmentationDataset",
+    "LiTS",
+    "MoNuSAC",
+    "TotalSegmentator2D",
+]
eva/vision/data/datasets/segmentation/_utils.py

@@ -0,0 +1,38 @@
+from typing import Any, Tuple
+
+import numpy.typing as npt
+
+from eva.vision.data.datasets import wsi
+
+
+def get_coords_at_index(
+    dataset: wsi.MultiWsiDataset, index: int
+) -> Tuple[Tuple[int, int], int, int]:
+    """Returns the coordinates ((x,y),width,height) of the patch at the given index.
+
+    Args:
+        dataset: The WSI dataset instance.
+        index: The sample index.
+    """
+    image_index = dataset._get_dataset_idx(index)
+    patch_index = index if image_index == 0 else index - dataset.cumulative_sizes[image_index - 1]
+    wsi_dataset = dataset.datasets[image_index]
+    if isinstance(wsi_dataset, wsi.WsiDataset):
+        coords = wsi_dataset._coords
+        return coords.x_y[patch_index], coords.width, coords.height
+    else:
+        raise Exception(f"Expected WsiDataset, got {type(wsi_dataset)}")
+
+
+def extract_mask_patch(
+    mask: npt.NDArray[Any], dataset: wsi.MultiWsiDataset, index: int
+) -> npt.NDArray[Any]:
+    """Reads the mask patch at the coordinates corresponding to the dataset index.
+
+    Args:
+        mask: The mask array.
+        dataset: The WSI dataset instance.
+        index: The sample index.
+    """
+    (x, y), width, height = get_coords_at_index(dataset, index)
+    return mask[y : y + height, x : x + width]
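
To make the index arithmetic concrete: get_coords_at_index maps a flat sample index into (slide, patch) coordinates via the concatenated dataset's cumulative_sizes, and extract_mask_patch slices the mask row-major, so y comes before x. A self-contained walk-through with illustrative numbers:

import numpy as np

# Two slides contributing 100 and 150 patches give cumulative_sizes == [100, 250];
# flat index 130 therefore falls into slide 1 as its 30th patch (130 - 100).
cumulative_sizes = [100, 250]
index = 130
image_index = next(i for i, size in enumerate(cumulative_sizes) if index < size)
patch_index = index if image_index == 0 else index - cumulative_sizes[image_index - 1]
assert (image_index, patch_index) == (1, 30)

# The mask patch is then the row-major slice mask[y : y + height, x : x + width]:
mask = np.zeros((1024, 1024), dtype=np.uint8)
(x, y), width, height = (256, 512), 224, 224
patch = mask[y : y + height, x : x + width]
assert patch.shape == (224, 224)
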
eva/vision/data/datasets/segmentation/base.py

@@ -3,38 +3,25 @@
 import abc
 from typing import Any, Callable, Dict, List, Tuple
 
-import numpy as np
+from torchvision import tv_tensors
 from typing_extensions import override
 
 from eva.vision.data.datasets import vision
 
 
-class ImageSegmentation(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
+class ImageSegmentation(vision.VisionDataset[Tuple[tv_tensors.Image, tv_tensors.Mask]], abc.ABC):
     """Image segmentation abstract dataset."""
 
-    def __init__(
-        self,
-        image_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
-        image_target_transforms: Callable | None = None,
-    ) -> None:
+    def __init__(self, transforms: Callable | None = None) -> None:
         """Initializes the image segmentation base class.
 
         Args:
-            image_transforms: A function/transform that takes in an image
-                and returns a transformed version.
-            target_transforms: A function/transform that takes in the target
-                and transforms it.
-            image_target_transforms: A function/transforms that takes in an
+            transforms: A function/transforms that takes in an
                 image and a label and returns the transformed versions of both.
-                This transform happens after the `image_transforms` and
-                `target_transforms`.
         """
         super().__init__()
 
-        self._image_transforms = image_transforms
-        self._target_transforms = target_transforms
-        self._image_target_transforms = image_target_transforms
+        self._transforms = transforms
 
     @property
     def classes(self) -> List[str] | None:
@@ -44,37 +31,38 @@ class ImageSegmentation(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
     def class_to_idx(self) -> Dict[str, int] | None:
         """Returns a mapping of the class name to its target index."""
 
-
-
+    @abc.abstractmethod
+    def load_image(self, index: int) -> tv_tensors.Image:
+        """Loads and returns the `index`'th image sample.
 
         Args:
-            index: The index of the data sample to
-                If `None`, it will return the metadata of the current dataset.
+            index: The index of the data sample to load.
 
         Returns:
-
+            An image torchvision tensor (channels, height, width).
         """
 
     @abc.abstractmethod
-    def
-    """
+    def load_mask(self, index: int) -> tv_tensors.Mask:
+        """Returns the `index`'th target masks sample.
 
         Args:
-            index: The index of the data sample to load.
+            index: The index of the data sample target masks to load.
 
         Returns:
-            The
+            The semantic mask as a (H x W) shaped tensor with integer
+                values which represent the pixel class id.
         """
 
-
-
-    """Returns the `index`'th target mask sample.
+    def load_metadata(self, index: int) -> Dict[str, Any] | None:
+        """Returns the dataset metadata.
 
         Args:
-            index: The index of the data sample
+            index: The index of the data sample to return the metadata of.
+                If `None`, it will return the metadata of the current dataset.
 
         Returns:
-            The sample
+            The sample metadata.
         """
 
     @abc.abstractmethod
@@ -83,30 +71,26 @@ class ImageSegmentation(vision.VisionDataset[Tuple[np.ndarray, np.ndarray]], abc.ABC):
         raise NotImplementedError
 
     @override
-    def __getitem__(self, index: int) -> Tuple[
+    def __getitem__(self, index: int) -> Tuple[tv_tensors.Image, tv_tensors.Mask, Dict[str, Any]]:
         image = self.load_image(index)
         mask = self.load_mask(index)
-
+        metadata = self.load_metadata(index) or {}
+        image_tensor, mask_tensor = self._apply_transforms(image, mask)
+        return image_tensor, mask_tensor, metadata
 
     def _apply_transforms(
-        self, image:
-    ) -> Tuple[
+        self, image: tv_tensors.Image, mask: tv_tensors.Mask
+    ) -> Tuple[tv_tensors.Image, tv_tensors.Mask]:
         """Applies the transforms to the provided data and returns them.
 
         Args:
             image: The desired image.
-
+            mask: The target segmentation mask.
 
         Returns:
-            A tuple with the image and the
+            A tuple with the image and the masks transformed.
         """
-        if self._image_transforms is not None:
-            image = self._image_transforms(image)
-
-        if self._target_transforms is not None:
-            target = self._target_transforms(target)
-
-        if self._image_target_transforms is not None:
-            image, target = self._image_target_transforms(image, target)
+        if self._transforms is not None:
+            image, mask = self._transforms(image, mask)
 
-        return image,
+        return image, mask
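
After this refactor a subclass only has to provide load_image and load_mask (plus whatever other abstract members the base declares outside the visible hunks), and may pass one joint transforms callable. A hypothetical minimal subclass sketch, assuming __len__ is the remaining abstract method hidden by this diff:

from typing import List

import torch
from torchvision import tv_tensors
from typing_extensions import override

from eva.vision.data.datasets.segmentation import base


class ToySegmentation(base.ImageSegmentation):
    """Hypothetical two-class dataset returning constant tensors."""

    @property
    @override
    def classes(self) -> List[str]:
        return ["background", "foreground"]

    @override
    def load_image(self, index: int) -> tv_tensors.Image:
        return tv_tensors.Image(torch.zeros(3, 64, 64))  # (C, H, W)

    @override
    def load_mask(self, index: int) -> tv_tensors.Mask:
        return tv_tensors.Mask(torch.zeros(64, 64, dtype=torch.long))  # (H, W) class ids

    def __len__(self) -> int:  # assumed to be the abstract member not shown in this diff
        return 8


dataset = ToySegmentation(transforms=None)
image, mask, metadata = dataset[0]  # metadata defaults to {} when load_metadata returns None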