EuroEval 15.4.1 → 15.5.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- euroeval/__init__.py +2 -2
- euroeval/benchmark_modules/hf.py +79 -39
- euroeval/benchmark_modules/litellm.py +204 -74
- euroeval/benchmark_modules/vllm.py +106 -42
- euroeval/benchmarker.py +35 -6
- euroeval/constants.py +11 -1
- euroeval/data_models.py +6 -2
- euroeval/dataset_configs.py +6 -6
- euroeval/task_utils/sequence_classification.py +70 -30
- euroeval/types.py +3 -3
- euroeval/utils.py +131 -32
- {euroeval-15.4.1.dist-info → euroeval-15.5.0.dist-info}/METADATA +6 -4
- {euroeval-15.4.1.dist-info → euroeval-15.5.0.dist-info}/RECORD +16 -16
- {euroeval-15.4.1.dist-info → euroeval-15.5.0.dist-info}/WHEEL +0 -0
- {euroeval-15.4.1.dist-info → euroeval-15.5.0.dist-info}/entry_points.txt +0 -0
- {euroeval-15.4.1.dist-info → euroeval-15.5.0.dist-info}/licenses/LICENSE +0 -0
--- a/euroeval/benchmark_modules/litellm.py
+++ b/euroeval/benchmark_modules/litellm.py
@@ -27,20 +27,17 @@ from litellm.exceptions import (
     BadRequestError,
     InternalServerError,
     NotFoundError,
+    RateLimitError,
     ServiceUnavailableError,
     Timeout,
 )
+from litellm.llms.vertex_ai.common_utils import VertexAIError
 from litellm.types.utils import ModelResponse
 from requests.exceptions import RequestException
 from tqdm.auto import tqdm
 from transformers import Trainer

-from ..constants import (
-    MAX_LOGPROBS,
-    REASONING_MAX_TOKENS,
-    TASK_GROUPS_USING_LOGPROBS,
-    TASKS_USING_JSON,
-)
+from ..constants import MAX_LOGPROBS, REASONING_MAX_TOKENS, TASKS_USING_JSON
 from ..data_models import (
     BenchmarkConfig,
     DatasetConfig,
@@ -69,7 +66,7 @@ from ..task_utils import (
     token_classification,
 )
 from ..types import ExtractLabelsFunction
-from ..utils import create_model_cache_dir, log_once
+from ..utils import create_model_cache_dir, get_first_label_token_mapping, log_once
 from .base import BenchmarkModule
 from .hf import HuggingFaceEncoderModel, load_hf_model_config, load_tokenizer

@@ -78,64 +75,80 @@ logger = logging.getLogger("euroeval")

 VOCAB_SIZE_MAPPING = {
     # OpenAI models
-    "(
-    "
-    "gpt-
-    "gpt-4-(
-    "gpt-
-    "gpt-
-    "
-    "gpt-3.5-turbo-instruct(-[0-9]{4})?": 100_256,
-    "gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_019,
-    "o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
+    r"gpt-4-(32k)?(-[0-9]{4})?": 100_256,
+    r"gpt-4-[0-9]{4}-preview": 100_256,
+    r"gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 100_256,
+    r"gpt-4-(vision|turbo)(-preview)?": 100_256,
+    r"gpt-3.5-turbo-instruct(-[0-9]{4})?": 100_256,
+    r"gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_019,
+    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
     # Anthropic models
-    "claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": -1,
+    r"(anthropic/)?claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": -1,
+    # Gemini models
+    r"(gemini/)?gemini-[1-9]\.[0-9]-(flash|pro).*": 256_128,
+    # xAI models
+    r"(xai/)?grok.*": -1,
 }


 MODEL_MAX_LENGTH_MAPPING = {
     # OpenAI models
-    "
-    "
-    "
-    "gpt-
-    "gpt-
-    "gpt-3.5-turbo-
-    "gpt-
-    "
-    "
-    "
-    "gpt-4-(vision|turbo)(-preview)?": 128_000,
-    "gpt-3.5-turbo-instruct(-[0-9]{4})?": 4_095,
-    "gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
-    "o1-(mini|preview)(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
-    "o1(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
-    "o[2-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
+    r"gpt-4(-[0-9]{4})?": 8_191,
+    r"gpt-4-32k(-[0-9]{4})?": 32_767,
+    r"gpt-4-[0-9]{4}-preview": 128_000,
+    r"gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
+    r"gpt-4-(vision|turbo)(-preview)?": 128_000,
+    r"gpt-3.5-turbo-instruct(-[0-9]{4})?": 4_095,
+    r"gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
+    r"o1-(mini|preview)(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
+    r"o1(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
+    r"o[2-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
     # Anthropic models
-    "claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": 200_000,
+    r"(anthropic/)?claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": 200_000,
+    # Gemini models
+    r"(gemini/)?gemini-1\.5-flash.*": 1_048_576,
+    r"(gemini/)?gemini-1\.5-pro.*": 2_097_152,
+    r"(gemini/)?gemini-2\.(0|5).*": 1_048_576,
+    # xAI models
+    r"(xai/)?grok.*": 131_072,
 }


 NUM_PARAMS_MAPPING = {
     # OpenAI models
-    "
-    "(
-    "(text-)?curie(-001)?": 13_000_000_000,
-    "((text|code)-)?davinci(-00[1-9])?": 175_000_000_000,
-    "gpt-(3.5|4)-turbo-((16|32)k)?(-[0-9]{4})?": -1,
-    "gpt-4-[0-9]{4}-preview": -1,
-    "gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
-    "gpt-4-(vision|turbo)(-preview)?": -1,
-    "gpt-3.5-turbo-instruct(-[0-9]{4})?": -1,
-    "gpt-4o(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
-    "gpt-4o-mini(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
-    "o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
+    r"gpt-4.*": -1,
+    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
     # Anthropic models
-    "
+    r"(anthropic/)?claude-*": -1,
+    # Gemini models
+    r"(gemini/)?gemini-1.5-flash-8b": 8_000_000_000,
+    r"(gemini/)?gemini-1.5-flash-[0-9]+": -1,
+    r"(gemini/)?gemini-2.(0|5).*": -1,
+    # xAI models
+    r"(xai/)?grok.*": -1,
 }


-
+ALLOWED_PARAMS = {
+    # OpenAI models
+    r"gpt-4.*": [],
+    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": ["low", "high"],
+    # Anthropic models
+    r"(anthropic/)?claude-3-.*": [],
+    r"(anthropic/)?claude-3.5-.*": [],
+    r"(anthropic/)?claude-3.7-sonnet.*": ["thinking"],
+    # Gemini models
+    r"(gemini/)?gemini-.*": [],
+    # xAI models
+    r"(xai/)?grok.*": [],
+}
+
+
+REASONING_MODELS = [
+    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?",
+    r"(gemini/)?gemini.*thinking.*",
+    r"(gemini/)?gemini-2.5-pro.*",
+]


 class LiteLLMModel(BenchmarkModule):
@@ -167,12 +180,18 @@ class LiteLLMModel(BenchmarkModule):
             "ollama/"
         ) or model_config.model_id.startswith("ollama_chat/")

+        raise_if_wrong_params(model_config=model_config, allowed_params=ALLOWED_PARAMS)
+
         super().__init__(
             model_config=model_config,
             dataset_config=dataset_config,
             benchmark_config=benchmark_config,
         )

+        self.buffer["first_label_token_mapping"] = get_first_label_token_mapping(
+            dataset_config=self.dataset_config, tokenizer=None
+        )
+
     @property
     def generative_type(self) -> GenerativeType | None:
         """Get the generative type of the model.
@@ -180,7 +199,9 @@ class LiteLLMModel(BenchmarkModule):
         Returns:
             The generative type of the model, or None if it has not been set yet.
         """
-        if re.fullmatch(
+        if self.model_config.revision == "thinking":
+            return GenerativeType.REASONING
+        elif re.fullmatch(
             pattern="|".join(REASONING_MODELS), string=self.model_config.model_id
         ):
             return GenerativeType.REASONING
@@ -218,7 +239,13 @@ class LiteLLMModel(BenchmarkModule):
             api_version=self.benchmark_config.api_version,
         )

-
+        # Get the mapping from labels to the first token in the label. We call this each
+        # time we generate a new dataset since the dataset config can change
+        self.buffer["first_label_token_mapping"] = get_first_label_token_mapping(
+            dataset_config=self.dataset_config, tokenizer=None
+        )
+
         if self.buffer["first_label_token_mapping"]:
             generation_kwargs["logprobs"] = True
             generation_kwargs["top_logprobs"] = MAX_LOGPROBS

@@ -227,6 +254,27 @@ class LiteLLMModel(BenchmarkModule):
                 "Prompt must contain 'json' for JSON tasks."
             )
             generation_kwargs["response_format"] = dict(type="json_object")
+            log_once(
+                "Enabling JSON response format for model "
+                f"{self.model_config.model_id!r}",
+                level=logging.DEBUG,
+            )
+
+        if self.model_config.revision == "thinking":
+            generation_kwargs["thinking"] = dict(
+                type="enabled", budget_tokens=REASONING_MAX_TOKENS
+            )
+            log_once(
+                f"Enabling thinking mode for model {self.model_config.model_id!r}",
+                level=logging.DEBUG,
+            )
+        elif self.model_config.revision in {"low", "high"}:
+            generation_kwargs["reasoning_effort"] = self.model_config.revision
+            log_once(
+                f"Enabling reasoning effort {self.model_config.revision!r} for model "
+                f"{self.model_config.model_id!r}",
+                level=logging.DEBUG,
+            )

         # This drops generation kwargs that are not supported by the model
         litellm.drop_params = True
@@ -235,39 +283,60 @@ class LiteLLMModel(BenchmarkModule):
         # handle using newlines as stop sequences, so we try both.
         num_attempts = 10
         for _ in range(num_attempts):
+            stop_messages = ["stop_sequences"]
+            logprobs_messages = [
+                "you are not allowed to request logprobs",
+                "you've reached the maximum number of requests with logprobs",
+                "logprobs is not supported",
+                "logprobs is not enabled",
+            ]
+            temperature_messages = [
+                "'temperature' is not supported with this model.",
+                "temperature is not supported with this model",
+            ]
             try:
                 model_response = litellm.completion(
                     messages=messages, max_retries=3, **generation_kwargs
                 )
                 break
-            except BadRequestError as e:
-                if
+            except (BadRequestError, RateLimitError) as e:
+                if any(msg.lower() in str(e).lower() for msg in stop_messages):
                     generation_kwargs["stop"] = None
-                elif "you are not allowed to request logprobs" in str(e).lower():
-                    generation_kwargs.pop("logprobs")
-                    generation_kwargs.pop("top_logprobs")
                 elif (
-
+                    any(msg.lower() in str(e).lower() for msg in logprobs_messages)
+                    # Special case for Vertex AI models, since they have strict rate
+                    # limits on using logprobs. They also have a cap of 5 logprobs, but
+                    # we ignore this since the rate limiting makes it unusable anyway.
+                    or (isinstance(e, VertexAIError) and "logprobs" in str(e).lower())
                 ):
+                    generation_kwargs.pop("logprobs")
+                    generation_kwargs.pop("top_logprobs")
+                elif any(msg.lower() in str(e).lower() for msg in temperature_messages):
                     generation_kwargs.pop("temperature")
+                elif isinstance(e, RateLimitError):
+                    raise InvalidModel(
+                        "You have encountered your rate limit for model "
+                        f"{self.model_config.model_id!r}. The error message was: {e}"
+                    )
                 else:
                     raise InvalidBenchmark(
                         f"Failed to generate text. The error message was: {e}"
                     )
+            except APIError as e:
+                raise InvalidBenchmark(
+                    f"Failed to generate text. The error message was: {e}"
+                )
             except (
+                APIConnectionError,
                 Timeout,
                 ServiceUnavailableError,
-                APIConnectionError,
                 InternalServerError,
-            ):
+            ) as e:
                 logger.debug(
-                    "Service temporarily unavailable.
+                    f"Service temporarily unavailable. The error message was: {e}. "
+                    f"Retrying in 5 seconds..."
                 )
                 sleep(5)
-            except APIError as e:
-                raise InvalidBenchmark(
-                    f"Failed to generate text. The error message was: {e}"
-                )
             except AuthenticationError:
                 raise NeedsAdditionalArgument(
                     cli_argument="--api-key",
@@ -280,6 +349,15 @@ class LiteLLMModel(BenchmarkModule):
         )

         assert isinstance(model_response, ModelResponse)
+        if not model_response.choices:
+            # This happens for reasoning models, when they don't finish thinking and run
+            # out of tokens. Happens quite rarely, but we need to handle it.
+            logger.warning(
+                f"The model {self.model_config.model_id!r} did not end up generating "
+                "any text. This is likely because the model ran out of tokens while "
+                "reasoning. Returning an empty string."
+            )
+            return GenerativeModelOutput(sequences=[""])
         model_response_choices = model_response.choices[0]
         assert isinstance(model_response_choices, litellm.Choices)
         generation_output = model_response_choices.message["content"] or ""
@@ -314,7 +392,7 @@ class LiteLLMModel(BenchmarkModule):
         # If it is an Ollama model then we can get the number of parameters from the
         # Ollama Python SDK
         if self.is_ollama:
-            ollama_model_id = self.model_config.model_id.split("/")[
+            ollama_model_id = "/".join(self.model_config.model_id.split("/")[1:])
             model_info = ollama.show(ollama_model_id).modelinfo
             if model_info is not None:
                 num_params = model_info.get("general.parameter_count")
@@ -334,7 +412,7 @@ class LiteLLMModel(BenchmarkModule):
             num_labels=self.dataset_config.num_labels,
             id2label=self.dataset_config.id2label,
             label2id=self.dataset_config.label2id,
-            revision=
+            revision="main",
             model_cache_dir=self.model_config.model_cache_dir,
             api_key=self.benchmark_config.api_key,
             trust_remote_code=self.benchmark_config.trust_remote_code,
@@ -345,7 +423,7 @@ class LiteLLMModel(BenchmarkModule):
         try:
             repo_info = hf_api.model_info(
                 repo_id=model_id,
-                revision=
+                revision="main",
                 token=os.getenv("HUGGINGFACE_API_KEY")
                 or self.benchmark_config.api_key
                 or True,
@@ -398,7 +476,7 @@ class LiteLLMModel(BenchmarkModule):
             num_labels=self.dataset_config.num_labels,
             id2label=self.dataset_config.id2label,
             label2id=self.dataset_config.label2id,
-            revision=
+            revision="main",
             model_cache_dir=self.model_config.model_cache_dir,
             api_key=self.benchmark_config.api_key,
             trust_remote_code=self.benchmark_config.trust_remote_code,
@@ -442,7 +520,7 @@ class LiteLLMModel(BenchmarkModule):
         # If it is an Ollama model then we can get the maximum length from the Ollama
         # Python SDK
         if self.is_ollama:
-            ollama_model_id = self.model_config.model_id.split("/")[
+            ollama_model_id = "/".join(self.model_config.model_id.split("/")[1:])
             model_info = ollama.show(ollama_model_id).modelinfo
             if model_info is not None:
                 context_length_keys = [
@@ -478,7 +556,7 @@ class LiteLLMModel(BenchmarkModule):
             num_labels=self.dataset_config.num_labels,
             id2label=self.dataset_config.id2label,
             label2id=self.dataset_config.label2id,
-            revision=
+            revision="main",
             model_cache_dir=self.model_config.model_cache_dir,
             api_key=self.benchmark_config.api_key,
             trust_remote_code=self.benchmark_config.trust_remote_code,
@@ -563,6 +641,7 @@ class LiteLLMModel(BenchmarkModule):
                 return partial(
                     sequence_classification.extract_labels_from_generation,
                     dataset_config=self.dataset_config,
+                    first_label_token_mapping=self.buffer["first_label_token_mapping"],
                 )
             case TaskGroup.TEXT_TO_TEXT:
                 return text_to_text.extract_labels_from_generation
@@ -605,12 +684,13 @@ class LiteLLMModel(BenchmarkModule):
             Whether the model exists, or an error describing why we cannot check
             whether the model exists.
         """
+        model_id, _ = model_id.split("@") if "@" in model_id else (model_id, "main")
         if model_id in litellm.model_list:
             return True

         # If it is an Ollama model then try to download it
         if model_id.startswith("ollama/") or model_id.startswith("ollama_chat/"):
-            ollama_model_id = model_id.split("/")[
+            ollama_model_id = "/".join(model_id.split("/")[1:])
             downloaded_ollama_models: list[str] = [
                 model_obj.model
                 for model_obj in ollama.list().models
@@ -657,12 +737,29 @@ class LiteLLMModel(BenchmarkModule):
                 api_version=benchmark_config.api_version,
             )
             return True
+        except (
+            APIConnectionError,
+            Timeout,
+            ServiceUnavailableError,
+            InternalServerError,
+        ) as e:
+            logger.debug(
+                f"Service temporarily unavailable. The error message was: {e}. "
+                "Retrying in 10 seconds..."
+            )
+            sleep(5)
+        except RateLimitError:
+            logger.warning(
+                f"Rate limit exceeded for model {model_id!r}. Retrying in 10 "
+                "seconds..."
+            )
+            sleep(10)
         except APIError as e:
             if "'503 Service Unavailable" not in str(e):
                 raise e
             logger.warning(
-                f"Failed to check if model {model_id!r} exists. Retrying in "
-
+                f"Failed to check if model {model_id!r} exists. Retrying in 10 "
+                "seconds..."
             )
             sleep(10)
         except (BadRequestError, NotFoundError):
@@ -708,9 +805,10 @@ class LiteLLMModel(BenchmarkModule):
         Returns:
             The model configuration.
         """
+        model_id, revision = model_id.split("@") if "@" in model_id else (model_id, "")
         return ModelConfig(
             model_id=model_id,
-            revision=
+            revision=revision,
             task="text-generation",
             languages=list(),
             merge=False,
@@ -1025,3 +1123,35 @@ class LiteLLMModel(BenchmarkModule):

         examples["messages"] = messages_list
         return examples
+
+
+def raise_if_wrong_params(
+    model_config: ModelConfig, allowed_params: dict[str, list[str]]
+) -> None:
+    """Raise an error if the model configuration has invalid parameters.
+
+    Args:
+        model_config:
+            The model configuration.
+        allowed_params:
+            The allowed parameters for the model.
+
+    Raises:
+        InvalidModel:
+            If the model configuration has invalid parameters.
+    """
+    param = model_config.revision
+    if param == "":
+        return
+    for model_regex, allowed_params_list in allowed_params.items():
+        if re.fullmatch(pattern=model_regex, string=model_config.model_id):
+            if param not in allowed_params_list:
+                msg = (
+                    f"Invalid parameter {param!r} for model {model_config.model_id!r}."
+                )
+                if allowed_params_list:
+                    msg += f" Allowed parameters are: {', '.join(allowed_params_list)}."
+                else:
+                    msg += " No parameters are allowed."
+                raise InvalidModel(msg)
+            return
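
For context on the most user-visible change above: LiteLLM model IDs can now carry a generation parameter after an `@` suffix (for example `o1@high`, or a Claude 3.7 Sonnet ID with `@thinking`). The parameter is stored in the config's `revision` field, checked against the per-model `ALLOWED_PARAMS` allow-list by the new `raise_if_wrong_params` helper, and later turned into `reasoning_effort` or `thinking` generation kwargs; the bare model ID (without the suffix) is what is sent to the provider. The snippet below is a minimal, self-contained sketch of that parse-and-validate flow, not EuroEval's actual API: `ModelSpec`, `parse_model_id` and `check_params` are illustrative stand-ins, `ValueError` stands in for EuroEval's `InvalidModel`, and the `ALLOWED_PARAMS` table is abridged from the diff.

import re
from dataclasses import dataclass

# Abridged copy of the ALLOWED_PARAMS table introduced above; regexes taken
# verbatim from the diff.
ALLOWED_PARAMS = {
    r"gpt-4.*": [],
    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": ["low", "high"],
    r"(anthropic/)?claude-3.7-sonnet.*": ["thinking"],
    r"(gemini/)?gemini-.*": [],
    r"(xai/)?grok.*": [],
}


@dataclass
class ModelSpec:
    """Illustrative stand-in for EuroEval's ModelConfig (only the fields used here)."""

    model_id: str
    revision: str = ""


def parse_model_id(raw: str) -> ModelSpec:
    """Split an ``id@param`` string; the parameter is stored in ``revision``."""
    model_id, _, param = raw.partition("@")
    return ModelSpec(model_id=model_id, revision=param)


def check_params(spec: ModelSpec, allowed_params: dict[str, list[str]]) -> None:
    """Mirror of the new raise_if_wrong_params check (ValueError stands in for
    EuroEval's InvalidModel exception)."""
    if spec.revision == "":
        return
    for model_regex, allowed in allowed_params.items():
        if re.fullmatch(model_regex, spec.model_id):
            if spec.revision not in allowed:
                detail = (
                    f" Allowed parameters are: {', '.join(allowed)}."
                    if allowed
                    else " No parameters are allowed."
                )
                raise ValueError(
                    f"Invalid parameter {spec.revision!r} for model "
                    f"{spec.model_id!r}.{detail}"
                )
            return


if __name__ == "__main__":
    # "o1@high" matches the o-series pattern and "high" is in its allow-list, so this passes.
    check_params(parse_model_id("o1@high"), ALLOWED_PARAMS)
    # "gpt-4o@high" matches "gpt-4.*", whose allow-list is empty, so this raises.
    try:
        check_params(parse_model_id("gpt-4o@high"), ALLOWED_PARAMS)
    except ValueError as err:
        print(err)

Note that, as the diff shows, the real implementation also strips the `@` suffix in `model_exists` and `get_model_config`, so only the bare model ID is ever used to look the model up.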