ScandEval 16.11.0-py3-none-any.whl → 16.13.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scandeval/__init__.py +0 -9
- scandeval/async_utils.py +46 -0
- scandeval/benchmark_config_factory.py +31 -2
- scandeval/benchmark_modules/fresh.py +2 -1
- scandeval/benchmark_modules/hf.py +76 -23
- scandeval/benchmark_modules/litellm.py +33 -15
- scandeval/benchmark_modules/vllm.py +97 -44
- scandeval/benchmarker.py +29 -33
- scandeval/cli.py +11 -0
- scandeval/constants.py +36 -2
- scandeval/custom_dataset_configs.py +152 -0
- scandeval/data_loading.py +87 -31
- scandeval/data_models.py +405 -224
- scandeval/dataset_configs/__init__.py +51 -25
- scandeval/dataset_configs/albanian.py +1 -1
- scandeval/dataset_configs/belarusian.py +47 -0
- scandeval/dataset_configs/bulgarian.py +1 -1
- scandeval/dataset_configs/catalan.py +1 -1
- scandeval/dataset_configs/croatian.py +1 -1
- scandeval/dataset_configs/danish.py +3 -2
- scandeval/dataset_configs/dutch.py +16 -5
- scandeval/dataset_configs/english.py +4 -3
- scandeval/dataset_configs/estonian.py +8 -7
- scandeval/dataset_configs/faroese.py +1 -1
- scandeval/dataset_configs/finnish.py +5 -4
- scandeval/dataset_configs/french.py +6 -5
- scandeval/dataset_configs/german.py +4 -3
- scandeval/dataset_configs/greek.py +1 -1
- scandeval/dataset_configs/hungarian.py +1 -1
- scandeval/dataset_configs/icelandic.py +4 -3
- scandeval/dataset_configs/italian.py +4 -3
- scandeval/dataset_configs/latvian.py +2 -2
- scandeval/dataset_configs/lithuanian.py +1 -1
- scandeval/dataset_configs/norwegian.py +6 -5
- scandeval/dataset_configs/polish.py +4 -3
- scandeval/dataset_configs/portuguese.py +5 -4
- scandeval/dataset_configs/romanian.py +2 -2
- scandeval/dataset_configs/serbian.py +1 -1
- scandeval/dataset_configs/slovene.py +1 -1
- scandeval/dataset_configs/spanish.py +4 -3
- scandeval/dataset_configs/swedish.py +4 -3
- scandeval/dataset_configs/ukrainian.py +1 -1
- scandeval/generation_utils.py +6 -6
- scandeval/metrics/__init__.py +1 -0
- scandeval/metrics/bias.py +237 -0
- scandeval/metrics/huggingface.py +2 -1
- scandeval/metrics/llm_as_a_judge.py +1 -1
- scandeval/metrics/pipeline.py +1 -1
- scandeval/model_cache.py +34 -4
- scandeval/prompt_templates/linguistic_acceptability.py +9 -0
- scandeval/prompt_templates/multiple_choice.py +9 -0
- scandeval/prompt_templates/named_entity_recognition.py +21 -0
- scandeval/prompt_templates/reading_comprehension.py +10 -0
- scandeval/prompt_templates/sentiment_classification.py +11 -0
- scandeval/string_utils.py +157 -0
- scandeval/task_group_utils/sequence_classification.py +2 -5
- scandeval/task_group_utils/token_classification.py +2 -4
- scandeval/tasks.py +22 -0
- scandeval/tokenisation_utils.py +12 -1
- scandeval/utils.py +13 -383
- scandeval-16.13.0.dist-info/METADATA +334 -0
- scandeval-16.13.0.dist-info/RECORD +94 -0
- scandeval-16.11.0.dist-info/METADATA +0 -649
- scandeval-16.11.0.dist-info/RECORD +0 -89
- {scandeval-16.11.0.dist-info → scandeval-16.13.0.dist-info}/WHEEL +0 -0
- {scandeval-16.11.0.dist-info → scandeval-16.13.0.dist-info}/entry_points.txt +0 -0
- {scandeval-16.11.0.dist-info → scandeval-16.13.0.dist-info}/licenses/LICENSE +0 -0
scandeval/__init__.py
CHANGED
@@ -110,15 +110,6 @@ os.environ["DISABLE_AIOHTTP_TRANSPORT"] = "True"
 os.environ["VLLM_USE_V1"] = "1"
 
 
-# Use the FlashInfer flash-attention backend for vLLM, unless the user has already
-# specified a different backend.
-if os.getenv("VLLM_ATTENTION_BACKEND") is None:
-    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
-    os.environ["USER_HAS_SET_VLLM_ATTENTION_BACKEND"] = "0"
-else:
-    os.environ["USER_HAS_SET_VLLM_ATTENTION_BACKEND"] = "1"
-
-
 # Set the HF_TOKEN env var to copy the HUGGINGFACE_API_KEY env var, as vLLM uses the
 # former and LiteLLM uses the latter
 if os.getenv("HUGGINGFACE_API_KEY"):
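The removed block above used to pin vLLM's attention backend to FlashInfer at import time; the new `attention_backend` field added to the benchmark config (see the `benchmark_config_factory.py` diff below) suggests the choice is now explicit configuration. As a minimal sketch, not part of the diff, a caller relying on the old default could set the same environment variable themselves before the package is imported:

import os

# The env var must be in place before vLLM is first imported anywhere.
os.environ.setdefault("VLLM_ATTENTION_BACKEND", "FLASHINFER")

import scandeval  # noqa: E402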
scandeval/async_utils.py
ADDED
@@ -0,0 +1,46 @@
+"""Utility functions for asyncronous tasks."""
+
+import asyncio
+import typing as t
+
+from .constants import T
+
+
+def safe_run(coroutine: t.Coroutine[t.Any, t.Any, T]) -> T:
+    """Run a coroutine, ensuring that the event loop is always closed when we're done.
+
+    Args:
+        coroutine:
+            The coroutine to run.
+
+    Returns:
+        The result of the coroutine.
+    """
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:  # If the current event loop is closed
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+    response = loop.run_until_complete(coroutine)
+    return response
+
+
+async def add_semaphore_and_catch_exception(
+    coroutine: t.Coroutine[t.Any, t.Any, T], semaphore: asyncio.Semaphore
+) -> T | Exception:
+    """Run a coroutine with a semaphore.
+
+    Args:
+        coroutine:
+            The coroutine to run.
+        semaphore:
+            The semaphore to use.
+
+    Returns:
+        The result of the coroutine.
+    """
+    async with semaphore:
+        try:
+            return await coroutine
+        except Exception as exc:
+            return exc
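Since the module is entirely new, a short usage sketch may help; the `fetch` coroutine, the failure at index 2, and the concurrency limit of 2 are all invented for illustration:

import asyncio

from scandeval.async_utils import add_semaphore_and_catch_exception, safe_run


async def fetch(idx: int) -> int:
    """Toy stand-in for an API call."""
    await asyncio.sleep(0.01)
    if idx == 2:
        raise ValueError("boom")
    return idx


async def run_all() -> list[int | Exception]:
    # At most two coroutines run at once; exceptions come back as values
    # instead of propagating out of asyncio.gather.
    semaphore = asyncio.Semaphore(2)
    return await asyncio.gather(
        *[add_semaphore_and_catch_exception(fetch(i), semaphore) for i in range(4)]
    )


results = safe_run(run_all())  # e.g. [0, 1, ValueError("boom"), 3]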
scandeval/benchmark_config_factory.py
CHANGED

@@ -1,6 +1,7 @@
 """Factory class for creating dataset configurations."""
 
 import collections.abc as c
+import importlib.util
 import sys
 import typing as t
 from pathlib import Path
@@ -13,6 +14,9 @@ from .enums import Device
 from .exceptions import InvalidBenchmark
 from .languages import get_all_languages
 
+if importlib.util.find_spec("vllm") is not None:
+    pass
+
 if t.TYPE_CHECKING:
     from .data_models import Language
 
@@ -42,6 +46,8 @@ def build_benchmark_config(
         dataset=benchmark_config_params.dataset,
         languages=languages,
         custom_datasets_file=benchmark_config_params.custom_datasets_file,
+        api_key=benchmark_config_params.api_key,
+        cache_dir=Path(benchmark_config_params.cache_dir),
     )
 
     return BenchmarkConfig(
@@ -68,6 +74,7 @@ def build_benchmark_config(
         api_base=benchmark_config_params.api_base,
         api_version=benchmark_config_params.api_version,
         gpu_memory_utilization=benchmark_config_params.gpu_memory_utilization,
+        attention_backend=benchmark_config_params.attention_backend,
         generative_type=benchmark_config_params.generative_type,
         debug=benchmark_config_params.debug,
         run_with_cli=benchmark_config_params.run_with_cli,
@@ -154,7 +161,9 @@ def prepare_dataset_configs(
     languages: c.Sequence["Language"],
     dataset: "str | DatasetConfig | c.Sequence[str | DatasetConfig] | None",
     custom_datasets_file: Path,
-) -> list["DatasetConfig"]:
+    api_key: str | None,
+    cache_dir: Path,
+) -> list["DatasetConfig"]:
     """Prepare dataset config(s) for benchmarking.
 
     Args:
@@ -168,6 +177,10 @@
             included, limited by the `task` and `languages` parameters.
         custom_datasets_file:
             A path to a Python file containing custom dataset configurations.
+        api_key:
+            The API key to use for accessing the Hugging Face Hub.
+        cache_dir:
+            The directory to store the cache in.
 
     Returns:
         The prepared dataset configs.
@@ -176,9 +189,25 @@
         InvalidBenchmark:
             If the task or dataset is not found in the benchmark tasks or datasets.
     """
+    # Extract the dataset IDs from the `dataset` argument
+    dataset_ids: list[str] = list()
+    if isinstance(dataset, str):
+        dataset_ids.append(dataset)
+    elif isinstance(dataset, DatasetConfig):
+        dataset_ids.append(dataset.name)
+    elif isinstance(dataset, list):
+        for d in dataset:
+            if isinstance(d, str):
+                dataset_ids.append(d)
+            elif isinstance(d, DatasetConfig):
+                dataset_ids.append(d.name)
+
     # Create the list of dataset configs
     all_dataset_configs = get_all_dataset_configs(
-        custom_datasets_file=custom_datasets_file
+        custom_datasets_file=custom_datasets_file,
+        dataset_ids=dataset_ids,
+        api_key=api_key,
+        cache_dir=cache_dir,
     )
     all_official_dataset_configs: c.Sequence[DatasetConfig] = [
         dataset_config
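The added block normalises the `dataset` argument, which may be a string, a `DatasetConfig`, or a mixed list of both, into a flat list of IDs. A standalone sketch of the same normalisation (the `DatasetConfig` stand-in and the dataset names are illustrative):

from dataclasses import dataclass


@dataclass
class DatasetConfig:  # minimal stand-in for scandeval.data_models.DatasetConfig
    name: str


def extract_dataset_ids(dataset) -> list[str]:
    """Flatten a str | DatasetConfig | list of either into dataset IDs."""
    dataset_ids: list[str] = []
    if isinstance(dataset, str):
        dataset_ids.append(dataset)
    elif isinstance(dataset, DatasetConfig):
        dataset_ids.append(dataset.name)
    elif isinstance(dataset, list):
        for d in dataset:
            dataset_ids.append(d if isinstance(d, str) else d.name)
    return dataset_ids


assert extract_dataset_ids(["angry-tweets", DatasetConfig(name="scala-da")]) == [
    "angry-tweets",
    "scala-da",
]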
scandeval/benchmark_modules/fresh.py
CHANGED

@@ -28,8 +28,9 @@ from ..exceptions import (
 )
 from ..generation_utils import raise_if_wrong_params
 from ..logging_utils import block_terminal_output
+from ..model_cache import create_model_cache_dir
 from ..types import Tokeniser
-from ..utils import
+from ..utils import get_hf_token
 from .hf import (
     HuggingFaceEncoderModel,
     align_model_and_tokeniser,
scandeval/benchmark_modules/hf.py
CHANGED

@@ -1,6 +1,7 @@
 """Encoder models from the Hugging Face Hub."""
 
 import collections.abc as c
+import importlib
 import logging
 import re
 import typing as t
@@ -63,6 +64,8 @@ from ..exceptions import (
 from ..generation_utils import raise_if_wrong_params
 from ..languages import get_all_languages
 from ..logging_utils import block_terminal_output, log, log_once
+from ..model_cache import create_model_cache_dir
+from ..string_utils import split_model_id
 from ..task_group_utils import (
     multiple_choice_classification,
     question_answering,
@@ -70,13 +73,7 @@ from ..task_group_utils import (
 )
 from ..tokenisation_utils import get_bos_token, get_eos_token
 from ..types import Tokeniser
-from ..utils import (
-    create_model_cache_dir,
-    get_class_by_name,
-    get_hf_token,
-    internet_connection_available,
-    split_model_id,
-)
+from ..utils import get_hf_token, internet_connection_available
 from .base import BenchmarkModule
 
 try:
@@ -381,7 +378,7 @@ class HuggingFaceEncoderModel(BenchmarkModule):
         if "label" in examples:
             try:
                 examples["label"] = [
-                    self._model.config.label2id[lbl.lower()]
+                    self._model.config.label2id[str(lbl).lower()]
                     if self._model.config.label2id is not None
                     else lbl
                     for lbl in examples["label"]
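The `str(lbl)` cast guards against datasets that store labels as integers, where calling `.lower()` directly would raise. A toy illustration (the mapping is invented):

label2id = {"0": 0, "positive": 1}  # hypothetical label mapping
lbl = 0  # an integer label, as some datasets provide them
print(label2id[str(lbl).lower()])  # 0: works for int and str labels alike
# label2id[lbl.lower()]  # AttributeError: 'int' object has no attribute 'lower'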
@@ -758,20 +755,30 @@ def get_model_repo_info(
     # model info object.
     model_info: HfApiModelInfo | None = None
     if Path(model_id).is_dir():
-        if
-            (Path(model_id) / required_file).exists()
-            for required_file in LOCAL_MODELS_REQUIRED_FILES
-        ):
+        if Path(model_id, "config.json").exists():
             log_once(
-                f"The local model directory {model_id!r} has
-
-                "
+                f"The local model directory {model_id!r} has a 'config.json' file, so "
+                "we're skipping looking up model information from the Hugging Face "
+                "Hub.",
                 level=logging.DEBUG,
             )
             model_info = HfApiModelInfo(id=model_id, tags=None, pipeline_tag=None)
+        elif Path(model_id, "adapter_config.json").exists():
+            log_once(
+                f"The local model directory {model_id!r} has an 'adapter_config.json' "
+                "file, so we're skipping looking up model information from the Hugging "
+                "Face Hub.",
+                level=logging.DEBUG,
+            )
+            model_info = HfApiModelInfo(
+                id=model_id,
+                tags=None,
+                pipeline_tag=None,
+                siblings=[dict(rfilename="adapter_config.json")],
+            )
         else:
             log_once(
-                f"The local model directory {model_id} does not contain
+                f"The local model directory {model_id} does not contain any of the "
                 f"required files: {LOCAL_MODELS_REQUIRED_FILES}. Skipping this "
                 f"model.",
                 level=logging.WARNING,
@@ -807,8 +814,8 @@
         log(
             f"Could not access the model {model_id} with the revision "
             f"{revision}. The error was {str(e)!r}. Please set the "
-            "`HUGGINGFACE_API_KEY` environment variable or
-            "`--api-key` argument.",
+            "`HUGGINGFACE_API_KEY` or `HF_TOKEN` environment variable or "
+            "use the `--api-key` argument.",
             level=logging.DEBUG,
         )
         return None
@@ -876,8 +883,9 @@
         for tag in GENERATIVE_PIPELINE_TAGS
         for class_name in TASK_MAPPING.get(tag, dict()).values()  # type: ignore[attr-defined]
     ]
-    if class_names is not None and any(
-        class_name in generative_class_names for class_name in class_names
+    if class_names is not None and (
+        any(class_name in generative_class_names for class_name in class_names)
+        or any("ForCausalLM" in class_name for class_name in class_names)
     ):
         pipeline_tag = "text-generation"
     else:
@@ -1084,8 +1092,8 @@ def load_hf_model_config(
             f"The model {model_id!r} is a gated repository. Please ensure "
             "that you are logged in with `hf auth login` or have provided a "
             "valid Hugging Face access token with the `HUGGINGFACE_API_KEY` "
-            "environment variable or the `--api-key` argument.
-            "your account has access to this model."
+            "or `HF_TOKEN` environment variable or the `--api-key` argument. "
+            "Also check that your account has access to this model."
         ) from e
     raise InvalidModel(
         f"Couldn't load model config for {model_id!r}. The error was "
@@ -1121,7 +1129,11 @@
     )
 
     # Ensure that the PAD token ID is set
-    if
+    if (
+        hasattr(config, "eos_token_id")
+        and config.eos_token_id is not None
+        and (not hasattr(config, "pad_token_id") or config.pad_token_id is None)
+    ):
         if isinstance(config.eos_token_id, list):
             config.pad_token_id = config.eos_token_id[0]
         else:
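A sketch of the tightened guard, using a plain namespace instead of a real transformers config; the final `else` branch is assumed from context, since the hunk cuts off at `else:`:

from types import SimpleNamespace

# Stand-in for a transformers PretrainedConfig with a list-valued EOS token.
config = SimpleNamespace(eos_token_id=[128001, 128009], pad_token_id=None)

if (
    hasattr(config, "eos_token_id")
    and config.eos_token_id is not None
    and (not hasattr(config, "pad_token_id") or config.pad_token_id is None)
):
    if isinstance(config.eos_token_id, list):
        config.pad_token_id = config.eos_token_id[0]
    else:
        # Assumed from context; the hunk ends before this branch.
        config.pad_token_id = config.eos_token_id

assert config.pad_token_id == 128001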
@@ -1319,3 +1331,44 @@ def task_group_to_class_name(task_group: TaskGroup) -> str:
     )
     pascal_case = special_case_mapping.get(pascal_case, pascal_case)
     return f"AutoModelFor{pascal_case}"
+
+
+def get_class_by_name(
+    class_name: str | c.Sequence[str], module_name: str
+) -> t.Type | None:
+    """Get a class by its name.
+
+    Args:
+        class_name:
+            The name of the class, written in kebab-case. The corresponding class name
+            must be the same, but written in PascalCase, and lying in a module with the
+            same name, but written in snake_case. If a list of strings is passed, the
+            first class that is found is returned.
+        module_name:
+            The name of the module where the class is located.
+
+    Returns:
+        The class. If the class is not found, None is returned.
+    """
+    if isinstance(class_name, str):
+        class_name = [class_name]
+
+    error_messages = list()
+    for name in class_name:
+        try:
+            module = importlib.import_module(name=module_name)
+            class_: t.Type = getattr(module, name)
+            return class_
+        except (ModuleNotFoundError, AttributeError) as e:
+            error_messages.append(str(e))
+
+    if error_messages:
+        errors = "\n- " + "\n- ".join(error_messages)
+        log(
+            f"Could not find the class with the name(s) {', '.join(class_name)}. The "
+            f"following error messages were raised: {errors}",
+            level=logging.DEBUG,
+        )
+
+    # If the class could not be found, return None
+    return None
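A hedged usage sketch of the relocated helper; the class names are real transformers exports, but the call itself is illustrative:

from scandeval.benchmark_modules.hf import get_class_by_name

# Returns the first name that resolves on the module, or None if none do.
cls = get_class_by_name(
    class_name=["AutoModelForSequenceClassification", "DoesNotExist"],
    module_name="transformers",
)
assert cls is not None and cls.__name__ == "AutoModelForSequenceClassification"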
scandeval/benchmark_modules/litellm.py
CHANGED

@@ -40,7 +40,7 @@ from pydantic import ValidationError, conlist, create_model
 from requests.exceptions import RequestException
 from tqdm.asyncio import tqdm as tqdm_async
 
-from ..
+from ..async_utils import add_semaphore_and_catch_exception, safe_run
 from ..constants import (
     JSON_STRIP_CHARACTERS,
     LITELLM_CLASSIFICATION_OUTPUT_KEY,
@@ -74,6 +74,8 @@ from ..generation_utils import (
     raise_if_wrong_params,
 )
 from ..logging_utils import get_pbar, log, log_once
+from ..model_cache import create_model_cache_dir
+from ..string_utils import split_model_id
 from ..task_group_utils import (
     question_answering,
     sequence_classification,
@@ -83,13 +85,7 @@ from ..task_group_utils import (
 from ..tasks import NER
 from ..tokenisation_utils import get_first_label_token_mapping
 from ..types import ExtractLabelsFunction
-from ..utils import (
-    add_semaphore_and_catch_exception,
-    create_model_cache_dir,
-    get_hf_token,
-    safe_run,
-    split_model_id,
-)
+from ..utils import get_hf_token
 from .base import BenchmarkModule
 from .hf import HuggingFaceEncoderModel, load_hf_model_config, load_tokeniser
 
@@ -700,10 +696,10 @@ class LiteLLMModel(BenchmarkModule):
         elif isinstance(
             error, (Timeout, ServiceUnavailableError, InternalServerError, SystemError)
         ):
-
-
-                "Retrying in 10 seconds...",
-                level=logging.
+            log(
+                "Service temporarily unavailable during generation. The error "
+                f"message was: {error}. Retrying in 10 seconds...",
+                level=logging.INFO,
             )
             return generation_kwargs, 10
         elif isinstance(error, UnsupportedParamsError):
@@ -764,6 +760,20 @@ class LiteLLMModel(BenchmarkModule):
                 run_with_cli=self.benchmark_config.run_with_cli,
             ) from error
 
+        if (
+            isinstance(error, (BadRequestError, NotFoundError))
+            and self.benchmark_config.api_base is not None
+            and not self.benchmark_config.api_base.endswith("/v1")
+        ):
+            log_once(
+                f"The API base {self.benchmark_config.api_base!r} is not valid. We "
+                "will try appending '/v1' to it and try again.",
+                level=logging.DEBUG,
+            )
+            self.benchmark_config.api_base += "/v1"
+            generation_kwargs["api_base"] = self.benchmark_config.api_base
+            return generation_kwargs, 0
+
         raise InvalidBenchmark(
             f"Failed to generate text. The error message was: {error}"
         ) from error
@@ -1390,9 +1400,10 @@
                 InternalServerError,
             ) as e:
                 log(
-
+                    "Service temporarily unavailable while checking for model "
+                    f"existence of the model {model_id!r}. The error message was: {e}. "
                     "Retrying in 10 seconds...",
-                    level=logging.
+                    level=logging.INFO,
                 )
                 sleep(10)
             except APIError as e:
@@ -1567,7 +1578,6 @@
 
         return dataset
 
-    @cache_arguments()
     def get_generation_kwargs(self, dataset_config: DatasetConfig) -> dict[str, t.Any]:
         """Get the generation arguments for the model.
 
@@ -1865,6 +1875,14 @@ def clean_model_id(model_id: str, benchmark_config: BenchmarkConfig) -> str:
         else:
             prefix = "openai/"
         model_id = prefix + model_id
+
+    # When we want to evaluate an OpenAI model on a custom inference server, such as HF
+    # inference endpoints, LiteLLM gets confused since it's already using the `openai/`
+    # prefix. We thus have to add it twice, and this hack here is to ensure that we
+    # don't store the results with model ID `openai/openai/...`.
+    elif benchmark_config.api_base is not None and model_id.startswith("openai/"):
+        model_id = "openai/openai/" + re.sub(r"(openai/)*", "", model_id)
+
     return model_id
 
 
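A small sketch of what the regex in the new branch does; the model IDs are illustrative:

import re

# The regex strips any run of leading "openai/" prefixes, after which exactly
# two are prepended: one for LiteLLM's routing, one for the actual model ID.
for model_id in ["openai/gpt-4o", "openai/openai/gpt-4o"]:
    print("openai/openai/" + re.sub(r"(openai/)*", "", model_id))
# Both iterations print: openai/openai/gpt-4o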