EuroEval 15.4.2__py3-none-any.whl → 15.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of EuroEval might be problematic.

Files changed (54)
  1. euroeval/__init__.py +2 -2
  2. euroeval/benchmark_modules/base.py +3 -2
  3. euroeval/benchmark_modules/fresh.py +8 -6
  4. euroeval/benchmark_modules/hf.py +44 -33
  5. euroeval/benchmark_modules/litellm.py +314 -120
  6. euroeval/benchmark_modules/vllm.py +99 -59
  7. euroeval/benchmarker.py +52 -21
  8. euroeval/callbacks.py +2 -2
  9. euroeval/constants.py +9 -2
  10. euroeval/data_models.py +258 -44
  11. euroeval/dataset_configs/__init__.py +61 -0
  12. euroeval/dataset_configs/danish.py +120 -0
  13. euroeval/dataset_configs/dutch.py +123 -0
  14. euroeval/dataset_configs/english.py +88 -0
  15. euroeval/dataset_configs/faroese.py +53 -0
  16. euroeval/dataset_configs/french.py +83 -0
  17. euroeval/dataset_configs/german.py +91 -0
  18. euroeval/dataset_configs/icelandic.py +148 -0
  19. euroeval/dataset_configs/italian.py +81 -0
  20. euroeval/dataset_configs/norwegian.py +178 -0
  21. euroeval/dataset_configs/spanish.py +78 -0
  22. euroeval/dataset_configs/swedish.py +100 -0
  23. euroeval/exceptions.py +10 -10
  24. euroeval/finetuning.py +6 -10
  25. euroeval/generation.py +1 -0
  26. euroeval/human_evaluation.py +2 -2
  27. euroeval/languages.py +20 -13
  28. euroeval/model_cache.py +1 -1
  29. euroeval/model_loading.py +1 -12
  30. euroeval/prompt_templates/__init__.py +8 -0
  31. euroeval/prompt_templates/linguistic_acceptability.py +112 -0
  32. euroeval/prompt_templates/multiple_choice.py +97 -0
  33. euroeval/prompt_templates/named_entity_recognition.py +257 -0
  34. euroeval/prompt_templates/reading_comprehension.py +118 -0
  35. euroeval/prompt_templates/sentiment_classification.py +137 -0
  36. euroeval/prompt_templates/summarization.py +97 -0
  37. euroeval/speed_benchmark.py +1 -1
  38. euroeval/{task_utils → task_group_utils}/multiple_choice_classification.py +19 -11
  39. euroeval/{task_utils → task_group_utils}/question_answering.py +31 -30
  40. euroeval/{task_utils → task_group_utils}/sequence_classification.py +45 -10
  41. euroeval/{task_utils → task_group_utils}/text_to_text.py +1 -1
  42. euroeval/{task_utils → task_group_utils}/token_classification.py +3 -2
  43. euroeval/tasks.py +54 -0
  44. euroeval/tokenization_utils.py +343 -0
  45. euroeval/types.py +3 -1
  46. euroeval/utils.py +5 -254
  47. {euroeval-15.4.2.dist-info → euroeval-15.6.0.dist-info}/METADATA +31 -9
  48. euroeval-15.6.0.dist-info/RECORD +59 -0
  49. euroeval/dataset_configs.py +0 -2408
  50. euroeval-15.4.2.dist-info/RECORD +0 -40
  51. /euroeval/{task_utils → task_group_utils}/__init__.py +0 -0
  52. {euroeval-15.4.2.dist-info → euroeval-15.6.0.dist-info}/WHEEL +0 -0
  53. {euroeval-15.4.2.dist-info → euroeval-15.6.0.dist-info}/entry_points.txt +0 -0
  54. {euroeval-15.4.2.dist-info → euroeval-15.6.0.dist-info}/licenses/LICENSE +0 -0
euroeval/benchmark_modules/litellm.py

@@ -27,20 +27,17 @@ from litellm.exceptions import (
     BadRequestError,
     InternalServerError,
     NotFoundError,
+    RateLimitError,
     ServiceUnavailableError,
     Timeout,
 )
-from litellm.types.utils import ModelResponse
+from litellm.llms.vertex_ai.common_utils import VertexAIError
+from litellm.types.utils import ChoiceLogprobs, ModelResponse
 from requests.exceptions import RequestException
 from tqdm.auto import tqdm
-from transformers import Trainer
+from transformers.trainer import Trainer

-from ..constants import (
-    MAX_LOGPROBS,
-    REASONING_MAX_TOKENS,
-    TASK_GROUPS_USING_LOGPROBS,
-    TASKS_USING_JSON,
-)
+from ..constants import MAX_LOGPROBS, REASONING_MAX_TOKENS, TASKS_USING_JSON
 from ..data_models import (
     BenchmarkConfig,
     DatasetConfig,
@@ -62,12 +59,13 @@ from ..exceptions import (
     NeedsEnvironmentVariable,
     NeedsExtraInstalled,
 )
-from ..task_utils import (
+from ..task_group_utils import (
     question_answering,
     sequence_classification,
     text_to_text,
     token_classification,
 )
+from ..tokenization_utils import get_first_label_token_mapping
 from ..types import ExtractLabelsFunction
 from ..utils import create_model_cache_dir, log_once
 from .base import BenchmarkModule
@@ -78,64 +76,80 @@ logger = logging.getLogger("euroeval")

 VOCAB_SIZE_MAPPING = {
     # OpenAI models
-    "(text-)?(ada|babbage|curie|davinci)(-001)?": 50_257,
-    "(code|text)-davinci-00[2-9]": 50_281,
-    "gpt-3.5-turbo(-16k)?(-[0-9]{4})?": 100_256,
-    "gpt-4-(32k)?(-[0-9]{4})?": 100_256,
-    "gpt-4-[0-9]{4}-preview": 100_256,
-    "gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 100_256,
-    "gpt-4-(vision|turbo)(-preview)?": 100_256,
-    "gpt-3.5-turbo-instruct(-[0-9]{4})?": 100_256,
-    "gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_019,
-    "o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
+    r"gpt-4-(32k)?(-[0-9]{4})?": 100_256,
+    r"gpt-4-[0-9]{4}-preview": 100_256,
+    r"gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 100_256,
+    r"gpt-4-(vision|turbo)(-preview)?": 100_256,
+    r"gpt-3.5-turbo-instruct(-[0-9]{4})?": 100_256,
+    r"gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_019,
+    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
     # Anthropic models
-    "claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": -1,
+    r"(anthropic/)?claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": -1,
+    # Gemini models
+    r"(gemini/)?gemini-[1-9]\.[0-9]-(flash|pro).*": 256_128,
+    # xAI models
+    r"(xai/)?grok.*": -1,
 }


 MODEL_MAX_LENGTH_MAPPING = {
     # OpenAI models
-    "(text-)?(ada|babbage|curie|davinci)(-001)?": 2_050,
-    "text-davinci-00[2-9]": 4_098,
-    "code-davinci-00[1-9]": 8_002,
-    "gpt-3.5-turbo-0613": 4_096,
-    "gpt-3.5-turbo(-[0-9]{4})?": 16_385,
-    "gpt-3.5-turbo-16k(-[0-9]{4})?": 16_384,
-    "gpt-4(-[0-9]{4})?": 8_191,
-    "gpt-4-32k(-[0-9]{4})?": 32_767,
-    "gpt-4-[0-9]{4}-preview": 128_000,
-    "gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
-    "gpt-4-(vision|turbo)(-preview)?": 128_000,
-    "gpt-3.5-turbo-instruct(-[0-9]{4})?": 4_095,
-    "gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
-    "o1-(mini|preview)(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
-    "o1(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
-    "o[2-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
+    r"gpt-4(-[0-9]{4})?": 8_191,
+    r"gpt-4-32k(-[0-9]{4})?": 32_767,
+    r"gpt-4-[0-9]{4}-preview": 128_000,
+    r"gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
+    r"gpt-4-(vision|turbo)(-preview)?": 128_000,
+    r"gpt-3.5-turbo-instruct(-[0-9]{4})?": 4_095,
+    r"gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
+    r"o1-(mini|preview)(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
+    r"o1(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
+    r"o[2-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
     # Anthropic models
-    "claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": 200_000,
+    r"(anthropic/)?claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": 200_000,
+    # Gemini models
+    r"(gemini/)?gemini-1\.5-flash.*": 1_048_576,
+    r"(gemini/)?gemini-1\.5-pro.*": 2_097_152,
+    r"(gemini/)?gemini-2\.(0|5).*": 1_048_576,
+    # xAI models
+    r"(xai/)?grok.*": 131_072,
 }


 NUM_PARAMS_MAPPING = {
     # OpenAI models
-    "(text-)?ada(-001)?": 350_000_000,
-    "(text-)?babbage(-001)?": 3_000_000_000,
-    "(text-)?curie(-001)?": 13_000_000_000,
-    "((text|code)-)?davinci(-00[1-9])?": 175_000_000_000,
-    "gpt-(3.5|4)-turbo-((16|32)k)?(-[0-9]{4})?": -1,
-    "gpt-4-[0-9]{4}-preview": -1,
-    "gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
-    "gpt-4-(vision|turbo)(-preview)?": -1,
-    "gpt-3.5-turbo-instruct(-[0-9]{4})?": -1,
-    "gpt-4o(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
-    "gpt-4o-mini(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
-    "o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
+    r"gpt-4.*": -1,
+    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
+    # Anthropic models
+    r"(anthropic/)?claude-*": -1,
+    # Gemini models
+    r"(gemini/)?gemini-1.5-flash-8b": 8_000_000_000,
+    r"(gemini/)?gemini-1.5-flash-[0-9]+": -1,
+    r"(gemini/)?gemini-2.(0|5).*": -1,
+    # xAI models
+    r"(xai/)?grok.*": -1,
+}
+
+
+ALLOWED_PARAMS = {
+    # OpenAI models
+    r"gpt-4.*": [],
+    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": ["low", "high"],
     # Anthropic models
-    "claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": -1,
+    r"(anthropic/)?claude-3-.*": [],
+    r"(anthropic/)?claude-3.5-.*": [],
+    r"(anthropic/)?claude-3.7-sonnet.*": ["thinking"],
+    # Gemini models
+    r"(gemini/)?gemini-.*": [],
+    # xAI models
+    r"(xai/)?grok.*": [],
 }


-REASONING_MODELS = ["o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?"]
+REASONING_MODELS = [
+    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?",
+    r"(gemini/)?gemini.*thinking.*",
+    r"(gemini/)?gemini-2.5-pro.*",
+]


 class LiteLLMModel(BenchmarkModule):
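
The lookup tables above are keyed by regular expressions rather than literal model IDs, so resolving a model's vocabulary size, context length or parameter count means scanning the keys for a full match. A minimal sketch of that kind of lookup, using only the re module (the helper name is illustrative and not part of the diff):

import re

# Subset of MODEL_MAX_LENGTH_MAPPING, copied from the diff above.
MODEL_MAX_LENGTH_MAPPING = {
    r"gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
    r"(gemini/)?gemini-2\.(0|5).*": 1_048_576,
    r"(xai/)?grok.*": 131_072,
}


def lookup(model_id: str, mapping: dict[str, int]) -> int | None:
    """Return the value whose regex key fully matches the model ID, if any."""
    for pattern, value in mapping.items():
        if re.fullmatch(pattern=pattern, string=model_id):
            return value
    return None


print(lookup("gemini/gemini-2.5-flash", MODEL_MAX_LENGTH_MAPPING))  # 1048576
print(lookup("some-unknown-model", MODEL_MAX_LENGTH_MAPPING))       # None
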
@@ -167,12 +181,18 @@ class LiteLLMModel(BenchmarkModule):
             "ollama/"
         ) or model_config.model_id.startswith("ollama_chat/")

+        raise_if_wrong_params(model_config=model_config, allowed_params=ALLOWED_PARAMS)
+
         super().__init__(
             model_config=model_config,
             dataset_config=dataset_config,
             benchmark_config=benchmark_config,
         )

+        self.buffer["first_label_token_mapping"] = get_first_label_token_mapping(
+            dataset_config=self.dataset_config, tokenizer=None
+        )
+
     @property
     def generative_type(self) -> GenerativeType | None:
         """Get the generative type of the model.
@@ -180,7 +200,9 @@ class LiteLLMModel(BenchmarkModule):
         Returns:
             The generative type of the model, or None if it has not been set yet.
         """
-        if re.fullmatch(
+        if self.model_config.revision == "thinking":
+            return GenerativeType.REASONING
+        elif re.fullmatch(
             pattern="|".join(REASONING_MODELS), string=self.model_config.model_id
         ):
             return GenerativeType.REASONING
@@ -218,7 +240,13 @@ class LiteLLMModel(BenchmarkModule):
             api_version=self.benchmark_config.api_version,
         )

-        if self.dataset_config.task.task_group in TASK_GROUPS_USING_LOGPROBS:
+        # Get the mapping from labels to the first token in the label. We call this each
+        # time we generate a new dataset since the dataset config can change
+        self.buffer["first_label_token_mapping"] = get_first_label_token_mapping(
+            dataset_config=self.dataset_config, tokenizer=None
+        )
+
+        if self.buffer["first_label_token_mapping"]:
             generation_kwargs["logprobs"] = True
             generation_kwargs["top_logprobs"] = MAX_LOGPROBS

@@ -227,6 +255,27 @@ class LiteLLMModel(BenchmarkModule):
                 "Prompt must contain 'json' for JSON tasks."
             )
             generation_kwargs["response_format"] = dict(type="json_object")
+            log_once(
+                "Enabling JSON response format for model "
+                f"{self.model_config.model_id!r}",
+                level=logging.DEBUG,
+            )
+
+        if self.model_config.revision == "thinking":
+            generation_kwargs["thinking"] = dict(
+                type="enabled", budget_tokens=REASONING_MAX_TOKENS
+            )
+            log_once(
+                f"Enabling thinking mode for model {self.model_config.model_id!r}",
+                level=logging.DEBUG,
+            )
+        elif self.model_config.revision in {"low", "high"}:
+            generation_kwargs["reasoning_effort"] = self.model_config.revision
+            log_once(
+                f"Enabling reasoning effort {self.model_config.revision!r} for model "
+                f"{self.model_config.model_id!r}",
+                level=logging.DEBUG,
+            )

         # This drops generation kwargs that are not supported by the model
         litellm.drop_params = True
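
The new branches above only add keys to the kwargs that are later passed to litellm.completion. Roughly, the extra keys look like this for a model ID given with an '@' parameter (a sketch; the REASONING_MAX_TOKENS value below is a stand-in, the real constant lives in euroeval.constants):

# Stand-in value for the sketch; the real constant is defined in euroeval.constants.
REASONING_MAX_TOKENS = 8_192

# Benchmarking "<model-id>@thinking" sets revision == "thinking" and adds:
thinking_extras = dict(
    thinking=dict(type="enabled", budget_tokens=REASONING_MAX_TOKENS)
)

# Benchmarking "<model-id>@low" or "<model-id>@high" instead adds:
reasoning_effort_extras = dict(reasoning_effort="high")
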
@@ -235,39 +284,60 @@ class LiteLLMModel(BenchmarkModule):
         # handle using newlines as stop sequences, so we try both.
         num_attempts = 10
         for _ in range(num_attempts):
+            stop_messages = ["stop_sequences"]
+            logprobs_messages = [
+                "you are not allowed to request logprobs",
+                "you've reached the maximum number of requests with logprobs",
+                "logprobs is not supported",
+                "logprobs is not enabled",
+            ]
+            temperature_messages = [
+                "'temperature' is not supported with this model.",
+                "temperature is not supported with this model",
+            ]
             try:
                 model_response = litellm.completion(
                     messages=messages, max_retries=3, **generation_kwargs
                 )
                 break
-            except BadRequestError as e:
-                if "stop_sequences" in str(e).lower():
+            except (BadRequestError, RateLimitError) as e:
+                if any(msg.lower() in str(e).lower() for msg in stop_messages):
                     generation_kwargs["stop"] = None
-                elif "you are not allowed to request logprobs" in str(e).lower():
-                    generation_kwargs.pop("logprobs")
-                    generation_kwargs.pop("top_logprobs")
                 elif (
-                    "'temperature' is not supported with this model." in str(e).lower()
+                    any(msg.lower() in str(e).lower() for msg in logprobs_messages)
+                    # Special case for Vertex AI models, since they have strict rate
+                    # limits on using logprobs. They also have a cap of 5 logprobs, but
+                    # we ignore this since the rate limiting makes it unusable anyway.
+                    or (isinstance(e, VertexAIError) and "logprobs" in str(e).lower())
                 ):
+                    generation_kwargs.pop("logprobs")
+                    generation_kwargs.pop("top_logprobs")
+                elif any(msg.lower() in str(e).lower() for msg in temperature_messages):
                     generation_kwargs.pop("temperature")
+                elif isinstance(e, RateLimitError):
+                    raise InvalidModel(
+                        "You have encountered your rate limit for model "
+                        f"{self.model_config.model_id!r}. Skipping."
+                    )
                 else:
                     raise InvalidBenchmark(
                         f"Failed to generate text. The error message was: {e}"
                     )
+            except APIError as e:
+                raise InvalidBenchmark(
+                    f"Failed to generate text. The error message was: {e}"
+                )
             except (
+                APIConnectionError,
                 Timeout,
                 ServiceUnavailableError,
-                APIConnectionError,
                 InternalServerError,
-            ):
+            ) as e:
                 logger.debug(
-                    "Service temporarily unavailable. Retrying in 5 seconds..."
+                    f"Service temporarily unavailable. The error message was: {e}. "
+                    f"Retrying in 5 seconds..."
                 )
                 sleep(5)
-            except APIError as e:
-                raise InvalidBenchmark(
-                    f"Failed to generate text. The error message was: {e}"
-                )
             except AuthenticationError:
                 raise NeedsAdditionalArgument(
                     cli_argument="--api-key",
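
The loop above follows a retry-and-degrade pattern: a BadRequestError (or now a RateLimitError) whose message points at a specific kwarg causes that kwarg to be dropped before retrying, transient connection errors just sleep and retry, and anything else aborts the benchmark. A stripped-down sketch of the same pattern, with stand-in exception types in place of the litellm ones:

import time


def call_with_degradation(call_api, kwargs: dict, num_attempts: int = 10):
    """Retry an API call, dropping kwargs that the provider rejects."""
    for _ in range(num_attempts):
        try:
            return call_api(**kwargs)
        except ValueError as e:  # stand-in for BadRequestError / RateLimitError
            message = str(e).lower()
            if "stop" in message:
                kwargs["stop"] = None
            elif "logprobs" in message:
                kwargs.pop("logprobs", None)
                kwargs.pop("top_logprobs", None)
            elif "temperature" in message:
                kwargs.pop("temperature", None)
            else:
                raise
        except ConnectionError:  # stand-in for the transient litellm errors
            time.sleep(5)
    raise RuntimeError("Failed to generate text after repeated attempts.")
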
@@ -280,6 +350,15 @@ class LiteLLMModel(BenchmarkModule):
                 )

         assert isinstance(model_response, ModelResponse)
+        if not model_response.choices:
+            # This happens for reasoning models, when they don't finish thinking and run
+            # out of tokens. Happens quite rarely, but we need to handle it.
+            logger.warning(
+                f"The model {self.model_config.model_id!r} did not end up generating "
+                "any text. This is likely because the model ran out of tokens while "
+                "reasoning. Returning an empty string."
+            )
+            return GenerativeModelOutput(sequences=[""])
         model_response_choices = model_response.choices[0]
         assert isinstance(model_response_choices, litellm.Choices)
         generation_output = model_response_choices.message["content"] or ""
@@ -288,14 +367,22 @@ class LiteLLMModel(BenchmarkModule):
         # Structure the model output as a GenerativeModelOutput object
         model_output = GenerativeModelOutput(sequences=[generation_output])
         if hasattr(model_response_choices, "logprobs"):
-            logprobs_list: list[list[tuple[str, float]]] = [
-                [
-                    (top_logprob.token, top_logprob.logprob)
-                    for top_logprob in content.top_logprobs
+            logprobs_obj = model_response_choices.logprobs
+            if isinstance(logprobs_obj, ChoiceLogprobs):
+                logprobs_list: list[list[tuple[str, float]]] = [
+                    [
+                        (top_logprob.token, top_logprob.logprob)
+                        for top_logprob in content.top_logprobs
+                    ]
+                    for content in model_response_choices.logprobs.content or list()
                 ]
-                for content in model_response_choices.logprobs.content or list()
-            ]
-            model_output.scores = [logprobs_list]
+                model_output.scores = [logprobs_list]
+            else:
+                log_once(
+                    "The logprobs object is malformed, so we won't use logprobs to "
+                    "determine the labels.",
+                    level=logging.WARNING,
+                )

         return model_output
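
The comprehension above flattens litellm's ChoiceLogprobs payload into one list of (token, logprob) candidates per generated position, which is what the label-extraction code consumes. A made-up example of the resulting structure:

# Made-up values illustrating the shape produced by the comprehension above.
logprobs_list: list[list[tuple[str, float]]] = [
    [("pos", -0.02), ("neg", -3.91), ("neu", -5.10)],  # candidates for token 1
    [("itive", -0.01), ("ative", -4.73)],              # candidates for token 2
]

# One entry per generated sequence in the batch:
# model_output.scores = [logprobs_list]
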

@@ -314,7 +401,7 @@ class LiteLLMModel(BenchmarkModule):
         # If it is an Ollama model then we can get the number of parameters from the
         # Ollama Python SDK
         if self.is_ollama:
-            ollama_model_id = self.model_config.model_id.split("/")[-1]
+            ollama_model_id = "/".join(self.model_config.model_id.split("/")[1:])
             model_info = ollama.show(ollama_model_id).modelinfo
             if model_info is not None:
                 num_params = model_info.get("general.parameter_count")
@@ -325,7 +412,7 @@ class LiteLLMModel(BenchmarkModule):
         # get the number of parameters from the Hugging Face model configuration from
         # the Hugging Face Hub
         if self.model_config.model_id.startswith("huggingface/"):
-            model_id = self.model_config.model_id.split(sep="/", maxsplit=1)[-1]
+            model_id = "/".join(self.model_config.model_id.split(sep="/")[-2:])
             if HuggingFaceEncoderModel.model_exists(
                 model_id=model_id, benchmark_config=self.benchmark_config
             ):
@@ -334,7 +421,7 @@ class LiteLLMModel(BenchmarkModule):
                     num_labels=self.dataset_config.num_labels,
                     id2label=self.dataset_config.id2label,
                     label2id=self.dataset_config.label2id,
-                    revision=self.model_config.revision,
+                    revision="main",
                     model_cache_dir=self.model_config.model_cache_dir,
                     api_key=self.benchmark_config.api_key,
                     trust_remote_code=self.benchmark_config.trust_remote_code,
@@ -345,7 +432,7 @@ class LiteLLMModel(BenchmarkModule):
             try:
                 repo_info = hf_api.model_info(
                     repo_id=model_id,
-                    revision=self.model_config.revision,
+                    revision="main",
                     token=os.getenv("HUGGINGFACE_API_KEY")
                     or self.benchmark_config.api_key
                     or True,
@@ -389,7 +476,7 @@ class LiteLLMModel(BenchmarkModule):
         # get the vocabulary size from the Hugging Face model configuration from the
         # Hugging Face Hub
         if self.model_config.model_id.startswith("huggingface/"):
-            model_id = self.model_config.model_id.split(sep="/", maxsplit=1)[-1]
+            model_id = "/".join(self.model_config.model_id.split(sep="/")[-2:])
             if HuggingFaceEncoderModel.model_exists(
                 model_id=model_id, benchmark_config=self.benchmark_config
             ):
@@ -398,7 +485,7 @@ class LiteLLMModel(BenchmarkModule):
                     num_labels=self.dataset_config.num_labels,
                     id2label=self.dataset_config.id2label,
                     label2id=self.dataset_config.label2id,
-                    revision=self.model_config.revision,
+                    revision="main",
                     model_cache_dir=self.model_config.model_cache_dir,
                     api_key=self.benchmark_config.api_key,
                     trust_remote_code=self.benchmark_config.trust_remote_code,
@@ -442,7 +529,7 @@ class LiteLLMModel(BenchmarkModule):
         # If it is an Ollama model then we can get the maximum length from the Ollama
         # Python SDK
         if self.is_ollama:
-            ollama_model_id = self.model_config.model_id.split("/")[-1]
+            ollama_model_id = "/".join(self.model_config.model_id.split("/")[1:])
             model_info = ollama.show(ollama_model_id).modelinfo
             if model_info is not None:
                 context_length_keys = [
@@ -469,7 +556,7 @@ class LiteLLMModel(BenchmarkModule):
         # get the maximum length from the Hugging Face model configuration from the
         # Hugging Face Hub
         if self.model_config.model_id.startswith("huggingface/"):
-            model_id = self.model_config.model_id.split(sep="/", maxsplit=1)[-1]
+            model_id = "/".join(self.model_config.model_id.split(sep="/")[-2:])
             if HuggingFaceEncoderModel.model_exists(
                 model_id=model_id, benchmark_config=self.benchmark_config
             ):
@@ -478,7 +565,7 @@ class LiteLLMModel(BenchmarkModule):
                     num_labels=self.dataset_config.num_labels,
                     id2label=self.dataset_config.id2label,
                     label2id=self.dataset_config.label2id,
-                    revision=self.model_config.revision,
+                    revision="main",
                     model_cache_dir=self.model_config.model_cache_dir,
                     api_key=self.benchmark_config.api_key,
                     trust_remote_code=self.benchmark_config.trust_remote_code,
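
The repeated change from split("/")[-1] to "/".join(split("/")[1:]) (and to [-2:] for the huggingface/ prefix) matters for model IDs that contain more than one slash, which the old code silently truncated. A quick illustration:

ollama_id = "ollama_chat/hf.co/org/model:Q4_K_M"

# Old behaviour: keeps only the last path component and loses the repository.
print(ollama_id.split("/")[-1])            # model:Q4_K_M

# New behaviour: strips only the provider prefix.
print("/".join(ollama_id.split("/")[1:]))  # hf.co/org/model:Q4_K_M

hf_id = "huggingface/org/model"

# New behaviour for Hugging Face IDs: keep the final "org/model" part.
print("/".join(hf_id.split(sep="/")[-2:]))  # org/model
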
@@ -563,6 +650,7 @@ class LiteLLMModel(BenchmarkModule):
                 return partial(
                     sequence_classification.extract_labels_from_generation,
                     dataset_config=self.dataset_config,
+                    first_label_token_mapping=self.buffer["first_label_token_mapping"],
                 )
             case TaskGroup.TEXT_TO_TEXT:
                 return text_to_text.extract_labels_from_generation
@@ -605,45 +693,15 @@ class LiteLLMModel(BenchmarkModule):
             Whether the model exists, or an error describing why we cannot check
             whether the model exists.
         """
+        model_id, _ = model_id.split("@") if "@" in model_id else (model_id, "main")
         if model_id in litellm.model_list:
             return True

-        # If it is an Ollama model then try to download it
+        # Separate check for Ollama models
         if model_id.startswith("ollama/") or model_id.startswith("ollama_chat/"):
-            ollama_model_id = model_id.split("/")[-1]
-            downloaded_ollama_models: list[str] = [
-                model_obj.model
-                for model_obj in ollama.list().models
-                if model_obj.model is not None
-            ]
-            if ollama_model_id not in downloaded_ollama_models:
-                try:
-                    response = ollama.pull(model=ollama_model_id, stream=True)
-                    with tqdm(
-                        desc=f"Downloading {ollama_model_id}",
-                        unit_scale=True,
-                        unit="B",
-                        leave=False,
-                    ) as pbar:
-                        for status in response:
-                            if status.total is not None:
-                                pbar.total = status.total
-                            if status.completed is not None:
-                                pbar.update(status.completed - pbar.n)
-                except ollama.ResponseError as e:
-                    if "file does not exist" in str(e).lower():
-                        return False
-                    else:
-                        raise InvalidModel(
-                            f"Failed to download Ollama model {ollama_model_id}. The "
-                            f"error message was: {e}"
-                        )
-            else:
-                log_once(
-                    f"Ollama model {ollama_model_id!r} already downloaded, so skipping "
-                    "download.",
-                    level=logging.DEBUG,
-                )
+            ollama_model_exists = try_download_ollama_model(model_id=model_id)
+            if ollama_model_exists:
+                return ollama_model_exists

         num_attempts = 10
         for _ in range(num_attempts):
@@ -657,12 +715,27 @@ class LiteLLMModel(BenchmarkModule):
                     api_version=benchmark_config.api_version,
                 )
                 return True
+            # A rate limit indicates that the model *does* exist, but we are being rate
+            # limited.
+            except RateLimitError:
+                return True
+            except (
+                APIConnectionError,
+                Timeout,
+                ServiceUnavailableError,
+                InternalServerError,
+            ) as e:
+                logger.debug(
+                    f"Service temporarily unavailable. The error message was: {e}. "
+                    "Retrying in 10 seconds..."
+                )
+                sleep(5)
             except APIError as e:
                 if "'503 Service Unavailable" not in str(e):
                     raise e
                 logger.warning(
-                    f"Failed to check if model {model_id!r} exists. Retrying in "
-                    f"{num_attempts} seconds..."
+                    f"Failed to check if model {model_id!r} exists. Retrying in 10 "
+                    "seconds..."
                 )
                 sleep(10)
             except (BadRequestError, NotFoundError):
@@ -708,9 +781,10 @@ class LiteLLMModel(BenchmarkModule):
         Returns:
             The model configuration.
         """
+        model_id, revision = model_id.split("@") if "@" in model_id else (model_id, "")
         return ModelConfig(
             model_id=model_id,
-            revision="main",
+            revision=revision,
             task="text-generation",
             languages=list(),
             merge=False,
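
A model ID can now carry a parameter after an '@' (for example something like o3-mini@high, or a Claude model with @thinking), and that parameter is stored in the ModelConfig's revision field rather than being treated as a Git revision. A small sketch of the parsing done by the added line above:

def split_model_param(model_id: str) -> tuple[str, str]:
    """Mirror the parsing above: 'o3-mini@high' -> ('o3-mini', 'high')."""
    if "@" in model_id:
        base, param = model_id.split("@", maxsplit=1)
        return base, param
    return model_id, ""


print(split_model_param("o3-mini@high"))       # ('o3-mini', 'high')
print(split_model_param("gpt-4o-2024-08-06"))  # ('gpt-4o-2024-08-06', '')
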
@@ -1025,3 +1099,123 @@ class LiteLLMModel(BenchmarkModule):

         examples["messages"] = messages_list
         return examples
+
+
+def raise_if_wrong_params(
+    model_config: ModelConfig, allowed_params: dict[str, list[str]]
+) -> None:
+    """Raise an error if the model configuration has invalid parameters.
+
+    Args:
+        model_config:
+            The model configuration.
+        allowed_params:
+            The allowed parameters for the model.
+
+    Raises:
+        InvalidModel:
+            If the model configuration has invalid parameters.
+    """
+    param = model_config.revision
+    if param == "":
+        return
+    for model_regex, allowed_params_list in allowed_params.items():
+        if re.fullmatch(pattern=model_regex, string=model_config.model_id):
+            if param not in allowed_params_list:
+                msg = (
+                    f"Invalid parameter {param!r} for model {model_config.model_id!r}."
+                )
+                if allowed_params_list:
+                    msg += f" Allowed parameters are: {', '.join(allowed_params_list)}."
+                else:
+                    msg += " No parameters are allowed."
+                raise InvalidModel(msg)
+            return
+
+
+def try_download_ollama_model(model_id: str) -> bool:
+    """Try to download an Ollama model.
+
+    Args:
+        model_id:
+            The model ID. If the model does not start with "ollama/" or "ollama_chat/"
+            then this function will return False.
+
+    Returns:
+        Whether the model was downloaded successfully.
+    """
+    if not (model_id.startswith("ollama/") or model_id.startswith("ollama_chat/")):
+        return False
+
+    if model_id.startswith("ollama/"):
+        log_once(
+            "You're trying to benchmark a model with the old 'ollama/' prefix, which "
+            "probably results in bad performance, as it doesn't use the model's chat "
+            "template. If the model is not a chat model then just disregard this "
+            "warning, but if it is a chat model then please cancel this run and "
+            "use the 'ollama_chat/' prefix instead.",
+            level=logging.WARNING,
+        )
+
+    downloaded_ollama_models: list[str] = [
+        model_obj.model
+        for model_obj in ollama.list().models
+        if model_obj.model is not None
+    ]
+
+    ollama_model_id = "/".join(model_id.split("/")[1:])
+    if ollama_model_id not in downloaded_ollama_models:
+        # Try fetching the model info
+        try:
+            response = ollama.pull(model=ollama_model_id, stream=True)
+        except ollama.ResponseError as e:
+            if "file does not exist" in str(e).lower():
+                # Check if the model exists if we prepend "hf.co/"
+                try:
+                    ollama_model_id_with_prefix = f"hf.co/{ollama_model_id}"
+                    model_id_with_prefix = (
+                        f"{model_id.split('/')[0]}/{ollama_model_id_with_prefix}"
+                    )
+                    ollama.pull(model=ollama_model_id_with_prefix, stream=True)
+                    log_once(
+                        f"The model {model_id!r} cannot be found on Ollama, but the "
+                        f"model {model_id_with_prefix} *was* found, so we would "
+                        "recommend you cancelling this run and trying the evaluation "
+                        "with that model ID instead."
+                    )
+                    return False
+                except ollama.ResponseError as inner_e:
+                    if "file does not exist" in str(inner_e).lower():
+                        return False
+                    else:
+                        raise InvalidModel(
+                            f"Failed to download Ollama model {ollama_model_id}. "
+                            f"The error message was: {inner_e}"
+                        )
+            else:
+                raise InvalidModel(
+                    f"Failed to download Ollama model {ollama_model_id}. "
+                    f"The error message was: {e}"
+                )
+
+        # Download the model
+        with tqdm(
+            desc=f"Downloading {ollama_model_id}",
+            unit_scale=True,
+            unit="B",
+            leave=False,
+        ) as pbar:
+            for status in response:
+                if status.total is not None:
+                    pbar.total = status.total
+                if status.completed is not None:
+                    pbar.update(status.completed - pbar.n)
+        return True
+
+    else:
+        log_once(
+            f"Ollama model {ollama_model_id!r} already downloaded, so skipping "
+            "download.",
+            level=logging.DEBUG,
+        )
+    return True
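
raise_if_wrong_params ties the '@' parameter back to the ALLOWED_PARAMS table near the top of the file: the first regex that fully matches the model ID decides which parameters are accepted. A sketch of the expected outcomes, using a minimal stand-in for ModelConfig since its full constructor is not shown in this diff:

import re
from dataclasses import dataclass


@dataclass
class FakeModelConfig:  # stand-in for euroeval.data_models.ModelConfig
    model_id: str
    revision: str


# Subset of ALLOWED_PARAMS, copied from the diff above.
ALLOWED_PARAMS = {
    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": ["low", "high"],
    r"(anthropic/)?claude-3.7-sonnet.*": ["thinking"],
    r"(gemini/)?gemini-.*": [],
}


def check(config: FakeModelConfig) -> str:
    """Re-run the same logic as raise_if_wrong_params, returning a verdict."""
    if config.revision == "":
        return "ok (no parameter given)"
    for pattern, allowed in ALLOWED_PARAMS.items():
        if re.fullmatch(pattern, config.model_id):
            return "ok" if config.revision in allowed else "InvalidModel"
    return "ok (no matching pattern)"


print(check(FakeModelConfig("o3-mini", "high")))                      # ok
print(check(FakeModelConfig("gemini/gemini-2.0-flash", "thinking")))  # InvalidModel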