EuroEval 15.12.0__py3-none-any.whl → 16.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- euroeval/__init__.py +32 -14
- euroeval/benchmark_config_factory.py +92 -180
- euroeval/benchmark_modules/base.py +49 -39
- euroeval/benchmark_modules/fresh.py +35 -21
- euroeval/benchmark_modules/hf.py +280 -244
- euroeval/benchmark_modules/litellm.py +752 -312
- euroeval/benchmark_modules/vllm.py +570 -268
- euroeval/benchmarker.py +651 -528
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/cli.py +49 -38
- euroeval/constants.py +44 -25
- euroeval/data_loading.py +111 -55
- euroeval/data_models.py +490 -323
- euroeval/dataset_configs/__init__.py +26 -4
- euroeval/dataset_configs/bosnian.py +39 -0
- euroeval/dataset_configs/bulgarian.py +56 -0
- euroeval/dataset_configs/croatian.py +56 -0
- euroeval/dataset_configs/czech.py +75 -0
- euroeval/dataset_configs/danish.py +78 -50
- euroeval/dataset_configs/dutch.py +74 -44
- euroeval/dataset_configs/english.py +71 -36
- euroeval/dataset_configs/estonian.py +111 -0
- euroeval/dataset_configs/faroese.py +25 -18
- euroeval/dataset_configs/finnish.py +63 -26
- euroeval/dataset_configs/french.py +65 -32
- euroeval/dataset_configs/german.py +77 -36
- euroeval/dataset_configs/greek.py +64 -0
- euroeval/dataset_configs/icelandic.py +68 -57
- euroeval/dataset_configs/italian.py +68 -36
- euroeval/dataset_configs/latvian.py +87 -0
- euroeval/dataset_configs/lithuanian.py +64 -0
- euroeval/dataset_configs/norwegian.py +98 -72
- euroeval/dataset_configs/polish.py +96 -0
- euroeval/dataset_configs/portuguese.py +63 -40
- euroeval/dataset_configs/serbian.py +64 -0
- euroeval/dataset_configs/slovak.py +55 -0
- euroeval/dataset_configs/slovene.py +56 -0
- euroeval/dataset_configs/spanish.py +68 -34
- euroeval/dataset_configs/swedish.py +82 -41
- euroeval/dataset_configs/ukrainian.py +64 -0
- euroeval/enums.py +12 -6
- euroeval/exceptions.py +21 -1
- euroeval/finetuning.py +34 -26
- euroeval/generation.py +76 -41
- euroeval/generation_utils.py +169 -34
- euroeval/languages.py +1020 -188
- euroeval/logging_utils.py +268 -0
- euroeval/metrics/__init__.py +6 -0
- euroeval/metrics/base.py +85 -0
- euroeval/metrics/huggingface.py +216 -0
- euroeval/metrics/llm_as_a_judge.py +260 -0
- euroeval/metrics/pipeline.py +289 -0
- euroeval/metrics/speed.py +48 -0
- euroeval/model_cache.py +40 -21
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/__init__.py +2 -0
- euroeval/prompt_templates/classification.py +206 -0
- euroeval/prompt_templates/linguistic_acceptability.py +157 -22
- euroeval/prompt_templates/multiple_choice.py +159 -17
- euroeval/prompt_templates/named_entity_recognition.py +318 -21
- euroeval/prompt_templates/reading_comprehension.py +207 -16
- euroeval/prompt_templates/sentiment_classification.py +205 -22
- euroeval/prompt_templates/summarization.py +122 -22
- euroeval/prompt_templates/token_classification.py +279 -0
- euroeval/scores.py +20 -9
- euroeval/speed_benchmark.py +11 -12
- euroeval/task_group_utils/multiple_choice_classification.py +21 -12
- euroeval/task_group_utils/question_answering.py +101 -73
- euroeval/task_group_utils/sequence_classification.py +144 -61
- euroeval/task_group_utils/text_to_text.py +33 -12
- euroeval/task_group_utils/token_classification.py +86 -89
- euroeval/tasks.py +75 -16
- euroeval/tokenisation_utils.py +603 -0
- euroeval/types.py +17 -11
- euroeval/utils.py +332 -137
- euroeval-16.7.1.dist-info/METADATA +623 -0
- euroeval-16.7.1.dist-info/RECORD +84 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
- euroeval/human_evaluation.py +0 -737
- euroeval/metrics.py +0 -452
- euroeval/tokenization_utils.py +0 -498
- euroeval-15.12.0.dist-info/METADATA +0 -285
- euroeval-15.12.0.dist-info/RECORD +0 -63
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
euroeval/utils.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Utility functions to be used in other scripts."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import collections.abc as c
|
|
4
5
|
import gc
|
|
5
6
|
import importlib
|
|
6
7
|
import importlib.metadata
|
|
@@ -8,34 +9,28 @@ import importlib.util
|
|
|
8
9
|
import logging
|
|
9
10
|
import os
|
|
10
11
|
import random
|
|
12
|
+
import re
|
|
13
|
+
import socket
|
|
11
14
|
import sys
|
|
12
15
|
import typing as t
|
|
13
|
-
import warnings
|
|
14
|
-
from functools import cache
|
|
15
16
|
from pathlib import Path
|
|
17
|
+
from types import ModuleType, TracebackType
|
|
16
18
|
|
|
17
|
-
import
|
|
19
|
+
import demjson3
|
|
20
|
+
import huggingface_hub as hf_hub
|
|
18
21
|
import numpy as np
|
|
19
|
-
import requests
|
|
20
22
|
import torch
|
|
21
|
-
from datasets.utils import disable_progress_bar
|
|
22
|
-
from requests.exceptions import RequestException
|
|
23
|
-
from transformers import logging as tf_logging
|
|
24
23
|
|
|
25
|
-
from .
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
24
|
+
from .caching_utils import cache_arguments
|
|
25
|
+
from .constants import T
|
|
26
|
+
from .exceptions import InvalidBenchmark, InvalidModel, NaNValueInModelOutput
|
|
27
|
+
from .logging_utils import log, log_once
|
|
29
28
|
|
|
30
29
|
if t.TYPE_CHECKING:
|
|
31
|
-
from
|
|
32
|
-
|
|
30
|
+
from .data_models import ModelIdComponents
|
|
33
31
|
from .types import Predictions
|
|
34
32
|
|
|
35
33
|
|
|
36
|
-
logger = logging.getLogger("euroeval")
|
|
37
|
-
|
|
38
|
-
|
|
39
34
|
def create_model_cache_dir(cache_dir: str, model_id: str) -> str:
|
|
40
35
|
"""Create cache directory for a model.
|
|
41
36
|
|
|
@@ -54,6 +49,72 @@ def create_model_cache_dir(cache_dir: str, model_id: str) -> str:
|
|
|
54
49
|
return str(cache_dir_path)
|
|
55
50
|
|
|
56
51
|
|
|
52
|
+
def resolve_model_path(download_dir: str) -> str:
    """Resolve the path to the directory containing the model config files and weights.

    The Hugging Face Hub `cache_dir` layout puts the files of a model in
    `<download_dir>/models--<model_id_path>/snapshots/...`, where `<model_id_path>`
    is the last component of `download_dir`. If the files are spread over several
    directories (e.g., because the safetensors version of the weights was added in an
    unmerged PR), a `model_files` directory containing symlinks to all of the files is
    created, since e.g. vLLM can only be pointed at a single directory.

    Args:
        download_dir:
            The download directory.

    Returns:
        The path to the model.

    Raises:
        InvalidModel:
            If the model path is not valid, if required files are missing, or if
            several model files share the same name.
    """
    # Name of the directory holding symlinks to files spread over several directories
    merged_dir_name = "model_files"

    download_path = Path(download_dir)

    # Get the 'path safe' version of the model id, which is the last dir in the path
    model_id_path = download_path.name

    # Hf hub `cache_dir` puts the files in models--`model_id_path`/snapshots
    snapshots_path = download_path / f"models--{model_id_path}" / "snapshots"
    if not snapshots_path.exists():
        raise InvalidModel(
            f"Attempted to load models from the {snapshots_path} directory, "
            "but it does not exist."
        )

    # Get all files in the snapshots. The symlink directory created by a previous call
    # lives inside the snapshots directory, so it is skipped here, as its entries
    # would otherwise be picked up as duplicates of the real files
    found_files = [
        found_file
        for found_file in snapshots_path.rglob("*")
        if found_file.is_file()
        and merged_dir_name not in found_file.relative_to(snapshots_path).parts
    ]
    if not found_files:
        raise InvalidModel(f"No model files found at {snapshots_path}")

    # Make sure that no two files share a name, as they would otherwise collide when
    # symlinking them into a single directory below
    seen_names: set[str] = set()
    duplicate_names: set[str] = set()
    for found_file in found_files:
        if found_file.name in seen_names:
            duplicate_names.add(found_file.name)
        seen_names.add(found_file.name)
    if duplicate_names:
        raise InvalidModel(
            f"Found multiple model config files for {model_id_path} at "
            f"{snapshots_path}: {', '.join(sorted(duplicate_names))}"
        )

    # Check that found_files contains at least a 'config.json'
    config_file = next(
        (found_file for found_file in found_files if found_file.name == "config.json"),
        None,
    )
    if config_file is None:
        raise InvalidModel(
            f"Missing required file 'config.json' for {model_id_path} at "
            f"{snapshots_path}"
        )
    model_path = config_file.parent

    # As a precaution we also check that all of the files are in the same directory;
    # if not, we create a new dir with symlinks to all of the files from all snapshots
    if any(found_file.parent != found_files[0].parent for found_file in found_files):
        merged_path = model_path.parent / merged_dir_name
        merged_path.mkdir(exist_ok=True)
        for found_file in found_files:
            link_path = merged_path / found_file.name
            # Refresh links left behind by a previous call
            if link_path.is_symlink():
                link_path.unlink()
            # Use an absolute target: a relative target would be resolved relative to
            # the link's own directory, not the current working directory
            link_path.symlink_to(found_file.absolute())
        model_path = merged_path

    return str(model_path)
|
|
116
|
+
|
|
117
|
+
|
|
57
118
|
def clear_memory() -> None:
|
|
58
119
|
"""Clears the memory of unused items."""
|
|
59
120
|
for gc_generation in range(3):
|
|
@@ -84,67 +145,9 @@ def enforce_reproducibility(seed: int = 4242) -> np.random.Generator:
|
|
|
84
145
|
return rng
|
|
85
146
|
|
|
86
147
|
|
|
87
|
-
def
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
This filters warnings from some libraries, sets the logging level to ERROR for some
|
|
91
|
-
libraries, disabled tokeniser progress bars when using Hugging Face tokenisers, and
|
|
92
|
-
disables most of the logging from the `transformers` library.
|
|
93
|
-
"""
|
|
94
|
-
# Ignore miscellaneous warnings
|
|
95
|
-
warnings.filterwarnings("ignore", category=UserWarning)
|
|
96
|
-
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
97
|
-
warnings.filterwarnings(
|
|
98
|
-
"ignore",
|
|
99
|
-
module="torch.nn.parallel*",
|
|
100
|
-
message="Was asked to gather along dimension 0, but all input tensors were "
|
|
101
|
-
"scalars; will instead unsqueeze and return a vector.",
|
|
102
|
-
)
|
|
103
|
-
warnings.filterwarnings("ignore", module="seqeval*")
|
|
104
|
-
|
|
105
|
-
# Up the logging level, to disable outputs
|
|
106
|
-
logging.getLogger("filelock").setLevel(logging.CRITICAL)
|
|
107
|
-
logging.getLogger("absl").setLevel(logging.CRITICAL)
|
|
108
|
-
logging.getLogger("datasets").setLevel(logging.CRITICAL)
|
|
109
|
-
logging.getLogger("openai").setLevel(logging.CRITICAL)
|
|
110
|
-
logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.CRITICAL)
|
|
111
|
-
logging.getLogger("torch.nn.parallel.distributed").setLevel(logging.CRITICAL)
|
|
112
|
-
logging.getLogger("vllm").setLevel(logging.CRITICAL)
|
|
113
|
-
logging.getLogger("vllm.engine.llm_engine").setLevel(logging.CRITICAL)
|
|
114
|
-
logging.getLogger("vllm.transformers_utils.tokenizer").setLevel(logging.CRITICAL)
|
|
115
|
-
logging.getLogger("vllm.core.scheduler").setLevel(logging.CRITICAL)
|
|
116
|
-
logging.getLogger("vllm.model_executor.weight_utils").setLevel(logging.CRITICAL)
|
|
117
|
-
logging.getLogger("vllm.platforms").setLevel(logging.CRITICAL)
|
|
118
|
-
logging.getLogger("httpx").setLevel(logging.CRITICAL)
|
|
119
|
-
logging.getLogger("ray._private.worker").setLevel(logging.CRITICAL)
|
|
120
|
-
logging.getLogger("ray._private.services").setLevel(logging.CRITICAL)
|
|
121
|
-
logging.getLogger("matplotlib.font_manager").setLevel(logging.CRITICAL)
|
|
122
|
-
logging.getLogger("accelerate").setLevel(logging.CRITICAL)
|
|
123
|
-
logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
|
|
124
|
-
logging.getLogger("LiteLLM Router").setLevel(logging.CRITICAL)
|
|
125
|
-
logging.getLogger("LiteLLM Proxy").setLevel(logging.CRITICAL)
|
|
126
|
-
logging.getLogger("huggingface_hub").setLevel(logging.CRITICAL)
|
|
127
|
-
|
|
128
|
-
# This suppresses vLLM logging
|
|
129
|
-
os.environ["LOG_LEVEL"] = "CRITICAL"
|
|
130
|
-
os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
|
|
131
|
-
|
|
132
|
-
if importlib.util.find_spec("ray") is not None:
|
|
133
|
-
ray._private.worker._worker_logs_enabled = False
|
|
134
|
-
|
|
135
|
-
# Disable the tokeniser progress bars
|
|
136
|
-
disable_progress_bar()
|
|
137
|
-
|
|
138
|
-
# Disable most of the `transformers` logging
|
|
139
|
-
tf_logging._default_log_level = logging.CRITICAL
|
|
140
|
-
tf_logging.set_verbosity(logging.CRITICAL)
|
|
141
|
-
logging.getLogger("transformers.trainer").setLevel(logging.CRITICAL)
|
|
142
|
-
|
|
143
|
-
# Disable logging from `litellm`
|
|
144
|
-
litellm.suppress_debug_info = True
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
def get_class_by_name(class_name: str | list[str], module_name: str) -> t.Type | None:
|
|
148
|
+
def get_class_by_name(
|
|
149
|
+
class_name: str | c.Sequence[str], module_name: str
|
|
150
|
+
) -> t.Type | None:
|
|
148
151
|
"""Get a class by its name.
|
|
149
152
|
|
|
150
153
|
Args:
|
|
@@ -173,9 +176,10 @@ def get_class_by_name(class_name: str | list[str], module_name: str) -> t.Type |
|
|
|
173
176
|
|
|
174
177
|
if error_messages:
|
|
175
178
|
errors = "\n- " + "\n- ".join(error_messages)
|
|
176
|
-
|
|
179
|
+
log(
|
|
177
180
|
f"Could not find the class with the name(s) {', '.join(class_name)}. The "
|
|
178
|
-
f"following error messages were raised: {errors}"
|
|
181
|
+
f"following error messages were raised: {errors}",
|
|
182
|
+
level=logging.DEBUG,
|
|
179
183
|
)
|
|
180
184
|
|
|
181
185
|
# If the class could not be found, return None
|
|
@@ -197,40 +201,27 @@ def get_min_cuda_compute_capability() -> float | None:
|
|
|
197
201
|
return float(f"{major}.{minor}")
|
|
198
202
|
|
|
199
203
|
|
|
204
|
+
@cache_arguments(disable_condition=lambda: hasattr(sys, "_called_from_test"))
|
|
200
205
|
def internet_connection_available() -> bool:
    """Check whether an internet connection is available.

    This is done by opening a TCP connection to 1.1.1.1 on port 80. A timeout is used,
    so that the check fails quickly on networks that silently drop packets instead of
    blocking until the operating system gives up on the connection attempt.

    Returns:
        Whether or not internet connection is available.

    Raises:
        Exception:
            Any unexpected error, other than an `OSError` or a socket-blocking error
            raised by `pytest-socket`, is re-raised.
    """
    timeout_seconds = 10

    # Errors raised by `pytest-socket` when network access is blocked in tests. They
    # are matched by name to avoid importing the test-only dependency
    pytest_socket_errors = {"SocketConnectBlockedError", "SocketBlockedError"}

    try:
        # The context manager ensures the socket is closed again
        with socket.create_connection(("1.1.1.1", 80), timeout=timeout_seconds):
            return True
    except OSError:
        # Includes timeouts, refused connections and unreachable networks
        return False
    except Exception as e:
        if type(e).__name__ in pytest_socket_errors:
            return False
        raise
|
|
234
225
|
|
|
235
226
|
|
|
236
227
|
def raise_if_model_output_contains_nan_values(model_output: "Predictions") -> None:
|
|
@@ -288,34 +279,6 @@ def unscramble(scrambled_text: str) -> str:
|
|
|
288
279
|
return unscrambled
|
|
289
280
|
|
|
290
281
|
|
|
291
|
-
@cache
|
|
292
|
-
def log_once(message: str, level: int = logging.INFO) -> None:
|
|
293
|
-
"""Log a message once.
|
|
294
|
-
|
|
295
|
-
This is ensured by caching the input/output pairs of this function, using the
|
|
296
|
-
`functools.cache` decorator.
|
|
297
|
-
|
|
298
|
-
Args:
|
|
299
|
-
message:
|
|
300
|
-
The message to log.
|
|
301
|
-
level:
|
|
302
|
-
The logging level. Defaults to logging.INFO.
|
|
303
|
-
"""
|
|
304
|
-
match level:
|
|
305
|
-
case logging.DEBUG:
|
|
306
|
-
logger.debug(message)
|
|
307
|
-
case logging.INFO:
|
|
308
|
-
logger.info(message)
|
|
309
|
-
case logging.WARNING:
|
|
310
|
-
logger.warning(message)
|
|
311
|
-
case logging.ERROR:
|
|
312
|
-
logger.error(message)
|
|
313
|
-
case logging.CRITICAL:
|
|
314
|
-
logger.critical(message)
|
|
315
|
-
case _:
|
|
316
|
-
raise ValueError(f"Invalid logging level: {level}")
|
|
317
|
-
|
|
318
|
-
|
|
319
282
|
def get_package_version(package_name: str) -> str | None:
|
|
320
283
|
"""Get the version of a package.
|
|
321
284
|
|
|
@@ -332,9 +295,6 @@ def get_package_version(package_name: str) -> str | None:
|
|
|
332
295
|
return None
|
|
333
296
|
|
|
334
297
|
|
|
335
|
-
T = t.TypeVar("T", bound=object)
|
|
336
|
-
|
|
337
|
-
|
|
338
298
|
def safe_run(coroutine: t.Coroutine[t.Any, t.Any, T]) -> T:
|
|
339
299
|
"""Run a coroutine, ensuring that the event loop is always closed when we're done.
|
|
340
300
|
|
|
@@ -348,7 +308,8 @@ def safe_run(coroutine: t.Coroutine[t.Any, t.Any, T]) -> T:
|
|
|
348
308
|
loop = asyncio.new_event_loop()
|
|
349
309
|
try:
|
|
350
310
|
asyncio.set_event_loop(loop)
|
|
351
|
-
|
|
311
|
+
response = loop.run_until_complete(coroutine)
|
|
312
|
+
return response
|
|
352
313
|
finally:
|
|
353
314
|
loop.close()
|
|
354
315
|
asyncio.set_event_loop(None)
|
|
@@ -373,3 +334,237 @@ async def add_semaphore_and_catch_exception(
|
|
|
373
334
|
return await coroutine
|
|
374
335
|
except Exception as exc:
|
|
375
336
|
return exc
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def extract_json_dict_from_string(s: str) -> dict | None:
    """Find and decode the first brace-delimited JSON dictionary in a string.

    Only a dictionary without nested braces is matched; it is decoded with `demjson3`.
    Every failure is logged at DEBUG level and results in `None` being returned.

    Args:
        s:
            The string to search for a JSON dictionary.

    Returns:
        The decoded dictionary if one with only string keys was found, otherwise None.
    """
    flat_dict_pattern = re.compile(r"\{[^{}]*?\}", flags=re.DOTALL)
    match = flat_dict_pattern.search(s)
    if match is None:
        log(
            "The model output does not contain any JSON dictionary, so cannot parse "
            f"it. Skipping. Here is the output: {s!r}",
            level=logging.DEBUG,
        )
        return None

    candidate = match.group()
    try:
        decoded = demjson3.decode(txt=candidate)
    except demjson3.JSONDecodeError:
        log(
            "The model output is not valid JSON, so cannot parse it. Skipping. "
            f"Here is the output: {candidate!r}",
            level=logging.DEBUG,
        )
        return None

    if not isinstance(decoded, dict):
        log(
            "The model output is not a JSON dictionary, so cannot parse "
            f"it. Skipping. Here is the output: {candidate!r}",
            level=logging.DEBUG,
        )
        return None

    # Reject dictionaries with any non-string key
    if any(not isinstance(key, str) for key in decoded):
        log(
            "The model output is not a JSON dictionary with string keys, "
            "so cannot parse it. Skipping. Here is the output: "
            f"{candidate!r}",
            level=logging.DEBUG,
        )
        return None

    return decoded
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
@cache_arguments()
|
|
386
|
+
def get_hf_token(api_key: str | None) -> str | bool:
    """Determine which Hugging Face token to use.

    The sources are tried in order: the explicit `api_key`, then the
    `HUGGINGFACE_API_KEY` environment variable, and finally a local login, which is
    detected by calling `huggingface_hub.whoami`.

    Args:
        api_key:
            The API key to use as the Hugging Face token. If None, we will try to
            extract it in other ways.

    Returns:
        The Hugging Face token, or True if no token is set but the user is logged in, or
        False if no token is set and the user is not logged in.

    Raises:
        Exception:
            Errors from `huggingface_hub.whoami` other than
            `LocalTokenNotFoundError` are propagated.
    """
    if api_key is not None:
        log_once(
            "Using the Hugging Face API key passed to the function.",
            level=logging.DEBUG,
        )
        return api_key

    env_token = os.getenv("HUGGINGFACE_API_KEY")
    if env_token is not None:
        log_once(
            "Using the Hugging Face API key from the environment variable "
            "`HUGGINGFACE_API_KEY`.",
            level=logging.DEBUG,
        )
        return env_token

    try:
        hf_hub.whoami()
    except hf_hub.errors.LocalTokenNotFoundError:
        log_once(
            "No Hugging Face API key was set and the user is not logged in to Hugging "
            "Face, so no token will be used.",
            level=logging.DEBUG,
        )
        return False
    else:
        log_once(
            "No Hugging Face API key was set, but the user is logged in to Hugging "
            "Face, so using the local token.",
            level=logging.DEBUG,
        )
        return True
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def extract_multiple_choice_labels(
    prompt: str, candidate_labels: c.Sequence[str]
) -> c.Sequence[str]:
    """Extract the multiple choice labels that are present in a prompt.

    A candidate label counts as present if it appears in the prompt after a word
    boundary and is followed by a dot and a space (e.g., 'a. '), ignoring case. The
    label is matched literally, even if it contains regex metacharacters.

    Args:
        prompt:
            The prompt to extract the labels from.
        candidate_labels:
            The candidate labels to look for in the prompt.

    Returns:
        The candidate labels found in the prompt, in the order of `candidate_labels`.

    Raises:
        InvalidBenchmark:
            If none of the candidate labels could be found in the prompt.
    """
    # `re.escape` makes labels such as '.' or '(' match literally instead of acting as
    # regex syntax (a wildcard, or an invalid pattern)
    sample_candidate_labels: list[str] = [
        candidate_label
        for candidate_label in candidate_labels
        if re.search(
            pattern=rf"\b{re.escape(candidate_label)}\. ",
            string=prompt,
            flags=re.IGNORECASE,
        )
        is not None
    ]
    if not sample_candidate_labels:
        raise InvalidBenchmark(
            "Could not extract any candidate labels from the prompt. Please ensure "
            "that the candidate labels are present in the prompt, each followed by a "
            "dot and a space (e.g., 'a. '). The candidate labels are: "
            f"{', '.join(candidate_labels)}. Here is the prompt: {prompt!r}"
        )
    return sample_candidate_labels
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def split_model_id(model_id: str) -> "ModelIdComponents":
    """Split a model ID into its base ID, revision and parameter.

    The base ID is everything before the first '@' or '#'. A revision is given as an
    '@<revision>' part and a parameter as a '#<param>' part of the string.

    Args:
        model_id:
            The model ID to split.

    Returns:
        The components of the model ID. The revision defaults to "main" and the
        parameter to None when they are absent.

    Raises:
        InvalidModel:
            If no base model ID can be extracted, i.e., if `model_id` is empty or
            starts with '@' or '#'.
    """
    # Importing here to avoid circular imports
    from .data_models import ModelIdComponents

    base_match = re.match(pattern=r"^[^@#]+", string=model_id)
    if base_match is None:
        raise InvalidModel(f"The model ID {model_id!r} is not valid.")

    revision_match = re.search(pattern=r"@([^@#]+)", string=model_id)
    param_match = re.search(pattern=r"#([^@#]+)", string=model_id)
    return ModelIdComponents(
        model_id=base_match.group(),
        revision="main" if revision_match is None else revision_match.group(1),
        param=None if param_match is None else param_match.group(1),
    )
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def load_custom_datasets_module() -> ModuleType | None:
|
|
492
|
+
"""Load the custom datasets module if it exists.
|
|
493
|
+
|
|
494
|
+
Raises:
|
|
495
|
+
RuntimeError:
|
|
496
|
+
If the custom datasets module cannot be loaded.
|
|
497
|
+
"""
|
|
498
|
+
custom_datasets_file = Path("custom_datasets.py")
|
|
499
|
+
if custom_datasets_file.exists():
|
|
500
|
+
spec = importlib.util.spec_from_file_location(
|
|
501
|
+
name="custom_datasets_module", location=str(custom_datasets_file.resolve())
|
|
502
|
+
)
|
|
503
|
+
if spec is None:
|
|
504
|
+
log_once(
|
|
505
|
+
"Could not load the spec for the custom datasets file from "
|
|
506
|
+
f"{custom_datasets_file.resolve()}.",
|
|
507
|
+
level=logging.ERROR,
|
|
508
|
+
)
|
|
509
|
+
return None
|
|
510
|
+
module = importlib.util.module_from_spec(spec=spec)
|
|
511
|
+
if spec.loader is None:
|
|
512
|
+
log_once(
|
|
513
|
+
"Could not load the module for the custom datasets file from "
|
|
514
|
+
f"{custom_datasets_file.resolve()}.",
|
|
515
|
+
level=logging.ERROR,
|
|
516
|
+
)
|
|
517
|
+
return None
|
|
518
|
+
spec.loader.exec_module(module)
|
|
519
|
+
return module
|
|
520
|
+
return None
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
class flash_attention_backend:
|
|
524
|
+
"""Context manager to temporarily set the flash attention backend.
|
|
525
|
+
|
|
526
|
+
This sets the `VLLM_ATTENTION_BACKEND` environment variable to `FLASH_ATTN`
|
|
527
|
+
for the duration of the context manager, and restores the previous value afterwards.
|
|
528
|
+
"""
|
|
529
|
+
|
|
530
|
+
def __init__(self, disabled: bool = False) -> None:
|
|
531
|
+
"""Initialise the context manager.
|
|
532
|
+
|
|
533
|
+
Args:
|
|
534
|
+
disabled:
|
|
535
|
+
If True, this context manager does nothing.
|
|
536
|
+
"""
|
|
537
|
+
self.disabled = (
|
|
538
|
+
True if disabled else os.environ["VLLM_ATTENTION_BACKEND"] != "FLASHINFER"
|
|
539
|
+
)
|
|
540
|
+
self.previous_value: str | None = None
|
|
541
|
+
|
|
542
|
+
def __enter__(self) -> None:
|
|
543
|
+
"""Enter the context manager."""
|
|
544
|
+
if self.disabled:
|
|
545
|
+
return
|
|
546
|
+
self.previous_value = os.getenv("VLLM_ATTENTION_BACKEND")
|
|
547
|
+
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"
|
|
548
|
+
|
|
549
|
+
def __exit__(
|
|
550
|
+
self,
|
|
551
|
+
exc_type: t.Type[BaseException] | None,
|
|
552
|
+
exc_value: BaseException | None,
|
|
553
|
+
traceback: type[BaseException] | None,
|
|
554
|
+
) -> None:
|
|
555
|
+
"""Exit the context manager.
|
|
556
|
+
|
|
557
|
+
Args:
|
|
558
|
+
exc_type:
|
|
559
|
+
The type of the exception.
|
|
560
|
+
exc_value:
|
|
561
|
+
The value of the exception.
|
|
562
|
+
exc_tb:
|
|
563
|
+
The traceback of the exception.
|
|
564
|
+
"""
|
|
565
|
+
if self.disabled:
|
|
566
|
+
return
|
|
567
|
+
if self.previous_value is None:
|
|
568
|
+
os.environ.pop("VLLM_ATTENTION_BACKEND", None)
|
|
569
|
+
else:
|
|
570
|
+
os.environ["VLLM_ATTENTION_BACKEND"] = self.previous_value
|