kaiko-eva 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/callbacks/__init__.py +3 -2
- eva/core/callbacks/config.py +143 -0
- eva/core/callbacks/writers/__init__.py +6 -3
- eva/core/callbacks/writers/embeddings/__init__.py +6 -0
- eva/core/callbacks/writers/embeddings/_manifest.py +71 -0
- eva/core/callbacks/writers/embeddings/base.py +192 -0
- eva/core/callbacks/writers/embeddings/classification.py +117 -0
- eva/core/callbacks/writers/embeddings/segmentation.py +78 -0
- eva/core/callbacks/writers/embeddings/typings.py +38 -0
- eva/core/data/datasets/__init__.py +10 -2
- eva/core/data/datasets/classification/__init__.py +5 -2
- eva/core/data/datasets/classification/embeddings.py +15 -135
- eva/core/data/datasets/classification/multi_embeddings.py +110 -0
- eva/core/data/datasets/embeddings.py +167 -0
- eva/core/data/splitting/__init__.py +6 -0
- eva/core/data/splitting/random.py +41 -0
- eva/core/data/splitting/stratified.py +56 -0
- eva/core/data/transforms/__init__.py +3 -1
- eva/core/data/transforms/padding/__init__.py +5 -0
- eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
- eva/core/data/transforms/sampling/__init__.py +5 -0
- eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
- eva/core/loggers/__init__.py +7 -0
- eva/core/loggers/dummy.py +38 -0
- eva/core/loggers/experimental_loggers.py +8 -0
- eva/core/loggers/log/__init__.py +6 -0
- eva/core/loggers/log/image.py +71 -0
- eva/core/loggers/log/parameters.py +74 -0
- eva/core/loggers/log/utils.py +13 -0
- eva/core/loggers/loggers.py +6 -0
- eva/core/metrics/__init__.py +6 -2
- eva/core/metrics/defaults/__init__.py +10 -3
- eva/core/metrics/defaults/classification/__init__.py +1 -1
- eva/core/metrics/defaults/classification/binary.py +0 -9
- eva/core/metrics/defaults/classification/multiclass.py +0 -8
- eva/core/metrics/defaults/segmentation/__init__.py +5 -0
- eva/core/metrics/defaults/segmentation/multiclass.py +43 -0
- eva/core/metrics/generalized_dice.py +59 -0
- eva/core/metrics/mean_iou.py +120 -0
- eva/core/metrics/structs/schemas.py +3 -1
- eva/core/models/__init__.py +3 -1
- eva/core/models/modules/head.py +16 -15
- eva/core/models/modules/module.py +25 -1
- eva/core/models/modules/typings.py +14 -1
- eva/core/models/modules/utils/batch_postprocess.py +37 -5
- eva/core/models/networks/__init__.py +1 -2
- eva/core/models/networks/mlp.py +2 -2
- eva/core/models/transforms/__init__.py +6 -0
- eva/core/models/{networks/transforms → transforms}/extract_cls_features.py +10 -2
- eva/core/models/transforms/extract_patch_features.py +47 -0
- eva/core/models/wrappers/__init__.py +13 -0
- eva/core/models/{networks/wrappers → wrappers}/base.py +3 -2
- eva/core/models/{networks/wrappers → wrappers}/from_function.py +5 -12
- eva/core/models/{networks/wrappers → wrappers}/huggingface.py +15 -11
- eva/core/models/{networks/wrappers → wrappers}/onnx.py +6 -3
- eva/core/trainers/_recorder.py +69 -7
- eva/core/trainers/functional.py +23 -5
- eva/core/trainers/trainer.py +20 -6
- eva/core/utils/__init__.py +6 -0
- eva/core/utils/clone.py +27 -0
- eva/core/utils/memory.py +28 -0
- eva/core/utils/operations.py +26 -0
- eva/core/utils/parser.py +20 -0
- eva/vision/__init__.py +2 -2
- eva/vision/callbacks/__init__.py +5 -0
- eva/vision/callbacks/loggers/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/__init__.py +5 -0
- eva/vision/callbacks/loggers/batch/base.py +130 -0
- eva/vision/callbacks/loggers/batch/segmentation.py +188 -0
- eva/vision/data/datasets/__init__.py +24 -4
- eva/vision/data/datasets/_utils.py +3 -3
- eva/vision/data/datasets/_validators.py +15 -2
- eva/vision/data/datasets/classification/__init__.py +6 -2
- eva/vision/data/datasets/classification/bach.py +10 -15
- eva/vision/data/datasets/classification/base.py +17 -24
- eva/vision/data/datasets/classification/camelyon16.py +244 -0
- eva/vision/data/datasets/classification/crc.py +10 -15
- eva/vision/data/datasets/classification/mhist.py +10 -15
- eva/vision/data/datasets/classification/panda.py +184 -0
- eva/vision/data/datasets/classification/patch_camelyon.py +13 -16
- eva/vision/data/datasets/classification/wsi.py +105 -0
- eva/vision/data/datasets/segmentation/__init__.py +15 -2
- eva/vision/data/datasets/segmentation/_utils.py +38 -0
- eva/vision/data/datasets/segmentation/base.py +31 -47
- eva/vision/data/datasets/segmentation/bcss.py +236 -0
- eva/vision/data/datasets/segmentation/consep.py +156 -0
- eva/vision/data/datasets/segmentation/embeddings.py +34 -0
- eva/vision/data/datasets/segmentation/lits.py +178 -0
- eva/vision/data/datasets/segmentation/monusac.py +236 -0
- eva/vision/data/datasets/segmentation/total_segmentator_2d.py +325 -0
- eva/vision/data/datasets/wsi.py +187 -0
- eva/vision/data/transforms/__init__.py +3 -2
- eva/vision/data/transforms/common/__init__.py +2 -1
- eva/vision/data/transforms/common/resize_and_clamp.py +51 -0
- eva/vision/data/transforms/common/resize_and_crop.py +6 -7
- eva/vision/data/transforms/normalization/__init__.py +6 -0
- eva/vision/data/transforms/normalization/clamp.py +43 -0
- eva/vision/data/transforms/normalization/functional/__init__.py +5 -0
- eva/vision/data/transforms/normalization/functional/rescale_intensity.py +28 -0
- eva/vision/data/transforms/normalization/rescale_intensity.py +53 -0
- eva/vision/data/wsi/__init__.py +16 -0
- eva/vision/data/wsi/backends/__init__.py +69 -0
- eva/vision/data/wsi/backends/base.py +115 -0
- eva/vision/data/wsi/backends/openslide.py +73 -0
- eva/vision/data/wsi/backends/pil.py +52 -0
- eva/vision/data/wsi/backends/tiffslide.py +42 -0
- eva/vision/data/wsi/patching/__init__.py +6 -0
- eva/vision/data/wsi/patching/coordinates.py +98 -0
- eva/vision/data/wsi/patching/mask.py +123 -0
- eva/vision/data/wsi/patching/samplers/__init__.py +14 -0
- eva/vision/data/wsi/patching/samplers/_utils.py +50 -0
- eva/vision/data/wsi/patching/samplers/base.py +48 -0
- eva/vision/data/wsi/patching/samplers/foreground_grid.py +99 -0
- eva/vision/data/wsi/patching/samplers/grid.py +47 -0
- eva/vision/data/wsi/patching/samplers/random.py +41 -0
- eva/vision/losses/__init__.py +5 -0
- eva/vision/losses/dice.py +40 -0
- eva/vision/models/__init__.py +4 -2
- eva/vision/models/modules/__init__.py +5 -0
- eva/vision/models/modules/semantic_segmentation.py +161 -0
- eva/vision/models/networks/__init__.py +1 -2
- eva/vision/models/networks/backbones/__init__.py +6 -0
- eva/vision/models/networks/backbones/_utils.py +39 -0
- eva/vision/models/networks/backbones/pathology/__init__.py +31 -0
- eva/vision/models/networks/backbones/pathology/bioptimus.py +34 -0
- eva/vision/models/networks/backbones/pathology/gigapath.py +33 -0
- eva/vision/models/networks/backbones/pathology/histai.py +46 -0
- eva/vision/models/networks/backbones/pathology/kaiko.py +123 -0
- eva/vision/models/networks/backbones/pathology/lunit.py +68 -0
- eva/vision/models/networks/backbones/pathology/mahmood.py +62 -0
- eva/vision/models/networks/backbones/pathology/owkin.py +22 -0
- eva/vision/models/networks/backbones/registry.py +47 -0
- eva/vision/models/networks/backbones/timm/__init__.py +5 -0
- eva/vision/models/networks/backbones/timm/backbones.py +54 -0
- eva/vision/models/networks/backbones/universal/__init__.py +8 -0
- eva/vision/models/networks/backbones/universal/vit.py +54 -0
- eva/vision/models/networks/decoders/__init__.py +6 -0
- eva/vision/models/networks/decoders/decoder.py +7 -0
- eva/vision/models/networks/decoders/segmentation/__init__.py +11 -0
- eva/vision/models/networks/decoders/segmentation/common.py +74 -0
- eva/vision/models/networks/decoders/segmentation/conv2d.py +114 -0
- eva/vision/models/networks/decoders/segmentation/linear.py +125 -0
- eva/vision/models/wrappers/__init__.py +6 -0
- eva/vision/models/wrappers/from_registry.py +48 -0
- eva/vision/models/wrappers/from_timm.py +68 -0
- eva/vision/utils/colormap.py +77 -0
- eva/vision/utils/convert.py +67 -0
- eva/vision/utils/io/__init__.py +10 -4
- eva/vision/utils/io/image.py +21 -2
- eva/vision/utils/io/mat.py +36 -0
- eva/vision/utils/io/nifti.py +40 -15
- eva/vision/utils/io/text.py +10 -3
- kaiko_eva-0.1.0.dist-info/METADATA +553 -0
- kaiko_eva-0.1.0.dist-info/RECORD +205 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/WHEEL +1 -1
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/entry_points.txt +2 -0
- eva/core/callbacks/writers/embeddings.py +0 -169
- eva/core/callbacks/writers/typings.py +0 -23
- eva/core/models/networks/transforms/__init__.py +0 -5
- eva/core/models/networks/wrappers/__init__.py +0 -8
- eva/vision/data/datasets/classification/total_segmentator.py +0 -213
- eva/vision/data/datasets/segmentation/total_segmentator.py +0 -212
- eva/vision/models/networks/postprocesses/__init__.py +0 -5
- eva/vision/models/networks/postprocesses/cls.py +0 -25
- kaiko_eva-0.0.1.dist-info/METADATA +0 -405
- kaiko_eva-0.0.1.dist-info/RECORD +0 -110
- /eva/core/models/{networks → wrappers}/_utils.py +0 -0
- {kaiko_eva-0.0.1.dist-info → kaiko_eva-0.1.0.dist-info}/licenses/LICENSE +0 -0

eva/core/data/datasets/classification/__init__.py
@@ -1,5 +1,8 @@
-"""
+"""Embedding cllassification datasets API."""
 
 from eva.core.data.datasets.classification.embeddings import EmbeddingsClassificationDataset
+from eva.core.data.datasets.classification.multi_embeddings import (
+    MultiEmbeddingsClassificationDataset,
+)
 
-__all__ = ["EmbeddingsClassificationDataset"]
+__all__ = ["EmbeddingsClassificationDataset", "MultiEmbeddingsClassificationDataset"]
eva/core/data/datasets/classification/embeddings.py
@@ -1,154 +1,34 @@
 """Embeddings classification dataset."""
 
 import os
-from typing import Callable, Dict, Tuple
 
-import numpy as np
-import pandas as pd
 import torch
 from typing_extensions import override
 
-from eva.core.data.datasets import
-from eva.core.utils import io
+from eva.core.data.datasets import embeddings as embeddings_base
 
 
-class EmbeddingsClassificationDataset(
-    """Embeddings classification
-
-    default_column_mapping: Dict[str, str] = {
-        "data": "embeddings",
-        "target": "target",
-        "split": "split",
-    }
-    """The default column mapping of the variables to the manifest columns."""
-
-    def __init__(
-        self,
-        root: str,
-        manifest_file: str,
-        split: str | None = None,
-        column_mapping: Dict[str, str] = default_column_mapping,
-        embeddings_transforms: Callable | None = None,
-        target_transforms: Callable | None = None,
-    ) -> None:
-        """Initialize dataset.
-
-        Expects a manifest file listing the paths of .pt files that contain
-        tensor embeddings of shape [embedding_dim] or [1, embedding_dim].
-
-        Args:
-            root: Root directory of the dataset.
-            manifest_file: The path to the manifest file, which is relative to
-                the `root` argument.
-            split: The dataset split to use. The `split` column of the manifest
-                file will be splitted based on this value.
-            column_mapping: Defines the map between the variables and the manifest
-                columns. It will overwrite the `default_column_mapping` with
-                the provided values, so that `column_mapping` can contain only the
-                values which are altered or missing.
-            embeddings_transforms: A function/transform that transforms the embedding.
-            target_transforms: A function/transform that transforms the target.
-        """
-        super().__init__()
-
-        self._root = root
-        self._manifest_file = manifest_file
-        self._split = split
-        self._column_mapping = self.default_column_mapping | column_mapping
-        self._embeddings_transforms = embeddings_transforms
-        self._target_transforms = target_transforms
-
-        self._data: pd.DataFrame
-
-    def filename(self, index: int) -> str:
-        """Returns the filename of the `index`'th data sample.
-
-        Note that this is the relative file path to the root.
-
-        Args:
-            index: The index of the data-sample to select.
-
-        Returns:
-            The filename of the `index`'th data sample.
-        """
-        return self._data.at[index, self._column_mapping["data"]]
+class EmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
+    """Embeddings dataset class for classification tasks."""
 
     @override
-    def setup(self):
-        self._data = self._load_manifest()
-
-    def __getitem__(self, index) -> Tuple[torch.Tensor, np.ndarray]:
-        """Returns the `index`'th data sample.
-
-        Args:
-            index: The index of the data-sample to select.
-
-        Returns:
-            A data sample and its target.
-        """
-        embeddings = self._load_embeddings(index)
-        target = self._load_target(index)
-        return self._apply_transforms(embeddings, target)
-
-    def __len__(self) -> int:
-        """Returns the total length of the data."""
-        return len(self._data)
-
     def _load_embeddings(self, index: int) -> torch.Tensor:
-        """Returns the `index`'th embedding sample.
-
-        Args:
-            index: The index of the data sample to load.
-
-        Returns:
-            The sample embedding as an array.
-        """
         filename = self.filename(index)
         embeddings_path = os.path.join(self._root, filename)
         tensor = torch.load(embeddings_path, map_location="cpu")
+        if isinstance(tensor, list):
+            if len(tensor) > 1:
+                raise ValueError(
+                    f"Expected a single tensor in the .pt file, but found {len(tensor)}."
+                )
+            tensor = tensor[0]
         return tensor.squeeze(0)
 
-
-
-
-        Args:
-            index: The index of the data sample to load.
-
-        Returns:
-            The sample target as an array.
-        """
+    @override
+    def _load_target(self, index: int) -> torch.Tensor:
         target = self._data.at[index, self._column_mapping["target"]]
-        return
-
-    def _load_manifest(self) -> pd.DataFrame:
-        """Loads manifest file and filters the data based on the split column.
-
-        Returns:
-            The data as a pandas DataFrame.
-        """
-        manifest_path = os.path.join(self._root, self._manifest_file)
-        data = io.read_dataframe(manifest_path)
-        if self._split is not None:
-            filtered_data = data.loc[data[self._column_mapping["split"]] == self._split]
-            data = filtered_data.reset_index(drop=True)
-        return data
+        return torch.tensor(target, dtype=torch.int64)
 
-
-
-
-        """Applies the transforms to the provided data and returns them.
-
-        Args:
-            embeddings: The embeddings to be transformed.
-            target: The training target.
-
-        Returns:
-            A tuple with the embeddings and the target transformed.
-        """
-        if self._embeddings_transforms is not None:
-            embeddings = self._embeddings_transforms(embeddings)
-
-        if self._target_transforms is not None:
-            target = self._target_transforms(target)
-
-        return embeddings, target
+    @override
+    def __len__(self) -> int:
+        return len(self._data)
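For context, the refactored `_load_embeddings` above now also tolerates `.pt` files that wrap the embedding in a single-element list. A minimal sketch of the on-disk layouts it accepts; the file names and the 768 embedding dimension are hypothetical:

```python
import torch

# Any of these layouts loads cleanly through the new EmbeddingsClassificationDataset:
torch.save(torch.rand(768), "emb_a.pt")        # shape [embedding_dim]
torch.save(torch.rand(1, 768), "emb_b.pt")     # shape [1, embedding_dim], squeezed on load
torch.save([torch.rand(1, 768)], "emb_c.pt")   # single-element list, unwrapped on load

# A list containing more than one tensor raises the ValueError shown in the diff.
```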
eva/core/data/datasets/classification/multi_embeddings.py
@@ -0,0 +1,110 @@
+"""Dataset class for where a sample corresponds to multiple embeddings."""
+
+import os
+from typing import Callable, Dict, List, Literal
+
+import numpy as np
+import torch
+from typing_extensions import override
+
+from eva.core.data.datasets import embeddings as embeddings_base
+
+
+class MultiEmbeddingsClassificationDataset(embeddings_base.EmbeddingsDataset[torch.Tensor]):
+    """Dataset class for where a sample corresponds to multiple embeddings.
+
+    Example use case: Slide level dataset where each slide has multiple patch embeddings.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        manifest_file: str,
+        split: Literal["train", "val", "test"],
+        column_mapping: Dict[str, str] = embeddings_base.default_column_mapping,
+        embeddings_transforms: Callable | None = None,
+        target_transforms: Callable | None = None,
+    ):
+        """Initialize dataset.
+
+        Expects a manifest file listing the paths of `.pt` files containing tensor embeddings.
+
+        The manifest must have a `column_mapping["multi_id"]` column that contains the
+        unique identifier group of embeddings. For oncology datasets, this would be usually
+        the slide id. Each row in the manifest file points to a .pt file that can contain
+        one or multiple embeddings (either as a list or stacked tensors). There can also be
+        multiple rows for the same `multi_id`, in which case the embeddings from the different
+        .pt files corresponding to that same `multi_id` will be stacked along the first dimension.
+
+        Args:
+            root: Root directory of the dataset.
+            manifest_file: The path to the manifest file, which is relative to
+                the `root` argument.
+            split: The dataset split to use. The `split` column of the manifest
+                file will be splitted based on this value.
+            column_mapping: Defines the map between the variables and the manifest
+                columns. It will overwrite the `default_column_mapping` with
+                the provided values, so that `column_mapping` can contain only the
+                values which are altered or missing.
+            embeddings_transforms: A function/transform that transforms the embedding.
+            target_transforms: A function/transform that transforms the target.
+        """
+        super().__init__(
+            manifest_file=manifest_file,
+            root=root,
+            split=split,
+            column_mapping=column_mapping,
+            embeddings_transforms=embeddings_transforms,
+            target_transforms=target_transforms,
+        )
+
+        self._multi_ids: List[int]
+
+    @override
+    def setup(self):
+        super().setup()
+        self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique())
+
+    @override
+    def _load_embeddings(self, index: int) -> torch.Tensor:
+        """Loads and stacks all embedding corresponding to the `index`'th multi_id."""
+        # Get all embeddings for the given index (multi_id)
+        multi_id = self._multi_ids[index]
+        embedding_paths = self._data.loc[
+            self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["path"]
+        ].to_list()
+
+        # Load embeddings and stack them accross the first dimension
+        embeddings = []
+        for path in embedding_paths:
+            embedding = torch.load(os.path.join(self._root, path), map_location="cpu")
+            if isinstance(embedding, list):
+                embedding = torch.stack(embedding, dim=0)
+            embeddings.append(embedding.unsqueeze(0) if embedding.ndim == 1 else embedding)
+        embeddings = torch.cat(embeddings, dim=0)
+
+        if not embeddings.ndim == 2:
+            raise ValueError(f"Expected 2D tensor, got {embeddings.ndim} for {multi_id}.")
+
+        return embeddings
+
+    @override
+    def _load_target(self, index: int) -> np.ndarray:
+        """Returns the target corresponding to the `index`'th multi_id.
+
+        This method assumes that all the embeddings corresponding to the same `multi_id`
+        have the same target. If this is not the case, it will raise an error.
+        """
+        multi_id = self._multi_ids[index]
+        targets = self._data.loc[
+            self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"]
+        ]
+
+        if not targets.nunique() == 1:
+            raise ValueError(f"Multiple targets found for {multi_id}.")
+
+        return np.asarray(targets.iloc[0], dtype=np.int64)
+
+    @override
+    def __len__(self) -> int:
+        return len(self._multi_ids)
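A hedged usage sketch of the new class, assuming a CSV manifest that `eva.core.utils.io.read_dataframe` can parse and the default column mapping (`embeddings`, `target`, `split`, `wsi_id`); the paths and file names below are hypothetical:

```python
from eva.core.data.datasets.classification.multi_embeddings import (
    MultiEmbeddingsClassificationDataset,
)

# Hypothetical manifest rows (one .pt file per row, grouped by wsi_id):
#   embeddings,target,split,wsi_id
#   patches/slide_0/0.pt,1,train,slide_0
#   patches/slide_0/1.pt,1,train,slide_0
#   patches/slide_1/0.pt,0,train,slide_1
dataset = MultiEmbeddingsClassificationDataset(
    root="/data/embeddings",       # hypothetical root directory
    manifest_file="manifest.csv",  # hypothetical manifest name
    split="train",
)
dataset.setup()
embeddings, target = dataset[0]    # embeddings: [n_patches, embedding_dim]
```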
eva/core/data/datasets/embeddings.py
@@ -0,0 +1,167 @@
+"""Base dataset class for Embeddings."""
+
+import abc
+import multiprocessing
+import os
+from typing import Callable, Dict, Generic, Literal, Tuple, TypeVar
+
+import pandas as pd
+import torch
+from typing_extensions import override
+
+from eva.core.data.datasets import base
+from eva.core.utils import io
+
+TargetType = TypeVar("TargetType")
+"""The target data type."""
+
+
+default_column_mapping: Dict[str, str] = {
+    "path": "embeddings",
+    "target": "target",
+    "split": "split",
+    "multi_id": "wsi_id",
+}
+"""The default column mapping of the variables to the manifest columns."""
+
+
+class EmbeddingsDataset(base.Dataset, Generic[TargetType]):
+    """Abstract base class for embedding datasets."""
+
+    def __init__(
+        self,
+        root: str,
+        manifest_file: str,
+        split: Literal["train", "val", "test"] | None = None,
+        column_mapping: Dict[str, str] = default_column_mapping,
+        embeddings_transforms: Callable | None = None,
+        target_transforms: Callable | None = None,
+    ) -> None:
+        """Initialize dataset.
+
+        Expects a manifest file listing the paths of .pt files that contain
+        tensor embeddings of shape [embedding_dim] or [1, embedding_dim].
+
+        Args:
+            root: Root directory of the dataset.
+            manifest_file: The path to the manifest file, which is relative to
+                the `root` argument.
+            split: The dataset split to use. The `split` column of the manifest
+                file will be splitted based on this value.
+            column_mapping: Defines the map between the variables and the manifest
+                columns. It will overwrite the `default_column_mapping` with
+                the provided values, so that `column_mapping` can contain only the
+                values which are altered or missing.
+            embeddings_transforms: A function/transform that transforms the embedding.
+            target_transforms: A function/transform that transforms the target.
+        """
+        super().__init__()
+
+        self._root = root
+        self._manifest_file = manifest_file
+        self._split = split
+        self._column_mapping = default_column_mapping | column_mapping
+        self._embeddings_transforms = embeddings_transforms
+        self._target_transforms = target_transforms
+
+        self._data: pd.DataFrame
+
+        self._set_multiprocessing_start_method()
+
+    def filename(self, index: int) -> str:
+        """Returns the filename of the `index`'th data sample.
+
+        Note that this is the relative file path to the root.
+
+        Args:
+            index: The index of the data-sample to select.
+
+        Returns:
+            The filename of the `index`'th data sample.
+        """
+        return self._data.at[index, self._column_mapping["path"]]
+
+    @override
+    def setup(self):
+        self._data = self._load_manifest()
+
+    @abc.abstractmethod
+    def __len__(self) -> int:
+        """Returns the total length of the data."""
+
+    def __getitem__(self, index) -> Tuple[torch.Tensor, TargetType]:
+        """Returns the `index`'th data sample.
+
+        Args:
+            index: The index of the data-sample to select.
+
+        Returns:
+            A data sample and its target.
+        """
+        embeddings = self._load_embeddings(index)
+        target = self._load_target(index)
+        return self._apply_transforms(embeddings, target)
+
+    @abc.abstractmethod
+    def _load_embeddings(self, index: int) -> torch.Tensor:
+        """Returns the `index`'th embedding sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            The embedding sample as a tensor.
+        """
+
+    @abc.abstractmethod
+    def _load_target(self, index: int) -> TargetType:
+        """Returns the `index`'th target sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            The sample target as an array.
+        """
+
+    def _load_manifest(self) -> pd.DataFrame:
+        """Loads manifest file and filters the data based on the split column.
+
+        Returns:
+            The data as a pandas DataFrame.
+        """
+        manifest_path = os.path.join(self._root, self._manifest_file)
+        data = io.read_dataframe(manifest_path)
+        if self._split is not None:
+            filtered_data = data.loc[data[self._column_mapping["split"]] == self._split]
+            data = filtered_data.reset_index(drop=True)
+        return data
+
+    def _apply_transforms(
+        self, embeddings: torch.Tensor, target: TargetType
+    ) -> Tuple[torch.Tensor, TargetType]:
+        """Applies the transforms to the provided data and returns them.
+
+        Args:
+            embeddings: The embeddings to be transformed.
+            target: The training target.
+
+        Returns:
+            A tuple with the embeddings and the target transformed.
+        """
+        if self._embeddings_transforms is not None:
+            embeddings = self._embeddings_transforms(embeddings)
+
+        if self._target_transforms is not None:
+            target = self._target_transforms(target)
+
+        return embeddings, target
+
+    def _set_multiprocessing_start_method(self):
+        """Sets the multiprocessing start method to spawn.
+
+        If the start method is not set explicitly, the torch data loaders will
+        use the OS default method, which for some unix systems is `fork` and
+        can lead to runtime issues such as deadlocks in this context.
+        """
+        multiprocessing.set_start_method("spawn", force=True)
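To make the contract of the abstract base concrete, here is a minimal sketch of a subclass that fills in the three abstract members (`_load_embeddings`, `_load_target`, `__len__`); the class name is hypothetical and the loading logic simply mirrors the classification dataset shown earlier:

```python
import os

import numpy as np
import torch
from typing_extensions import override

from eva.core.data.datasets import embeddings as embeddings_base


class SimpleEmbeddingsDataset(embeddings_base.EmbeddingsDataset[np.ndarray]):
    """Illustrative subclass: one embedding file per manifest row, integer target."""

    @override
    def _load_embeddings(self, index: int) -> torch.Tensor:
        # `filename`, `_root` and `_data` are provided by the base class above.
        path = os.path.join(self._root, self.filename(index))
        return torch.load(path, map_location="cpu").squeeze(0)

    @override
    def _load_target(self, index: int) -> np.ndarray:
        target = self._data.at[index, self._column_mapping["target"]]
        return np.asarray(target, dtype=np.int64)

    @override
    def __len__(self) -> int:
        return len(self._data)
```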
eva/core/data/splitting/random.py
@@ -0,0 +1,41 @@
+"""Functions for random splitting."""
+
+from typing import Any, List, Sequence, Tuple
+
+import numpy as np
+
+
+def random_split(
+    samples: Sequence[Any],
+    train_ratio: float,
+    val_ratio: float,
+    test_ratio: float = 0.0,
+    seed: int = 42,
+) -> Tuple[List[int], List[int], List[int] | None]:
+    """Splits the samples into random train, validation, and test (optional) sets.
+
+    Args:
+        samples: The samples to split.
+        train_ratio: The ratio of the training set.
+        val_ratio: The ratio of the validation set.
+        test_ratio: The ratio of the test set (optional).
+        seed: The seed for reproducibility.
+
+    Returns:
+        The indices of the train, validation, and test sets as lists.
+    """
+    if train_ratio + val_ratio + (test_ratio or 0) != 1:
+        raise ValueError("The sum of the ratios must be equal to 1.")
+
+    np.random.seed(seed)
+    n_samples = len(samples)
+    indices = np.random.permutation(n_samples)
+
+    n_train = int(np.floor(train_ratio * n_samples))
+    n_val = n_samples - n_train if test_ratio == 0.0 else int(np.floor(val_ratio * n_samples)) or 1
+
+    train_indices = list(indices[:n_train])
+    val_indices = list(indices[n_train : n_train + n_val])
+    test_indices = list(indices[n_train + n_val :]) if test_ratio > 0.0 else None
+
+    return train_indices, val_indices, test_indices
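A quick usage sketch; the import path is assumed from the file location, and the ratios are chosen so they sum to exactly 1.0 in floating point:

```python
from eva.core.data.splitting.random import random_split

samples = list(range(10))
train_idx, val_idx, test_idx = random_split(
    samples, train_ratio=0.5, val_ratio=0.25, test_ratio=0.25, seed=0
)
# With 10 samples this yields 5 train, floor(0.25 * 10) = 2 val and 3 test indices.
# If test_ratio is left at 0.0, the remainder goes to validation and test_idx is None.
```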
eva/core/data/splitting/stratified.py
@@ -0,0 +1,56 @@
+"""Functions for stratified splitting."""
+
+from typing import Any, List, Sequence, Tuple
+
+import numpy as np
+
+
+def stratified_split(
+    samples: Sequence[Any],
+    targets: Sequence[Any],
+    train_ratio: float,
+    val_ratio: float,
+    test_ratio: float = 0.0,
+    seed: int = 42,
+) -> Tuple[List[int], List[int], List[int] | None]:
+    """Splits the samples into stratified train, validation, and test (optional) sets.
+
+    Args:
+        samples: The samples to split.
+        targets: The corresponding targets used for stratification.
+        train_ratio: The ratio of the training set.
+        val_ratio: The ratio of the validation set.
+        test_ratio: The ratio of the test set (optional).
+        seed: The seed for reproducibility.
+
+    Returns:
+        The indices of the train, validation, and test sets.
+    """
+    if len(samples) != len(targets):
+        raise ValueError("The number of samples and targets must be equal.")
+    if train_ratio + val_ratio + (test_ratio or 0) != 1:
+        raise ValueError("The sum of the ratios must be equal to 1.")
+
+    np.random.seed(seed)
+    unique_classes, y_indices = np.unique(targets, return_inverse=True)
+    n_classes = unique_classes.shape[0]
+
+    train_indices, val_indices, test_indices = [], [], []
+
+    for c in range(n_classes):
+        class_indices = np.where(y_indices == c)[0]
+        np.random.shuffle(class_indices)
+
+        n_train = int(np.floor(train_ratio * len(class_indices))) or 1
+        n_val = (
+            len(class_indices) - n_train
+            if test_ratio == 0.0
+            else int(np.floor(val_ratio * len(class_indices))) or 1
+        )
+
+        train_indices.extend(class_indices[:n_train])
+        val_indices.extend(class_indices[n_train : n_train + n_val])
+        if test_ratio > 0.0:
+            test_indices.extend(class_indices[n_train + n_val :])
+
+    return train_indices, val_indices, test_indices or None
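And the stratified counterpart, which draws the ratios per class so every class is represented in each split (import path again assumed from the file location):

```python
from eva.core.data.splitting.stratified import stratified_split

samples = list(range(8))
targets = [0, 0, 0, 0, 1, 1, 1, 1]
train_idx, val_idx, test_idx = stratified_split(
    samples, targets, train_ratio=0.5, val_ratio=0.5
)
# Each class contributes floor(0.5 * 4) = 2 indices to train and the remaining 2 to
# validation; test_idx is None because test_ratio defaults to 0.0.
```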
eva/core/data/transforms/__init__.py
@@ -1,5 +1,7 @@
 """Core data transforms."""
 
 from eva.core.data.transforms.dtype import ArrayToFloatTensor, ArrayToTensor
+from eva.core.data.transforms.padding import Pad2DTensor
+from eva.core.data.transforms.sampling import SampleFromAxis
 
-__all__ = ["ArrayToFloatTensor", "ArrayToTensor"]
+__all__ = ["ArrayToFloatTensor", "ArrayToTensor", "Pad2DTensor", "SampleFromAxis"]
eva/core/data/transforms/padding/pad_2d_tensor.py
@@ -0,0 +1,38 @@
+"""Padding transformation for 2D tensors."""
+
+import torch
+import torch.nn.functional
+
+
+class Pad2DTensor:
+    """Pads a 2D tensor to a fixed dimension accross the first dimension."""
+
+    def __init__(self, pad_size: int, pad_value: int | float = float("-inf")):
+        """Initialize the transformation.
+
+        Args:
+            pad_size: The size to pad the tensor to. If the tensor is larger than this size,
+                no padding will be applied.
+            pad_value: The value to use for padding.
+        """
+        self._pad_size = pad_size
+        self._pad_value = pad_value
+
+    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Call method for the transformation.
+
+        Args:
+            tensor: The input tensor of shape [n, embedding_dim].
+
+        Returns:
+            A tensor of shape [max(n, pad_dim), embedding_dim].
+        """
+        n_pad_values = self._pad_size - tensor.size(0)
+        if n_pad_values > 0:
+            tensor = torch.nn.functional.pad(
+                tensor,
+                pad=(0, 0, 0, n_pad_values),
+                mode="constant",
+                value=self._pad_value,
+            )
+        return tensor
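A short sketch of how the padding transform behaves, assuming it is re-exported from `eva.core.data.transforms` as the __init__ diff above shows:

```python
import torch

from eva.core.data.transforms import Pad2DTensor

pad = Pad2DTensor(pad_size=5, pad_value=0.0)
print(pad(torch.ones(3, 8)).shape)  # torch.Size([5, 8]); two rows of 0.0 appended
print(pad(torch.ones(7, 8)).shape)  # torch.Size([7, 8]); already >= pad_size, left untouched
```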
eva/core/data/transforms/sampling/sample_from_axis.py
@@ -0,0 +1,40 @@
+"""Sampling transformations."""
+
+import torch
+
+
+class SampleFromAxis:
+    """Samples n_samples entries from a tensor along a given axis."""
+
+    def __init__(self, n_samples: int, seed: int = 42, axis: int = 0):
+        """Initialize the transformation.
+
+        Args:
+            n_samples: The number of samples to draw.
+            seed: The seed to use for sampling.
+            axis: The axis along which to sample.
+        """
+        self._seed = seed
+        self._n_samples = n_samples
+        self._axis = axis
+        self._generator = self._get_generator()
+
+    def _get_generator(self):
+        """Return a torch random generator with fixed seed."""
+        generator = torch.Generator()
+        generator.manual_seed(self._seed)
+        return generator
+
+    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Call method for the transformation.
+
+        Args:
+            tensor: The input tensor of shape [n, embedding_dim].
+
+        Returns:
+            A tensor of shape [n_samples, embedding_dim].
+        """
+        indices = torch.randperm(tensor.size(self._axis), generator=self._generator)[
+            : self._n_samples
+        ]
+        return tensor.index_select(self._axis, indices)
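And a matching sketch for the sampling transform (same assumed re-export). Note that the generator is created once in `__init__`, so repeated calls advance its state and draw fresh permutations rather than repeating the first one:

```python
import torch

from eva.core.data.transforms import SampleFromAxis

sampler = SampleFromAxis(n_samples=4, seed=0)
patches = torch.rand(100, 768)   # e.g. 100 patch embeddings of dimension 768
print(sampler(patches).shape)    # torch.Size([4, 768])
```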
eva/core/loggers/__init__.py
@@ -0,0 +1,7 @@
+"""Experimental loggers API."""
+
+from eva.core.loggers.dummy import DummyLogger
+from eva.core.loggers.experimental_loggers import ExperimentalLoggers
+from eva.core.loggers.log import log_parameters
+
+__all__ = ["DummyLogger", "ExperimentalLoggers", "log_parameters"]