evalsense 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evalsense/__init__.py +0 -0
- evalsense/cli/__init__.py +0 -0
- evalsense/cli/__main__.py +4 -0
- evalsense/cli/datasets.py +18 -0
- evalsense/cli/main.py +25 -0
- evalsense/constants.py +33 -0
- evalsense/dataset_config/ACI-BENCH.yml +51 -0
- evalsense/datasets/__init__.py +23 -0
- evalsense/datasets/dataset_config.py +302 -0
- evalsense/datasets/dataset_manager.py +292 -0
- evalsense/datasets/managers/__init__.py +3 -0
- evalsense/datasets/managers/aci_bench.py +83 -0
- evalsense/evaluation/__init__.py +25 -0
- evalsense/evaluation/evaluator.py +107 -0
- evalsense/evaluation/evaluators/__init__.py +41 -0
- evalsense/evaluation/evaluators/bertscore.py +273 -0
- evalsense/evaluation/evaluators/bleu.py +159 -0
- evalsense/evaluation/evaluators/g_eval.py +272 -0
- evalsense/evaluation/evaluators/qags.py +910 -0
- evalsense/evaluation/evaluators/rouge.py +134 -0
- evalsense/evaluation/experiment.py +228 -0
- evalsense/generation/__init__.py +4 -0
- evalsense/generation/generation_steps.py +11 -0
- evalsense/generation/model_config.py +70 -0
- evalsense/logging.py +61 -0
- evalsense/py.typed +0 -0
- evalsense/tasks/__init__.py +7 -0
- evalsense/tasks/task_preprocessor.py +106 -0
- evalsense/utils/__init__.py +0 -0
- evalsense/utils/dict.py +14 -0
- evalsense/utils/files.py +249 -0
- evalsense/utils/huggingface.py +20 -0
- evalsense/utils/text.py +274 -0
- evalsense/workflow/__init__.py +9 -0
- evalsense/workflow/analysers/__init__.py +11 -0
- evalsense/workflow/analysers/metric_correlation_analyser.py +201 -0
- evalsense/workflow/analysers/tabular_analyser.py +93 -0
- evalsense/workflow/pipeline.py +529 -0
- evalsense/workflow/project.py +426 -0
- evalsense/workflow/result_analyser.py +31 -0
- evalsense-0.1.0.dist-info/METADATA +139 -0
- evalsense-0.1.0.dist-info/RECORD +45 -0
- evalsense-0.1.0.dist-info/WHEEL +4 -0
- evalsense-0.1.0.dist-info/entry_points.txt +2 -0
- evalsense-0.1.0.dist-info/licenses/LICENCE +21 -0
evalsense/__init__.py
ADDED
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import typer
|
|
2
|
+
|
|
3
|
+
# Typer sub-application that groups the dataset management commands.
datasets_app = typer.Typer(
    help="Manage datasets for EvalSense.",
    no_args_is_help=True,
)


@datasets_app.command(no_args_is_help=True)
def get(name: str) -> None:
    """
    Download and prepare a dataset.
    """
    # NOTE: the docstring above is kept short on purpose; Typer uses it as the
    # help text of the command. Currently the command only reports the
    # requested dataset name.
    message = f"Downloading and preparing dataset {name}."
    print(message)


# Allow running this module directly as a script.
if __name__ == "__main__":
    datasets_app()
|
evalsense/cli/main.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import typer
|
|
2
|
+
from typing_extensions import Annotated
|
|
3
|
+
|
|
4
|
+
from evalsense.cli.datasets import datasets_app
|
|
5
|
+
|
|
6
|
+
# Root Typer application of the EvalSense command-line interface.
app = typer.Typer(
    help="EvalSense: A tool for evaluating LLM performance on healthcare tasks.",
    no_args_is_help=True,
)
# Mount the dataset management commands under the "datasets" sub-command.
app.add_typer(datasets_app, name="datasets")


@app.command(no_args_is_help=True)
def run(
    model: Annotated[str, typer.Option("--model", "-m")],
    dataset: Annotated[str, typer.Option("--dataset", "-d")],
) -> None:
    """
    Run a model on a dataset.
    """
    # NOTE: the docstring above is kept short on purpose; Typer uses it as the
    # help text of the command. Currently the command only reports the
    # selected model and dataset.
    message = f"Running model {model} on dataset {dataset}."
    print(message)


# Allow running this module directly as a script.
if __name__ == "__main__":
    app()
|
evalsense/constants.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
from platformdirs import user_cache_dir
|
|
5
|
+
|
|
6
|
+
# Application metadata
APP_NAME = "evalsense"
APP_AUTHOR = "NHS"
USER_AGENT = "EvalSense/0.1.0"

