ScandEval 16.12.0-py3-none-any.whl → 16.13.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scandeval/async_utils.py +46 -0
- scandeval/benchmark_config_factory.py +26 -2
- scandeval/benchmark_modules/fresh.py +2 -1
- scandeval/benchmark_modules/hf.py +50 -12
- scandeval/benchmark_modules/litellm.py +25 -15
- scandeval/benchmark_modules/vllm.py +3 -3
- scandeval/benchmarker.py +15 -33
- scandeval/cli.py +2 -4
- scandeval/constants.py +5 -0
- scandeval/custom_dataset_configs.py +152 -0
- scandeval/data_loading.py +87 -31
- scandeval/data_models.py +396 -225
- scandeval/dataset_configs/__init__.py +51 -25
- scandeval/dataset_configs/albanian.py +1 -1
- scandeval/dataset_configs/belarusian.py +47 -0
- scandeval/dataset_configs/bulgarian.py +1 -1
- scandeval/dataset_configs/catalan.py +1 -1
- scandeval/dataset_configs/croatian.py +1 -1
- scandeval/dataset_configs/danish.py +3 -2
- scandeval/dataset_configs/dutch.py +7 -6
- scandeval/dataset_configs/english.py +4 -3
- scandeval/dataset_configs/estonian.py +8 -7
- scandeval/dataset_configs/faroese.py +1 -1
- scandeval/dataset_configs/finnish.py +5 -4
- scandeval/dataset_configs/french.py +6 -5
- scandeval/dataset_configs/german.py +4 -3
- scandeval/dataset_configs/greek.py +1 -1
- scandeval/dataset_configs/hungarian.py +1 -1
- scandeval/dataset_configs/icelandic.py +4 -3
- scandeval/dataset_configs/italian.py +4 -3
- scandeval/dataset_configs/latvian.py +2 -2
- scandeval/dataset_configs/lithuanian.py +1 -1
- scandeval/dataset_configs/norwegian.py +6 -5
- scandeval/dataset_configs/polish.py +4 -3
- scandeval/dataset_configs/portuguese.py +5 -4
- scandeval/dataset_configs/romanian.py +2 -2
- scandeval/dataset_configs/serbian.py +1 -1
- scandeval/dataset_configs/slovene.py +1 -1
- scandeval/dataset_configs/spanish.py +4 -3
- scandeval/dataset_configs/swedish.py +4 -3
- scandeval/dataset_configs/ukrainian.py +1 -1
- scandeval/generation_utils.py +6 -6
- scandeval/metrics/llm_as_a_judge.py +1 -1
- scandeval/metrics/pipeline.py +1 -1
- scandeval/model_cache.py +34 -4
- scandeval/prompt_templates/linguistic_acceptability.py +9 -0
- scandeval/prompt_templates/multiple_choice.py +9 -0
- scandeval/prompt_templates/named_entity_recognition.py +21 -0
- scandeval/prompt_templates/reading_comprehension.py +10 -0
- scandeval/prompt_templates/sentiment_classification.py +11 -0
- scandeval/string_utils.py +157 -0
- scandeval/task_group_utils/sequence_classification.py +2 -5
- scandeval/task_group_utils/token_classification.py +2 -4
- scandeval/utils.py +6 -323
- scandeval-16.13.0.dist-info/METADATA +334 -0
- scandeval-16.13.0.dist-info/RECORD +94 -0
- scandeval-16.12.0.dist-info/METADATA +0 -667
- scandeval-16.12.0.dist-info/RECORD +0 -90
- {scandeval-16.12.0.dist-info → scandeval-16.13.0.dist-info}/WHEEL +0 -0
- {scandeval-16.12.0.dist-info → scandeval-16.13.0.dist-info}/entry_points.txt +0 -0
- {scandeval-16.12.0.dist-info → scandeval-16.13.0.dist-info}/licenses/LICENSE +0 -0
scandeval/async_utils.py
ADDED
@@ -0,0 +1,46 @@
+"""Utility functions for asyncronous tasks."""
+
+import asyncio
+import typing as t
+
+from .constants import T
+
+
+def safe_run(coroutine: t.Coroutine[t.Any, t.Any, T]) -> T:
+    """Run a coroutine, ensuring that the event loop is always closed when we're done.
+
+    Args:
+        coroutine:
+            The coroutine to run.
+
+    Returns:
+        The result of the coroutine.
+    """
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:  # If the current event loop is closed
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+    response = loop.run_until_complete(coroutine)
+    return response
+
+
+async def add_semaphore_and_catch_exception(
+    coroutine: t.Coroutine[t.Any, t.Any, T], semaphore: asyncio.Semaphore
+) -> T | Exception:
+    """Run a coroutine with a semaphore.
+
+    Args:
+        coroutine:
+            The coroutine to run.
+        semaphore:
+            The semaphore to use.
+
+    Returns:
+        The result of the coroutine.
+    """
+    async with semaphore:
+        try:
+            return await coroutine
+        except Exception as exc:
+            return exc
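For orientation, below is a minimal usage sketch of the two new helpers (not part of the diff). The import path follows the new module's location; the dummy coroutine and the concurrency limit of four are illustrative assumptions.

import asyncio

from scandeval.async_utils import add_semaphore_and_catch_exception, safe_run


async def fetch(i: int) -> int:
    """Stand-in for a remote call; purely illustrative."""
    await asyncio.sleep(0.01)
    return i * 2


async def run_all() -> list[int | Exception]:
    # At most four coroutines run concurrently; failures are returned, not raised.
    semaphore = asyncio.Semaphore(4)
    tasks = [add_semaphore_and_catch_exception(fetch(i), semaphore) for i in range(10)]
    return await asyncio.gather(*tasks)


# safe_run creates and sets a fresh event loop if the current one is closed.
results = safe_run(run_all())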

scandeval/benchmark_config_factory.py
CHANGED
@@ -46,6 +46,8 @@ def build_benchmark_config(
         dataset=benchmark_config_params.dataset,
         languages=languages,
         custom_datasets_file=benchmark_config_params.custom_datasets_file,
+        api_key=benchmark_config_params.api_key,
+        cache_dir=Path(benchmark_config_params.cache_dir),
     )
 
     return BenchmarkConfig(
@@ -159,7 +161,9 @@ def prepare_dataset_configs(
     languages: c.Sequence["Language"],
     dataset: "str | DatasetConfig | c.Sequence[str | DatasetConfig] | None",
     custom_datasets_file: Path,
-
+    api_key: str | None,
+    cache_dir: Path,
+) -> list["DatasetConfig"]:
     """Prepare dataset config(s) for benchmarking.
 
     Args:
@@ -173,6 +177,10 @@ def prepare_dataset_configs(
             included, limited by the `task` and `languages` parameters.
         custom_datasets_file:
            A path to a Python file containing custom dataset configurations.
+        api_key:
+            The API key to use for accessing the Hugging Face Hub.
+        cache_dir:
+            The directory to store the cache in.
 
     Returns:
         The prepared dataset configs.
@@ -181,9 +189,25 @@ def prepare_dataset_configs(
         InvalidBenchmark:
             If the task or dataset is not found in the benchmark tasks or datasets.
     """
+    # Extract the dataset IDs from the `dataset` argument
+    dataset_ids: list[str] = list()
+    if isinstance(dataset, str):
+        dataset_ids.append(dataset)
+    elif isinstance(dataset, DatasetConfig):
+        dataset_ids.append(dataset.name)
+    elif isinstance(dataset, list):
+        for d in dataset:
+            if isinstance(d, str):
+                dataset_ids.append(d)
+            elif isinstance(d, DatasetConfig):
+                dataset_ids.append(d.name)
+
     # Create the list of dataset configs
     all_dataset_configs = get_all_dataset_configs(
-        custom_datasets_file=custom_datasets_file
+        custom_datasets_file=custom_datasets_file,
+        dataset_ids=dataset_ids,
+        api_key=api_key,
+        cache_dir=cache_dir,
     )
     all_official_dataset_configs: c.Sequence[DatasetConfig] = [
         dataset_config

scandeval/benchmark_modules/fresh.py
CHANGED
@@ -28,8 +28,9 @@ from ..exceptions import (
 )
 from ..generation_utils import raise_if_wrong_params
 from ..logging_utils import block_terminal_output
+from ..model_cache import create_model_cache_dir
 from ..types import Tokeniser
-from ..utils import
+from ..utils import get_hf_token
 from .hf import (
     HuggingFaceEncoderModel,
     align_model_and_tokeniser,

scandeval/benchmark_modules/hf.py
CHANGED
@@ -1,6 +1,7 @@
 """Encoder models from the Hugging Face Hub."""
 
 import collections.abc as c
+import importlib
 import logging
 import re
 import typing as t
@@ -63,6 +64,8 @@ from ..exceptions import (
 from ..generation_utils import raise_if_wrong_params
 from ..languages import get_all_languages
 from ..logging_utils import block_terminal_output, log, log_once
+from ..model_cache import create_model_cache_dir
+from ..string_utils import split_model_id
 from ..task_group_utils import (
     multiple_choice_classification,
     question_answering,
@@ -70,13 +73,7 @@ from ..task_group_utils import (
 )
 from ..tokenisation_utils import get_bos_token, get_eos_token
 from ..types import Tokeniser
-from ..utils import
-    create_model_cache_dir,
-    get_class_by_name,
-    get_hf_token,
-    internet_connection_available,
-    split_model_id,
-)
+from ..utils import get_hf_token, internet_connection_available
 from .base import BenchmarkModule
 
 try:
@@ -381,7 +378,7 @@ class HuggingFaceEncoderModel(BenchmarkModule):
         if "label" in examples:
             try:
                 examples["label"] = [
-                    self._model.config.label2id[lbl.lower()]
+                    self._model.config.label2id[str(lbl).lower()]
                     if self._model.config.label2id is not None
                     else lbl
                     for lbl in examples["label"]
@@ -817,8 +814,8 @@ def get_model_repo_info(
         log(
             f"Could not access the model {model_id} with the revision "
             f"{revision}. The error was {str(e)!r}. Please set the "
-            "`HUGGINGFACE_API_KEY` environment variable or
-            "`--api-key` argument.",
+            "`HUGGINGFACE_API_KEY` or `HF_TOKEN` environment variable or "
+            "use the `--api-key` argument.",
             level=logging.DEBUG,
         )
         return None
@@ -1095,8 +1092,8 @@ def load_hf_model_config(
             f"The model {model_id!r} is a gated repository. Please ensure "
             "that you are logged in with `hf auth login` or have provided a "
             "valid Hugging Face access token with the `HUGGINGFACE_API_KEY` "
-            "environment variable or the `--api-key` argument.
-            "your account has access to this model."
+            "or `HF_TOKEN` environment variable or the `--api-key` argument. "
+            "Also check that your account has access to this model."
         ) from e
         raise InvalidModel(
             f"Couldn't load model config for {model_id!r}. The error was "
@@ -1334,3 +1331,44 @@ def task_group_to_class_name(task_group: TaskGroup) -> str:
     )
     pascal_case = special_case_mapping.get(pascal_case, pascal_case)
     return f"AutoModelFor{pascal_case}"
+
+
+def get_class_by_name(
+    class_name: str | c.Sequence[str], module_name: str
+) -> t.Type | None:
+    """Get a class by its name.
+
+    Args:
+        class_name:
+            The name of the class, written in kebab-case. The corresponding class name
+            must be the same, but written in PascalCase, and lying in a module with the
+            same name, but written in snake_case. If a list of strings is passed, the
+            first class that is found is returned.
+        module_name:
+            The name of the module where the class is located.
+
+    Returns:
+        The class. If the class is not found, None is returned.
+    """
+    if isinstance(class_name, str):
+        class_name = [class_name]
+
+    error_messages = list()
+    for name in class_name:
+        try:
+            module = importlib.import_module(name=module_name)
+            class_: t.Type = getattr(module, name)
+            return class_
+        except (ModuleNotFoundError, AttributeError) as e:
+            error_messages.append(str(e))
+
+    if error_messages:
+        errors = "\n- " + "\n- ".join(error_messages)
+        log(
+            f"Could not find the class with the name(s) {', '.join(class_name)}. The "
+            f"following error messages were raised: {errors}",
+            level=logging.DEBUG,
+        )
+
+    # If the class could not be found, return None
+    return None
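As a rough illustration (not part of the diff), the relocated `get_class_by_name` helper returns the first name that resolves in the given module and None otherwise. The import path below is an assumption for the sketch.

# Illustrative only: assumes the helper is importable from this module path.
from scandeval.benchmark_modules.hf import get_class_by_name

# "NoSuchClass" raises AttributeError and is recorded; the second name resolves
# in `transformers`, so that class is returned and nothing is logged.
model_class = get_class_by_name(
    class_name=["NoSuchClass", "AutoModelForSequenceClassification"],
    module_name="transformers",
)
print(model_class)  # <class 'transformers...AutoModelForSequenceClassification'>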

scandeval/benchmark_modules/litellm.py
CHANGED
@@ -40,7 +40,7 @@ from pydantic import ValidationError, conlist, create_model
 from requests.exceptions import RequestException
 from tqdm.asyncio import tqdm as tqdm_async
 
-from ..
+from ..async_utils import add_semaphore_and_catch_exception, safe_run
 from ..constants import (
     JSON_STRIP_CHARACTERS,
     LITELLM_CLASSIFICATION_OUTPUT_KEY,
@@ -74,6 +74,8 @@ from ..generation_utils import (
     raise_if_wrong_params,
 )
 from ..logging_utils import get_pbar, log, log_once
+from ..model_cache import create_model_cache_dir
+from ..string_utils import split_model_id
 from ..task_group_utils import (
     question_answering,
     sequence_classification,
@@ -83,13 +85,7 @@ from ..task_group_utils import (
 from ..tasks import NER
 from ..tokenisation_utils import get_first_label_token_mapping
 from ..types import ExtractLabelsFunction
-from ..utils import
-    add_semaphore_and_catch_exception,
-    create_model_cache_dir,
-    get_hf_token,
-    safe_run,
-    split_model_id,
-)
+from ..utils import get_hf_token
 from .base import BenchmarkModule
 from .hf import HuggingFaceEncoderModel, load_hf_model_config, load_tokeniser
 
@@ -700,10 +696,10 @@ class LiteLLMModel(BenchmarkModule):
         elif isinstance(
             error, (Timeout, ServiceUnavailableError, InternalServerError, SystemError)
         ):
-
-
-                "Retrying in 10 seconds...",
-                level=logging.
+            log(
+                "Service temporarily unavailable during generation. The error "
+                f"message was: {error}. Retrying in 10 seconds...",
+                level=logging.INFO,
             )
             return generation_kwargs, 10
         elif isinstance(error, UnsupportedParamsError):
@@ -764,6 +760,20 @@ class LiteLLMModel(BenchmarkModule):
                 run_with_cli=self.benchmark_config.run_with_cli,
             ) from error
 
+        if (
+            isinstance(error, (BadRequestError, NotFoundError))
+            and self.benchmark_config.api_base is not None
+            and not self.benchmark_config.api_base.endswith("/v1")
+        ):
+            log_once(
+                f"The API base {self.benchmark_config.api_base!r} is not valid. We "
+                "will try appending '/v1' to it and try again.",
+                level=logging.DEBUG,
+            )
+            self.benchmark_config.api_base += "/v1"
+            generation_kwargs["api_base"] = self.benchmark_config.api_base
+            return generation_kwargs, 0
+
         raise InvalidBenchmark(
             f"Failed to generate text. The error message was: {error}"
         ) from error
@@ -1390,9 +1400,10 @@ class LiteLLMModel(BenchmarkModule):
             InternalServerError,
         ) as e:
             log(
-
+                "Service temporarily unavailable while checking for model "
+                f"existence of the model {model_id!r}. The error message was: {e}. "
                 "Retrying in 10 seconds...",
-                level=logging.
+                level=logging.INFO,
             )
             sleep(10)
         except APIError as e:
@@ -1567,7 +1578,6 @@ class LiteLLMModel(BenchmarkModule):
 
         return dataset
 
-    @cache_arguments()
     def get_generation_kwargs(self, dataset_config: DatasetConfig) -> dict[str, t.Any]:
         """Get the generation arguments for the model.
 

scandeval/benchmark_modules/vllm.py
CHANGED
@@ -54,6 +54,8 @@ from ..generation_utils import (
 )
 from ..languages import get_all_languages
 from ..logging_utils import get_pbar, log, log_once, no_terminal_output
+from ..model_cache import create_model_cache_dir
+from ..string_utils import split_model_id
 from ..task_group_utils import (
     question_answering,
     sequence_classification,
@@ -73,12 +75,10 @@ from ..tokenisation_utils import (
 from ..types import ExtractLabelsFunction, Tokeniser
 from ..utils import (
     clear_memory,
-    create_model_cache_dir,
     get_hf_token,
     get_min_cuda_compute_capability,
     internet_connection_available,
     resolve_model_path,
-    split_model_id,
 )
 from .hf import HuggingFaceEncoderModel, get_model_repo_info, load_hf_model_config
 
@@ -1144,7 +1144,7 @@ def load_model_and_tokeniser(
         pipeline_parallel_size=pipeline_parallel_size,
         disable_custom_all_reduce=True,
         quantization=quantization,
-        dtype=dtype,
+        dtype=dtype,  # pyrefly: ignore[bad-argument-type]
         enforce_eager=True,
         # TEMP: Prefix caching isn't supported with sliding window in vLLM yet,
         # so we disable it for now
scandeval/benchmarker.py
CHANGED
@@ -18,7 +18,6 @@ from .benchmark_config_factory import build_benchmark_config
 from .constants import ATTENTION_BACKENDS, GENERATIVE_PIPELINE_TAGS
 from .data_loading import load_data, load_raw_data
 from .data_models import BenchmarkConfigParams, BenchmarkResult
-from .dataset_configs import get_all_dataset_configs
 from .enums import Device, GenerativeType, ModelType
 from .exceptions import HuggingFaceHubDown, InvalidBenchmark, InvalidModel
 from .finetuning import finetune
@@ -28,12 +27,9 @@ from .model_config import get_model_config
 from .model_loading import load_model
 from .scores import log_scores
 from .speed_benchmark import benchmark_speed
+from .string_utils import split_model_id
 from .tasks import SPEED
-from .utils import
-    enforce_reproducibility,
-    internet_connection_available,
-    split_model_id,
-)
+from .utils import enforce_reproducibility, internet_connection_available
 
 if t.TYPE_CHECKING:
     from .benchmark_modules import BenchmarkModule
@@ -79,7 +75,9 @@ class Benchmarker:
         api_base: str | None = None,
         api_version: str | None = None,
         gpu_memory_utilization: float = 0.8,
-        attention_backend:
+        attention_backend: t.Literal[
+            *ATTENTION_BACKENDS  # pyrefly: ignore[invalid-literal]
+        ] = "FLASHINFER",
         generative_type: GenerativeType | None = None,
         custom_datasets_file: Path | str = Path("custom_datasets.py"),
         debug: bool = False,
@@ -346,7 +344,9 @@ class Benchmarker:
             f"Loading data for {dataset_config.logging_string}", level=logging.INFO
         )
         dataset = load_raw_data(
-            dataset_config=dataset_config,
+            dataset_config=dataset_config,
+            cache_dir=benchmark_config.cache_dir,
+            api_key=benchmark_config.api_key,
        )
         del dataset
 
@@ -513,6 +513,11 @@ class Benchmarker:
            ValueError:
                If both `task` and `dataset` are specified.
        """
+        log(
+            "Started EuroEval run. Run with `--verbose` for more information.",
+            level=logging.INFO,
+        )
+
         if task is not None and dataset is not None:
             raise ValueError("Only one of `task` and `dataset` can be specified.")
 
@@ -790,7 +795,7 @@ class Benchmarker:
 
         # Update the benchmark config if the dataset requires it
         if (
-
+            dataset_config.val_split is None
             and not benchmark_config.evaluate_test_split
         ):
             log(
@@ -1066,7 +1071,7 @@ class Benchmarker:
             ),
             validation_split=(
                 None
-                if
+                if dataset_config.val_split is None
                 else not benchmark_config.evaluate_test_split
             ),
         )
@@ -1181,29 +1186,6 @@ def clear_model_cache_fn(cache_dir: str) -> None:
             rmtree(sub_model_dir)
 
 
-def prepare_dataset_configs(
-    dataset_names: c.Sequence[str], custom_datasets_file: Path
-) -> c.Sequence["DatasetConfig"]:
-    """Prepare the dataset configuration(s) to be benchmarked.
-
-    Args:
-        dataset_names:
-            The dataset names to benchmark.
-        custom_datasets_file:
-            A path to a Python file containing custom dataset configurations.
-
-    Returns:
-        The prepared list of model IDs.
-    """
-    return [
-        cfg
-        for cfg in get_all_dataset_configs(
-            custom_datasets_file=custom_datasets_file
-        ).values()
-        if cfg.name in dataset_names
-    ]
-
-
 def initial_logging(
     model_config: "ModelConfig",
     dataset_config: "DatasetConfig",
scandeval/cli.py
CHANGED
@@ -5,6 +5,7 @@ from pathlib import Path
 import click
 
 from .benchmarker import Benchmarker
+from .constants import ATTENTION_BACKENDS
 from .data_models import DatasetConfig
 from .enums import Device, GenerativeType
 from .languages import get_all_languages
@@ -174,10 +175,7 @@ from .languages import get_all_languages
     "--attention-backend",
     default="FLASHINFER",
     show_default=True,
-    type=click.Choice(
-        ["FLASHINFER", "FLASH_ATTN", "TRITON_ATTN", "FLEX_ATTENTION"],
-        case_sensitive=True,
-    ),
+    type=click.Choice(ATTENTION_BACKENDS, case_sensitive=True),
     help="The attention backend to use for vLLM. Only relevant if the model is "
     "generative.",
 )
scandeval/constants.py
CHANGED
@@ -134,3 +134,8 @@ ATTENTION_BACKENDS: list[str] = [
     "CPU_ATTN",
     "CUSTOM",
 ]
+
+# If a dataset configuration has more than this number of languages, we won't log any of
+# the languages. This is for instance the case for the speed benchmark, which has all
+# the languages. The threshold of 5 is somewhat arbitrary.
+MAX_NUMBER_OF_LOGGING_LANGUAGES = 5

scandeval/custom_dataset_configs.py
ADDED
@@ -0,0 +1,152 @@
+"""Load custom dataset configs."""
+
+import importlib.util
+import logging
+from pathlib import Path
+from types import ModuleType
+
+from huggingface_hub import HfApi
+
+from .data_models import DatasetConfig
+from .logging_utils import log_once
+from .utils import get_hf_token
+
+
+def load_custom_datasets_module(custom_datasets_file: Path) -> ModuleType | None:
+    """Load the custom datasets module if it exists.
+
+    Args:
+        custom_datasets_file:
+            The path to the custom datasets module.
+
+    Raises:
+        RuntimeError:
+            If the custom datasets module cannot be loaded.
+    """
+    if custom_datasets_file.exists():
+        spec = importlib.util.spec_from_file_location(
+            name="custom_datasets_module", location=str(custom_datasets_file.resolve())
+        )
+        if spec is None:
+            log_once(
+                "Could not load the spec for the custom datasets file from "
+                f"{custom_datasets_file.resolve()}.",
+                level=logging.ERROR,
+            )
+            return None
+        module = importlib.util.module_from_spec(spec=spec)
+        if spec.loader is None:
+            log_once(
+                "Could not load the module for the custom datasets file from "
+                f"{custom_datasets_file.resolve()}.",
+                level=logging.ERROR,
+            )
+            return None
+        spec.loader.exec_module(module)
+        return module
+    return None
+
+
+def try_get_dataset_config_from_repo(
+    dataset_id: str, api_key: str | None, cache_dir: Path
+) -> DatasetConfig | None:
+    """Try to get a dataset config from a Hugging Face dataset repository.
+
+    Args:
+        dataset_id:
+            The ID of the dataset to get the config for.
+        api_key:
+            The Hugging Face API key to use to check if the repositories have custom
+            dataset configs.
+        cache_dir:
+            The directory to store the cache in.
+
+    Returns:
+        The dataset config if it exists, otherwise None.
+    """
+    # Check if the dataset ID is a Hugging Face dataset ID, abort if it isn't
+    token = get_hf_token(api_key=api_key)
+    hf_api = HfApi(token=token)
+    if not hf_api.repo_exists(repo_id=dataset_id, repo_type="dataset"):
+        return None
+
+    # Check if the repository has a euroeval_config.py file, abort if it doesn't
+    repo_files = hf_api.list_repo_files(
+        repo_id=dataset_id, repo_type="dataset", revision="main"
+    )
+    if "euroeval_config.py" not in repo_files:
+        log_once(
+            f"Dataset {dataset_id} does not have a euroeval_config.py file, so we "
+            "cannot load it. Skipping.",
+            level=logging.WARNING,
+        )
+        return None
+
+    # Fetch the euroeval_config.py file, abort if loading failed
+    external_config_path = cache_dir / "external_dataset_configs" / dataset_id
+    external_config_path.mkdir(parents=True, exist_ok=True)
+    hf_api.hf_hub_download(
+        repo_id=dataset_id,
+        repo_type="dataset",
+        filename="euroeval_config.py",
+        local_dir=external_config_path,
+        local_dir_use_symlinks=False,
+    )
+    module = load_custom_datasets_module(
+        custom_datasets_file=external_config_path / "euroeval_config.py"
+    )
+    if module is None:
+        return None
+
+    # Check that there is exactly one dataset config, abort if there isn't
+    repo_dataset_configs = [
+        cfg for cfg in vars(module).values() if isinstance(cfg, DatasetConfig)
+    ]
+    if not repo_dataset_configs:
+        return None  # Already warned the user in this case, so we just skip
+    elif len(repo_dataset_configs) > 1:
+        log_once(
+            f"Dataset {dataset_id} has multiple dataset configurations. Please ensure "
+            "that only a single DatasetConfig is defined in the `euroeval_config.py` "
+            "file.",
+            level=logging.WARNING,
+        )
+        return None
+
+    # Get the dataset split names
+    splits = [
+        split["name"]
+        for split in hf_api.dataset_info(repo_id=dataset_id).card_data.dataset_info[
+            "splits"
+        ]
+    ]
+    train_split_candidates = sorted(
+        [split for split in splits if "train" in split.lower()], key=len
+    )
+    val_split_candidates = sorted(
+        [split for split in splits if "val" in split.lower()], key=len
+    )
+    test_split_candidates = sorted(
+        [split for split in splits if "test" in split.lower()], key=len
+    )
+    train_split = train_split_candidates[0] if train_split_candidates else None
+    val_split = val_split_candidates[0] if val_split_candidates else None
+    test_split = test_split_candidates[0] if test_split_candidates else None
+    if test_split is None:
+        log_once(
+            f"Dataset {dataset_id} does not have a test split, so we cannot load it. "
+            "Please ensure that the dataset has a test split.",
+            level=logging.ERROR,
+        )
+        return None
+
+    # Set up the config with the repo information
+    repo_dataset_config = repo_dataset_configs[0]
+    repo_dataset_config.name = dataset_id
+    repo_dataset_config.pretty_name = dataset_id
+    repo_dataset_config.source = dataset_id
+    repo_dataset_config.train_split = train_split
+    repo_dataset_config.val_split = val_split
+    repo_dataset_config.test_split = test_split
+
+    return repo_dataset_config
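The split-selection heuristic above can be read in isolation: among the repository's split names, the shortest name containing "train", "val" or "test" wins, and the config is rejected when no test split exists. A standalone sketch with made-up split names:

splits = ["train", "train_extra", "validation", "test"]

train_candidates = sorted([s for s in splits if "train" in s.lower()], key=len)
val_candidates = sorted([s for s in splits if "val" in s.lower()], key=len)
test_candidates = sorted([s for s in splits if "test" in s.lower()], key=len)

train_split = train_candidates[0] if train_candidates else None  # "train"
val_split = val_candidates[0] if val_candidates else None  # "validation"
test_split = test_candidates[0] if test_candidates else None  # "test"
print(train_split, val_split, test_split)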