EuroEval 15.4.2-py3-none-any.whl → 15.6.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- euroeval/__init__.py +2 -2
- euroeval/benchmark_modules/base.py +3 -2
- euroeval/benchmark_modules/fresh.py +8 -6
- euroeval/benchmark_modules/hf.py +44 -33
- euroeval/benchmark_modules/litellm.py +314 -120
- euroeval/benchmark_modules/vllm.py +99 -59
- euroeval/benchmarker.py +52 -21
- euroeval/callbacks.py +2 -2
- euroeval/constants.py +9 -2
- euroeval/data_models.py +258 -44
- euroeval/dataset_configs/__init__.py +61 -0
- euroeval/dataset_configs/danish.py +120 -0
- euroeval/dataset_configs/dutch.py +123 -0
- euroeval/dataset_configs/english.py +88 -0
- euroeval/dataset_configs/faroese.py +53 -0
- euroeval/dataset_configs/french.py +83 -0
- euroeval/dataset_configs/german.py +91 -0
- euroeval/dataset_configs/icelandic.py +148 -0
- euroeval/dataset_configs/italian.py +81 -0
- euroeval/dataset_configs/norwegian.py +178 -0
- euroeval/dataset_configs/spanish.py +78 -0
- euroeval/dataset_configs/swedish.py +100 -0
- euroeval/exceptions.py +10 -10
- euroeval/finetuning.py +6 -10
- euroeval/generation.py +1 -0
- euroeval/human_evaluation.py +2 -2
- euroeval/languages.py +20 -13
- euroeval/model_cache.py +1 -1
- euroeval/model_loading.py +1 -12
- euroeval/prompt_templates/__init__.py +8 -0
- euroeval/prompt_templates/linguistic_acceptability.py +112 -0
- euroeval/prompt_templates/multiple_choice.py +97 -0
- euroeval/prompt_templates/named_entity_recognition.py +257 -0
- euroeval/prompt_templates/reading_comprehension.py +118 -0
- euroeval/prompt_templates/sentiment_classification.py +137 -0
- euroeval/prompt_templates/summarization.py +97 -0
- euroeval/speed_benchmark.py +1 -1
- euroeval/{task_utils → task_group_utils}/multiple_choice_classification.py +19 -11
- euroeval/{task_utils → task_group_utils}/question_answering.py +31 -30
- euroeval/{task_utils → task_group_utils}/sequence_classification.py +45 -10
- euroeval/{task_utils → task_group_utils}/text_to_text.py +1 -1
- euroeval/{task_utils → task_group_utils}/token_classification.py +3 -2
- euroeval/tasks.py +54 -0
- euroeval/tokenization_utils.py +343 -0
- euroeval/types.py +3 -1
- euroeval/utils.py +5 -254
- {euroeval-15.4.2.dist-info → euroeval-15.6.0.dist-info}/METADATA +31 -9
- euroeval-15.6.0.dist-info/RECORD +59 -0
- euroeval/dataset_configs.py +0 -2408
- euroeval-15.4.2.dist-info/RECORD +0 -40
- /euroeval/{task_utils → task_group_utils}/__init__.py +0 -0
- {euroeval-15.4.2.dist-info → euroeval-15.6.0.dist-info}/WHEEL +0 -0
- {euroeval-15.4.2.dist-info → euroeval-15.6.0.dist-info}/entry_points.txt +0 -0
- {euroeval-15.4.2.dist-info → euroeval-15.6.0.dist-info}/licenses/LICENSE +0 -0
euroeval/benchmark_modules/litellm.py

@@ -27,20 +27,17 @@ from litellm.exceptions import (
     BadRequestError,
     InternalServerError,
     NotFoundError,
+    RateLimitError,
     ServiceUnavailableError,
     Timeout,
 )
-from litellm.
+from litellm.llms.vertex_ai.common_utils import VertexAIError
+from litellm.types.utils import ChoiceLogprobs, ModelResponse
 from requests.exceptions import RequestException
 from tqdm.auto import tqdm
-from transformers import Trainer
+from transformers.trainer import Trainer

-from ..constants import (
-    MAX_LOGPROBS,
-    REASONING_MAX_TOKENS,
-    TASK_GROUPS_USING_LOGPROBS,
-    TASKS_USING_JSON,
-)
+from ..constants import MAX_LOGPROBS, REASONING_MAX_TOKENS, TASKS_USING_JSON
 from ..data_models import (
     BenchmarkConfig,
     DatasetConfig,
@@ -62,12 +59,13 @@ from ..exceptions import (
     NeedsEnvironmentVariable,
     NeedsExtraInstalled,
 )
-from ..task_utils import (
+from ..task_group_utils import (
     question_answering,
     sequence_classification,
     text_to_text,
     token_classification,
 )
+from ..tokenization_utils import get_first_label_token_mapping
 from ..types import ExtractLabelsFunction
 from ..utils import create_model_cache_dir, log_once
 from .base import BenchmarkModule
@@ -78,64 +76,80 @@ logger = logging.getLogger("euroeval")

 VOCAB_SIZE_MAPPING = {
     # OpenAI models
-    "(
-    "
-    "gpt-
-    "gpt-4-(
-    "gpt-
-    "gpt-
-    "
-    "gpt-3.5-turbo-instruct(-[0-9]{4})?": 100_256,
-    "gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_019,
-    "o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
+    r"gpt-4-(32k)?(-[0-9]{4})?": 100_256,
+    r"gpt-4-[0-9]{4}-preview": 100_256,
+    r"gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 100_256,
+    r"gpt-4-(vision|turbo)(-preview)?": 100_256,
+    r"gpt-3.5-turbo-instruct(-[0-9]{4})?": 100_256,
+    r"gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_019,
+    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
     # Anthropic models
-    "claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": -1,
+    r"(anthropic/)?claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": -1,
+    # Gemini models
+    r"(gemini/)?gemini-[1-9]\.[0-9]-(flash|pro).*": 256_128,
+    # xAI models
+    r"(xai/)?grok.*": -1,
 }


 MODEL_MAX_LENGTH_MAPPING = {
     # OpenAI models
-    "
-    "
-    "
-    "gpt-
-    "gpt-
-    "gpt-3.5-turbo-
-    "gpt-
-    "
-    "
-    "
-    "gpt-4-(vision|turbo)(-preview)?": 128_000,
-    "gpt-3.5-turbo-instruct(-[0-9]{4})?": 4_095,
-    "gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
-    "o1-(mini|preview)(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
-    "o1(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
-    "o[2-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
+    r"gpt-4(-[0-9]{4})?": 8_191,
+    r"gpt-4-32k(-[0-9]{4})?": 32_767,
+    r"gpt-4-[0-9]{4}-preview": 128_000,
+    r"gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
+    r"gpt-4-(vision|turbo)(-preview)?": 128_000,
+    r"gpt-3.5-turbo-instruct(-[0-9]{4})?": 4_095,
+    r"gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
+    r"o1-(mini|preview)(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
+    r"o1(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
+    r"o[2-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
     # Anthropic models
-    "claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": 200_000,
+    r"(anthropic/)?claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": 200_000,
+    # Gemini models
+    r"(gemini/)?gemini-1\.5-flash.*": 1_048_576,
+    r"(gemini/)?gemini-1\.5-pro.*": 2_097_152,
+    r"(gemini/)?gemini-2\.(0|5).*": 1_048_576,
+    # xAI models
+    r"(xai/)?grok.*": 131_072,
 }


 NUM_PARAMS_MAPPING = {
     # OpenAI models
-    "
-    "(
-
-    "(
-
-    "
-    "
-    "
-
-    "
-
-
+    r"gpt-4.*": -1,
+    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
+    # Anthropic models
+    r"(anthropic/)?claude-*": -1,
+    # Gemini models
+    r"(gemini/)?gemini-1.5-flash-8b": 8_000_000_000,
+    r"(gemini/)?gemini-1.5-flash-[0-9]+": -1,
+    r"(gemini/)?gemini-2.(0|5).*": -1,
+    # xAI models
+    r"(xai/)?grok.*": -1,
+}
+
+
+ALLOWED_PARAMS = {
+    # OpenAI models
+    r"gpt-4.*": [],
+    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": ["low", "high"],
     # Anthropic models
-    "
+    r"(anthropic/)?claude-3-.*": [],
+    r"(anthropic/)?claude-3.5-.*": [],
+    r"(anthropic/)?claude-3.7-sonnet.*": ["thinking"],
+    # Gemini models
+    r"(gemini/)?gemini-.*": [],
+    # xAI models
+    r"(xai/)?grok.*": [],
 }


-REASONING_MODELS = [
+REASONING_MODELS = [
+    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?",
+    r"(gemini/)?gemini.*thinking.*",
+    r"(gemini/)?gemini-2.5-pro.*",
+]


 class LiteLLMModel(BenchmarkModule):
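All of these tables are keyed by regular expressions rather than literal model IDs, so a single entry covers every dated snapshot of a model family, with or without the provider prefix, and -1 appears to serve as an "unknown" sentinel. A minimal sketch of how such a table can be queried; the `lookup` helper is ours, not part of EuroEval, while the entries are copied from the diff:

```python
import re

MODEL_MAX_LENGTH_MAPPING = {
    r"gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
    r"(gemini/)?gemini-1\.5-pro.*": 2_097_152,
    r"(xai/)?grok.*": 131_072,
}

def lookup(model_id: str, mapping: dict[str, int]) -> int | None:
    # Return the value of the first regex key that fully matches the model ID
    for pattern, value in mapping.items():
        if re.fullmatch(pattern=pattern, string=model_id):
            return value
    return None

assert lookup("gpt-4o-mini-2024-07-18", MODEL_MAX_LENGTH_MAPPING) == 128_000
assert lookup("gemini/gemini-1.5-pro-002", MODEL_MAX_LENGTH_MAPPING) == 2_097_152
```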
@@ -167,12 +181,18 @@ class LiteLLMModel(BenchmarkModule):
             "ollama/"
         ) or model_config.model_id.startswith("ollama_chat/")

+        raise_if_wrong_params(model_config=model_config, allowed_params=ALLOWED_PARAMS)
+
         super().__init__(
             model_config=model_config,
             dataset_config=dataset_config,
             benchmark_config=benchmark_config,
         )

+        self.buffer["first_label_token_mapping"] = get_first_label_token_mapping(
+            dataset_config=self.dataset_config, tokenizer=None
+        )
+
     @property
     def generative_type(self) -> GenerativeType | None:
         """Get the generative type of the model.
@@ -180,7 +200,9 @@ class LiteLLMModel(BenchmarkModule):
         Returns:
             The generative type of the model, or None if it has not been set yet.
         """
-        if
+        if self.model_config.revision == "thinking":
+            return GenerativeType.REASONING
+        elif re.fullmatch(
             pattern="|".join(REASONING_MODELS), string=self.model_config.model_id
         ):
             return GenerativeType.REASONING
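The property joins the `REASONING_MODELS` patterns into a single alternation and full-matches the model ID against it; since `re.fullmatch` anchors the entire pattern, every alternative is effectively anchored on both sides. A self-contained sketch, where the helper function name is ours and the patterns are taken from the diff:

```python
import re

REASONING_MODELS = [
    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?",
    r"(gemini/)?gemini.*thinking.*",
    r"(gemini/)?gemini-2.5-pro.*",
]

def is_reasoning_model(model_id: str) -> bool:
    # fullmatch anchors the whole alternation, so partial hits don't count
    return re.fullmatch("|".join(REASONING_MODELS), model_id) is not None

assert is_reasoning_model("o3-mini")
assert is_reasoning_model("gemini/gemini-2.5-pro-exp")
assert not is_reasoning_model("gpt-4o")
```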
@@ -218,7 +240,13 @@ class LiteLLMModel(BenchmarkModule):
             api_version=self.benchmark_config.api_version,
         )

-
+        # Get the mapping from labels to the first token in the label. We call this each
+        # time we generate a new dataset since the dataset config can change
+        self.buffer["first_label_token_mapping"] = get_first_label_token_mapping(
+            dataset_config=self.dataset_config, tokenizer=None
+        )
+
+        if self.buffer["first_label_token_mapping"]:
             generation_kwargs["logprobs"] = True
             generation_kwargs["top_logprobs"] = MAX_LOGPROBS

@@ -227,6 +255,27 @@ class LiteLLMModel(BenchmarkModule):
                 "Prompt must contain 'json' for JSON tasks."
             )
             generation_kwargs["response_format"] = dict(type="json_object")
+            log_once(
+                "Enabling JSON response format for model "
+                f"{self.model_config.model_id!r}",
+                level=logging.DEBUG,
+            )
+
+        if self.model_config.revision == "thinking":
+            generation_kwargs["thinking"] = dict(
+                type="enabled", budget_tokens=REASONING_MAX_TOKENS
+            )
+            log_once(
+                f"Enabling thinking mode for model {self.model_config.model_id!r}",
+                level=logging.DEBUG,
+            )
+        elif self.model_config.revision in {"low", "high"}:
+            generation_kwargs["reasoning_effort"] = self.model_config.revision
+            log_once(
+                f"Enabling reasoning effort {self.model_config.revision!r} for model "
+                f"{self.model_config.model_id!r}",
+                level=logging.DEBUG,
+            )

         # This drops generation kwargs that are not supported by the model
         litellm.drop_params = True
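Here the model config's `revision` field doubles as an inference parameter: `thinking` switches on Anthropic-style extended thinking with a token budget, while `low`/`high` set OpenAI-style reasoning effort. A condensed sketch of the branch logic, with an assumed value for `REASONING_MAX_TOKENS` (the real constant lives in euroeval/constants.py):

```python
REASONING_MAX_TOKENS = 8_192  # assumed for illustration only

def reasoning_kwargs(revision: str) -> dict:
    # Mirror of the branches above: map a revision suffix to completion kwargs
    if revision == "thinking":
        return {"thinking": {"type": "enabled", "budget_tokens": REASONING_MAX_TOKENS}}
    if revision in {"low", "high"}:
        return {"reasoning_effort": revision}
    return {}

assert reasoning_kwargs("high") == {"reasoning_effort": "high"}
assert reasoning_kwargs("") == {}
```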
@@ -235,39 +284,60 @@ class LiteLLMModel(BenchmarkModule):
         # handle using newlines as stop sequences, so we try both.
         num_attempts = 10
         for _ in range(num_attempts):
+            stop_messages = ["stop_sequences"]
+            logprobs_messages = [
+                "you are not allowed to request logprobs",
+                "you've reached the maximum number of requests with logprobs",
+                "logprobs is not supported",
+                "logprobs is not enabled",
+            ]
+            temperature_messages = [
+                "'temperature' is not supported with this model.",
+                "temperature is not supported with this model",
+            ]
             try:
                 model_response = litellm.completion(
                     messages=messages, max_retries=3, **generation_kwargs
                 )
                 break
-            except BadRequestError as e:
-                if
+            except (BadRequestError, RateLimitError) as e:
+                if any(msg.lower() in str(e).lower() for msg in stop_messages):
                     generation_kwargs["stop"] = None
-                elif "you are not allowed to request logprobs" in str(e).lower():
-                    generation_kwargs.pop("logprobs")
-                    generation_kwargs.pop("top_logprobs")
                 elif (
-
+                    any(msg.lower() in str(e).lower() for msg in logprobs_messages)
+                    # Special case for Vertex AI models, since they have strict rate
+                    # limits on using logprobs. They also have a cap of 5 logprobs, but
+                    # we ignore this since the rate limiting makes it unusable anyway.
+                    or (isinstance(e, VertexAIError) and "logprobs" in str(e).lower())
                 ):
+                    generation_kwargs.pop("logprobs")
+                    generation_kwargs.pop("top_logprobs")
+                elif any(msg.lower() in str(e).lower() for msg in temperature_messages):
                     generation_kwargs.pop("temperature")
+                elif isinstance(e, RateLimitError):
+                    raise InvalidModel(
+                        "You have encountered your rate limit for model "
+                        f"{self.model_config.model_id!r}. Skipping."
+                    )
                 else:
                     raise InvalidBenchmark(
                         f"Failed to generate text. The error message was: {e}"
                     )
+            except APIError as e:
+                raise InvalidBenchmark(
+                    f"Failed to generate text. The error message was: {e}"
+                )
             except (
+                APIConnectionError,
                 Timeout,
                 ServiceUnavailableError,
-                APIConnectionError,
                 InternalServerError,
-            ):
+            ) as e:
                 logger.debug(
-                    "Service temporarily unavailable.
+                    f"Service temporarily unavailable. The error message was: {e}. "
+                    f"Retrying in 5 seconds..."
                 )
                 sleep(5)
-            except APIError as e:
-                raise InvalidBenchmark(
-                    f"Failed to generate text. The error message was: {e}"
-                )
             except AuthenticationError:
                 raise NeedsAdditionalArgument(
                     cli_argument="--api-key",
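The loop above implements a degrade-and-retry strategy: on recognised provider complaints it removes the offending kwarg and tries again, on transient errors it backs off, and only unrecognised errors abort the run. A stripped-down sketch of the pattern, with generic exceptions standing in for the litellm classes:

```python
import time

def complete_with_fallbacks(call, kwargs: dict, attempts: int = 10):
    # Degrade-and-retry: drop unsupported kwargs on provider complaints
    for _ in range(attempts):
        try:
            return call(**kwargs)
        except ValueError as e:  # stand-in for BadRequestError / RateLimitError
            msg = str(e).lower()
            if "stop_sequences" in msg:
                kwargs["stop"] = None            # provider rejects stop sequences
            elif "logprobs" in msg:
                kwargs.pop("logprobs", None)     # provider refuses logprobs
                kwargs.pop("top_logprobs", None)
            elif "temperature" in msg:
                kwargs.pop("temperature", None)  # provider pins the temperature
            else:
                raise
        except ConnectionError:  # stand-in for the transient litellm errors
            time.sleep(5)        # back off, then retry
    raise RuntimeError("Exhausted all retry attempts")
```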
@@ -280,6 +350,15 @@ class LiteLLMModel(BenchmarkModule):
         )

         assert isinstance(model_response, ModelResponse)
+        if not model_response.choices:
+            # This happens for reasoning models, when they don't finish thinking and run
+            # out of tokens. Happens quite rarely, but we need to handle it.
+            logger.warning(
+                f"The model {self.model_config.model_id!r} did not end up generating "
+                "any text. This is likely because the model ran out of tokens while "
+                "reasoning. Returning an empty string."
+            )
+            return GenerativeModelOutput(sequences=[""])
         model_response_choices = model_response.choices[0]
         assert isinstance(model_response_choices, litellm.Choices)
         generation_output = model_response_choices.message["content"] or ""
@@ -288,14 +367,22 @@ class LiteLLMModel(BenchmarkModule):
         # Structure the model output as a GenerativeModelOutput object
         model_output = GenerativeModelOutput(sequences=[generation_output])
         if hasattr(model_response_choices, "logprobs"):
-
-
-
-
+            logprobs_obj = model_response_choices.logprobs
+            if isinstance(logprobs_obj, ChoiceLogprobs):
+                logprobs_list: list[list[tuple[str, float]]] = [
+                    [
+                        (top_logprob.token, top_logprob.logprob)
+                        for top_logprob in content.top_logprobs
+                    ]
+                    for content in model_response_choices.logprobs.content or list()
                 ]
-
-
-
+                model_output.scores = [logprobs_list]
+            else:
+                log_once(
+                    "The logprobs object is malformed, so we won't use logprobs to "
+                    "determine the labels.",
+                    level=logging.WARNING,
+                )

         return model_output

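The `scores` structure assembled here is nested three deep: one list per generated sequence, one list per token position, and `(token, logprob)` pairs for the top candidates at each position. With invented values:

```python
# Shape of model_output.scores as built above (values are invented)
scores = [  # one entry per generated sequence
    [  # one entry per token position
        [("pos", -0.11), ("neg", -2.30)],        # top candidates at position 0
        [("itive", -0.05), ("itively", -3.00)],  # top candidates at position 1
    ]
]
```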
@@ -314,7 +401,7 @@ class LiteLLMModel(BenchmarkModule):
         # If it is an Ollama model then we can get the number of parameters from the
         # Ollama Python SDK
         if self.is_ollama:
-            ollama_model_id = self.model_config.model_id.split("/")[
+            ollama_model_id = "/".join(self.model_config.model_id.split("/")[1:])
             model_info = ollama.show(ollama_model_id).modelinfo
             if model_info is not None:
                 num_params = model_info.get("general.parameter_count")
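The change from `split("/")[1]` to `"/".join(split("/")[1:])` matters because Ollama model IDs can themselves contain slashes:

```python
model_id = "ollama_chat/hf.co/org/model"
print(model_id.split("/")[1])             # "hf.co"           (old: truncated)
print("/".join(model_id.split("/")[1:]))  # "hf.co/org/model" (new: intact)
```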
@@ -325,7 +412,7 @@ class LiteLLMModel(BenchmarkModule):
         # get the number of parameters from the Hugging Face model configuration from
         # the Hugging Face Hub
         if self.model_config.model_id.startswith("huggingface/"):
-            model_id = self.model_config.model_id.split(sep="/"
+            model_id = "/".join(self.model_config.model_id.split(sep="/")[-2:])
             if HuggingFaceEncoderModel.model_exists(
                 model_id=model_id, benchmark_config=self.benchmark_config
             ):
@@ -334,7 +421,7 @@ class LiteLLMModel(BenchmarkModule):
                 num_labels=self.dataset_config.num_labels,
                 id2label=self.dataset_config.id2label,
                 label2id=self.dataset_config.label2id,
-                revision=
+                revision="main",
                 model_cache_dir=self.model_config.model_cache_dir,
                 api_key=self.benchmark_config.api_key,
                 trust_remote_code=self.benchmark_config.trust_remote_code,
@@ -345,7 +432,7 @@ class LiteLLMModel(BenchmarkModule):
             try:
                 repo_info = hf_api.model_info(
                     repo_id=model_id,
-                    revision=
+                    revision="main",
                     token=os.getenv("HUGGINGFACE_API_KEY")
                     or self.benchmark_config.api_key
                     or True,
@@ -389,7 +476,7 @@ class LiteLLMModel(BenchmarkModule):
         # get the vocabulary size from the Hugging Face model configuration from the
         # Hugging Face Hub
         if self.model_config.model_id.startswith("huggingface/"):
-            model_id = self.model_config.model_id.split(sep="/"
+            model_id = "/".join(self.model_config.model_id.split(sep="/")[-2:])
             if HuggingFaceEncoderModel.model_exists(
                 model_id=model_id, benchmark_config=self.benchmark_config
             ):
@@ -398,7 +485,7 @@ class LiteLLMModel(BenchmarkModule):
                 num_labels=self.dataset_config.num_labels,
                 id2label=self.dataset_config.id2label,
                 label2id=self.dataset_config.label2id,
-                revision=
+                revision="main",
                 model_cache_dir=self.model_config.model_cache_dir,
                 api_key=self.benchmark_config.api_key,
                 trust_remote_code=self.benchmark_config.trust_remote_code,
@@ -442,7 +529,7 @@ class LiteLLMModel(BenchmarkModule):
         # If it is an Ollama model then we can get the maximum length from the Ollama
         # Python SDK
         if self.is_ollama:
-            ollama_model_id = self.model_config.model_id.split("/")[
+            ollama_model_id = "/".join(self.model_config.model_id.split("/")[1:])
             model_info = ollama.show(ollama_model_id).modelinfo
             if model_info is not None:
                 context_length_keys = [
@@ -469,7 +556,7 @@ class LiteLLMModel(BenchmarkModule):
         # get the maximum length from the Hugging Face model configuration from the
         # Hugging Face Hub
         if self.model_config.model_id.startswith("huggingface/"):
-            model_id = self.model_config.model_id.split(sep="/"
+            model_id = "/".join(self.model_config.model_id.split(sep="/")[-2:])
             if HuggingFaceEncoderModel.model_exists(
                 model_id=model_id, benchmark_config=self.benchmark_config
             ):
@@ -478,7 +565,7 @@ class LiteLLMModel(BenchmarkModule):
                 num_labels=self.dataset_config.num_labels,
                 id2label=self.dataset_config.id2label,
                 label2id=self.dataset_config.label2id,
-                revision=
+                revision="main",
                 model_cache_dir=self.model_config.model_cache_dir,
                 api_key=self.benchmark_config.api_key,
                 trust_remote_code=self.benchmark_config.trust_remote_code,
@@ -563,6 +650,7 @@ class LiteLLMModel(BenchmarkModule):
                 return partial(
                     sequence_classification.extract_labels_from_generation,
                     dataset_config=self.dataset_config,
+                    first_label_token_mapping=self.buffer["first_label_token_mapping"],
                 )
             case TaskGroup.TEXT_TO_TEXT:
                 return text_to_text.extract_labels_from_generation
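`functools.partial` is used here to pre-bind the configuration arguments, so the callable handed back to the benchmark loop only needs the generation output. An illustrative toy version, in which the function body and argument values are placeholders of our own:

```python
from functools import partial

def extract_labels_from_generation(outputs, dataset_config, first_label_token_mapping):
    # Placeholder body: map first tokens back to full labels
    return [first_label_token_mapping.get(o, o) for o in outputs]

extract_fn = partial(
    extract_labels_from_generation,
    dataset_config={"task": "sentiment-classification"},
    first_label_token_mapping={"pos": "positive", "neg": "negative"},
)
print(extract_fn(["pos", "neg"]))  # ['positive', 'negative']
```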
@@ -605,45 +693,15 @@ class LiteLLMModel(BenchmarkModule):
             Whether the model exists, or an error describing why we cannot check
             whether the model exists.
         """
+        model_id, _ = model_id.split("@") if "@" in model_id else (model_id, "main")
         if model_id in litellm.model_list:
             return True

-        #
+        # Separate check for Ollama models
         if model_id.startswith("ollama/") or model_id.startswith("ollama_chat/"):
-
-
-
-                for model_obj in ollama.list().models
-                if model_obj.model is not None
-            ]
-            if ollama_model_id not in downloaded_ollama_models:
-                try:
-                    response = ollama.pull(model=ollama_model_id, stream=True)
-                    with tqdm(
-                        desc=f"Downloading {ollama_model_id}",
-                        unit_scale=True,
-                        unit="B",
-                        leave=False,
-                    ) as pbar:
-                        for status in response:
-                            if status.total is not None:
-                                pbar.total = status.total
-                            if status.completed is not None:
-                                pbar.update(status.completed - pbar.n)
-                except ollama.ResponseError as e:
-                    if "file does not exist" in str(e).lower():
-                        return False
-                    else:
-                        raise InvalidModel(
-                            f"Failed to download Ollama model {ollama_model_id}. The "
-                            f"error message was: {e}"
-                        )
-            else:
-                log_once(
-                    f"Ollama model {ollama_model_id!r} already downloaded, so skipping "
-                    "download.",
-                    level=logging.DEBUG,
-                )
+            ollama_model_exists = try_download_ollama_model(model_id=model_id)
+            if ollama_model_exists:
+                return ollama_model_exists

         num_attempts = 10
         for _ in range(num_attempts):
@@ -657,12 +715,27 @@ class LiteLLMModel(BenchmarkModule):
                     api_version=benchmark_config.api_version,
                 )
                 return True
+            # A rate limit indicates that the model *does* exist, but we are being rate
+            # limited.
+            except RateLimitError:
+                return True
+            except (
+                APIConnectionError,
+                Timeout,
+                ServiceUnavailableError,
+                InternalServerError,
+            ) as e:
+                logger.debug(
+                    f"Service temporarily unavailable. The error message was: {e}. "
+                    "Retrying in 10 seconds..."
+                )
+                sleep(5)
             except APIError as e:
                 if "'503 Service Unavailable" not in str(e):
                     raise e
                 logger.warning(
-                    f"Failed to check if model {model_id!r} exists. Retrying in "
-
+                    f"Failed to check if model {model_id!r} exists. Retrying in 10 "
+                    "seconds..."
                 )
                 sleep(10)
             except (BadRequestError, NotFoundError):
@@ -708,9 +781,10 @@ class LiteLLMModel(BenchmarkModule):
         Returns:
             The model configuration.
         """
+        model_id, revision = model_id.split("@") if "@" in model_id else (model_id, "")
         return ModelConfig(
             model_id=model_id,
-            revision=
+            revision=revision,
             task="text-generation",
             languages=list(),
             merge=False,
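This is the parsing half of the `@` convention validated by `raise_if_wrong_params` below: a model ID may carry a parameter suffix, which is stored in the otherwise unused `revision` field. A small sketch of the same split (the helper name is ours):

```python
def split_model_id(model_id: str) -> tuple[str, str]:
    # Same parsing as above: "model@param" -> ("model", "param")
    if "@" in model_id:
        base, revision = model_id.split("@", 1)
        return base, revision
    return model_id, ""

assert split_model_id("o3-mini@high") == ("o3-mini", "high")
assert split_model_id("gemini/gemini-2.0-flash") == ("gemini/gemini-2.0-flash", "")
```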
@@ -1025,3 +1099,123 @@ class LiteLLMModel(BenchmarkModule):

         examples["messages"] = messages_list
         return examples
+
+
+def raise_if_wrong_params(
+    model_config: ModelConfig, allowed_params: dict[str, list[str]]
+) -> None:
+    """Raise an error if the model configuration has invalid parameters.
+
+    Args:
+        model_config:
+            The model configuration.
+        allowed_params:
+            The allowed parameters for the model.
+
+    Raises:
+        InvalidModel:
+            If the model configuration has invalid parameters.
+    """
+    param = model_config.revision
+    if param == "":
+        return
+    for model_regex, allowed_params_list in allowed_params.items():
+        if re.fullmatch(pattern=model_regex, string=model_config.model_id):
+            if param not in allowed_params_list:
+                msg = (
+                    f"Invalid parameter {param!r} for model {model_config.model_id!r}."
+                )
+                if allowed_params_list:
+                    msg += f" Allowed parameters are: {', '.join(allowed_params_list)}."
+                else:
+                    msg += " No parameters are allowed."
+                raise InvalidModel(msg)
+            return
+
+
+def try_download_ollama_model(model_id: str) -> bool:
+    """Try to download an Ollama model.
+
+    Args:
+        model_id:
+            The model ID. If the model does not start with "ollama/" or "ollama_chat/"
+            then this function will return False.
+
+    Returns:
+        Whether the model was downloaded successfully.
+    """
+    if not (model_id.startswith("ollama/") or model_id.startswith("ollama_chat/")):
+        return False
+
+    if model_id.startswith("ollama/"):
+        log_once(
+            "You're trying to benchmark a model with the old 'ollama/' prefix, which "
+            "probably results in bad performance, as it doesn't use the model's chat "
+            "template. If the model is not a chat model then just disregard this "
+            "warning, but if it is a chat model then please cancel this run and "
+            "use the 'ollama_chat/' prefix instead.",
+            level=logging.WARNING,
+        )
+
+    downloaded_ollama_models: list[str] = [
+        model_obj.model
+        for model_obj in ollama.list().models
+        if model_obj.model is not None
+    ]
+
+    ollama_model_id = "/".join(model_id.split("/")[1:])
+    if ollama_model_id not in downloaded_ollama_models:
+        # Try fetching the model info
+        try:
+            response = ollama.pull(model=ollama_model_id, stream=True)
+        except ollama.ResponseError as e:
+            if "file does not exist" in str(e).lower():
+                # Check if the model exists if we prepend "hf.co/"
+                try:
+                    ollama_model_id_with_prefix = f"hf.co/{ollama_model_id}"
+                    model_id_with_prefix = (
+                        f"{model_id.split('/')[0]}/{ollama_model_id_with_prefix}"
+                    )
+                    ollama.pull(model=ollama_model_id_with_prefix, stream=True)
+                    log_once(
+                        f"The model {model_id!r} cannot be found on Ollama, but the "
+                        f"model {model_id_with_prefix} *was* found, so we would "
+                        "recommend you cancelling this run and trying the evaluation "
+                        "with that model ID instead."
+                    )
+                    return False
+                except ollama.ResponseError as inner_e:
+                    if "file does not exist" in str(inner_e).lower():
+                        return False
+                    else:
+                        raise InvalidModel(
+                            f"Failed to download Ollama model {ollama_model_id}. "
+                            f"The error message was: {inner_e}"
+                        )
+            else:
+                raise InvalidModel(
+                    f"Failed to download Ollama model {ollama_model_id}. "
+                    f"The error message was: {e}"
+                )
+
+        # Download the model
+        with tqdm(
+            desc=f"Downloading {ollama_model_id}",
+            unit_scale=True,
+            unit="B",
+            leave=False,
+        ) as pbar:
+            for status in response:
+                if status.total is not None:
+                    pbar.total = status.total
+                if status.completed is not None:
+                    pbar.update(status.completed - pbar.n)
+        return True
+
+    else:
+        log_once(
+            f"Ollama model {ollama_model_id!r} already downloaded, so skipping "
+            "download.",
+            level=logging.DEBUG,
+        )
+        return True