# Datasets
DEFAULT_VERSION_NAME = "default"
DEFAULT_HASH_TYPE = "sha256"

# OpenAI API key taken from the environment; None when the variable is unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Root storage directory: the EVALSENSE_STORAGE_DIR environment variable wins,
# otherwise the per-user cache directory reported by platformdirs is used.
STORAGE_PATH = Path(
    os.environ["EVALSENSE_STORAGE_DIR"]
    if "EVALSENSE_STORAGE_DIR" in os.environ
    else user_cache_dir(APP_NAME, APP_AUTHOR)
)
DATA_PATH = STORAGE_PATH / "datasets"
MODELS_PATH = STORAGE_PATH / "models"
PROJECTS_PATH = STORAGE_PATH / "projects"

# Point HF_HUB_CACHE into the storage directory unless the user already set
# it explicitly.
os.environ.setdefault("HF_HUB_CACHE", str(STORAGE_PATH / "huggingface"))
|
|
29
|
+
|
|
30
|
+
# Directories searched for dataset configuration files: the "dataset_config"
# directory next to this module comes first, followed by every directory
# listed in the DATASET_CONFIG_PATH environment variable (entries separated
# by os.pathsep).
DATASET_CONFIG_PATHS = [Path(__file__).parent / "dataset_config"]
# Empty entries (an empty variable, doubled or trailing separators) are
# skipped: Path("") denotes the current working directory, which would
# otherwise be searched for configuration files unintentionally.
DATASET_CONFIG_PATHS.extend(
    Path(entry)
    for entry in os.environ.get("DATASET_CONFIG_PATH", "").split(os.pathsep)
    if entry
)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
name: ACI-BENCH
|
|
2
|
+
description: "Dataset for benchmarking clinical note generation from doctor-patient dialogue."
|
|
3
|
+
config_version: "v1"
|
|
4
|
+
default_version: "5d3cd4d8a25b4ebb5b2b87c3923a7b2b7150e33d"
|
|
5
|
+
source:
|
|
6
|
+
online: true
|
|
7
|
+
url_template: "https://raw.githubusercontent.com/wyim/aci-bench/{version}/data/challenge_data/{filename}"
|
|
8
|
+
requires_auth: false
|
|
9
|
+
versions:
|
|
10
|
+
- name: 5d3cd4d8a25b4ebb5b2b87c3923a7b2b7150e33d
|
|
11
|
+
splits:
|
|
12
|
+
- name: train
|
|
13
|
+
files:
|
|
14
|
+
- name: train.csv
|
|
15
|
+
hash: "6c778d4ac5e6cc6f1964786f9286e8d765c210f22ed6b57f83aff8497409cea4"
|
|
16
|
+
hash_type: sha256
|
|
17
|
+
- name: train_metadata.csv
|
|
18
|
+
hash: "7da650e223f04ff6bf1666cb62d52ebd83ce71ecfc4b2311cbe64d9f3ab19d83"
|
|
19
|
+
hash_type: sha256
|
|
20
|
+
- name: valid
|
|
21
|
+
files:
|
|
22
|
+
- name: valid.csv
|
|
23
|
+
hash: "6629e89e3fb409d2b3eceab60dc7b32fe1d3fb8d4e07795039284965522aa4d0"
|
|
24
|
+
hash_type: sha256
|
|
25
|
+
- name: valid_metadata.csv
|
|
26
|
+
hash: "ae4c7eef6fc97e22f447c33e4546691715821e2a671ab53d80eb1ee5598e2914"
|
|
27
|
+
hash_type: sha256
|
|
28
|
+
- name: test1
|
|
29
|
+
files:
|
|
30
|
+
- name: clinicalnlp_taskB_test1.csv
|
|
31
|
+
hash: "5cc4008e68545f84913744a8e493a58bdf17ba7e1b7a0be46d6943d6bfca9471"
|
|
32
|
+
hash_type: sha256
|
|
33
|
+
- name: clinicalnlp_taskB_test1_metadata.csv
|
|
34
|
+
hash: "6960581701816c6dbe1aea8a53df6cff2f1ca92b24b036f6303242ca681cbafd"
|
|
35
|
+
hash_type: sha256
|
|
36
|
+
- name: test2
|
|
37
|
+
files:
|
|
38
|
+
- name: clinicalnlp_taskC_test2.csv
|
|
39
|
+
hash: "599e3330a14e25a0e056aee1365ffac7ebe50058f15821eae42a5513c2bb5a4f"
|
|
40
|
+
hash_type: sha256
|
|
41
|
+
- name: clinicalnlp_taskC_test2_metadata.csv
|
|
42
|
+
hash: "60e799ee5033767e9f5c2e9c3d84f64366628e2aae8637cd5d96ca29ba01b83c"
|
|
43
|
+
hash_type: sha256
|
|
44
|
+
- name: test3
|
|
45
|
+
files:
|
|
46
|
+
- name: clef_taskC_test3.csv
|
|
47
|
+
hash: "d3c18362a42124ea2bd1b2b4b66ba76a11bb123dfdb416471ae3b5924d1428ec"
|
|
48
|
+
hash_type: sha256
|
|
49
|
+
- name: clef_taskC_test3_metadata.csv
|
|
50
|
+
hash: "e1bec9323b2bed8e544ead77fae6251c4709bc9ffe2d0de346d843334760736b"
|
|
51
|
+
hash_type: sha256
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from evalsense.datasets.dataset_config import (
|
|
2
|
+
OnlineSource,
|
|
3
|
+
LocalSource,
|
|
4
|
+
FileMetadata,
|
|
5
|
+
SplitMetadata,
|
|
6
|
+
VersionMetadata,
|
|
7
|
+
DatasetMetadata,
|
|
8
|
+
DatasetConfig,
|
|
9
|
+
)
|
|
10
|
+
from evalsense.datasets.dataset_manager import DatasetManager, DatasetRecord
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Public API of the datasets package, listed in the order the names are
# imported above (configuration models first, then the manager classes).
__all__ = [
    "OnlineSource",
    "LocalSource",
    "FileMetadata",
    "SplitMetadata",
    "VersionMetadata",
    "DatasetMetadata",
    "DatasetConfig",
    "DatasetManager",
    "DatasetRecord",
]
|
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import Literal, Optional
|
|
3
|
+
from typing_extensions import override
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, field_validator
|
|
7
|
+
import yaml
|
|
8
|
+
|
|
9
|
+
from evalsense.constants import (
|
|
10
|
+
DEFAULT_HASH_TYPE,
|
|
11
|
+
DATASET_CONFIG_PATHS,
|
|
12
|
+
)
|
|
13
|
+
from evalsense.utils.dict import deep_update
|
|
14
|
+
from evalsense.utils.files import to_safe_filename
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# TODO: Handle folders
|
|
18
|
+
class OnlineSource(BaseModel):
    """Describes dataset file(s) that are obtained from an online location.

    Attributes:
        online (Literal[True]): Marks the source as online; only `True` is
            accepted.
        url_template (str): The URL template for the dataset file(s),
            optionally taking a version and filename
        requires_auth (bool, optional): Whether authentication is required
            to access the dataset file(s). Defaults to False.
    """

    online: Literal[True]
    url_template: str
    requires_auth: bool = False
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class LocalSource(BaseModel):
    """Describes dataset file(s) that are available on the local filesystem.

    Attributes:
        online (Literal[False]): Marks the source as local; only `False` is
            accepted.
        path (Path): The path to the dataset file(s)
    """

    online: Literal[False]
    path: Path
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class FileMetadata(BaseModel):
    """Metadata describing a single dataset file.

    Attributes:
        name (str): The name of the dataset file
        hash (str, optional): The hash of the dataset file
        hash_type (str): The type of hash used for the dataset file
            (defaults to `DEFAULT_HASH_TYPE`)
        source (OnlineSource | LocalSource, optional): The source declared
            directly on the file. Use `effective_source` to also take
            inherited sources into account.
        parent (SplitMetadata, optional): The split the file belongs to;
            assigned by the parent split after it has been initialised.
    """

    name: str
    hash: str | None = None
    hash_type: str = DEFAULT_HASH_TYPE
    source: OnlineSource | LocalSource | None = None
    parent: Optional["SplitMetadata"] = None

    @property
    def effective_source(self) -> OnlineSource | LocalSource:
        """The source that applies to this file.

        The file's own source takes precedence; without one, the effective
        source of the parent is used.

        Returns:
            (OnlineSource | LocalSource): The effective source.

        Raises:
            RuntimeError: If the file has neither its own source nor a parent.
        """
        own_source = self.source
        if own_source is None:
            if self.parent is None:
                raise RuntimeError(
                    "Parent metadata not filled. Please report this issue."
                )
            return self.parent.effective_source
        return own_source
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class SplitMetadata(BaseModel):
    """Metadata describing a single dataset split.

    Attributes:
        name (str): The name of the dataset split
        files (dict[str, FileMetadata]): The dataset files in the split,
            keyed by file name. A list of entries that each carry a "name"
            key is accepted as input and converted to this form.
        source (OnlineSource | LocalSource, optional): The source declared
            directly on the split. Use `effective_source` to also take
            inherited sources into account.
        parent (VersionMetadata, optional): The version the split belongs
            to; assigned by the parent version after it has been initialised.
    """

    name: str
    files: dict[str, FileMetadata]
    source: OnlineSource | LocalSource | None = None
    parent: Optional["VersionMetadata"] = None

    @field_validator("files", mode="before")
    @classmethod
    def convert_list_to_dict(cls, files):
        """Turns a list of file entries into a dict keyed by entry name."""
        if not isinstance(files, list):
            # Anything that is not a list is handed on for regular validation.
            return files
        keyed = {}
        for entry in files:
            keyed[entry["name"]] = entry
        return keyed

    @override
    def model_post_init(self, _):
        """Registers this split as the parent of each of its files."""
        for member in self.files.values():
            member.parent = self

    @property
    def effective_source(self) -> OnlineSource | LocalSource:
        """The source that applies to this split.

        The split's own source takes precedence; without one, the effective
        source of the parent is used.

        Returns:
            (OnlineSource | LocalSource): The effective source.

        Raises:
            RuntimeError: If the split has neither its own source nor a
                parent.
        """
        own_source = self.source
        if own_source is None:
            if self.parent is None:
                raise RuntimeError(
                    "Parent metadata not filled. Please report this issue."
                )
            return self.parent.effective_source
        return own_source
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class VersionMetadata(BaseModel):
    """The metadata for a dataset version.

    Attributes:
        name (str): The name of the dataset version
        splits (dict[str, SplitMetadata]): The dataset splits in the version,
            keyed by split name. A list of entries that each carry a "name"
            key is accepted as input and converted to this form.
        files (dict[str, FileMetadata], optional): The dataset files that
            belong to the version itself rather than to a split (same list
            conversion as for `splits`)
        source (OnlineSource | LocalSource, optional): The immediate source of
            the dataset version (use `effective_source` to access the effective
            source, which may be inherited)
        parent (DatasetMetadata, optional): The parent dataset metadata;
            assigned by the parent dataset after it has been initialised.
    """

    name: str
    splits: dict[str, SplitMetadata]
    files: dict[str, FileMetadata] | None = None
    source: OnlineSource | LocalSource | None = None
    parent: Optional["DatasetMetadata"] = None

    @field_validator("splits", "files", mode="before")
    @classmethod
    def convert_list_to_dict(cls, vs):
        """Turns a list of named entries into a dict keyed by entry name."""
        if isinstance(vs, list):
            return {v["name"]: v for v in vs}
        return vs

    @override
    def model_post_init(self, _):
        """Registers this version as the parent of its splits and files."""
        for split in self.splits.values():
            split.parent = self
        # Version-level files need a parent as well: without one, a file that
        # has no source of its own fails in `effective_source` with
        # "Parent metadata not filled" instead of inheriting from the version.
        # NOTE(review): FileMetadata.parent is annotated as SplitMetadata;
        # FileMetadata.effective_source only reads `parent.effective_source`,
        # which this class provides too. Confirm no other code assumes a
        # split here.
        for file in (self.files or {}).values():
            file.parent = self  # type: ignore[assignment]

    @property
    def effective_source(self) -> OnlineSource | LocalSource:
        """The effective source of the dataset version.

        The version's own source takes precedence; without one, the effective
        source of the parent dataset is used.

        Returns:
            (OnlineSource | LocalSource): The effective source.

        Raises:
            RuntimeError: If the version has neither its own source nor a
                parent.
        """
        if self.source is not None:
            return self.source
        if self.parent is None:
            raise RuntimeError("Parent metadata not filled. Please report this issue.")
        return self.parent.effective_source

    def get_files(self, splits: list[str]) -> dict[str, FileMetadata]:
        """Gets the files for the specified splits.

        Version-level files are always included. Files are collected in the
        order of `splits`; on a name clash, the entry added later replaces
        the earlier one.

        Args:
            splits (list[str]): The names of the splits.

        Returns:
            (dict[str, FileMetadata]): The files for the splits.

        Raises:
            ValueError: If a requested split does not exist in this version.
        """
        files: dict[str, FileMetadata] = {}
        if self.files is not None:
            files.update(self.files)
        for split_name in splits:
            if split_name not in self.splits:
                raise ValueError(
                    f"Split '{split_name}' not found for version {self.name}."
                )
            files.update(self.splits[split_name].files)
        return files
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class DatasetMetadata(BaseModel):
    """Top-level metadata describing a dataset.

    Attributes:
        name (str): The name of the dataset
        versions (dict[str, VersionMetadata]): The dataset versions, keyed by
            version name. A list of entries that each carry a "name" key is
            accepted as input and converted to this form.
        source (OnlineSource | LocalSource, optional): The source declared
            directly on the dataset. There is no further level to inherit
            from, so `effective_source` fails when this is unset.
    """

    name: str
    versions: dict[str, VersionMetadata]
    source: OnlineSource | LocalSource | None = None

    @field_validator("versions", mode="before")
    @classmethod
    def convert_list_to_dict(cls, versions):
        """Turns a list of version entries into a dict keyed by entry name."""
        if not isinstance(versions, list):
            # Anything that is not a list is handed on for regular validation.
            return versions
        keyed = {}
        for entry in versions:
            keyed[entry["name"]] = entry
        return keyed

    @override
    def model_post_init(self, _):
        """Registers this dataset as the parent of each of its versions."""
        for member in self.versions.values():
            member.parent = self

    @property
    def effective_source(self) -> OnlineSource | LocalSource:
        """The source that applies to the dataset as a whole.

        Returns:
            (OnlineSource | LocalSource): The effective source.

        Raises:
            ValueError: If the dataset declares no source.
        """
        if self.source is None:
            raise ValueError("No effective source exists.")
        return self.source

    def get_files(self, version: str, splits: list[str]) -> dict[str, FileMetadata]:
        """Collects the files of the given splits within the given version.

        Args:
            version (str): The name of the version.
            splits (list[str]): The names of the splits.

        Returns:
            (dict[str, FileMetadata]): The files for the version and splits.

        Raises:
            ValueError: If the version does not exist for this dataset.
        """
        selected = self.versions.get(version)
        if selected is None:
            raise ValueError(f"Version '{version}' not found for dataset {self.name}.")
        return selected.get_files(splits)

    def get_splits(self, version: str) -> dict[str, SplitMetadata]:
        """Returns the splits defined for the given version.

        Args:
            version (str): The name of the version.

        Returns:
            (dict[str, SplitMetadata]): The splits for the version.

        Raises:
            ValueError: If the version does not exist for this dataset.
        """
        selected = self.versions.get(version)
        if selected is None:
            raise ValueError(f"Version '{version}' not found for dataset {self.name}.")
        return selected.splits
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class DatasetConfig:
    """Configuration for a dataset.

    Attributes:
        dataset_name (str): The name of the dataset.
        dataset_metadata (DatasetMetadata): The metadata for the dataset.
    """

    def __init__(self, dataset_name: str):
        """Initializes a new DatasetConfig.

        Looks for "<safe dataset name>.yml" in every directory listed in
        `DATASET_CONFIG_PATHS`, in order, and merges each file that is found
        into the configuration with `deep_update`. A file that cannot be read,
        parsed or merged is skipped with a warning.

        Args:
            dataset_name (str): The name of the dataset.

        Raises:
            pydantic.ValidationError: If the merged configuration is not a
                valid `DatasetMetadata`, e.g. because no configuration file
                was found for the dataset.
        """
        self.dataset_name = dataset_name
        # The file name does not depend on the directory, so compute it once.
        config_filename = to_safe_filename(dataset_name) + ".yml"
        config = {}
        for config_path in DATASET_CONFIG_PATHS:
            config_file = config_path / config_filename
            if not config_file.exists():
                continue
            try:
                # Decode explicitly as UTF-8 instead of relying on the
                # platform's locale encoding.
                with open(config_file, "r", encoding="utf-8") as f:
                    new_config = yaml.safe_load(f)
                config = deep_update(config, new_config)
            except Exception as e:
                # Deliberately best-effort: a broken config file must not
                # prevent the remaining config directories from being used.
                warnings.warn(
                    f"Failed to load dataset config from {config_file}: {e}"
                )
        self.dataset_metadata = DatasetMetadata(**config)

    def get_files(self, version: str, splits: list[str]) -> dict[str, FileMetadata]:
        """Gets the files for the specified version and splits.

        Args:
            version (str): The name of the version.
            splits (list[str]): The names of the splits.

        Returns:
            (dict[str, FileMetadata]): The files for the version and splits.
        """
        return self.dataset_metadata.get_files(version, splits)

    def get_splits(self, version: str) -> dict[str, SplitMetadata]:
        """Gets the dataset splits for the specified version.

        Args:
            version (str): The name of the version.

        Returns:
            (dict[str, SplitMetadata]): The splits for the version.
        """
        return self.dataset_metadata.get_splits(version)
|