kaiko-eva 0.0.0.dev8__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eva/core/callbacks/__init__.py +2 -1
- eva/core/callbacks/config.py +143 -0
- eva/core/data/datasets/__init__.py +10 -2
- eva/core/data/datasets/embeddings/__init__.py +13 -0
- eva/core/data/datasets/{classification/embeddings.py → embeddings/base.py} +41 -43
- eva/core/data/datasets/embeddings/classification/__init__.py +10 -0
- eva/core/data/datasets/embeddings/classification/embeddings.py +66 -0
- eva/core/data/datasets/embeddings/classification/multi_embeddings.py +106 -0
- eva/core/data/transforms/__init__.py +3 -1
- eva/core/data/transforms/padding/__init__.py +5 -0
- eva/core/data/transforms/padding/pad_2d_tensor.py +38 -0
- eva/core/data/transforms/sampling/__init__.py +5 -0
- eva/core/data/transforms/sampling/sample_from_axis.py +40 -0
- eva/core/loggers/__init__.py +7 -0
- eva/core/loggers/dummy.py +38 -0
- eva/core/loggers/experimental_loggers.py +8 -0
- eva/core/loggers/log/__init__.py +5 -0
- eva/core/loggers/log/parameters.py +64 -0
- eva/core/loggers/log/utils.py +13 -0
- eva/core/models/modules/head.py +6 -11
- eva/core/models/modules/module.py +25 -1
- eva/core/trainers/_recorder.py +69 -7
- eva/core/trainers/functional.py +22 -5
- eva/core/trainers/trainer.py +20 -6
- eva/vision/__init__.py +1 -1
- eva/vision/data/datasets/__init__.py +1 -8
- eva/vision/data/datasets/_utils.py +3 -3
- eva/vision/data/datasets/classification/__init__.py +1 -8
- eva/vision/data/datasets/segmentation/base.py +20 -35
- eva/vision/data/datasets/segmentation/total_segmentator.py +88 -69
- eva/vision/models/.DS_Store +0 -0
- eva/vision/models/networks/.DS_Store +0 -0
- eva/vision/utils/convert.py +24 -0
- eva/vision/utils/io/nifti.py +10 -6
- {kaiko_eva-0.0.0.dev8.dist-info → kaiko_eva-0.0.2.dist-info}/METADATA +71 -27
- {kaiko_eva-0.0.0.dev8.dist-info → kaiko_eva-0.0.2.dist-info}/RECORD +39 -23
- {kaiko_eva-0.0.0.dev8.dist-info → kaiko_eva-0.0.2.dist-info}/WHEEL +1 -1
- eva/core/data/datasets/classification/__init__.py +0 -5
- eva/vision/data/datasets/classification/total_segmentator.py +0 -213
- {kaiko_eva-0.0.0.dev8.dist-info → kaiko_eva-0.0.2.dist-info}/entry_points.txt +0 -0
- {kaiko_eva-0.0.0.dev8.dist-info → kaiko_eva-0.0.2.dist-info}/licenses/LICENSE +0 -0
eva/core/callbacks/__init__.py
CHANGED
eva/core/callbacks/config.py
ADDED
@@ -0,0 +1,143 @@
+"""Configuration logger callback."""
+
+import ast
+import os
+import sys
+from types import BuiltinFunctionType
+from typing import Any, Dict, List
+
+import lightning.pytorch as pl
+import yaml
+from lightning_fabric.utilities import cloud_io
+from loguru import logger as cli_logger
+from omegaconf import OmegaConf
+from typing_extensions import TypeGuard, override
+
+from eva.core import loggers
+
+
+class ConfigurationLogger(pl.Callback):
+    """Logs the submitted configuration to the experimental logger."""
+
+    _save_as: str = "config.yaml"
+
+    def __init__(self, verbose: bool = True) -> None:
+        """Initializes the callback.
+
+        Args:
+            verbose: Whether to print the submitted configuration
+                to the terminal.
+        """
+        super().__init__()
+
+        self._verbose = verbose
+
+    @override
+    def setup(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        stage: str | None = None,
+    ) -> None:
+        log_dir = trainer.log_dir
+        if not _logdir_exists(log_dir):
+            return
+
+        configuration = _load_submitted_config()
+
+        if self._verbose:
+            config_as_text = yaml.dump(configuration, sort_keys=False)
+            print(f"Configuration:\033[94m\n---\n{config_as_text}\033[0m")
+
+        save_as = os.path.join(log_dir, self._save_as)
+        fs = cloud_io.get_filesystem(log_dir)
+        with fs.open(save_as, "w") as output_file:
+            yaml.dump(configuration, output_file, sort_keys=False)
+
+        loggers.log_parameters(trainer.loggers, tag="configuration", parameters=configuration)
+
+
+def _logdir_exists(logdir: str | None, verbose: bool = True) -> TypeGuard[str]:
+    """Checks if the trainer has a log directory.
+
+    Args:
+        logdir: Trainer's logdir.
+        verbose: Whether to log if it does not exist.
+
+    Returns:
+        A bool indicating if the log directory exists or not.
+    """
+    exists = isinstance(logdir, str)
+    if not exists and verbose:
+        print("\n")
+        cli_logger.warning("Log directory is `None`. Configuration file will not be logged.\n")
+    return exists


def _load_submitted_config() -> Dict[str, Any]:
+    """Retrieves and loads the submitted configuration.
+
+    Returns:
+        The submitted configuration as a dictionary.
+    """
+    config_paths = _fetch_submitted_config_path()
+    return _load_yaml_files(config_paths)
+
+
+def _fetch_submitted_config_path() -> List[str]:
+    """Fetches the config paths from the command line arguments.
+
+    Returns:
+        The paths to the configuration files.
+    """
+    return list(filter(lambda f: f.endswith(".yaml"), sys.argv))
+
+
+def _load_yaml_files(paths: List[str]) -> Dict[str, Any]:
+    """Loads and merges yaml files from multiple paths.
+
+    Args:
+        paths: The paths to the yaml files.
+
+    Returns:
+        The merged configurations as a dictionary.
+    """
+    merged_config = {}
+    for config_path in paths:
+        fs = cloud_io.get_filesystem(config_path)
+        with fs.open(config_path, "r") as file:
+            omegaconf_file = OmegaConf.load(file)  # type: ignore
+            config_dict = OmegaConf.to_object(omegaconf_file)  # type: ignore
+            parsed_config = _type_resolver(config_dict)  # type: ignore
+            merged_config.update(parsed_config)
+    return merged_config
+
+
+def _type_resolver(mapping: Dict[str, Any]) -> Dict[str, Any]:
+    """Parses the string values of a dictionary in-place.
+
+    Args:
+        mapping: A dictionary object.
+
+    Returns:
+        The mapping with the formatted values.
+    """
+    for key, value in mapping.items():
+        if isinstance(value, dict):
+            formatted_value = _type_resolver(value)
+        elif isinstance(value, list) and isinstance(value[0], dict):
+            formatted_value = [_type_resolver(subvalue) for subvalue in value]
+        else:
+            try:
+                parsed_value = ast.literal_eval(value)  # type: ignore
+                formatted_value = (
+                    value if isinstance(parsed_value, BuiltinFunctionType) else parsed_value
+                )
+            except Exception:
+                formatted_value = value
+
+        mapping[key] = formatted_value
+
+    return mapping
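For orientation, a minimal sketch of how this callback might be attached to a Lightning trainer. The output directory and the YAML path on the command line are hypothetical; the callback itself recovers the submitted config from the `.yaml` arguments in `sys.argv` and writes it to `trainer.log_dir`.

import lightning.pytorch as pl

from eva.core.callbacks.config import ConfigurationLogger

# Assumes the process was launched with a YAML config on the command line,
# e.g. `python run.py configs/my_experiment.yaml` (hypothetical path), since
# the callback collects every `.yaml` entry from `sys.argv`.
trainer = pl.Trainer(
    default_root_dir="logs/example-run",  # hypothetical output directory
    callbacks=[ConfigurationLogger(verbose=True)],
)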
eva/core/data/datasets/__init__.py
CHANGED
@@ -1,7 +1,15 @@
 """Datasets API."""
 
 from eva.core.data.datasets.base import Dataset
-from eva.core.data.datasets.classification import EmbeddingsClassificationDataset
 from eva.core.data.datasets.dataset import TorchDataset
+from eva.core.data.datasets.embeddings import (
+    EmbeddingsClassificationDataset,
+    MultiEmbeddingsClassificationDataset,
+)
 
-__all__ = [
+__all__ = [
+    "Dataset",
+    "EmbeddingsClassificationDataset",
+    "MultiEmbeddingsClassificationDataset",
+    "TorchDataset",
+]
eva/core/data/datasets/embeddings/__init__.py
ADDED
@@ -0,0 +1,13 @@
+"""Datasets API."""
+
+from eva.core.data.datasets.embeddings.base import EmbeddingsDataset
+from eva.core.data.datasets.embeddings.classification import (
+    EmbeddingsClassificationDataset,
+    MultiEmbeddingsClassificationDataset,
+)
+
+__all__ = [
+    "EmbeddingsDataset",
+    "EmbeddingsClassificationDataset",
+    "MultiEmbeddingsClassificationDataset",
+]
eva/core/data/datasets/{classification/embeddings.py → embeddings/base.py}
RENAMED
@@ -1,7 +1,8 @@
-"""
+"""Base dataset class for Embeddings."""
 
+import abc
 import os
-from typing import Callable, Dict, Tuple
+from typing import Callable, Dict, Literal, Tuple
 
 import numpy as np
 import pandas as pd
@@ -11,22 +12,23 @@ from typing_extensions import override
 from eva.core.data.datasets import base
 from eva.core.utils import io
 
+default_column_mapping: Dict[str, str] = {
+    "path": "embeddings",
+    "target": "target",
+    "split": "split",
+    "multi_id": "slide_id",
+}
+"""The default column mapping of the variables to the manifest columns."""
 
-class EmbeddingsClassificationDataset(base.Dataset):
-    """Embeddings classification dataset."""
 
-
-
-        "target": "target",
-        "split": "split",
-    }
-    """The default column mapping of the variables to the manifest columns."""
+class EmbeddingsDataset(base.Dataset):
+    """Abstract base class for embedding datasets."""
 
     def __init__(
         self,
         root: str,
         manifest_file: str,
-        split:
+        split: Literal["train", "val", "test"] | None = None,
         column_mapping: Dict[str, str] = default_column_mapping,
         embeddings_transforms: Callable | None = None,
         target_transforms: Callable | None = None,
@@ -54,12 +56,38 @@ class EmbeddingsClassificationDataset(base.Dataset):
         self._root = root
         self._manifest_file = manifest_file
         self._split = split
-        self._column_mapping =
+        self._column_mapping = default_column_mapping | column_mapping
         self._embeddings_transforms = embeddings_transforms
         self._target_transforms = target_transforms
 
         self._data: pd.DataFrame
 
+    @abc.abstractmethod
+    def _load_embeddings(self, index: int) -> torch.Tensor:
+        """Returns the `index`'th embedding sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            The embedding sample as a tensor.
+        """
+
+    @abc.abstractmethod
+    def _load_target(self, index: int) -> np.ndarray:
+        """Returns the `index`'th target sample.
+
+        Args:
+            index: The index of the data sample to load.
+
+        Returns:
+            The sample target as an array.
+        """
+
+    @abc.abstractmethod
+    def __len__(self) -> int:
+        """Returns the total length of the data."""
+
     def filename(self, index: int) -> str:
         """Returns the filename of the `index`'th data sample.
 
@@ -71,7 +99,7 @@ class EmbeddingsClassificationDataset(base.Dataset):
         Returns:
             The filename of the `index`'th data sample.
         """
-        return self._data.at[index, self._column_mapping["
+        return self._data.at[index, self._column_mapping["path"]]
 
     @override
     def setup(self):
@@ -90,36 +118,6 @@ class EmbeddingsClassificationDataset(base.Dataset):
         target = self._load_target(index)
         return self._apply_transforms(embeddings, target)
 
-    def __len__(self) -> int:
-        """Returns the total length of the data."""
-        return len(self._data)
-
-    def _load_embeddings(self, index: int) -> torch.Tensor:
-        """Returns the `index`'th embedding sample.
-
-        Args:
-            index: The index of the data sample to load.
-
-        Returns:
-            The sample embedding as an array.
-        """
-        filename = self.filename(index)
-        embeddings_path = os.path.join(self._root, filename)
-        tensor = torch.load(embeddings_path, map_location="cpu")
-        return tensor.squeeze(0)
-
-    def _load_target(self, index: int) -> np.ndarray:
-        """Returns the `index`'th target sample.
-
-        Args:
-            index: The index of the data sample to load.
-
-        Returns:
-            The sample target as an array.
-        """
-        target = self._data.at[index, self._column_mapping["target"]]
-        return np.asarray(target, dtype=np.int64)
-
     def _load_manifest(self) -> pd.DataFrame:
         """Loads manifest file and filters the data based on the split column.
eva/core/data/datasets/embeddings/classification/__init__.py
ADDED
@@ -0,0 +1,10 @@
+"""Embedding classification datasets API."""
+
+from eva.core.data.datasets.embeddings.classification.embeddings import (
+    EmbeddingsClassificationDataset,
+)
+from eva.core.data.datasets.embeddings.classification.multi_embeddings import (
+    MultiEmbeddingsClassificationDataset,
+)
+
+__all__ = ["EmbeddingsClassificationDataset", "MultiEmbeddingsClassificationDataset"]
eva/core/data/datasets/embeddings/classification/embeddings.py
ADDED
@@ -0,0 +1,66 @@
+"""Embeddings classification dataset."""
+
+import os
+from typing import Callable, Dict, Literal
+
+import numpy as np
+import torch
+from typing_extensions import override
+
+from eva.core.data.datasets.embeddings import base
+
+
+class EmbeddingsClassificationDataset(base.EmbeddingsDataset):
+    """Embeddings dataset class for classification tasks."""
+
+    def __init__(
+        self,
+        root: str,
+        manifest_file: str,
+        split: Literal["train", "val", "test"] | None = None,
+        column_mapping: Dict[str, str] = base.default_column_mapping,
+        embeddings_transforms: Callable | None = None,
+        target_transforms: Callable | None = None,
+    ) -> None:
+        """Initialize dataset.
+
+        Expects a manifest file listing the paths of .pt files that contain
+        tensor embeddings of shape [embedding_dim] or [1, embedding_dim].
+
+        Args:
+            root: Root directory of the dataset.
+            manifest_file: The path to the manifest file, which is relative to
+                the `root` argument.
+            split: The dataset split to use. The data will be filtered based on
+                the `split` column of the manifest file.
+            column_mapping: Defines the map between the variables and the manifest
+                columns. It will overwrite the `default_column_mapping` with
+                the provided values, so that `column_mapping` can contain only the
+                values which are altered or missing.
+            embeddings_transforms: A function/transform that transforms the embedding.
+            target_transforms: A function/transform that transforms the target.
+        """
+        super().__init__(
+            root=root,
+            manifest_file=manifest_file,
+            split=split,
+            column_mapping=column_mapping,
+            embeddings_transforms=embeddings_transforms,
+            target_transforms=target_transforms,
+        )
+
+    @override
+    def _load_embeddings(self, index: int) -> torch.Tensor:
+        filename = self.filename(index)
+        embeddings_path = os.path.join(self._root, filename)
+        tensor = torch.load(embeddings_path, map_location="cpu")
+        return tensor.squeeze(0)
+
+    @override
+    def _load_target(self, index: int) -> np.ndarray:
+        target = self._data.at[index, self._column_mapping["target"]]
+        return np.asarray(target, dtype=np.int64)
+
+    @override
+    def __len__(self) -> int:
+        return len(self._data)
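A hedged usage sketch for the patch-level dataset above. The root directory, manifest name, and the overridden target column are placeholders; the manifest is otherwise assumed to follow the default column mapping, with each `embeddings` entry pointing to a .pt file relative to the root.

from eva.core.data.datasets import EmbeddingsClassificationDataset

dataset = EmbeddingsClassificationDataset(
    root="./data/patch_embeddings",      # placeholder root directory
    manifest_file="manifest.csv",        # placeholder manifest name
    split="train",
    column_mapping={"target": "label"},  # hypothetical: targets stored in a `label` column
)
dataset.setup()                # loads the manifest and filters it by split
embedding, target = dataset[0]  # embedding tensor plus an int64 target array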
eva/core/data/datasets/embeddings/classification/multi_embeddings.py
ADDED
@@ -0,0 +1,106 @@
+"""Dataset class where a sample corresponds to multiple embeddings."""
+
+import os
+from typing import Callable, Dict, List, Literal
+
+import numpy as np
+import torch
+from typing_extensions import override
+
+from eva.core.data.datasets.embeddings import base
+
+
+class MultiEmbeddingsClassificationDataset(base.EmbeddingsDataset):
+    """Dataset class where a sample corresponds to multiple embeddings.
+
+    Example use case: Slide level dataset where each slide has multiple patch embeddings.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        manifest_file: str,
+        split: Literal["train", "val", "test"],
+        column_mapping: Dict[str, str] = base.default_column_mapping,
+        embeddings_transforms: Callable | None = None,
+        target_transforms: Callable | None = None,
+    ):
+        """Initialize dataset.
+
+        Expects a manifest file listing the paths of `.pt` files containing tensor embeddings.
+
+        The manifest must have a `column_mapping["multi_id"]` column that contains the
+        unique identifier of each group of embeddings. For oncology datasets, this would
+        usually be the slide id. Each row in the manifest file points to a .pt file that
+        can contain one or multiple embeddings. There can also be multiple rows for the
+        same `multi_id`, in which case the embeddings from the different .pt files
+        corresponding to that same `multi_id` will be stacked along the first dimension.
+
+        Args:
+            root: Root directory of the dataset.
+            manifest_file: The path to the manifest file, which is relative to
+                the `root` argument.
+            split: The dataset split to use. The data will be filtered based on
+                the `split` column of the manifest file.
+            column_mapping: Defines the map between the variables and the manifest
+                columns. It will overwrite the `default_column_mapping` with
+                the provided values, so that `column_mapping` can contain only the
+                values which are altered or missing.
+            embeddings_transforms: A function/transform that transforms the embedding.
+            target_transforms: A function/transform that transforms the target.
+        """
+        super().__init__(
+            manifest_file=manifest_file,
+            root=root,
+            split=split,
+            column_mapping=column_mapping,
+            embeddings_transforms=embeddings_transforms,
+            target_transforms=target_transforms,
+        )
+
+        self._multi_ids: List[int]
+
+    @override
+    def setup(self):
+        super().setup()
+        self._multi_ids = list(self._data[self._column_mapping["multi_id"]].unique())
+
+    @override
+    def _load_embeddings(self, index: int) -> torch.Tensor:
+        """Loads and stacks all embeddings corresponding to the `index`'th multi_id."""
+        # Get all embeddings for the given index (multi_id)
+        multi_id = self._multi_ids[index]
+        embedding_paths = self._data.loc[
+            self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["path"]
+        ].to_list()
+        embedding_paths = [os.path.join(self._root, path) for path in embedding_paths]
+
+        # Load embeddings and stack them across the first dimension
+        embeddings = [torch.load(path, map_location="cpu") for path in embedding_paths]
+        embeddings = torch.cat(embeddings, dim=0)
+
+        if not embeddings.ndim == 2:
+            raise ValueError(f"Expected 2D tensor, got {embeddings.ndim} for {multi_id}.")
+
+        return embeddings
+
+    @override
+    def _load_target(self, index: int) -> np.ndarray:
+        """Returns the target corresponding to the `index`'th multi_id.
+
+        This method assumes that all the embeddings corresponding to the same `multi_id`
+        have the same target. If this is not the case, it will raise an error.
+        """
+        multi_id = self._multi_ids[index]
+        targets = self._data.loc[
+            self._data[self._column_mapping["multi_id"]] == multi_id, self._column_mapping["target"]
+        ]
+
+        if not targets.nunique() == 1:
+            raise ValueError(f"Multiple targets found for {multi_id}.")
+
+        return np.asarray(targets.iloc[0], dtype=np.int64)
+
+    @override
+    def __len__(self) -> int:
+        return len(self._data)
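A hedged sketch of how the slide-level dataset might be wired up; the directory layout and manifest below are made up to illustrate the `multi_id` grouping (two rows for `slide_001`, whose embeddings are stacked into one sample).

from eva.core.data.datasets import MultiEmbeddingsClassificationDataset

# Hypothetical manifest (default column names):
#
#   embeddings               target  split  slide_id
#   slide_001/patches_0.pt   1       train  slide_001
#   slide_001/patches_1.pt   1       train  slide_001
#   slide_002/patches_0.pt   0       train  slide_002
#
dataset = MultiEmbeddingsClassificationDataset(
    root="./data/slide_embeddings",  # placeholder
    manifest_file="manifest.csv",    # placeholder
    split="train",
)
dataset.setup()
# dataset[0] loads both `slide_001` files and concatenates them into a single
# [n_patches, embedding_dim] tensor, paired with the slide-level target.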
eva/core/data/transforms/__init__.py
CHANGED
@@ -1,5 +1,7 @@
 """Core data transforms."""
 
 from eva.core.data.transforms.dtype import ArrayToFloatTensor, ArrayToTensor
+from eva.core.data.transforms.padding import Pad2DTensor
+from eva.core.data.transforms.sampling import SampleFromAxis
 
-__all__ = ["ArrayToFloatTensor", "ArrayToTensor"]
+__all__ = ["ArrayToFloatTensor", "ArrayToTensor", "Pad2DTensor", "SampleFromAxis"]
eva/core/data/transforms/padding/pad_2d_tensor.py
ADDED
@@ -0,0 +1,38 @@
+"""Padding transformation for 2D tensors."""
+
+import torch
+import torch.nn.functional
+
+
+class Pad2DTensor:
+    """Pads a 2D tensor to a fixed dimension across the first dimension."""
+
+    def __init__(self, pad_size: int, pad_value: int | float = float("-inf")):
+        """Initialize the transformation.
+
+        Args:
+            pad_size: The size to pad the tensor to. If the tensor is larger than this size,
+                no padding will be applied.
+            pad_value: The value to use for padding.
+        """
+        self._pad_size = pad_size
+        self._pad_value = pad_value
+
+    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Call method for the transformation.
+
+        Args:
+            tensor: The input tensor of shape [n, embedding_dim].
+
+        Returns:
+            A tensor of shape [max(n, pad_size), embedding_dim].
+        """
+        n_pad_values = self._pad_size - tensor.size(0)
+        if n_pad_values > 0:
+            tensor = torch.nn.functional.pad(
+                tensor,
+                pad=(0, 0, 0, n_pad_values),
+                mode="constant",
+                value=self._pad_value,
+            )
+        return tensor
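A quick illustration of the padding behaviour described above, assuming the re-export from `eva.core.data.transforms` shown in the `__init__.py` diff; the shapes are arbitrary.

import torch

from eva.core.data.transforms import Pad2DTensor

pad = Pad2DTensor(pad_size=8, pad_value=float("-inf"))

short = torch.randn(5, 384)   # fewer rows than pad_size: three rows of -inf are appended
long = torch.randn(12, 384)   # more rows than pad_size: returned unchanged

assert pad(short).shape == (8, 384)
assert pad(long).shape == (12, 384)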
eva/core/data/transforms/sampling/sample_from_axis.py
ADDED
@@ -0,0 +1,40 @@
+"""Sampling transformations."""
+
+import torch
+
+
+class SampleFromAxis:
+    """Samples n_samples entries from a tensor along a given axis."""
+
+    def __init__(self, n_samples: int, seed: int = 42, axis: int = 0):
+        """Initialize the transformation.
+
+        Args:
+            n_samples: The number of samples to draw.
+            seed: The seed to use for sampling.
+            axis: The axis along which to sample.
+        """
+        self._seed = seed
+        self._n_samples = n_samples
+        self._axis = axis
+        self._generator = self._get_generator()
+
+    def _get_generator(self):
+        """Return a torch random generator with fixed seed."""
+        generator = torch.Generator()
+        generator.manual_seed(self._seed)
+        return generator
+
+    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Call method for the transformation.
+
+        Args:
+            tensor: The input tensor of shape [n, embedding_dim].
+
+        Returns:
+            A tensor of shape [n_samples, embedding_dim].
+        """
+        indices = torch.randperm(tensor.size(self._axis), generator=self._generator)[
+            : self._n_samples
+        ]
+        return tensor.index_select(self._axis, indices)
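These two transforms look designed to be chained as an `embeddings_transforms` for the multi-embedding dataset above, giving every slide a fixed-size [n, embedding_dim] input; that pairing is an assumption, and the sizes here are arbitrary.

import torch

from eva.core.data.transforms import Pad2DTensor, SampleFromAxis

sample = SampleFromAxis(n_samples=200, seed=42, axis=0)
pad = Pad2DTensor(pad_size=200)

patches = torch.randn(1000, 384)  # e.g. 1000 patch embeddings from one slide
fixed = pad(sample(patches))      # subsample 200 rows, then pad (a no-op here)

assert fixed.shape == (200, 384)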
eva/core/loggers/__init__.py
ADDED
@@ -0,0 +1,7 @@
+"""Experimental loggers API."""
+
+from eva.core.loggers.dummy import DummyLogger
+from eva.core.loggers.experimental_loggers import ExperimentalLoggers
+from eva.core.loggers.log import log_parameters
+
+__all__ = ["DummyLogger", "ExperimentalLoggers", "log_parameters"]
eva/core/loggers/dummy.py
ADDED
@@ -0,0 +1,38 @@
+"""Dummy logger class."""
+
+import lightning.pytorch.loggers.logger
+
+
+class DummyLogger(lightning.pytorch.loggers.logger.DummyLogger):
+    """Dummy logger class.
+
+    This logger is currently used as a placeholder when saving results
+    to remote storage, as common lightning loggers do not work
+    with azure blob storage:
+
+    <https://github.com/Lightning-AI/pytorch-lightning/issues/18861>
+    <https://github.com/Lightning-AI/pytorch-lightning/issues/19736>
+
+    Simply disabling the loggers when pointing to remote storage doesn't work
+    because callbacks such as LearningRateMonitor or ModelCheckpoint require a
+    logger to be present.
+    """
+
+    def __init__(self, save_dir: str) -> None:
+        """Initializes the logger.
+
+        Args:
+            save_dir: The save directory (this logger does not save anything,
+                but callbacks might use this path to save their outputs).
+        """
+        super().__init__()
+        self._save_dir = save_dir
+
+    @property
+    def save_dir(self) -> str:
+        """Returns the save directory."""
+        return self._save_dir
+
+    def __deepcopy__(self, memo=None):
+        """Override of the deepcopy method."""
+        return self
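A sketch of the remote-storage scenario the docstring describes: point the run at a blob-storage path and hand the trainer this placeholder logger so checkpoint and LR-monitor callbacks still find a `save_dir`. The URL scheme is a placeholder, and using it in place of other loggers this way is an assumption, not documented behaviour.

import lightning.pytorch as pl

from eva.core.loggers import DummyLogger

remote_dir = "az://my-container/eva-runs/run-001"  # hypothetical blob-storage location

trainer = pl.Trainer(
    default_root_dir=remote_dir,
    logger=DummyLogger(save_dir=remote_dir),  # stands in for TensorBoard/CSV loggers
)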