EuroEval 15.12.0-py3-none-any.whl → 16.7.1-py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- euroeval/__init__.py +32 -14
- euroeval/benchmark_config_factory.py +92 -180
- euroeval/benchmark_modules/base.py +49 -39
- euroeval/benchmark_modules/fresh.py +35 -21
- euroeval/benchmark_modules/hf.py +280 -244
- euroeval/benchmark_modules/litellm.py +752 -312
- euroeval/benchmark_modules/vllm.py +570 -268
- euroeval/benchmarker.py +651 -528
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/cli.py +49 -38
- euroeval/constants.py +44 -25
- euroeval/data_loading.py +111 -55
- euroeval/data_models.py +490 -323
- euroeval/dataset_configs/__init__.py +26 -4
- euroeval/dataset_configs/bosnian.py +39 -0
- euroeval/dataset_configs/bulgarian.py +56 -0
- euroeval/dataset_configs/croatian.py +56 -0
- euroeval/dataset_configs/czech.py +75 -0
- euroeval/dataset_configs/danish.py +78 -50
- euroeval/dataset_configs/dutch.py +74 -44
- euroeval/dataset_configs/english.py +71 -36
- euroeval/dataset_configs/estonian.py +111 -0
- euroeval/dataset_configs/faroese.py +25 -18
- euroeval/dataset_configs/finnish.py +63 -26
- euroeval/dataset_configs/french.py +65 -32
- euroeval/dataset_configs/german.py +77 -36
- euroeval/dataset_configs/greek.py +64 -0
- euroeval/dataset_configs/icelandic.py +68 -57
- euroeval/dataset_configs/italian.py +68 -36
- euroeval/dataset_configs/latvian.py +87 -0
- euroeval/dataset_configs/lithuanian.py +64 -0
- euroeval/dataset_configs/norwegian.py +98 -72
- euroeval/dataset_configs/polish.py +96 -0
- euroeval/dataset_configs/portuguese.py +63 -40
- euroeval/dataset_configs/serbian.py +64 -0
- euroeval/dataset_configs/slovak.py +55 -0
- euroeval/dataset_configs/slovene.py +56 -0
- euroeval/dataset_configs/spanish.py +68 -34
- euroeval/dataset_configs/swedish.py +82 -41
- euroeval/dataset_configs/ukrainian.py +64 -0
- euroeval/enums.py +12 -6
- euroeval/exceptions.py +21 -1
- euroeval/finetuning.py +34 -26
- euroeval/generation.py +76 -41
- euroeval/generation_utils.py +169 -34
- euroeval/languages.py +1020 -188
- euroeval/logging_utils.py +268 -0
- euroeval/metrics/__init__.py +6 -0
- euroeval/metrics/base.py +85 -0
- euroeval/metrics/huggingface.py +216 -0
- euroeval/metrics/llm_as_a_judge.py +260 -0
- euroeval/metrics/pipeline.py +289 -0
- euroeval/metrics/speed.py +48 -0
- euroeval/model_cache.py +40 -21
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/__init__.py +2 -0
- euroeval/prompt_templates/classification.py +206 -0
- euroeval/prompt_templates/linguistic_acceptability.py +157 -22
- euroeval/prompt_templates/multiple_choice.py +159 -17
- euroeval/prompt_templates/named_entity_recognition.py +318 -21
- euroeval/prompt_templates/reading_comprehension.py +207 -16
- euroeval/prompt_templates/sentiment_classification.py +205 -22
- euroeval/prompt_templates/summarization.py +122 -22
- euroeval/prompt_templates/token_classification.py +279 -0
- euroeval/scores.py +20 -9
- euroeval/speed_benchmark.py +11 -12
- euroeval/task_group_utils/multiple_choice_classification.py +21 -12
- euroeval/task_group_utils/question_answering.py +101 -73
- euroeval/task_group_utils/sequence_classification.py +144 -61
- euroeval/task_group_utils/text_to_text.py +33 -12
- euroeval/task_group_utils/token_classification.py +86 -89
- euroeval/tasks.py +75 -16
- euroeval/tokenisation_utils.py +603 -0
- euroeval/types.py +17 -11
- euroeval/utils.py +332 -137
- euroeval-16.7.1.dist-info/METADATA +623 -0
- euroeval-16.7.1.dist-info/RECORD +84 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
- euroeval/human_evaluation.py +0 -737
- euroeval/metrics.py +0 -452
- euroeval/tokenization_utils.py +0 -498
- euroeval-15.12.0.dist-info/METADATA +0 -285
- euroeval-15.12.0.dist-info/RECORD +0 -63
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
--- a/euroeval/benchmark_modules/litellm.py
+++ b/euroeval/benchmark_modules/litellm.py
@@ -2,8 +2,8 @@
 
 import asyncio
 import collections.abc as c
+import json
 import logging
-import os
 import re
 import typing as t
 from functools import cached_property, partial
@@ -27,16 +27,23 @@ from litellm.exceptions import (
     RateLimitError,
     ServiceUnavailableError,
     Timeout,
+    UnsupportedParamsError,
 )
 from litellm.llms.vertex_ai.common_utils import VertexAIError
 from litellm.router import Router
-from litellm.types.utils import ChoiceLogprobs
+from litellm.types.utils import ChoiceLogprobs, Logprobs
+from litellm.utils import supports_reasoning, supports_response_schema
 from pydantic import conlist, create_model
 from requests.exceptions import RequestException
 from tqdm.asyncio import tqdm as tqdm_async
-from tqdm.auto import tqdm
 
-from ..
+from ..caching_utils import cache_arguments
+from ..constants import (
+    JSON_STRIP_CHARACTERS,
+    LITELLM_CLASSIFICATION_OUTPUT_KEY,
+    MAX_LITELLM_LOGPROBS,
+    REASONING_MAX_TOKENS,
+)
 from ..data_models import (
     BenchmarkConfig,
     DatasetConfig,
@@ -58,34 +65,40 @@ from ..exceptions import (
     NeedsEnvironmentVariable,
     NeedsExtraInstalled,
 )
-from ..generation_utils import
+from ..generation_utils import (
+    apply_prompt,
+    extract_few_shot_examples,
+    raise_if_wrong_params,
+)
+from ..logging_utils import get_pbar, log, log_once
 from ..task_group_utils import (
     question_answering,
     sequence_classification,
     text_to_text,
     token_classification,
 )
-from ..
+from ..tasks import NER
+from ..tokenisation_utils import get_first_label_token_mapping
 from ..types import ExtractLabelsFunction
 from ..utils import (
     add_semaphore_and_catch_exception,
     create_model_cache_dir,
-
+    get_hf_token,
     safe_run,
+    split_model_id,
 )
 from .base import BenchmarkModule
-from .hf import HuggingFaceEncoderModel, load_hf_model_config,
+from .hf import HuggingFaceEncoderModel, load_hf_model_config, load_tokeniser
 
 if t.TYPE_CHECKING:
     from datasets import DatasetDict
     from litellm.types.utils import ModelResponse
     from transformers.trainer import Trainer
 
-logger = logging.getLogger("euroeval")
-
 
 VOCAB_SIZE_MAPPING = {
     # OpenAI models
+    r"gpt-5-.*": 100_256,
     r"gpt-4-(32k)?(-[0-9]{4})?": 100_256,
     r"gpt-4-[0-9]{4}-preview": 100_256,
     r"gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 100_256,
@@ -104,6 +117,7 @@ VOCAB_SIZE_MAPPING = {
 
 MODEL_MAX_LENGTH_MAPPING = {
     # OpenAI models
+    r"gpt-5-.*": 272_000,
     r"gpt-4(-[0-9]{4})?": 8_191,
     r"gpt-4-32k(-[0-9]{4})?": 32_767,
     r"gpt-4-[0-9]{4}-preview": 128_000,
@@ -117,6 +131,7 @@ MODEL_MAX_LENGTH_MAPPING = {
     r"gpt-4.1.*": 1_047_576,
     # Anthropic models
     r"(anthropic/)?claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": 200_000,
+    r"(anthropic/)?claude-(opus|sonnet|haiku)-[1-9](-[1-9])?-[0-9]{8}": 200_000,
     # Gemini models
     r"(gemini/)?gemini-1\.5-flash.*": 1_048_576,
     r"(gemini/)?gemini-1\.5-pro.*": 2_097_152,
@@ -128,6 +143,7 @@ MODEL_MAX_LENGTH_MAPPING = {
 
 NUM_PARAMS_MAPPING = {
     # OpenAI models
+    r"gpt-5-.*": -1,
     r"gpt-4.*": -1,
     r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
     # Anthropic models
@@ -141,25 +157,32 @@ NUM_PARAMS_MAPPING = {
 }
 
 
-ALLOWED_PARAMS = {
-    # OpenAI models
-    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": ["low", "medium", "high"],
-    # Anthropic models
-    r"(anthropic/)?claude-3-7-sonnet.*": ["no-thinking", "thinking"],
-    r"(anthropic/)?claude-(sonnet|opus)-4.*": ["no-thinking", "thinking"],
-    # Gemini models
-    r"(gemini/)?gemini-2.5-flash-lite.*": ["no-thinking", "thinking"],
-    r"(gemini/)?gemini-2.5-flash-[0-9].*": ["no-thinking", "thinking"],
-    # xAI models
-    r"(xai/)?grok-3-mini(-fast)?(-beta)?": ["low", "medium", "high"],
-}
-
-
 REASONING_MODELS = [
     r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?",
     r"(gemini/)?gemini.*thinking.*",
     r"(gemini/)?gemini-2.5.*",
     r"(xai/)?grok-3-mini.*",
+    r".*gpt-oss.*",
+]
+
+BASE_DECODER_MODELS = [
+    r"gpt-3.5-turbo-instruct.*",
+    r"ada-[0-9]{3}",
+    r"babbage-[0-9]{3}",
+    r"curie-[0-9]{3}",
+    r"davinci-[0-9]{3}",
+    r"text-davinci-[0-9]{3}",
+]
+
+CUSTOM_INFERENCE_API_PREFIXES = [
+    "hosted_vllm/",
+    "vllm/",
+    "ollama/",
+    "ollama_chat/",
+    "llamafile/",
+    "litellm_proxy/",
+    "lm_studio/",
+    "openai/",
 ]
 
 
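The hunks above and below show model IDs being matched against these regex lists via `re.fullmatch` on the joined patterns. The following is a minimal illustrative sketch of that matching pattern, not the package's own code (the real `generative_type` property also consults Ollama capabilities, explicit `#param` values, and litellm's `supports_reasoning`); the list contents here are a small subset copied from the diff.

```python
# Sketch: classify a model ID by regex lists, mirroring the re.fullmatch checks
# used in LiteLLMModel.generative_type. Not the package's actual implementation.
import re

REASONING_MODELS = [
    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?",
    r"(gemini/)?gemini-2.5.*",
    r".*gpt-oss.*",
]
BASE_DECODER_MODELS = [r"gpt-3.5-turbo-instruct.*", r"davinci-[0-9]{3}"]


def classify(model_id: str) -> str:
    """Return a rough generative type for the given model ID."""
    if re.fullmatch("|".join(REASONING_MODELS), model_id, flags=re.IGNORECASE):
        return "reasoning"
    if re.fullmatch("|".join(BASE_DECODER_MODELS), model_id, flags=re.IGNORECASE):
        return "base"
    return "instruction-tuned"


print(classify("gemini/gemini-2.5-pro"))   # reasoning
print(classify("gpt-3.5-turbo-instruct"))  # base
print(classify("gpt-4o-mini"))             # instruction-tuned
```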
@@ -169,12 +192,34 @@ class LiteLLMModel(BenchmarkModule):
     fresh_model = False
     batching_preference = BatchingPreference.ALL_AT_ONCE
     high_priority = False
+    allowed_params = {
+        # OpenAI models
+        re.compile(r"gpt-5-.*"): ["minimal", "low", "medium", "high"],
+        re.compile(r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?"): [
+            "low",
+            "medium",
+            "high",
+        ],
+        # Anthropic models
+        re.compile(r"(anthropic/)?claude-3-7-sonnet.*"): ["no-thinking", "thinking"],
+        re.compile(r"(anthropic/)?claude-(sonnet|opus)-4.*"): [
+            "no-thinking",
+            "thinking",
+        ],
+        # Gemini models
+        re.compile(r"(gemini/)?gemini-2.5-flash-lite.*"): ["no-thinking", "thinking"],
+        re.compile(r"(gemini/)?gemini-2.5-flash.*"): ["no-thinking", "thinking"],
+        # xAI models
+        re.compile(r"(xai/)?grok-3-mini(-fast)?(-beta)?"): ["low", "medium", "high"],
+    }
 
     def __init__(
         self,
         model_config: ModelConfig,
         dataset_config: DatasetConfig,
         benchmark_config: BenchmarkConfig,
+        log_metadata: bool = True,
+        **generation_kwargs: dict[str, t.Any],
     ) -> None:
         """Initialise the model.
 
@@ -185,7 +230,16 @@ class LiteLLMModel(BenchmarkModule):
                 The dataset configuration.
             benchmark_config:
                 The benchmark configuration.
+            log_metadata:
+                Whether to log the model metadata.
+            generation_kwargs:
+                The generation kwargs to pass to the model. If None, default values will
+                be used.
         """
+        raise_if_wrong_params(
+            model_config=model_config, allowed_params=self.allowed_params
+        )
+
         # Detect whether the model is an Ollama model, as we need to extract metadata
         # differently for these models
         self.is_ollama = model_config.model_id.startswith(
@@ -197,20 +251,22 @@ class LiteLLMModel(BenchmarkModule):
             else ollama.ShowResponse(model_info=None)
         )
 
-        raise_if_wrong_params(model_config=model_config, allowed_params=ALLOWED_PARAMS)
-
         super().__init__(
             model_config=model_config,
             dataset_config=dataset_config,
             benchmark_config=benchmark_config,
+            log_metadata=log_metadata,
         )
 
+        self.generation_kwargs = generation_kwargs
         self.buffer["first_label_token_mapping"] = get_first_label_token_mapping(
             dataset_config=self.dataset_config,
             model_config=self.model_config,
-
+            tokeniser=None,
             generative_type=self.generative_type,
+            log_metadata=self.log_metadata,
         )
+        self.buffer["max_concurrent_calls"] = 20
 
     @property
     def generative_type(self) -> GenerativeType | None:
@@ -219,29 +275,43 @@ class LiteLLMModel(BenchmarkModule):
         Returns:
             The generative type of the model, or None if it has not been set yet.
         """
-        if self.
+        if self.benchmark_config.generative_type is not None:
+            type_ = self.benchmark_config.generative_type
+        elif self.is_ollama:
             reasoning_model = "thinking" in (self._ollama_show.capabilities or [])
-
-            GenerativeType.REASONING
-
-
-
-
+            if reasoning_model:
+                type_ = GenerativeType.REASONING
+            elif self.model_config.model_id.startswith("ollama_chat/"):
+                type_ = GenerativeType.INSTRUCTION_TUNED
+            else:
+                type_ = GenerativeType.BASE
+        elif self.model_config.param in {"thinking"}:
             type_ = GenerativeType.REASONING
-        elif self.model_config.
+        elif self.model_config.param in {"no-thinking"}:
             type_ = GenerativeType.INSTRUCTION_TUNED
         elif re.fullmatch(
-            pattern="|".join(REASONING_MODELS),
+            pattern="|".join(REASONING_MODELS),
+            string=self.model_config.model_id,
+            flags=re.IGNORECASE,
         ):
             type_ = GenerativeType.REASONING
+        elif re.fullmatch(
+            pattern="|".join(BASE_DECODER_MODELS),
+            string=self.model_config.model_id,
+            flags=re.IGNORECASE,
+        ):
+            type_ = GenerativeType.BASE
+        elif supports_reasoning(model=self.model_config.model_id):
+            type_ = GenerativeType.REASONING
         else:
             type_ = GenerativeType.INSTRUCTION_TUNED
 
-
-
-
-
-
+        if self.log_metadata:
+            log_once(
+                f"Detected generative type {type_.name!r} for model "
+                f"{self.model_config.model_id!r}",
+                level=logging.DEBUG,
+            )
         return type_
 
     def generate(self, inputs: dict) -> GenerativeModelOutput:
@@ -253,143 +323,52 @@ class LiteLLMModel(BenchmarkModule):
 
         Returns:
             The generated model outputs.
+
+        Raises:
+            InvalidBenchmark:
+                If the inputs do not contain either 'messages' or 'text' keys.
         """
-
-
+        model_inputs: c.Sequence[c.Sequence[litellm.AllMessageValues] | str]
+        if "messages" in inputs:
+            model_inputs = inputs["messages"]
+        elif "text" in inputs:
+            model_inputs = inputs["text"]
+        else:
+            raise InvalidBenchmark(
+                "The inputs must contain either 'messages' or 'text' keys."
+            )
 
         # Get the mapping from labels to the first token in the label. We call this each
         # time we generate a new dataset since the dataset config can change
         self.buffer["first_label_token_mapping"] = get_first_label_token_mapping(
             dataset_config=self.dataset_config,
             model_config=self.model_config,
-
+            tokeniser=None,
             generative_type=self.generative_type,
+            log_metadata=self.log_metadata,
         )
 
-        # Set the core generation arguments
-        generation_kwargs: dict[str, t.Any] = dict(
-            model=self.model_config.model_id,
-            max_completion_tokens=(
-                REASONING_MAX_TOKENS
-                if self.generative_type == GenerativeType.REASONING
-                else self.dataset_config.max_generated_tokens
-            ),
-            stop=[],
-            temperature=0.0,
-            seed=4242,
-            api_key=self.benchmark_config.api_key,
-            api_base=self.benchmark_config.api_base,
-            api_version=self.benchmark_config.api_version,
-            max_retries=3,
-        )
-
-        # Set up the `response_format` generation argument if we are dealing with a task
-        # using structured generation
-        if self.dataset_config.task in TASKS_USING_JSON:
-            # Sanity check that "JSON" is included in the prompt, as some models require
-            # this
-            for conversation in conversations:
-                if not conversation:
-                    raise InvalidBenchmark(
-                        "Encountered an empty conversation in 'messages'."
-                    )
-                last_message = conversation[-1]
-                assert isinstance(last_message, dict), (
-                    f"Expected dict message, got {type(last_message)}"
-                )
-                assert "content" in last_message, (
-                    "Expected 'content' key in the last message of the conversation."
-                )
-                assert isinstance(last_message["content"], str), (
-                    "Expected 'content' to be a string."
-                )
-                assert "json" in last_message["content"].lower(), (
-                    "Prompt must contain 'json' for JSON tasks."
-                )
-
-            if self.generative_type == GenerativeType.REASONING:
-                log_once(
-                    f"The model {self.model_config.model_id!r} is a reasoning model "
-                    "and thus does not support structured generation, so we do not "
-                    "enable it.",
-                    level=logging.DEBUG,
-                )
-            elif litellm.utils.supports_response_schema(
-                model=self.model_config.model_id
-            ):
-                ner_tag_names = list(self.dataset_config.prompt_label_mapping.values())
-                keys_and_their_types: dict[str, t.Any] = {
-                    tag_name: (conlist(str, max_length=5), ...)
-                    for tag_name in ner_tag_names
-                }
-                pydantic_class = create_model("AnswerFormat", **keys_and_their_types)
-                generation_kwargs["response_format"] = pydantic_class
-                log_once(
-                    "Enabling structured generation for model "
-                    f"{self.model_config.model_id!r} with the JSON schema "
-                    f"{pydantic_class.model_json_schema()}",
-                    level=logging.DEBUG,
-                )
-            else:
-                generation_kwargs["response_format"] = dict(type="json_object")
-                log_once(
-                    "Enabling structured JSON generation for model "
-                    f"{self.model_config.model_id!r} with no custom JSON schema, as "
-                    "the model does not support schemas.",
-                    level=logging.DEBUG,
-                )
-
-        # If the model is an Ollama reasoning model, we ensure that thinking is enabled
-        if self.is_ollama and self.generative_type == GenerativeType.REASONING:
-            generation_kwargs["think"] = True
-            log_once(
-                "Enabling thinking mode for Ollama model "
-                f"{self.model_config.model_id!r}",
-                level=logging.DEBUG,
-            )
-
-        # Handle manually set parameters
-        if self.buffer["first_label_token_mapping"]:
-            generation_kwargs["logprobs"] = True
-            generation_kwargs["top_logprobs"] = MAX_LOGPROBS
-        if self.model_config.revision == "thinking":
-            generation_kwargs["thinking"] = dict(
-                type="enabled", budget_tokens=REASONING_MAX_TOKENS - 1
-            )
-            log_once(
-                f"Enabling thinking mode for model {self.model_config.model_id!r}",
-                level=logging.DEBUG,
-            )
-        elif self.model_config.revision == "no-thinking":
-            generation_kwargs["thinking"] = dict(type="disabled", budget_tokens=0)
-            log_once(
-                f"Disabling thinking mode for model {self.model_config.model_id!r}",
-                level=logging.DEBUG,
-            )
-        elif self.model_config.revision in {"low", "medium", "high"}:
-            generation_kwargs["reasoning_effort"] = self.model_config.revision
-            log_once(
-                f"Enabling reasoning effort {self.model_config.revision!r} for model "
-                f"{self.model_config.model_id!r}",
-                level=logging.DEBUG,
-            )
-
-        # Drop generation kwargs that are not supported by the model
-        litellm.drop_params = True
-
         all_responses: dict[int, "ModelResponse"] = {}
-
-
-        )
+        inputs_to_run: c.Sequence[
+            tuple[int, c.Sequence[litellm.AllMessageValues] | str]
+        ] = list(enumerate(model_inputs))
         for attempt in range(num_attempts := 10):
-            if not
+            if not inputs_to_run:
                 break
 
-
+            generation_kwargs = self.generation_kwargs or self.get_generation_kwargs(
+                dataset_config=self.dataset_config
+            )
+
+            batch_indices, batch_inputs = zip(*inputs_to_run)
             successes, failures = safe_run(
                 self._generate_async(
-                    model_id=
-
+                    model_id=clean_model_id(
+                        model_id=self.model_config.model_id,
+                        benchmark_config=self.benchmark_config,
+                    ),
+                    inputs=list(batch_inputs),
+                    max_concurrent_calls=self.buffer["max_concurrent_calls"],
                     **generation_kwargs,
                 )
             )
@@ -401,23 +380,50 @@ class LiteLLMModel(BenchmarkModule):
 
             # If all requests were successful, break
             if not failures:
-
+                inputs_to_run = []
                 break
 
             # Put the failed requests back in the queue to try again
-
-                (batch_indices[idx],
+            inputs_to_run = [
+                (batch_indices[idx], model_inputs[batch_indices[idx]])
                 for idx, _ in failures
             ]
-
+            log(
                 f"Attempt {attempt + 1:,}/{num_attempts:,}: retrying "
-                f"{len(
+                f"{len(inputs_to_run):,} failed message(s). Here is the first error: "
+                f"{failures[0][1]}.",
+                level=logging.DEBUG,
             )
 
+            # Check if any errors are due to HTTP 429 (too many requests), in which case
+            # we reduce the number of concurrent calls
+            http_429_errors = [
+                idx
+                for idx, (_, error) in enumerate(failures)
+                if isinstance(error, RateLimitError) and "Error code: 429" in str(error)
+            ]
+            if http_429_errors and self.buffer["max_concurrent_calls"] > 1:
+                failures = [
+                    failures[i]
+                    for i in range(len(failures))
+                    if i not in http_429_errors
+                ]
+                self.buffer["max_concurrent_calls"] = max(
+                    1, self.buffer["max_concurrent_calls"] // 2
+                )
+                log(
+                    f"Reducing the maximum number of concurrent calls to "
+                    f"{self.buffer['max_concurrent_calls']:,} due to rate limiting.",
+                    level=logging.DEBUG,
+                )
+                continue
+
             # Attempt to handle the exceptions, to improve the chance of getting
             # successful generations next time around
             for _, error in failures:
-                self._handle_exception(
+                generation_kwargs = self._handle_exception(
+                    error=error, **generation_kwargs
+                )
 
             # Sleep for a second to avoid pinging the API server too quickly
             sleep(1)
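The hunk above retries failed requests and, when HTTP 429 errors are seen, halves the concurrency cap (never below one) before the next attempt. Below is a minimal sketch of just that backoff rule, written for illustration only; it is not code from the package.

```python
# Sketch of the rate-limit backoff rule shown above: on an HTTP 429, the
# concurrency cap for the next attempt is halved, with a floor of 1.
def next_concurrency(current: int, saw_http_429: bool) -> int:
    """Return the concurrency cap to use for the next retry attempt."""
    return max(1, current // 2) if saw_http_429 else current


caps = []
cap = 20
for _ in range(6):
    caps.append(cap)
    cap = next_concurrency(cap, saw_http_429=True)
print(caps)  # [20, 10, 5, 2, 1, 1]
```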
@@ -427,22 +433,20 @@ class LiteLLMModel(BenchmarkModule):
         )
 
         # Extract the generations from the model output
-        ordered_responses = [all_responses[i] for i in range(len(
+        ordered_responses = [all_responses[i] for i in range(len(model_inputs))]
         model_output = self._create_model_output(
             model_responses=ordered_responses, model_id=self.model_config.model_id
         )
 
-        if len(
+        if len(model_inputs) != len(model_output.sequences):
             raise InvalidBenchmark(
-                f"Number of model inputs ({len(
+                f"Number of model inputs ({len(model_inputs):,}) does not match the "
                 f"number of model outputs ({len(model_output.sequences):,})."
             )
 
         return model_output
 
-    def _handle_exception(
-        self, error: Exception, generation_kwargs: dict[str, t.Any]
-    ) -> None:
+    def _handle_exception(self, error: Exception, **generation_kwargs) -> dict:
         """Handle an exception from the model.
 
         Args:
@@ -450,26 +454,47 @@ class LiteLLMModel(BenchmarkModule):
                 The exception to handle.
             generation_kwargs:
                 The generation kwargs to pass to the model.
+
+        Returns:
+            The updated generation kwargs to pass to the model.
         """
         error_msg = str(error).lower()
         model_id = self.model_config.model_id
 
         # Error messages that we want to catch and handle
-        stop_messages = [
+        stop_messages = [
+            "stop_sequences",
+            "'stop' is not supported with this model",
+            "'$.stop' is invalid",
+        ]
+        stop_pattern = re.compile(r"does not support parameters: \[.*'stop'.*\]")
         logprobs_messages = [
             "you are not allowed to request logprobs",
             "you've reached the maximum number of requests with logprobs",
             "logprobs is not supported",
             "logprobs is not enabled",
+            "Invalid value at 'generation_config.response_logprobs' (TYPE_BOOL)",
         ]
+        logprobs_pattern = re.compile(
+            r"does not support parameters: \[.*'logprobs'.*\]"
+        )
+        top_logprobs_messages = ["got an unexpected keyword argument 'top_logprobs'"]
+        top_logprobs_pattern = re.compile(
+            r"does not support parameters: \[.*'top_logprobs'.*\]"
+        )
+        max_completion_tokens_pattern = re.compile(
+            r"does not support parameters: \[.*'max_completion_tokens'.*\]"
+        )
         temperature_messages = [
             "'temperature' is not supported with this model.",
             "temperature is not supported with this model",
+            r"does not support parameters: \[.*'temperature'.*\]",
         ]
         temperature_must_be_one_messages = [
             "`temperature` may only be set to 1",
            "'temperature' does not support 0.0 with this model. Only the default "
             "(1) value is supported",
+            "Only temperature=1 is supported",
         ]
         max_items_messages = ["'maxItems' is not permitted."]
         no_json_schema_messages = ["Property keys should match pattern"]
@@ -477,17 +502,27 @@ class LiteLLMModel(BenchmarkModule):
             r"the thinking budget [0-9]+ is invalid. please choose a value between "
             r"[0-9]+ and ([0-9]+)\."
         )
+        requires_thinking_disabled_messages = ["thinking.type: Field required"]
+        seed_pattern = re.compile(r"does not support parameters: \[.*'seed'.*\]")
+        response_format_messages = [
+            "got an unexpected keyword argument 'response_format'",
+            "the model returned empty outputs",
+        ]
 
-        if
+        if (
+            any(msg.lower() in error_msg for msg in stop_messages)
+            or stop_pattern.search(string=error_msg) is not None
+        ):
             log_once(
                 f"The model {model_id!r} does not support "
                 "stop sequences, so disabling them.",
                 level=logging.DEBUG,
             )
             generation_kwargs["stop"] = None
-            return
+            return generation_kwargs
         elif (
             any(msg.lower() in error_msg for msg in logprobs_messages)
+            or logprobs_pattern.search(string=error_msg) is not None
             # Special case for Vertex AI models, since they have strict rate
             # limits on using logprobs. They also have a cap of 5 logprobs, but
             # we ignore this since the rate limiting makes it unusable anyway.
@@ -497,9 +532,32 @@ class LiteLLMModel(BenchmarkModule):
                 f"The model {model_id!r} does not support logprobs, so disabling it.",
                 level=logging.DEBUG,
             )
+            self.buffer["first_label_token_mapping"] = False
             generation_kwargs.pop("logprobs", None)
             generation_kwargs.pop("top_logprobs", None)
-
+            generation_kwargs.pop("response_format", None)
+            return generation_kwargs
+        elif (
+            any(msg.lower() in error_msg for msg in top_logprobs_messages)
+            or top_logprobs_pattern.search(string=error_msg) is not None
+        ):
+            log_once(
+                f"The model {model_id!r} does not support the `top_logprobs` argument, "
+                "so moving the value to `logprobs`.",
+                level=logging.DEBUG,
+            )
+            generation_kwargs["logprobs"] = generation_kwargs.pop("top_logprobs", None)
+            return generation_kwargs
+        elif max_completion_tokens_pattern.search(string=error_msg):
+            log_once(
+                f"The model {model_id!r} does not support max_completion_tokens, so "
+                "disabling it.",
+                level=logging.DEBUG,
+            )
+            generation_kwargs["max_tokens"] = generation_kwargs.pop(
+                "max_completion_tokens", None
+            )
+            return generation_kwargs
         elif any(msg.lower() in error_msg for msg in temperature_messages):
             log_once(
                 f"The model {model_id!r} does not support "
@@ -507,7 +565,7 @@ class LiteLLMModel(BenchmarkModule):
                 level=logging.DEBUG,
             )
             generation_kwargs.pop("temperature", None)
-            return
+            return generation_kwargs
         elif any(msg.lower() in error_msg for msg in temperature_must_be_one_messages):
             log_once(
                 f"The model {model_id!r} requires "
@@ -515,8 +573,11 @@ class LiteLLMModel(BenchmarkModule):
                 level=logging.DEBUG,
             )
             generation_kwargs["temperature"] = 1.0
-            return
-        elif
+            return generation_kwargs
+        elif (
+            any(msg.lower() in error_msg for msg in max_items_messages)
+            and self.dataset_config.task == NER
+        ):
             log_once(
                 f"The model {model_id!r} does not support "
                 "maxItems in the JSON schema, so disabling it.",
@@ -524,11 +585,11 @@ class LiteLLMModel(BenchmarkModule):
             )
             ner_tag_names = list(self.dataset_config.prompt_label_mapping.values())
             keys_and_their_types = {
-                tag_name: (
+                tag_name: (c.Sequence[str], ...) for tag_name in ner_tag_names
             }
             pydantic_class = create_model("AnswerFormat", **keys_and_their_types)
             generation_kwargs["response_format"] = pydantic_class
-            return
+            return generation_kwargs
         elif any(msg.lower() in error_msg for msg in no_json_schema_messages):
             log_once(
                 f"The model {self.model_config.model_id!r} does not support "
@@ -536,7 +597,7 @@ class LiteLLMModel(BenchmarkModule):
                 level=logging.DEBUG,
             )
             generation_kwargs["response_format"] = dict(type="json_object")
-            return
+            return generation_kwargs
         elif thinking_match := thinking_budget_pattern.search(string=error_msg):
             thinking_budget = int(thinking_match.group(1))
             if thinking_budget >= REASONING_MAX_TOKENS:
@@ -545,7 +606,7 @@ class LiteLLMModel(BenchmarkModule):
                     f"{thinking_budget:,} tokens, which is within the limit of "
                     f"{REASONING_MAX_TOKENS:,} tokens. This should not happen. The "
                     f"error message was: {error_msg}."
-                )
+                ) from error
             log_once(
                 f"The model {model_id!r} can at most use {thinking_budget:,} tokens "
                 "for reasoning, which is less than the default of "
@@ -556,59 +617,135 @@ class LiteLLMModel(BenchmarkModule):
             generation_kwargs["thinking"] = dict(
                 type="enabled", budget_tokens=thinking_budget - 1
             )
-            return
+            return generation_kwargs
+        elif (
+            any(msg.lower() in error_msg for msg in requires_thinking_disabled_messages)
+            and self.generative_type != GenerativeType.REASONING
+        ):
+            log_once(
+                f"The model {model_id!r} requires the `thinking.type` field to be "
+                f"set to `disabled` rather than just setting `budget_tokens` to 0. "
+                "Setting `thinking.type` to `disabled`.",
+                level=logging.DEBUG,
+            )
+            generation_kwargs["thinking"] = dict(type="disabled")
+            return generation_kwargs
+        elif re.search(pattern=seed_pattern, string=error_msg):
+            log_once(
+                f"The model {model_id!r} does not support the `seed` parameter, so "
+                "disabling it.",
+                level=logging.DEBUG,
+            )
+            generation_kwargs.pop("seed", None)
+            return generation_kwargs
+        elif any(msg.lower() in error_msg for msg in response_format_messages):
+            log_once(
+                f"The model {model_id!r} does not support the `response_format` "
+                "parameter, so disabling it.",
+                level=logging.DEBUG,
+            )
+            generation_kwargs.pop("response_format", None)
+            return generation_kwargs
+        # If there are too many I/O connections, we increase the number of allowed file
+        # descriptors
+        elif "too many open files" in error_msg:
+            raise InvalidBenchmark(
+                "There are too many file descriptors running. See the current "
+                "value by running `ulimit -n`. Try increasing it by running "
+                "`ulimit -n <new-value>` and try again."
+            ) from error
         elif isinstance(
             error, (Timeout, ServiceUnavailableError, InternalServerError, SystemError)
         ):
-
+            log(
                 f"Service temporarily unavailable. The error message was: {error}. "
-
+                "Retrying in 10 seconds...",
+                level=logging.DEBUG,
+            )
+            sleep(10)
+            return generation_kwargs
+        elif isinstance(error, UnsupportedParamsError):
+            unsupported_param_match = re.search(
+                pattern=r"(?<=does not support parameters\: \[')([^ ']+)(?='\])",
+                string=error.message,
             )
-
-
+            if unsupported_param_match is None:
+                raise InvalidModel(error.message) from error
+            else:
+                unsupported_param = unsupported_param_match.group(0)
+                raise InvalidModel(
+                    f"The model {model_id!r} does not support the parameter "
+                    f"{unsupported_param!r}. Try again without this parameter. "
+                    "Skipping this model."
+                ) from error
         elif isinstance(error, (APIConnectionError, OSError)):
-            # If there are too many I/O connections, we increase the number of allowed
-            # file descriptors
-            if "too many open files" in error_msg:
-                raise InvalidBenchmark(
-                    "There are too many file descriptors running. See the current "
-                    "value by running `ulimit -n`. Try increasing it by running "
-                    "`ulimit -n <new-value>` and try again."
-                )
             raise InvalidBenchmark(
                 f"Encountered {type(error)} during generation: {error}."
-            )
+            ) from error
 
-        if isinstance(error,
+        if isinstance(error, NotFoundError):
             raise InvalidModel(
+                f"The model {model_id!r} was not found. Please check the model ID "
+                "and try again."
+            ) from error
+
+        if isinstance(error, RateLimitError):
+            log(
                 f"You have encountered your rate limit for model {model_id!r}. "
-                "
+                "Retrying in 10 seconds...",
+                level=logging.DEBUG,
+            )
+            sleep(10)
+            return generation_kwargs
+
+        if (
+            isinstance(error, BadRequestError)
+            and (
+                retry_match := re.search(
+                    pattern=r"\bretry in ([0-9]+(.[0-9]+)?) ?(s|seconds)\b",
+                    string=error_msg,
+                )
+            )
+            is not None
+        ):
+            retry_seconds = float(retry_match.group(1))
+            log(
+                f"Bad request error encountered. Retrying in {retry_seconds:.1f} "
+                "seconds...",
+                level=logging.DEBUG,
             )
+            sleep(retry_seconds)
+            return generation_kwargs
 
         if isinstance(error, AuthenticationError):
             raise NeedsAdditionalArgument(
                 cli_argument="--api-key",
                 script_argument="api_key=<your-api-key>",
                 run_with_cli=self.benchmark_config.run_with_cli,
-            )
+            ) from error
 
         raise InvalidBenchmark(
             f"Failed to generate text. The error message was: {error}"
-        )
+        ) from error
 
     async def _generate_async(
         self,
         model_id: str,
-
+        inputs: c.Sequence[c.Sequence[litellm.AllMessageValues] | str],
+        max_concurrent_calls: int,
         **generation_kwargs,
-    ) -> tuple[
+    ) -> tuple[
+        c.Sequence[tuple[int, "ModelResponse"]], c.Sequence[tuple[int, Exception]]
+    ]:
         """Generate outputs from the model asynchronously.
 
         Args:
             model_id:
                 The ID of the model to use for generation.
-
-                The
+            inputs:
+                The inputs to pass to the model.
+            max_concurrent_calls:
+                The maximum number of concurrent calls to make to the model.
             **generation_kwargs:
                 Additional generation arguments to pass to the model.
 
@@ -621,24 +758,67 @@ class LiteLLMModel(BenchmarkModule):
         # for all the requests, preventing "too many open files" errors
         router = Router(
             model_list=[
-
+                litellm.DeploymentTypedDict(
                     model_name=self.model_config.model_id,
-                    litellm_params=
+                    litellm_params=litellm.LiteLLMParamsTypedDict(model=model_id),
                 )
             ]
         )
 
         # Get the LLM generations asynchronously
-        max_concurrent_calls = 20
         semaphore = asyncio.Semaphore(max_concurrent_calls)
-
-
-
-
+        if self.generative_type == GenerativeType.BASE:
+            if not all(isinstance(input_, str) for input_ in inputs):
+                raise InvalidBenchmark(
+                    "For base generative models, all inputs must be strings."
+                )
+            requests = [
+                add_semaphore_and_catch_exception(
+                    router.atext_completion(
+                        model=clean_model_id(
+                            model_id=model_id, benchmark_config=self.benchmark_config
+                        ),
+                        prompt=input_,
+                        **generation_kwargs,
+                    ),
+                    semaphore=semaphore,
+                )
+                for input_ in inputs
+                if isinstance(input_, str)
+            ]
+        else:
+            if not all(isinstance(input_, list) for input_ in inputs):
+                raise InvalidBenchmark(
+                    "For instruction-tuned and reasoning generative models, all "
+                    "inputs must be lists of messages."
+                )
+            requests = [
+                add_semaphore_and_catch_exception(
+                    router.acompletion(
+                        model=clean_model_id(
+                            model_id=model_id, benchmark_config=self.benchmark_config
+                        ),
+                        messages=input_,
+                        **generation_kwargs,
+                    ),
+                    semaphore=semaphore,
+                )
+                for input_ in inputs
+                if isinstance(input_, list)
+            ]
+        responses = await tqdm_async.gather(
+            *requests, colour="yellow", ascii="—▰", leave=False
+        )
+
+        # If the outputs are empty, convert them to exceptions
+        if all(
+            not isinstance(response, Exception)
+            and response.choices[0].message.content == "{}"
+            for response in responses
+        ):
+            responses = [ValueError("The model returned empty outputs.")] * len(
+                responses
             )
-            for conversation in conversations
-        ]
-        responses = await tqdm_async.gather(*requests, leave=False)
 
         # Separate the successful responses from the failed ones
         successes = [
@@ -655,13 +835,18 @@ class LiteLLMModel(BenchmarkModule):
         # Close connections
         for request in requests:
             if hasattr(request, "close"):
-
+                try:
+                    request.close()
+                except RuntimeError as e:
+                    log(
+                        f"RuntimeError during request.close(): {e}", level=logging.DEBUG
+                    )
 
         return successes, failures
 
     @staticmethod
     def _create_model_output(
-        model_responses:
+        model_responses: c.Sequence["ModelResponse"], model_id: str
     ) -> GenerativeModelOutput:
         """Create a GenerativeModelOutput object from a list of ModelResponse objects.
 
@@ -680,45 +865,104 @@ class LiteLLMModel(BenchmarkModule):
         for model_response in model_responses:
             if not model_response.choices:
                 sequences.append("")
-
+                log(
                     f"The model {model_id!r} did not end up "
                     "generating any text. This is likely because the model ran "
-                    "out of tokens while reasoning. Returning an empty string."
+                    "out of tokens while reasoning. Returning an empty string.",
+                    level=logging.WARNING,
                 )
                 continue
 
             model_response_choices = model_response.choices[0]
-
-
-
-
+
+            if isinstance(model_response_choices, litellm.Choices):
+                generated_message: litellm.Message = model_response_choices.message
+                generation_output = generated_message.content or ""
+                generation_output = generation_output.strip()
+            elif isinstance(model_response_choices, litellm.litellm.TextChoices):
+                generation_output = model_response_choices.text or ""
+            else:
+                raise InvalidBenchmark(
+                    "The model response choices must be of type Choices or "
+                    f"TextChoices. Got {type(model_response_choices)}."
+                )
+
+            # In the case where we're dealing with a classification task, the model is
+            # outputting a JSON dictionary, so we will extract the generated text from
+            # within the dictionary
+            generation_dct: dict[str, t.Any] | None = None
+            if LITELLM_CLASSIFICATION_OUTPUT_KEY in generation_output:
+                try:
+                    generation_dct = json.loads(generation_output)
+                    assert isinstance(generation_dct, dict)
+                    if set(generation_dct.keys()) == {
+                        LITELLM_CLASSIFICATION_OUTPUT_KEY
+                    }:
+                        generation_output = str(
+                            generation_dct[LITELLM_CLASSIFICATION_OUTPUT_KEY]
+                        ).strip()
+                except json.JSONDecodeError:
+                    pass
 
             # Structure the model output as a GenerativeModelOutput object
             sequences.append(generation_output)
-            if
+            if (
+                hasattr(model_response_choices, "logprobs")
+                and model_response_choices.logprobs is not None
+            ):
                 logprobs_obj = model_response_choices.logprobs
+
+                if not isinstance(logprobs_obj, (Logprobs, ChoiceLogprobs)):
+                    log_once(
+                        "The logprobs object is malformed, so we won't use logprobs to "
+                        "determine the labels.",
+                        level=logging.WARNING,
+                    )
+                    continue
+
+                logprobs_list: c.Sequence[c.Sequence[tuple[str, float]]]
                 if isinstance(logprobs_obj, ChoiceLogprobs):
-                    logprobs_list
+                    logprobs_list = [
                         [
                             (top_logprob.token, top_logprob.logprob)
                             for top_logprob in content.top_logprobs
                         ]
-                        for content in
+                        for content in logprobs_obj.content or list()
                     ]
-                    scores.append(logprobs_list)
                 else:
-
-
-
-
-
+                    logprobs_list = [
+                        [
+                            (token, logprob)
+                            for token, logprob in (top_logprobs_dct or dict()).items()
+                        ]
+                        for top_logprobs_dct in logprobs_obj.top_logprobs or list()
+                    ]
+
+                # If the model outputted a JSON dictionary, we need to find the
+                # token index of the value within the dictionary, rather than the
+                # first token of the entire output
+                if generation_dct:
+                    key_name = next(iter(generation_dct.keys()))
+                    logprobs_list = [
+                        lst
+                        for lst in logprobs_list
+                        if (
+                            lst
+                            and lst[0]
+                            and (token := lst[0][0].strip(JSON_STRIP_CHARACTERS))
+                            and not key_name.startswith(token)
+                        )
+                    ]
+
+                scores.append(logprobs_list)
 
         if not sequences:
-
+            log(
                 "No sequences were generated by the model "
                 f"{model_id!r}. This may be due to the "
                 "model running out of tokens or an issue with the input data. "
-                "Returning an empty GenerativeModelOutput."
+                "Returning an empty GenerativeModelOutput.",
+                level=logging.WARNING,
             )
             return GenerativeModelOutput(sequences=[], scores=None)
 
@@ -778,9 +1022,7 @@ class LiteLLMModel(BenchmarkModule):
             repo_info = hf_api.model_info(
                 repo_id=model_id,
                 revision="main",
-                token=
-                or self.benchmark_config.api_key
-                or True,
+                token=get_hf_token(api_key=self.benchmark_config.api_key),
             )
         except (
             RepositoryNotFoundError,
@@ -837,10 +1079,11 @@ class LiteLLMModel(BenchmarkModule):
                 run_with_cli=self.benchmark_config.run_with_cli,
             )
 
-
+        tokeniser = load_tokeniser(
             model=None,
             model_id=model_id,
             trust_remote_code=self.benchmark_config.trust_remote_code,
+            model_config=self.model_config,
         )
 
         if (
@@ -849,10 +1092,10 @@ class LiteLLMModel(BenchmarkModule):
         ):
             vocab_size = hf_config.vocab_size
         elif (
-            hasattr(
-            and
+            hasattr(tokeniser, "vocab_size")
+            and tokeniser.vocab_size is not None
         ):
-            vocab_size =
+            vocab_size = tokeniser.vocab_size
         else:
             vocab_size = -1
         return vocab_size
@@ -883,13 +1126,15 @@ class LiteLLMModel(BenchmarkModule):
         if context_length_keys:
             context_length = model_info[context_length_keys[0]]
             if context_length is not None:
-
-
-
-
-
+                if self.log_metadata:
+                    log_once(
+                        f"Detected context length key "
+                        f"{context_length_keys[0]!r} for Ollama model "
+                        f"{ollama_model_id!r}",
+                        level=logging.DEBUG,
+                    )
                 return int(context_length)
-
+        elif self.log_metadata:
             log_once(
                 f"Tried to get the maximum length of the Ollama model "
                 f"{ollama_model_id!r}, but could not find a context length. "
@@ -917,26 +1162,27 @@ class LiteLLMModel(BenchmarkModule):
                 run_with_cli=self.benchmark_config.run_with_cli,
             )
 
-
+        tokeniser = load_tokeniser(
             model=None,
             model_id=model_id,
             trust_remote_code=self.benchmark_config.trust_remote_code,
+            model_config=self.model_config,
         )
 
         all_max_lengths: list[int] = list()
 
-        # Add the registered max length of the
+        # Add the registered max length of the tokeniser
         if hasattr(
-
-        ) and
-            all_max_lengths.append(
+            tokeniser, "model_max_length"
+        ) and tokeniser.model_max_length < int(1e30):
+            all_max_lengths.append(tokeniser.model_max_length)
 
         # Add the max length derived from the model's input sizes
-        if hasattr(
+        if hasattr(tokeniser, "max_model_input_sizes"):
             all_max_lengths.extend(
                 [
                     size
-                    for size in
+                    for size in tokeniser.max_model_input_sizes.values()
                     if size is not None
                 ]
             )
@@ -970,7 +1216,7 @@ class LiteLLMModel(BenchmarkModule):
         return -1
 
     @property
-    def data_collator(self) -> c.Callable[[
+    def data_collator(self) -> c.Callable[[c.Sequence[t.Any]], dict[str, t.Any]]:
        """The data collator used to prepare samples during finetuning.
 
         Returns:
@@ -995,6 +1241,7 @@ class LiteLLMModel(BenchmarkModule):
                 return partial(
                     sequence_classification.extract_labels_from_generation,
                     dataset_config=self.dataset_config,
+                    model_config=self.model_config,
                     first_label_token_mapping=self.buffer["first_label_token_mapping"],
                 )
             case TaskGroup.TEXT_TO_TEXT:
@@ -1038,13 +1285,15 @@ class LiteLLMModel(BenchmarkModule):
             Whether the model exists, or an error describing why we cannot check
             whether the model exists.
         """
-        model_id
+        model_id = split_model_id(model_id=model_id).model_id
         if model_id in litellm.model_list:
             return True
 
         # Separate check for Ollama models
         if model_id.startswith("ollama/") or model_id.startswith("ollama_chat/"):
-            ollama_model_exists = try_download_ollama_model(
+            ollama_model_exists = try_download_ollama_model(
+                model_id=model_id, progress_bar=benchmark_config.progress_bar
+            )
             if ollama_model_exists:
                 return ollama_model_exists
 
@@ -1053,7 +1302,9 @@ class LiteLLMModel(BenchmarkModule):
         try:
             litellm.completion(
                 messages=[dict(role="user", content="X")],
-                model=
+                model=clean_model_id(
+                    model_id=model_id, benchmark_config=benchmark_config
+                ),
                 max_tokens=1,
                 api_key=benchmark_config.api_key,
                 api_base=benchmark_config.api_base,
@@ -1070,20 +1321,31 @@ class LiteLLMModel(BenchmarkModule):
             ServiceUnavailableError,
             InternalServerError,
         ) as e:
-
+            log(
                 f"Service temporarily unavailable. The error message was: {e}. "
-                "Retrying in 10 seconds..."
+                "Retrying in 10 seconds...",
+                level=logging.DEBUG,
             )
-            sleep(
+            sleep(10)
         except APIError as e:
             if "'503 Service Unavailable" not in str(e):
                 raise e
-
+            log(
                 f"Failed to check if model {model_id!r} exists. Retrying in 10 "
-                "seconds..."
+                "seconds...",
+                level=logging.WARNING,
             )
             sleep(10)
         except (BadRequestError, NotFoundError):
+            # In case we're using `api_base`, try again with the `/v1` suffix
+            if (
+                benchmark_config.api_base is not None
+                and not benchmark_config.api_base.endswith("/v1")
+            ):
+                benchmark_config.api_base += "/v1"
+                continue
+
+            # Check for misspelled model IDs
             candidate_models = [
                 candidate_model_id
                 for candidate_model_id in litellm.model_list
@@ -1093,21 +1355,25 @@ class LiteLLMModel(BenchmarkModule):
                 case 0:
                     pass
                 case 1:
-
+                    log(
                         f"Could not find the model ID {model_id!r}. Did you mean "
-                        f"{candidate_models[0]!r}?"
+                        f"{candidate_models[0]!r}?",
+                        level=logging.WARNING,
                     )
                 case _:
                     candidate_models_str = "', '".join(candidate_models)
-
+                    log(
                         f"Could not find the model ID {model_id!r}. Did you mean "
-
+                        "any of the following model IDs: "
+                        f"'{candidate_models_str}'?",
+                        level=logging.WARNING,
                     )
             return False
         else:
-
+            log(
                 f"Failed to check if model {model_id!r} exists after {num_attempts} "
-                "attempts. Assuming it does not exist."
+                "attempts. Assuming it does not exist.",
+                level=logging.ERROR,
             )
             return False
 
@@ -1126,10 +1392,30 @@ class LiteLLMModel(BenchmarkModule):
         Returns:
             The model configuration.
         """
-
+        model_id_components = split_model_id(model_id=model_id)
+
+        # Backwards compatibility: If the revision is set but not the parameter, we
+        # assume that the revision is actually the parameter and log this as a warning.
+        if model_id_components.revision != "main" and model_id_components.param is None:
+            proper_model_id = (
+                f"{model_id_components.model_id}#{model_id_components.revision}"
+            )
+            log_once(
+                f"The model ID {model_id!r} specifies a revision "
+                f"{model_id_components.revision!r} but not a parameter. We assume "
+                "that the revision is actually the parameter and set the revision "
+                "to 'main'. In the future, use the new '#' syntax to specify the "
+                f"parameter (in this case, this would be {proper_model_id!r}), as this "
+                "will be an error in future versions of EuroEval.",
+                level=logging.WARNING,
+            )
+            model_id_components.param = model_id_components.revision
+            model_id_components.revision = "main"
+
         return ModelConfig(
-            model_id=model_id,
-            revision=revision,
+            model_id=model_id_components.model_id,
+            revision=model_id_components.revision,
+            param=model_id_components.param,
             task="text-generation",
             languages=list(),
             merge=False,
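The hunk above implies the new model-ID grammar for API models: an optional revision plus an optional run-time parameter appended to the base ID. The following is a minimal sketch of how such an ID could be split, assuming the revision is introduced by `@` and the parameter by the new `#` separator; the real `split_model_id` helper lives elsewhere in the package and may handle more cases.

from dataclasses import dataclass


@dataclass
class ModelIdComponents:
    model_id: str
    revision: str
    param: str | None


def split_model_id_sketch(model_id: str) -> ModelIdComponents:
    """Split 'org/model@revision#param' into its parts (illustrative only)."""
    model_id, _, param = model_id.partition("#")
    model_id, _, revision = model_id.partition("@")
    return ModelIdComponents(
        model_id=model_id, revision=revision or "main", param=param or None
    )


# New-style ID: the generation parameter goes after '#'.
print(split_model_id_sketch("provider/model#thinking"))
# Old-style ID that put the parameter in the revision slot; the shim in the hunk
# above moves it into `param` and resets the revision to 'main'.
print(split_model_id_sketch("provider/model@thinking"))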
@@ -1184,7 +1470,10 @@ class LiteLLMModel(BenchmarkModule):

         if self.benchmark_config.few_shot:
             few_shot_examples = extract_few_shot_examples(
-                dataset=dataset,
+                dataset=dataset,
+                dataset_config=self.dataset_config,
+                benchmark_config=self.benchmark_config,
+                itr_idx=itr_idx,
             )
         else:
             few_shot_examples = list()
@@ -1195,9 +1484,9 @@ class LiteLLMModel(BenchmarkModule):
                 few_shot_examples=few_shot_examples,
                 model_config=self.model_config,
                 dataset_config=self.dataset_config,
-
+                generative_type=self.generative_type,
                 always_populate_text_field=False,
-
+                tokeniser=None,
             ),
             batched=True,
             load_from_cache_file=False,
@@ -1206,46 +1495,169 @@ class LiteLLMModel(BenchmarkModule):

         return dataset

+    @cache_arguments()
+    def get_generation_kwargs(self, dataset_config: DatasetConfig) -> dict[str, t.Any]:
+        """Get the generation arguments for the model.

-
-
-
-
+        Args:
+            dataset_config:
+                The dataset configuration, which is used to determine the generative
+                type of the model. We use this as an argument here rather than using
+                `self.dataset_config` to ensure that that the cache is updated when the
+                dataset configuration changes.

-
-
-
-
-
+        Returns:
+            The generation arguments for the model.
+        """
+        # Set the core generation arguments
+        generation_kwargs: dict[str, t.Any] = dict(
+            max_completion_tokens=(
+                REASONING_MAX_TOKENS
+                if self.generative_type == GenerativeType.REASONING
+                else dataset_config.max_generated_tokens
+            ),
+            stop=[],
+            temperature=0.0,
+            seed=4242,
+            api_key=self.benchmark_config.api_key,
+            api_base=self.benchmark_config.api_base,
+            api_version=self.benchmark_config.api_version,
+            max_retries=3,
+        )

-
-
-
-
-
-
-
-
-
-        if param not in allowed_params_list:
-            msg = (
-                f"Invalid parameter {param!r} for model {model_config.model_id!r}."
+        # Set up the `response_format` generation argument if we are dealing with a task
+        # using structured generation
+        if dataset_config.task.uses_structured_output:
+            if self.generative_type == GenerativeType.REASONING:
+                log_once(
+                    f"The model {self.model_config.model_id!r} is a reasoning model "
+                    "and thus does not support structured generation, so we do not "
+                    "enable it.",
+                    level=logging.DEBUG,
                 )
-
-
+            elif supports_response_schema(model=self.model_config.model_id):
+                if dataset_config.task == NER:
+                    ner_tag_names = list(dataset_config.prompt_label_mapping.values())
+                    keys_and_their_types: dict[str, t.Any] = {
+                        tag_name: (conlist(str, max_length=5), ...)
+                        for tag_name in ner_tag_names
+                    }
+                    pydantic_class = create_model(
+                        "AnswerFormat", **keys_and_their_types
+                    )
                 else:
-
-
-
+                    raise InvalidBenchmark(
+                        "This task requires structured generation, but it has not "
+                        "been implemented for this task yet. Please open an issue "
+                        "at https://github.com/EuroEval/EuroEval/issues."
+                    )
+                generation_kwargs["response_format"] = pydantic_class
+                log_once(
+                    "Enabling structured generation for model "
+                    f"{self.model_config.model_id!r} with the JSON schema "
+                    f"{pydantic_class.model_json_schema()}",
+                    level=logging.DEBUG,
+                )
+            else:
+                generation_kwargs["response_format"] = dict(type="json_object")
+                log_once(
+                    "Enabling structured JSON generation for model "
+                    f"{self.model_config.model_id!r} with no custom JSON schema, as "
+                    "the model does not support schemas.",
+                    level=logging.DEBUG,
+                )
+        elif self.dataset_config.task.uses_logprobs and self.dataset_config.labels:
+            localised_labels = [
+                self.dataset_config.prompt_label_mapping[label]
+                for label in self.dataset_config.labels
+            ]
+            keys_and_their_types = {
+                LITELLM_CLASSIFICATION_OUTPUT_KEY: (t.Literal[*localised_labels], ...)
+            }
+            pydantic_class = create_model("AnswerFormat", **keys_and_their_types)
+            generation_kwargs["response_format"] = pydantic_class

+        # If the model is an Ollama reasoning model, we ensure that thinking is enabled
+        if self.is_ollama and self.generative_type == GenerativeType.REASONING:
+            generation_kwargs["think"] = True
+            log_once(
+                "Enabling thinking mode for Ollama model "
+                f"{self.model_config.model_id!r}",
+                level=logging.DEBUG,
+            )

-
+        # Handle manually set parameters
+        if self.buffer["first_label_token_mapping"]:
+            generation_kwargs["logprobs"] = True
+            generation_kwargs["top_logprobs"] = MAX_LITELLM_LOGPROBS
+        if self.model_config.param == "thinking":
+            generation_kwargs["thinking"] = dict(
+                type="enabled", budget_tokens=REASONING_MAX_TOKENS - 1
+            )
+            log_once(
+                f"Enabling thinking mode for model {self.model_config.model_id!r}",
+                level=logging.DEBUG,
+            )
+        elif self.model_config.param == "no-thinking":
+            generation_kwargs["thinking"] = dict(budget_tokens=0)
+            log_once(
+                f"Disabling thinking mode for model {self.model_config.model_id!r}",
+                level=logging.DEBUG,
+            )
+        elif self.model_config.param in {"minimal", "low", "medium", "high"}:
+            generation_kwargs["reasoning_effort"] = self.model_config.param
+            log_once(
+                f"Enabling reasoning effort {self.model_config.param!r} for model "
+                f"{self.model_config.model_id!r}",
+                level=logging.DEBUG,
+            )
+
+        # First attempt is a test run with a single conversation to handle errors
+        # quickly. We repeat this multiple times to deal with different types of
+        # errors, and stop if we get a successful response.
+        test_input: c.Sequence[litellm.AllMessageValues] | str
+        if self.generative_type == GenerativeType.BASE:
+            test_input = "Test message"
+        else:
+            test_input = [
+                litellm.ChatCompletionUserMessage(role="user", content="Test message")
+            ]
+        for _ in range(num_attempts := 10):
+            _, failures = safe_run(
+                self._generate_async(
+                    model_id=clean_model_id(
+                        model_id=self.model_config.model_id,
+                        benchmark_config=self.benchmark_config,
+                    ),
+                    inputs=[test_input],
+                    max_concurrent_calls=1,
+                    **generation_kwargs,
+                )
+            )
+            if not failures:
+                break
+            for _, error in failures:
+                generation_kwargs = self._handle_exception(
+                    error=error, **generation_kwargs
+                )
+        else:
+            raise InvalidModel(
+                "Failed to get a successful response from the model "
+                f"{self.model_config.model_id!r} after {num_attempts} attempts."
+            )
+
+        return generation_kwargs
+
+
+def try_download_ollama_model(model_id: str, progress_bar: bool) -> bool:
     """Try to download an Ollama model.

     Args:
         model_id:
             The model ID. If the model does not start with "ollama/" or "ollama_chat/"
             then this function will return False.
+        progress_bar:
+            Whether to show a progress bar while downloading the model.

     Returns:
         Whether the model was downloaded successfully.
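As a point of reference, the NER branch of `get_generation_kwargs` above builds its `response_format` with plain pydantic calls. The following standalone sketch reproduces that construction with illustrative tag names; EuroEval takes the real names from `dataset_config.prompt_label_mapping`.

import typing as t

from pydantic import conlist, create_model

# Hypothetical localised NER tag names; the benchmark derives these per language.
ner_tag_names = ["person", "location", "organisation", "miscellaneous"]

# Each tag becomes a required field holding a list of at most five extracted strings.
keys_and_their_types: dict[str, t.Any] = {
    tag_name: (conlist(str, max_length=5), ...) for tag_name in ner_tag_names
}
AnswerFormat = create_model("AnswerFormat", **keys_and_their_types)

# The resulting JSON schema is what LiteLLM receives as `response_format`.
print(AnswerFormat.model_json_schema())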
@@ -1268,16 +1680,16 @@ def try_download_ollama_model(model_id: str) -> bool:
     )

     try:
-        downloaded_ollama_models:
+        downloaded_ollama_models: c.Sequence[str] = [
             model_obj.model
             for model_obj in ollama.list().models
             if model_obj.model is not None
         ]
-    except ConnectionError:
+    except ConnectionError as e:
         raise InvalidModel(
             "Ollama does not seem to be running, so we cannot evaluate the model "
             f"{model_id!r}. Please make sure that Ollama is running and try again."
-        )
+        ) from e

     ollama_model_id = "/".join(model_id.split("/")[1:])
     if ollama_model_id not in downloaded_ollama_models:
@@ -1297,7 +1709,8 @@ def try_download_ollama_model(model_id: str) -> bool:
                     f"The model {model_id!r} cannot be found on Ollama, but the "
                     f"model {model_id_with_prefix} *was* found, so we would "
                     "recommend you cancelling this run and trying the evaluation "
-                    "with that model ID instead."
+                    "with that model ID instead.",
+                    level=logging.WARNING,
                 )
                 return False
         except ollama.ResponseError as inner_e:
@@ -1307,19 +1720,19 @@ def try_download_ollama_model(model_id: str) -> bool:
                 raise InvalidModel(
                     f"Failed to download Ollama model {ollama_model_id}. "
                     f"The error message was: {inner_e}"
-                )
+                ) from inner_e
             else:
                 raise InvalidModel(
                     f"Failed to download Ollama model {ollama_model_id}. "
                     f"The error message was: {e}"
-                )
+                ) from e

     # Download the model
-    with
+    with get_pbar(
         desc=f"Downloading {ollama_model_id}",
         unit_scale=True,
         unit="B",
-
+        disable=not progress_bar,
     ) as pbar:
         for status in response:
             if status.total is not None:
@@ -1335,3 +1748,30 @@ def try_download_ollama_model(model_id: str) -> bool:
             level=logging.DEBUG,
         )
     return True
+
+
+def clean_model_id(model_id: str, benchmark_config: BenchmarkConfig) -> str:
+    """Clean a model ID.
+
+    This adds the default `openai/` prefix to the model ID if we're benchmarking a
+    custom API inference server and no prefix is used, just to make it more
+    convenient for the user.
+
+    Args:
+        model_id:
+            The model ID.
+        benchmark_config:
+            The benchmark configuration.
+
+    Returns:
+        The cleaned model ID.
+    """
+    if benchmark_config.api_base is not None and not any(
+        model_id.startswith(prefix) for prefix in CUSTOM_INFERENCE_API_PREFIXES
+    ):
+        if benchmark_config.generative_type == GenerativeType.BASE:
+            prefix = "text-completion-openai/"
+        else:
+            prefix = "openai/"
+        model_id = prefix + model_id
+    return model_id
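The new `clean_model_id` helper above only rewrites IDs when a custom `api_base` is configured and no known provider prefix is already present. Below is a minimal sketch of the same behaviour with an assumed prefix list; the real `CUSTOM_INFERENCE_API_PREFIXES` constant is defined in `euroeval.constants` and may differ.

# Assumed prefix list for illustration only.
CUSTOM_INFERENCE_API_PREFIXES = ("openai/", "text-completion-openai/", "hosted_vllm/")


def clean_model_id_sketch(model_id: str, api_base: str | None, base_model: bool) -> str:
    """Prefix bare model IDs so LiteLLM routes them to the custom server."""
    if api_base is not None and not model_id.startswith(CUSTOM_INFERENCE_API_PREFIXES):
        # Base (non-instruct) models go through the text-completions route instead
        # of the chat-completions route.
        prefix = "text-completion-openai/" if base_model else "openai/"
        model_id = prefix + model_id
    return model_id


# A bare ID served from a local OpenAI-compatible server gets the chat prefix:
print(clean_model_id_sketch("my-model", "http://localhost:8000/v1", base_model=False))
# -> openai/my-model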