EuroEval 15.12.0-py3-none-any.whl → 16.7.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. euroeval/__init__.py +32 -14
  2. euroeval/benchmark_config_factory.py +92 -180
  3. euroeval/benchmark_modules/base.py +49 -39
  4. euroeval/benchmark_modules/fresh.py +35 -21
  5. euroeval/benchmark_modules/hf.py +280 -244
  6. euroeval/benchmark_modules/litellm.py +752 -312
  7. euroeval/benchmark_modules/vllm.py +570 -268
  8. euroeval/benchmarker.py +651 -528
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/cli.py +49 -38
  12. euroeval/constants.py +44 -25
  13. euroeval/data_loading.py +111 -55
  14. euroeval/data_models.py +490 -323
  15. euroeval/dataset_configs/__init__.py +26 -4
  16. euroeval/dataset_configs/bosnian.py +39 -0
  17. euroeval/dataset_configs/bulgarian.py +56 -0
  18. euroeval/dataset_configs/croatian.py +56 -0
  19. euroeval/dataset_configs/czech.py +75 -0
  20. euroeval/dataset_configs/danish.py +78 -50
  21. euroeval/dataset_configs/dutch.py +74 -44
  22. euroeval/dataset_configs/english.py +71 -36
  23. euroeval/dataset_configs/estonian.py +111 -0
  24. euroeval/dataset_configs/faroese.py +25 -18
  25. euroeval/dataset_configs/finnish.py +63 -26
  26. euroeval/dataset_configs/french.py +65 -32
  27. euroeval/dataset_configs/german.py +77 -36
  28. euroeval/dataset_configs/greek.py +64 -0
  29. euroeval/dataset_configs/icelandic.py +68 -57
  30. euroeval/dataset_configs/italian.py +68 -36
  31. euroeval/dataset_configs/latvian.py +87 -0
  32. euroeval/dataset_configs/lithuanian.py +64 -0
  33. euroeval/dataset_configs/norwegian.py +98 -72
  34. euroeval/dataset_configs/polish.py +96 -0
  35. euroeval/dataset_configs/portuguese.py +63 -40
  36. euroeval/dataset_configs/serbian.py +64 -0
  37. euroeval/dataset_configs/slovak.py +55 -0
  38. euroeval/dataset_configs/slovene.py +56 -0
  39. euroeval/dataset_configs/spanish.py +68 -34
  40. euroeval/dataset_configs/swedish.py +82 -41
  41. euroeval/dataset_configs/ukrainian.py +64 -0
  42. euroeval/enums.py +12 -6
  43. euroeval/exceptions.py +21 -1
  44. euroeval/finetuning.py +34 -26
  45. euroeval/generation.py +76 -41
  46. euroeval/generation_utils.py +169 -34
  47. euroeval/languages.py +1020 -188
  48. euroeval/logging_utils.py +268 -0
  49. euroeval/metrics/__init__.py +6 -0
  50. euroeval/metrics/base.py +85 -0
  51. euroeval/metrics/huggingface.py +216 -0
  52. euroeval/metrics/llm_as_a_judge.py +260 -0
  53. euroeval/metrics/pipeline.py +289 -0
  54. euroeval/metrics/speed.py +48 -0
  55. euroeval/model_cache.py +40 -21
  56. euroeval/model_config.py +4 -5
  57. euroeval/model_loading.py +3 -0
  58. euroeval/prompt_templates/__init__.py +2 -0
  59. euroeval/prompt_templates/classification.py +206 -0
  60. euroeval/prompt_templates/linguistic_acceptability.py +157 -22
  61. euroeval/prompt_templates/multiple_choice.py +159 -17
  62. euroeval/prompt_templates/named_entity_recognition.py +318 -21
  63. euroeval/prompt_templates/reading_comprehension.py +207 -16
  64. euroeval/prompt_templates/sentiment_classification.py +205 -22
  65. euroeval/prompt_templates/summarization.py +122 -22
  66. euroeval/prompt_templates/token_classification.py +279 -0
  67. euroeval/scores.py +20 -9
  68. euroeval/speed_benchmark.py +11 -12
  69. euroeval/task_group_utils/multiple_choice_classification.py +21 -12
  70. euroeval/task_group_utils/question_answering.py +101 -73
  71. euroeval/task_group_utils/sequence_classification.py +144 -61
  72. euroeval/task_group_utils/text_to_text.py +33 -12
  73. euroeval/task_group_utils/token_classification.py +86 -89
  74. euroeval/tasks.py +75 -16
  75. euroeval/tokenisation_utils.py +603 -0
  76. euroeval/types.py +17 -11
  77. euroeval/utils.py +332 -137
  78. euroeval-16.7.1.dist-info/METADATA +623 -0
  79. euroeval-16.7.1.dist-info/RECORD +84 -0
  80. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
  81. euroeval/human_evaluation.py +0 -737
  82. euroeval/metrics.py +0 -452
  83. euroeval/tokenization_utils.py +0 -498
  84. euroeval-15.12.0.dist-info/METADATA +0 -285
  85. euroeval-15.12.0.dist-info/RECORD +0 -63
  86. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
  87. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
@@ -2,8 +2,8 @@
2
2
 
3
3
  import asyncio
4
4
  import collections.abc as c
5
+ import json
5
6
  import logging
6
- import os
7
7
  import re
8
8
  import typing as t
9
9
  from functools import cached_property, partial
@@ -27,16 +27,23 @@ from litellm.exceptions import (
27
27
  RateLimitError,
28
28
  ServiceUnavailableError,
29
29
  Timeout,
30
+ UnsupportedParamsError,
30
31
  )
31
32
  from litellm.llms.vertex_ai.common_utils import VertexAIError
32
33
  from litellm.router import Router
33
- from litellm.types.utils import ChoiceLogprobs
34
+ from litellm.types.utils import ChoiceLogprobs, Logprobs
35
+ from litellm.utils import supports_reasoning, supports_response_schema
34
36
  from pydantic import conlist, create_model
35
37
  from requests.exceptions import RequestException
36
38
  from tqdm.asyncio import tqdm as tqdm_async
37
- from tqdm.auto import tqdm
38
39
 
39
- from ..constants import MAX_LOGPROBS, REASONING_MAX_TOKENS, TASKS_USING_JSON
40
+ from ..caching_utils import cache_arguments
41
+ from ..constants import (
42
+ JSON_STRIP_CHARACTERS,
43
+ LITELLM_CLASSIFICATION_OUTPUT_KEY,
44
+ MAX_LITELLM_LOGPROBS,
45
+ REASONING_MAX_TOKENS,
46
+ )
40
47
  from ..data_models import (
41
48
  BenchmarkConfig,
42
49
  DatasetConfig,
@@ -58,34 +65,40 @@ from ..exceptions import (
58
65
  NeedsEnvironmentVariable,
59
66
  NeedsExtraInstalled,
60
67
  )
61
- from ..generation_utils import apply_prompt, extract_few_shot_examples
68
+ from ..generation_utils import (
69
+ apply_prompt,
70
+ extract_few_shot_examples,
71
+ raise_if_wrong_params,
72
+ )
73
+ from ..logging_utils import get_pbar, log, log_once
62
74
  from ..task_group_utils import (
63
75
  question_answering,
64
76
  sequence_classification,
65
77
  text_to_text,
66
78
  token_classification,
67
79
  )
68
- from ..tokenization_utils import get_first_label_token_mapping
80
+ from ..tasks import NER
81
+ from ..tokenisation_utils import get_first_label_token_mapping
69
82
  from ..types import ExtractLabelsFunction
70
83
  from ..utils import (
71
84
  add_semaphore_and_catch_exception,
72
85
  create_model_cache_dir,
73
- log_once,
86
+ get_hf_token,
74
87
  safe_run,
88
+ split_model_id,
75
89
  )
76
90
  from .base import BenchmarkModule
77
- from .hf import HuggingFaceEncoderModel, load_hf_model_config, load_tokenizer
91
+ from .hf import HuggingFaceEncoderModel, load_hf_model_config, load_tokeniser
78
92
 
79
93
  if t.TYPE_CHECKING:
80
94
  from datasets import DatasetDict
81
95
  from litellm.types.utils import ModelResponse
82
96
  from transformers.trainer import Trainer
83
97
 
84
- logger = logging.getLogger("euroeval")
85
-
86
98
 
87
99
  VOCAB_SIZE_MAPPING = {
88
100
  # OpenAI models
101
+ r"gpt-5-.*": 100_256,
89
102
  r"gpt-4-(32k)?(-[0-9]{4})?": 100_256,
90
103
  r"gpt-4-[0-9]{4}-preview": 100_256,
91
104
  r"gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 100_256,
@@ -104,6 +117,7 @@ VOCAB_SIZE_MAPPING = {
104
117
 
105
118
  MODEL_MAX_LENGTH_MAPPING = {
106
119
  # OpenAI models
120
+ r"gpt-5-.*": 272_000,
107
121
  r"gpt-4(-[0-9]{4})?": 8_191,
108
122
  r"gpt-4-32k(-[0-9]{4})?": 32_767,
109
123
  r"gpt-4-[0-9]{4}-preview": 128_000,
@@ -117,6 +131,7 @@ MODEL_MAX_LENGTH_MAPPING = {
117
131
  r"gpt-4.1.*": 1_047_576,
118
132
  # Anthropic models
119
133
  r"(anthropic/)?claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": 200_000,
134
+ r"(anthropic/)?claude-(opus|sonnet|haiku)-[1-9](-[1-9])?-[0-9]{8}": 200_000,
120
135
  # Gemini models
121
136
  r"(gemini/)?gemini-1\.5-flash.*": 1_048_576,
122
137
  r"(gemini/)?gemini-1\.5-pro.*": 2_097_152,
@@ -128,6 +143,7 @@ MODEL_MAX_LENGTH_MAPPING = {
128
143
 
129
144
  NUM_PARAMS_MAPPING = {
130
145
  # OpenAI models
146
+ r"gpt-5-.*": -1,
131
147
  r"gpt-4.*": -1,
132
148
  r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
133
149
  # Anthropic models
@@ -141,25 +157,32 @@ NUM_PARAMS_MAPPING = {
141
157
  }
142
158
 
143
159
 
144
- ALLOWED_PARAMS = {
145
- # OpenAI models
146
- r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": ["low", "medium", "high"],
147
- # Anthropic models
148
- r"(anthropic/)?claude-3-7-sonnet.*": ["no-thinking", "thinking"],
149
- r"(anthropic/)?claude-(sonnet|opus)-4.*": ["no-thinking", "thinking"],
150
- # Gemini models
151
- r"(gemini/)?gemini-2.5-flash-lite.*": ["no-thinking", "thinking"],
152
- r"(gemini/)?gemini-2.5-flash-[0-9].*": ["no-thinking", "thinking"],
153
- # xAI models
154
- r"(xai/)?grok-3-mini(-fast)?(-beta)?": ["low", "medium", "high"],
155
- }
156
-
157
-
158
160
  REASONING_MODELS = [
159
161
  r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?",
160
162
  r"(gemini/)?gemini.*thinking.*",
161
163
  r"(gemini/)?gemini-2.5.*",
162
164
  r"(xai/)?grok-3-mini.*",
165
+ r".*gpt-oss.*",
166
+ ]
167
+
168
+ BASE_DECODER_MODELS = [
169
+ r"gpt-3.5-turbo-instruct.*",
170
+ r"ada-[0-9]{3}",
171
+ r"babbage-[0-9]{3}",
172
+ r"curie-[0-9]{3}",
173
+ r"davinci-[0-9]{3}",
174
+ r"text-davinci-[0-9]{3}",
175
+ ]
176
+
177
+ CUSTOM_INFERENCE_API_PREFIXES = [
178
+ "hosted_vllm/",
179
+ "vllm/",
180
+ "ollama/",
181
+ "ollama_chat/",
182
+ "llamafile/",
183
+ "litellm_proxy/",
184
+ "lm_studio/",
185
+ "openai/",
163
186
  ]
164
187
 
165
188
 
@@ -169,12 +192,34 @@ class LiteLLMModel(BenchmarkModule):
169
192
  fresh_model = False
170
193
  batching_preference = BatchingPreference.ALL_AT_ONCE
171
194
  high_priority = False
195
+ allowed_params = {
196
+ # OpenAI models
197
+ re.compile(r"gpt-5-.*"): ["minimal", "low", "medium", "high"],
198
+ re.compile(r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?"): [
199
+ "low",
200
+ "medium",
201
+ "high",
202
+ ],
203
+ # Anthropic models
204
+ re.compile(r"(anthropic/)?claude-3-7-sonnet.*"): ["no-thinking", "thinking"],
205
+ re.compile(r"(anthropic/)?claude-(sonnet|opus)-4.*"): [
206
+ "no-thinking",
207
+ "thinking",
208
+ ],
209
+ # Gemini models
210
+ re.compile(r"(gemini/)?gemini-2.5-flash-lite.*"): ["no-thinking", "thinking"],
211
+ re.compile(r"(gemini/)?gemini-2.5-flash.*"): ["no-thinking", "thinking"],
212
+ # xAI models
213
+ re.compile(r"(xai/)?grok-3-mini(-fast)?(-beta)?"): ["low", "medium", "high"],
214
+ }
172
215
 
173
216
  def __init__(
174
217
  self,
175
218
  model_config: ModelConfig,
176
219
  dataset_config: DatasetConfig,
177
220
  benchmark_config: BenchmarkConfig,
221
+ log_metadata: bool = True,
222
+ **generation_kwargs: dict[str, t.Any],
178
223
  ) -> None:
179
224
  """Initialise the model.
180
225
 
@@ -185,7 +230,16 @@ class LiteLLMModel(BenchmarkModule):
185
230
  The dataset configuration.
186
231
  benchmark_config:
187
232
  The benchmark configuration.
233
+ log_metadata:
234
+ Whether to log the model metadata.
235
+ generation_kwargs:
236
+ The generation kwargs to pass to the model. If None, default values will
237
+ be used.
188
238
  """
239
+ raise_if_wrong_params(
240
+ model_config=model_config, allowed_params=self.allowed_params
241
+ )
242
+
189
243
  # Detect whether the model is an Ollama model, as we need to extract metadata
190
244
  # differently for these models
191
245
  self.is_ollama = model_config.model_id.startswith(
@@ -197,20 +251,22 @@ class LiteLLMModel(BenchmarkModule):
197
251
  else ollama.ShowResponse(model_info=None)
198
252
  )
199
253
 
200
- raise_if_wrong_params(model_config=model_config, allowed_params=ALLOWED_PARAMS)
201
-
202
254
  super().__init__(
203
255
  model_config=model_config,
204
256
  dataset_config=dataset_config,
205
257
  benchmark_config=benchmark_config,
258
+ log_metadata=log_metadata,
206
259
  )
207
260
 
261
+ self.generation_kwargs = generation_kwargs
208
262
  self.buffer["first_label_token_mapping"] = get_first_label_token_mapping(
209
263
  dataset_config=self.dataset_config,
210
264
  model_config=self.model_config,
211
- tokenizer=None,
265
+ tokeniser=None,
212
266
  generative_type=self.generative_type,
267
+ log_metadata=self.log_metadata,
213
268
  )
269
+ self.buffer["max_concurrent_calls"] = 20
214
270
 
215
271
  @property
216
272
  def generative_type(self) -> GenerativeType | None:
@@ -219,29 +275,43 @@ class LiteLLMModel(BenchmarkModule):
219
275
  Returns:
220
276
  The generative type of the model, or None if it has not been set yet.
221
277
  """
222
- if self.is_ollama:
278
+ if self.benchmark_config.generative_type is not None:
279
+ type_ = self.benchmark_config.generative_type
280
+ elif self.is_ollama:
223
281
  reasoning_model = "thinking" in (self._ollama_show.capabilities or [])
224
- type_ = (
225
- GenerativeType.REASONING
226
- if reasoning_model
227
- else GenerativeType.INSTRUCTION_TUNED
228
- )
229
- elif self.model_config.revision in {"thinking"}:
282
+ if reasoning_model:
283
+ type_ = GenerativeType.REASONING
284
+ elif self.model_config.model_id.startswith("ollama_chat/"):
285
+ type_ = GenerativeType.INSTRUCTION_TUNED
286
+ else:
287
+ type_ = GenerativeType.BASE
288
+ elif self.model_config.param in {"thinking"}:
230
289
  type_ = GenerativeType.REASONING
231
- elif self.model_config.revision in {"no-thinking"}:
290
+ elif self.model_config.param in {"no-thinking"}:
232
291
  type_ = GenerativeType.INSTRUCTION_TUNED
233
292
  elif re.fullmatch(
234
- pattern="|".join(REASONING_MODELS), string=self.model_config.model_id
293
+ pattern="|".join(REASONING_MODELS),
294
+ string=self.model_config.model_id,
295
+ flags=re.IGNORECASE,
235
296
  ):
236
297
  type_ = GenerativeType.REASONING
298
+ elif re.fullmatch(
299
+ pattern="|".join(BASE_DECODER_MODELS),
300
+ string=self.model_config.model_id,
301
+ flags=re.IGNORECASE,
302
+ ):
303
+ type_ = GenerativeType.BASE
304
+ elif supports_reasoning(model=self.model_config.model_id):
305
+ type_ = GenerativeType.REASONING
237
306
  else:
238
307
  type_ = GenerativeType.INSTRUCTION_TUNED
239
308
 
240
- log_once(
241
- f"Detected generative type {type_.name!r} for model "
242
- f"{self.model_config.model_id!r}",
243
- level=logging.DEBUG,
244
- )
309
+ if self.log_metadata:
310
+ log_once(
311
+ f"Detected generative type {type_.name!r} for model "
312
+ f"{self.model_config.model_id!r}",
313
+ level=logging.DEBUG,
314
+ )
245
315
  return type_
246
316
 
247
317
  def generate(self, inputs: dict) -> GenerativeModelOutput:
@@ -253,143 +323,52 @@ class LiteLLMModel(BenchmarkModule):
253
323
 
254
324
  Returns:
255
325
  The generated model outputs.
326
+
327
+ Raises:
328
+ InvalidBenchmark:
329
+ If the inputs do not contain either 'messages' or 'text' keys.
256
330
  """
257
- assert "messages" in inputs, "The input must contain a 'messages' key."
258
- conversations: list[list[litellm.AllMessageValues]] = inputs["messages"]
331
+ model_inputs: c.Sequence[c.Sequence[litellm.AllMessageValues] | str]
332
+ if "messages" in inputs:
333
+ model_inputs = inputs["messages"]
334
+ elif "text" in inputs:
335
+ model_inputs = inputs["text"]
336
+ else:
337
+ raise InvalidBenchmark(
338
+ "The inputs must contain either 'messages' or 'text' keys."
339
+ )
259
340
 
260
341
  # Get the mapping from labels to the first token in the label. We call this each
261
342
  # time we generate a new dataset since the dataset config can change
262
343
  self.buffer["first_label_token_mapping"] = get_first_label_token_mapping(
263
344
  dataset_config=self.dataset_config,
264
345
  model_config=self.model_config,
265
- tokenizer=None,
346
+ tokeniser=None,
266
347
  generative_type=self.generative_type,
348
+ log_metadata=self.log_metadata,
267
349
  )
268
350
 
269
- # Set the core generation arguments
270
- generation_kwargs: dict[str, t.Any] = dict(
271
- model=self.model_config.model_id,
272
- max_completion_tokens=(
273
- REASONING_MAX_TOKENS
274
- if self.generative_type == GenerativeType.REASONING
275
- else self.dataset_config.max_generated_tokens
276
- ),
277
- stop=[],
278
- temperature=0.0,
279
- seed=4242,
280
- api_key=self.benchmark_config.api_key,
281
- api_base=self.benchmark_config.api_base,
282
- api_version=self.benchmark_config.api_version,
283
- max_retries=3,
284
- )
285
-
286
- # Set up the `response_format` generation argument if we are dealing with a task
287
- # using structured generation
288
- if self.dataset_config.task in TASKS_USING_JSON:
289
- # Sanity check that "JSON" is included in the prompt, as some models require
290
- # this
291
- for conversation in conversations:
292
- if not conversation:
293
- raise InvalidBenchmark(
294
- "Encountered an empty conversation in 'messages'."
295
- )
296
- last_message = conversation[-1]
297
- assert isinstance(last_message, dict), (
298
- f"Expected dict message, got {type(last_message)}"
299
- )
300
- assert "content" in last_message, (
301
- "Expected 'content' key in the last message of the conversation."
302
- )
303
- assert isinstance(last_message["content"], str), (
304
- "Expected 'content' to be a string."
305
- )
306
- assert "json" in last_message["content"].lower(), (
307
- "Prompt must contain 'json' for JSON tasks."
308
- )
309
-
310
- if self.generative_type == GenerativeType.REASONING:
311
- log_once(
312
- f"The model {self.model_config.model_id!r} is a reasoning model "
313
- "and thus does not support structured generation, so we do not "
314
- "enable it.",
315
- level=logging.DEBUG,
316
- )
317
- elif litellm.utils.supports_response_schema(
318
- model=self.model_config.model_id
319
- ):
320
- ner_tag_names = list(self.dataset_config.prompt_label_mapping.values())
321
- keys_and_their_types: dict[str, t.Any] = {
322
- tag_name: (conlist(str, max_length=5), ...)
323
- for tag_name in ner_tag_names
324
- }
325
- pydantic_class = create_model("AnswerFormat", **keys_and_their_types)
326
- generation_kwargs["response_format"] = pydantic_class
327
- log_once(
328
- "Enabling structured generation for model "
329
- f"{self.model_config.model_id!r} with the JSON schema "
330
- f"{pydantic_class.model_json_schema()}",
331
- level=logging.DEBUG,
332
- )
333
- else:
334
- generation_kwargs["response_format"] = dict(type="json_object")
335
- log_once(
336
- "Enabling structured JSON generation for model "
337
- f"{self.model_config.model_id!r} with no custom JSON schema, as "
338
- "the model does not support schemas.",
339
- level=logging.DEBUG,
340
- )
341
-
342
- # If the model is an Ollama reasoning model, we ensure that thinking is enabled
343
- if self.is_ollama and self.generative_type == GenerativeType.REASONING:
344
- generation_kwargs["think"] = True
345
- log_once(
346
- "Enabling thinking mode for Ollama model "
347
- f"{self.model_config.model_id!r}",
348
- level=logging.DEBUG,
349
- )
350
-
351
- # Handle manually set parameters
352
- if self.buffer["first_label_token_mapping"]:
353
- generation_kwargs["logprobs"] = True
354
- generation_kwargs["top_logprobs"] = MAX_LOGPROBS
355
- if self.model_config.revision == "thinking":
356
- generation_kwargs["thinking"] = dict(
357
- type="enabled", budget_tokens=REASONING_MAX_TOKENS - 1
358
- )
359
- log_once(
360
- f"Enabling thinking mode for model {self.model_config.model_id!r}",
361
- level=logging.DEBUG,
362
- )
363
- elif self.model_config.revision == "no-thinking":
364
- generation_kwargs["thinking"] = dict(type="disabled", budget_tokens=0)
365
- log_once(
366
- f"Disabling thinking mode for model {self.model_config.model_id!r}",
367
- level=logging.DEBUG,
368
- )
369
- elif self.model_config.revision in {"low", "medium", "high"}:
370
- generation_kwargs["reasoning_effort"] = self.model_config.revision
371
- log_once(
372
- f"Enabling reasoning effort {self.model_config.revision!r} for model "
373
- f"{self.model_config.model_id!r}",
374
- level=logging.DEBUG,
375
- )
376
-
377
- # Drop generation kwargs that are not supported by the model
378
- litellm.drop_params = True
379
-
380
351
  all_responses: dict[int, "ModelResponse"] = {}
381
- conversations_to_run: list[tuple[int, list[litellm.AllMessageValues]]] = list(
382
- enumerate(conversations)
383
- )
352
+ inputs_to_run: c.Sequence[
353
+ tuple[int, c.Sequence[litellm.AllMessageValues] | str]
354
+ ] = list(enumerate(model_inputs))
384
355
  for attempt in range(num_attempts := 10):
385
- if not conversations_to_run:
356
+ if not inputs_to_run:
386
357
  break
387
358
 
388
- batch_indices, batch_conversations = zip(*conversations_to_run)
359
+ generation_kwargs = self.generation_kwargs or self.get_generation_kwargs(
360
+ dataset_config=self.dataset_config
361
+ )
362
+
363
+ batch_indices, batch_inputs = zip(*inputs_to_run)
389
364
  successes, failures = safe_run(
390
365
  self._generate_async(
391
- model_id=self.model_config.model_id,
392
- conversations=list(batch_conversations),
366
+ model_id=clean_model_id(
367
+ model_id=self.model_config.model_id,
368
+ benchmark_config=self.benchmark_config,
369
+ ),
370
+ inputs=list(batch_inputs),
371
+ max_concurrent_calls=self.buffer["max_concurrent_calls"],
393
372
  **generation_kwargs,
394
373
  )
395
374
  )
@@ -401,23 +380,50 @@ class LiteLLMModel(BenchmarkModule):
401
380
 
402
381
  # If all requests were successful, break
403
382
  if not failures:
404
- conversations_to_run = []
383
+ inputs_to_run = []
405
384
  break
406
385
 
407
386
  # Put the failed requests back in the queue to try again
408
- conversations_to_run = [
409
- (batch_indices[idx], conversations[batch_indices[idx]])
387
+ inputs_to_run = [
388
+ (batch_indices[idx], model_inputs[batch_indices[idx]])
410
389
  for idx, _ in failures
411
390
  ]
412
- logger.debug(
391
+ log(
413
392
  f"Attempt {attempt + 1:,}/{num_attempts:,}: retrying "
414
- f"{len(conversations_to_run):,} failed message(s)"
393
+ f"{len(inputs_to_run):,} failed message(s). Here is the first error: "
394
+ f"{failures[0][1]}.",
395
+ level=logging.DEBUG,
415
396
  )
416
397
 
398
+ # Check if any errors are due to HTTP 429 (too many requests), in which case
399
+ # we reduce the number of concurrent calls
400
+ http_429_errors = [
401
+ idx
402
+ for idx, (_, error) in enumerate(failures)
403
+ if isinstance(error, RateLimitError) and "Error code: 429" in str(error)
404
+ ]
405
+ if http_429_errors and self.buffer["max_concurrent_calls"] > 1:
406
+ failures = [
407
+ failures[i]
408
+ for i in range(len(failures))
409
+ if i not in http_429_errors
410
+ ]
411
+ self.buffer["max_concurrent_calls"] = max(
412
+ 1, self.buffer["max_concurrent_calls"] // 2
413
+ )
414
+ log(
415
+ f"Reducing the maximum number of concurrent calls to "
416
+ f"{self.buffer['max_concurrent_calls']:,} due to rate limiting.",
417
+ level=logging.DEBUG,
418
+ )
419
+ continue
420
+
417
421
  # Attempt to handle the exceptions, to improve the chance of getting
418
422
  # successful generations next time around
419
423
  for _, error in failures:
420
- self._handle_exception(error=error, generation_kwargs=generation_kwargs)
424
+ generation_kwargs = self._handle_exception(
425
+ error=error, **generation_kwargs
426
+ )
421
427
 
422
428
  # Sleep for a second to avoid pinging the API server too quickly
423
429
  sleep(1)
@@ -427,22 +433,20 @@ class LiteLLMModel(BenchmarkModule):
427
433
  )
428
434
 
429
435
  # Extract the generations from the model output
430
- ordered_responses = [all_responses[i] for i in range(len(conversations))]
436
+ ordered_responses = [all_responses[i] for i in range(len(model_inputs))]
431
437
  model_output = self._create_model_output(
432
438
  model_responses=ordered_responses, model_id=self.model_config.model_id
433
439
  )
434
440
 
435
- if len(conversations) != len(model_output.sequences):
441
+ if len(model_inputs) != len(model_output.sequences):
436
442
  raise InvalidBenchmark(
437
- f"Number of model inputs ({len(conversations):,}) does not match the "
443
+ f"Number of model inputs ({len(model_inputs):,}) does not match the "
438
444
  f"number of model outputs ({len(model_output.sequences):,})."
439
445
  )
440
446
 
441
447
  return model_output
442
448
 
443
- def _handle_exception(
444
- self, error: Exception, generation_kwargs: dict[str, t.Any]
445
- ) -> None:
449
+ def _handle_exception(self, error: Exception, **generation_kwargs) -> dict:
446
450
  """Handle an exception from the model.
447
451
 
448
452
  Args:
@@ -450,26 +454,47 @@ class LiteLLMModel(BenchmarkModule):
450
454
  The exception to handle.
451
455
  generation_kwargs:
452
456
  The generation kwargs to pass to the model.
457
+
458
+ Returns:
459
+ The updated generation kwargs to pass to the model.
453
460
  """
454
461
  error_msg = str(error).lower()
455
462
  model_id = self.model_config.model_id
456
463
 
457
464
  # Error messages that we want to catch and handle
458
- stop_messages = ["stop_sequences", "'stop' is not supported with this model"]
465
+ stop_messages = [
466
+ "stop_sequences",
467
+ "'stop' is not supported with this model",
468
+ "'$.stop' is invalid",
469
+ ]
470
+ stop_pattern = re.compile(r"does not support parameters: \[.*'stop'.*\]")
459
471
  logprobs_messages = [
460
472
  "you are not allowed to request logprobs",
461
473
  "you've reached the maximum number of requests with logprobs",
462
474
  "logprobs is not supported",
463
475
  "logprobs is not enabled",
476
+ "Invalid value at 'generation_config.response_logprobs' (TYPE_BOOL)",
464
477
  ]
478
+ logprobs_pattern = re.compile(
479
+ r"does not support parameters: \[.*'logprobs'.*\]"
480
+ )
481
+ top_logprobs_messages = ["got an unexpected keyword argument 'top_logprobs'"]
482
+ top_logprobs_pattern = re.compile(
483
+ r"does not support parameters: \[.*'top_logprobs'.*\]"
484
+ )
485
+ max_completion_tokens_pattern = re.compile(
486
+ r"does not support parameters: \[.*'max_completion_tokens'.*\]"
487
+ )
465
488
  temperature_messages = [
466
489
  "'temperature' is not supported with this model.",
467
490
  "temperature is not supported with this model",
491
+ r"does not support parameters: \[.*'temperature'.*\]",
468
492
  ]
469
493
  temperature_must_be_one_messages = [
470
494
  "`temperature` may only be set to 1",
471
495
  "'temperature' does not support 0.0 with this model. Only the default "
472
496
  "(1) value is supported",
497
+ "Only temperature=1 is supported",
473
498
  ]
474
499
  max_items_messages = ["'maxItems' is not permitted."]
475
500
  no_json_schema_messages = ["Property keys should match pattern"]
@@ -477,17 +502,27 @@ class LiteLLMModel(BenchmarkModule):
477
502
  r"the thinking budget [0-9]+ is invalid. please choose a value between "
478
503
  r"[0-9]+ and ([0-9]+)\."
479
504
  )
505
+ requires_thinking_disabled_messages = ["thinking.type: Field required"]
506
+ seed_pattern = re.compile(r"does not support parameters: \[.*'seed'.*\]")
507
+ response_format_messages = [
508
+ "got an unexpected keyword argument 'response_format'",
509
+ "the model returned empty outputs",
510
+ ]
480
511
 
481
- if any(msg.lower() in error_msg for msg in stop_messages):
512
+ if (
513
+ any(msg.lower() in error_msg for msg in stop_messages)
514
+ or stop_pattern.search(string=error_msg) is not None
515
+ ):
482
516
  log_once(
483
517
  f"The model {model_id!r} does not support "
484
518
  "stop sequences, so disabling them.",
485
519
  level=logging.DEBUG,
486
520
  )
487
521
  generation_kwargs["stop"] = None
488
- return
522
+ return generation_kwargs
489
523
  elif (
490
524
  any(msg.lower() in error_msg for msg in logprobs_messages)
525
+ or logprobs_pattern.search(string=error_msg) is not None
491
526
  # Special case for Vertex AI models, since they have strict rate
492
527
  # limits on using logprobs. They also have a cap of 5 logprobs, but
493
528
  # we ignore this since the rate limiting makes it unusable anyway.
@@ -497,9 +532,32 @@ class LiteLLMModel(BenchmarkModule):
497
532
  f"The model {model_id!r} does not support logprobs, so disabling it.",
498
533
  level=logging.DEBUG,
499
534
  )
535
+ self.buffer["first_label_token_mapping"] = False
500
536
  generation_kwargs.pop("logprobs", None)
501
537
  generation_kwargs.pop("top_logprobs", None)
502
- return
538
+ generation_kwargs.pop("response_format", None)
539
+ return generation_kwargs
540
+ elif (
541
+ any(msg.lower() in error_msg for msg in top_logprobs_messages)
542
+ or top_logprobs_pattern.search(string=error_msg) is not None
543
+ ):
544
+ log_once(
545
+ f"The model {model_id!r} does not support the `top_logprobs` argument, "
546
+ "so moving the value to `logprobs`.",
547
+ level=logging.DEBUG,
548
+ )
549
+ generation_kwargs["logprobs"] = generation_kwargs.pop("top_logprobs", None)
550
+ return generation_kwargs
551
+ elif max_completion_tokens_pattern.search(string=error_msg):
552
+ log_once(
553
+ f"The model {model_id!r} does not support max_completion_tokens, so "
554
+ "disabling it.",
555
+ level=logging.DEBUG,
556
+ )
557
+ generation_kwargs["max_tokens"] = generation_kwargs.pop(
558
+ "max_completion_tokens", None
559
+ )
560
+ return generation_kwargs
503
561
  elif any(msg.lower() in error_msg for msg in temperature_messages):
504
562
  log_once(
505
563
  f"The model {model_id!r} does not support "
@@ -507,7 +565,7 @@ class LiteLLMModel(BenchmarkModule):
507
565
  level=logging.DEBUG,
508
566
  )
509
567
  generation_kwargs.pop("temperature", None)
510
- return
568
+ return generation_kwargs
511
569
  elif any(msg.lower() in error_msg for msg in temperature_must_be_one_messages):
512
570
  log_once(
513
571
  f"The model {model_id!r} requires "
@@ -515,8 +573,11 @@ class LiteLLMModel(BenchmarkModule):
515
573
  level=logging.DEBUG,
516
574
  )
517
575
  generation_kwargs["temperature"] = 1.0
518
- return
519
- elif any(msg.lower() in error_msg for msg in max_items_messages):
576
+ return generation_kwargs
577
+ elif (
578
+ any(msg.lower() in error_msg for msg in max_items_messages)
579
+ and self.dataset_config.task == NER
580
+ ):
520
581
  log_once(
521
582
  f"The model {model_id!r} does not support "
522
583
  "maxItems in the JSON schema, so disabling it.",
@@ -524,11 +585,11 @@ class LiteLLMModel(BenchmarkModule):
524
585
  )
525
586
  ner_tag_names = list(self.dataset_config.prompt_label_mapping.values())
526
587
  keys_and_their_types = {
527
- tag_name: (list[str], ...) for tag_name in ner_tag_names
588
+ tag_name: (c.Sequence[str], ...) for tag_name in ner_tag_names
528
589
  }
529
590
  pydantic_class = create_model("AnswerFormat", **keys_and_their_types)
530
591
  generation_kwargs["response_format"] = pydantic_class
531
- return
592
+ return generation_kwargs
532
593
  elif any(msg.lower() in error_msg for msg in no_json_schema_messages):
533
594
  log_once(
534
595
  f"The model {self.model_config.model_id!r} does not support "
@@ -536,7 +597,7 @@ class LiteLLMModel(BenchmarkModule):
536
597
  level=logging.DEBUG,
537
598
  )
538
599
  generation_kwargs["response_format"] = dict(type="json_object")
539
- return
600
+ return generation_kwargs
540
601
  elif thinking_match := thinking_budget_pattern.search(string=error_msg):
541
602
  thinking_budget = int(thinking_match.group(1))
542
603
  if thinking_budget >= REASONING_MAX_TOKENS:
@@ -545,7 +606,7 @@ class LiteLLMModel(BenchmarkModule):
545
606
  f"{thinking_budget:,} tokens, which is within the limit of "
546
607
  f"{REASONING_MAX_TOKENS:,} tokens. This should not happen. The "
547
608
  f"error message was: {error_msg}."
548
- )
609
+ ) from error
549
610
  log_once(
550
611
  f"The model {model_id!r} can at most use {thinking_budget:,} tokens "
551
612
  "for reasoning, which is less than the default of "
@@ -556,59 +617,135 @@ class LiteLLMModel(BenchmarkModule):
556
617
  generation_kwargs["thinking"] = dict(
557
618
  type="enabled", budget_tokens=thinking_budget - 1
558
619
  )
559
- return
620
+ return generation_kwargs
621
+ elif (
622
+ any(msg.lower() in error_msg for msg in requires_thinking_disabled_messages)
623
+ and self.generative_type != GenerativeType.REASONING
624
+ ):
625
+ log_once(
626
+ f"The model {model_id!r} requires the `thinking.type` field to be "
627
+ f"set to `disabled` rather than just setting `budget_tokens` to 0. "
628
+ "Setting `thinking.type` to `disabled`.",
629
+ level=logging.DEBUG,
630
+ )
631
+ generation_kwargs["thinking"] = dict(type="disabled")
632
+ return generation_kwargs
633
+ elif re.search(pattern=seed_pattern, string=error_msg):
634
+ log_once(
635
+ f"The model {model_id!r} does not support the `seed` parameter, so "
636
+ "disabling it.",
637
+ level=logging.DEBUG,
638
+ )
639
+ generation_kwargs.pop("seed", None)
640
+ return generation_kwargs
641
+ elif any(msg.lower() in error_msg for msg in response_format_messages):
642
+ log_once(
643
+ f"The model {model_id!r} does not support the `response_format` "
644
+ "parameter, so disabling it.",
645
+ level=logging.DEBUG,
646
+ )
647
+ generation_kwargs.pop("response_format", None)
648
+ return generation_kwargs
649
+ # If there are too many I/O connections, we increase the number of allowed file
650
+ # descriptors
651
+ elif "too many open files" in error_msg:
652
+ raise InvalidBenchmark(
653
+ "There are too many file descriptors running. See the current "
654
+ "value by running `ulimit -n`. Try increasing it by running "
655
+ "`ulimit -n <new-value>` and try again."
656
+ ) from error
560
657
  elif isinstance(
561
658
  error, (Timeout, ServiceUnavailableError, InternalServerError, SystemError)
562
659
  ):
563
- logger.debug(
660
+ log(
564
661
  f"Service temporarily unavailable. The error message was: {error}. "
565
- f"Retrying in 5 seconds..."
662
+ "Retrying in 10 seconds...",
663
+ level=logging.DEBUG,
664
+ )
665
+ sleep(10)
666
+ return generation_kwargs
667
+ elif isinstance(error, UnsupportedParamsError):
668
+ unsupported_param_match = re.search(
669
+ pattern=r"(?<=does not support parameters\: \[')([^ ']+)(?='\])",
670
+ string=error.message,
566
671
  )
567
- sleep(5)
568
- return
672
+ if unsupported_param_match is None:
673
+ raise InvalidModel(error.message) from error
674
+ else:
675
+ unsupported_param = unsupported_param_match.group(0)
676
+ raise InvalidModel(
677
+ f"The model {model_id!r} does not support the parameter "
678
+ f"{unsupported_param!r}. Try again without this parameter. "
679
+ "Skipping this model."
680
+ ) from error
569
681
  elif isinstance(error, (APIConnectionError, OSError)):
570
- # If there are too many I/O connections, we increase the number of allowed
571
- # file descriptors
572
- if "too many open files" in error_msg:
573
- raise InvalidBenchmark(
574
- "There are too many file descriptors running. See the current "
575
- "value by running `ulimit -n`. Try increasing it by running "
576
- "`ulimit -n <new-value>` and try again."
577
- )
578
682
  raise InvalidBenchmark(
579
683
  f"Encountered {type(error)} during generation: {error}."
580
- )
684
+ ) from error
581
685
 
582
- if isinstance(error, RateLimitError):
686
+ if isinstance(error, NotFoundError):
583
687
  raise InvalidModel(
688
+ f"The model {model_id!r} was not found. Please check the model ID "
689
+ "and try again."
690
+ ) from error
691
+
692
+ if isinstance(error, RateLimitError):
693
+ log(
584
694
  f"You have encountered your rate limit for model {model_id!r}. "
585
- "Skipping."
695
+ "Retrying in 10 seconds...",
696
+ level=logging.DEBUG,
697
+ )
698
+ sleep(10)
699
+ return generation_kwargs
700
+
701
+ if (
702
+ isinstance(error, BadRequestError)
703
+ and (
704
+ retry_match := re.search(
705
+ pattern=r"\bretry in ([0-9]+(.[0-9]+)?) ?(s|seconds)\b",
706
+ string=error_msg,
707
+ )
708
+ )
709
+ is not None
710
+ ):
711
+ retry_seconds = float(retry_match.group(1))
712
+ log(
713
+ f"Bad request error encountered. Retrying in {retry_seconds:.1f} "
714
+ "seconds...",
715
+ level=logging.DEBUG,
586
716
  )
717
+ sleep(retry_seconds)
718
+ return generation_kwargs
587
719
 
588
720
  if isinstance(error, AuthenticationError):
589
721
  raise NeedsAdditionalArgument(
590
722
  cli_argument="--api-key",
591
723
  script_argument="api_key=<your-api-key>",
592
724
  run_with_cli=self.benchmark_config.run_with_cli,
593
- )
725
+ ) from error
594
726
 
595
727
  raise InvalidBenchmark(
596
728
  f"Failed to generate text. The error message was: {error}"
597
- )
729
+ ) from error
598
730
 
599
731
  async def _generate_async(
600
732
  self,
601
733
  model_id: str,
602
- conversations: list[list[litellm.AllMessageValues]],
734
+ inputs: c.Sequence[c.Sequence[litellm.AllMessageValues] | str],
735
+ max_concurrent_calls: int,
603
736
  **generation_kwargs,
604
- ) -> tuple[list[tuple[int, "ModelResponse"]], list[tuple[int, Exception]]]:
737
+ ) -> tuple[
738
+ c.Sequence[tuple[int, "ModelResponse"]], c.Sequence[tuple[int, Exception]]
739
+ ]:
605
740
  """Generate outputs from the model asynchronously.
606
741
 
607
742
  Args:
608
743
  model_id:
609
744
  The ID of the model to use for generation.
610
- conversations:
611
- The conversations to pass to the model.
745
+ inputs:
746
+ The inputs to pass to the model.
747
+ max_concurrent_calls:
748
+ The maximum number of concurrent calls to make to the model.
612
749
  **generation_kwargs:
613
750
  Additional generation arguments to pass to the model.
614
751
 
@@ -621,24 +758,67 @@ class LiteLLMModel(BenchmarkModule):
621
758
  # for all the requests, preventing "too many open files" errors
622
759
  router = Router(
623
760
  model_list=[
624
- dict(
761
+ litellm.DeploymentTypedDict(
625
762
  model_name=self.model_config.model_id,
626
- litellm_params=generation_kwargs,
763
+ litellm_params=litellm.LiteLLMParamsTypedDict(model=model_id),
627
764
  )
628
765
  ]
629
766
  )
630
767
 
631
768
  # Get the LLM generations asynchronously
632
- max_concurrent_calls = 20
633
769
  semaphore = asyncio.Semaphore(max_concurrent_calls)
634
- requests = [
635
- add_semaphore_and_catch_exception(
636
- router.acompletion(model=model_id, messages=conversation),
637
- semaphore=semaphore,
770
+ if self.generative_type == GenerativeType.BASE:
771
+ if not all(isinstance(input_, str) for input_ in inputs):
772
+ raise InvalidBenchmark(
773
+ "For base generative models, all inputs must be strings."
774
+ )
775
+ requests = [
776
+ add_semaphore_and_catch_exception(
777
+ router.atext_completion(
778
+ model=clean_model_id(
779
+ model_id=model_id, benchmark_config=self.benchmark_config
780
+ ),
781
+ prompt=input_,
782
+ **generation_kwargs,
783
+ ),
784
+ semaphore=semaphore,
785
+ )
786
+ for input_ in inputs
787
+ if isinstance(input_, str)
788
+ ]
789
+ else:
790
+ if not all(isinstance(input_, list) for input_ in inputs):
791
+ raise InvalidBenchmark(
792
+ "For instruction-tuned and reasoning generative models, all "
793
+ "inputs must be lists of messages."
794
+ )
795
+ requests = [
796
+ add_semaphore_and_catch_exception(
797
+ router.acompletion(
798
+ model=clean_model_id(
799
+ model_id=model_id, benchmark_config=self.benchmark_config
800
+ ),
801
+ messages=input_,
802
+ **generation_kwargs,
803
+ ),
804
+ semaphore=semaphore,
805
+ )
806
+ for input_ in inputs
807
+ if isinstance(input_, list)
808
+ ]
809
+ responses = await tqdm_async.gather(
810
+ *requests, colour="yellow", ascii="—▰", leave=False
811
+ )
812
+
813
+ # If the outputs are empty, convert them to exceptions
814
+ if all(
815
+ not isinstance(response, Exception)
816
+ and response.choices[0].message.content == "{}"
817
+ for response in responses
818
+ ):
819
+ responses = [ValueError("The model returned empty outputs.")] * len(
820
+ responses
638
821
  )
639
- for conversation in conversations
640
- ]
641
- responses = await tqdm_async.gather(*requests, leave=False)
642
822
 
643
823
  # Separate the successful responses from the failed ones
644
824
  successes = [
@@ -655,13 +835,18 @@ class LiteLLMModel(BenchmarkModule):
655
835
  # Close connections
656
836
  for request in requests:
657
837
  if hasattr(request, "close"):
658
- request.close()
838
+ try:
839
+ request.close()
840
+ except RuntimeError as e:
841
+ log(
842
+ f"RuntimeError during request.close(): {e}", level=logging.DEBUG
843
+ )
659
844
 
660
845
  return successes, failures
661
846
 
662
847
  @staticmethod
663
848
  def _create_model_output(
664
- model_responses: list["ModelResponse"], model_id: str
849
+ model_responses: c.Sequence["ModelResponse"], model_id: str
665
850
  ) -> GenerativeModelOutput:
666
851
  """Create a GenerativeModelOutput object from a list of ModelResponse objects.
667
852
 
@@ -680,45 +865,104 @@ class LiteLLMModel(BenchmarkModule):
680
865
  for model_response in model_responses:
681
866
  if not model_response.choices:
682
867
  sequences.append("")
683
- logger.warning(
868
+ log(
684
869
  f"The model {model_id!r} did not end up "
685
870
  "generating any text. This is likely because the model ran "
686
- "out of tokens while reasoning. Returning an empty string."
871
+ "out of tokens while reasoning. Returning an empty string.",
872
+ level=logging.WARNING,
687
873
  )
688
874
  continue
689
875
 
690
876
  model_response_choices = model_response.choices[0]
691
- assert isinstance(model_response_choices, litellm.Choices)
692
- generated_message: litellm.Message = model_response_choices.message
693
- generation_output = generated_message.content or ""
694
- generation_output = generation_output.strip()
877
+
878
+ if isinstance(model_response_choices, litellm.Choices):
879
+ generated_message: litellm.Message = model_response_choices.message
880
+ generation_output = generated_message.content or ""
881
+ generation_output = generation_output.strip()
882
+ elif isinstance(model_response_choices, litellm.litellm.TextChoices):
883
+ generation_output = model_response_choices.text or ""
884
+ else:
885
+ raise InvalidBenchmark(
886
+ "The model response choices must be of type Choices or "
887
+ f"TextChoices. Got {type(model_response_choices)}."
888
+ )
889
+
890
+ # In the case where we're dealing with a classification task, the model is
891
+ # outputting a JSON dictionary, so we will extract the generated text from
892
+ # within the dictionary
893
+ generation_dct: dict[str, t.Any] | None = None
894
+ if LITELLM_CLASSIFICATION_OUTPUT_KEY in generation_output:
895
+ try:
896
+ generation_dct = json.loads(generation_output)
897
+ assert isinstance(generation_dct, dict)
898
+ if set(generation_dct.keys()) == {
899
+ LITELLM_CLASSIFICATION_OUTPUT_KEY
900
+ }:
901
+ generation_output = str(
902
+ generation_dct[LITELLM_CLASSIFICATION_OUTPUT_KEY]
903
+ ).strip()
904
+ except json.JSONDecodeError:
905
+ pass
695
906
 
696
907
  # Structure the model output as a GenerativeModelOutput object
697
908
  sequences.append(generation_output)
698
- if hasattr(model_response_choices, "logprobs"):
909
+ if (
910
+ hasattr(model_response_choices, "logprobs")
911
+ and model_response_choices.logprobs is not None
912
+ ):
699
913
  logprobs_obj = model_response_choices.logprobs
914
+
915
+ if not isinstance(logprobs_obj, (Logprobs, ChoiceLogprobs)):
916
+ log_once(
917
+ "The logprobs object is malformed, so we won't use logprobs to "
918
+ "determine the labels.",
919
+ level=logging.WARNING,
920
+ )
921
+ continue
922
+
923
+ logprobs_list: c.Sequence[c.Sequence[tuple[str, float]]]
700
924
  if isinstance(logprobs_obj, ChoiceLogprobs):
701
- logprobs_list: list[list[tuple[str, float]]] = [
925
+ logprobs_list = [
702
926
  [
703
927
  (top_logprob.token, top_logprob.logprob)
704
928
  for top_logprob in content.top_logprobs
705
929
  ]
706
- for content in model_response_choices.logprobs.content or list()
930
+ for content in logprobs_obj.content or list()
707
931
  ]
708
- scores.append(logprobs_list)
709
932
  else:
710
- log_once(
711
- "The logprobs object is malformed, so we won't use logprobs to "
712
- "determine the labels.",
713
- level=logging.WARNING,
714
- )
933
+ logprobs_list = [
934
+ [
935
+ (token, logprob)
936
+ for token, logprob in (top_logprobs_dct or dict()).items()
937
+ ]
938
+ for top_logprobs_dct in logprobs_obj.top_logprobs or list()
939
+ ]
940
+
941
+ # If the model outputted a JSON dictionary, we need to find the
942
+ # token index of the value within the dictionary, rather than the
943
+ # first token of the entire output
944
+ if generation_dct:
945
+ key_name = next(iter(generation_dct.keys()))
946
+ logprobs_list = [
947
+ lst
948
+ for lst in logprobs_list
949
+ if (
950
+ lst
951
+ and lst[0]
952
+ and (token := lst[0][0].strip(JSON_STRIP_CHARACTERS))
953
+ and not key_name.startswith(token)
954
+ )
955
+ ]
956
+
957
+ scores.append(logprobs_list)
715
958
 
716
959
  if not sequences:
717
- logger.warning(
960
+ log(
718
961
  "No sequences were generated by the model "
719
962
  f"{model_id!r}. This may be due to the "
720
963
  "model running out of tokens or an issue with the input data. "
721
- "Returning an empty GenerativeModelOutput."
964
+ "Returning an empty GenerativeModelOutput.",
965
+ level=logging.WARNING,
722
966
  )
723
967
  return GenerativeModelOutput(sequences=[], scores=None)
724
968
 
@@ -778,9 +1022,7 @@ class LiteLLMModel(BenchmarkModule):
778
1022
  repo_info = hf_api.model_info(
779
1023
  repo_id=model_id,
780
1024
  revision="main",
781
- token=os.getenv("HUGGINGFACE_API_KEY")
782
- or self.benchmark_config.api_key
783
- or True,
1025
+ token=get_hf_token(api_key=self.benchmark_config.api_key),
784
1026
  )
785
1027
  except (
786
1028
  RepositoryNotFoundError,
@@ -837,10 +1079,11 @@ class LiteLLMModel(BenchmarkModule):
837
1079
  run_with_cli=self.benchmark_config.run_with_cli,
838
1080
  )
839
1081
 
840
- tokenizer = load_tokenizer(
1082
+ tokeniser = load_tokeniser(
841
1083
  model=None,
842
1084
  model_id=model_id,
843
1085
  trust_remote_code=self.benchmark_config.trust_remote_code,
1086
+ model_config=self.model_config,
844
1087
  )
845
1088
 
846
1089
  if (
@@ -849,10 +1092,10 @@ class LiteLLMModel(BenchmarkModule):
849
1092
  ):
850
1093
  vocab_size = hf_config.vocab_size
851
1094
  elif (
852
- hasattr(tokenizer, "vocab_size")
853
- and tokenizer.vocab_size is not None
1095
+ hasattr(tokeniser, "vocab_size")
1096
+ and tokeniser.vocab_size is not None
854
1097
  ):
855
- vocab_size = tokenizer.vocab_size
1098
+ vocab_size = tokeniser.vocab_size
856
1099
  else:
857
1100
  vocab_size = -1
858
1101
  return vocab_size
@@ -883,13 +1126,15 @@ class LiteLLMModel(BenchmarkModule):
883
1126
  if context_length_keys:
884
1127
  context_length = model_info[context_length_keys[0]]
885
1128
  if context_length is not None:
886
- log_once(
887
- f"Detected context length key {context_length_keys[0]!r} "
888
- f"for Ollama model {ollama_model_id!r}",
889
- level=logging.DEBUG,
890
- )
1129
+ if self.log_metadata:
1130
+ log_once(
1131
+ f"Detected context length key "
1132
+ f"{context_length_keys[0]!r} for Ollama model "
1133
+ f"{ollama_model_id!r}",
1134
+ level=logging.DEBUG,
1135
+ )
891
1136
  return int(context_length)
892
- else:
1137
+ elif self.log_metadata:
893
1138
  log_once(
894
1139
  f"Tried to get the maximum length of the Ollama model "
895
1140
  f"{ollama_model_id!r}, but could not find a context length. "
@@ -917,26 +1162,27 @@ class LiteLLMModel(BenchmarkModule):
917
1162
  run_with_cli=self.benchmark_config.run_with_cli,
918
1163
  )
919
1164
 
920
- tokenizer = load_tokenizer(
1165
+ tokeniser = load_tokeniser(
921
1166
  model=None,
922
1167
  model_id=model_id,
923
1168
  trust_remote_code=self.benchmark_config.trust_remote_code,
1169
+ model_config=self.model_config,
924
1170
  )
925
1171
 
926
1172
  all_max_lengths: list[int] = list()
927
1173
 
928
- # Add the registered max length of the tokenizer
1174
+ # Add the registered max length of the tokeniser
929
1175
  if hasattr(
930
- tokenizer, "model_max_length"
931
- ) and tokenizer.model_max_length < int(1e30):
932
- all_max_lengths.append(tokenizer.model_max_length)
1176
+ tokeniser, "model_max_length"
1177
+ ) and tokeniser.model_max_length < int(1e30):
1178
+ all_max_lengths.append(tokeniser.model_max_length)
933
1179
 
934
1180
  # Add the max length derived from the model's input sizes
935
- if hasattr(tokenizer, "max_model_input_sizes"):
1181
+ if hasattr(tokeniser, "max_model_input_sizes"):
936
1182
  all_max_lengths.extend(
937
1183
  [
938
1184
  size
939
- for size in tokenizer.max_model_input_sizes.values()
1185
+ for size in tokeniser.max_model_input_sizes.values()
940
1186
  if size is not None
941
1187
  ]
942
1188
  )
@@ -970,7 +1216,7 @@ class LiteLLMModel(BenchmarkModule):
970
1216
  return -1
971
1217
 
972
1218
  @property
973
- def data_collator(self) -> c.Callable[[list[t.Any]], dict[str, t.Any]]:
1219
+ def data_collator(self) -> c.Callable[[c.Sequence[t.Any]], dict[str, t.Any]]:
974
1220
  """The data collator used to prepare samples during finetuning.
975
1221
 
976
1222
  Returns:
@@ -995,6 +1241,7 @@ class LiteLLMModel(BenchmarkModule):
995
1241
  return partial(
996
1242
  sequence_classification.extract_labels_from_generation,
997
1243
  dataset_config=self.dataset_config,
1244
+ model_config=self.model_config,
998
1245
  first_label_token_mapping=self.buffer["first_label_token_mapping"],
999
1246
  )
1000
1247
  case TaskGroup.TEXT_TO_TEXT:
@@ -1038,13 +1285,15 @@ class LiteLLMModel(BenchmarkModule):
1038
1285
  Whether the model exists, or an error describing why we cannot check
1039
1286
  whether the model exists.
1040
1287
  """
1041
- model_id, _ = model_id.split("@") if "@" in model_id else (model_id, "main")
1288
+ model_id = split_model_id(model_id=model_id).model_id
1042
1289
  if model_id in litellm.model_list:
1043
1290
  return True
1044
1291
 
1045
1292
  # Separate check for Ollama models
1046
1293
  if model_id.startswith("ollama/") or model_id.startswith("ollama_chat/"):
1047
- ollama_model_exists = try_download_ollama_model(model_id=model_id)
1294
+ ollama_model_exists = try_download_ollama_model(
1295
+ model_id=model_id, progress_bar=benchmark_config.progress_bar
1296
+ )
1048
1297
  if ollama_model_exists:
1049
1298
  return ollama_model_exists
1050
1299
 
@@ -1053,7 +1302,9 @@ class LiteLLMModel(BenchmarkModule):
1053
1302
  try:
1054
1303
  litellm.completion(
1055
1304
  messages=[dict(role="user", content="X")],
1056
- model=model_id,
1305
+ model=clean_model_id(
1306
+ model_id=model_id, benchmark_config=benchmark_config
1307
+ ),
1057
1308
  max_tokens=1,
1058
1309
  api_key=benchmark_config.api_key,
1059
1310
  api_base=benchmark_config.api_base,
@@ -1070,20 +1321,31 @@ class LiteLLMModel(BenchmarkModule):
1070
1321
  ServiceUnavailableError,
1071
1322
  InternalServerError,
1072
1323
  ) as e:
1073
- logger.debug(
1324
+ log(
1074
1325
  f"Service temporarily unavailable. The error message was: {e}. "
1075
- "Retrying in 10 seconds..."
1326
+ "Retrying in 10 seconds...",
1327
+ level=logging.DEBUG,
1076
1328
  )
1077
- sleep(5)
1329
+ sleep(10)
1078
1330
  except APIError as e:
1079
1331
  if "'503 Service Unavailable" not in str(e):
1080
1332
  raise e
1081
- logger.warning(
1333
+ log(
1082
1334
  f"Failed to check if model {model_id!r} exists. Retrying in 10 "
1083
- "seconds..."
1335
+ "seconds...",
1336
+ level=logging.WARNING,
1084
1337
  )
1085
1338
  sleep(10)
1086
1339
  except (BadRequestError, NotFoundError):
1340
+ # In case we're using `api_base`, try again with the `/v1` suffix
1341
+ if (
1342
+ benchmark_config.api_base is not None
1343
+ and not benchmark_config.api_base.endswith("/v1")
1344
+ ):
1345
+ benchmark_config.api_base += "/v1"
1346
+ continue
1347
+
1348
+ # Check for misspelled model IDs
1087
1349
  candidate_models = [
1088
1350
  candidate_model_id
1089
1351
  for candidate_model_id in litellm.model_list
@@ -1093,21 +1355,25 @@ class LiteLLMModel(BenchmarkModule):
1093
1355
  case 0:
1094
1356
  pass
1095
1357
  case 1:
1096
- logger.warning(
1358
+ log(
1097
1359
  f"Could not find the model ID {model_id!r}. Did you mean "
1098
- f"{candidate_models[0]!r}?"
1360
+ f"{candidate_models[0]!r}?",
1361
+ level=logging.WARNING,
1099
1362
  )
1100
1363
  case _:
1101
1364
  candidate_models_str = "', '".join(candidate_models)
1102
- logger.warning(
1365
+ log(
1103
1366
  f"Could not find the model ID {model_id!r}. Did you mean "
1104
- f"any of the following model IDs: '{candidate_models_str}'?"
1367
+ "any of the following model IDs: "
1368
+ f"'{candidate_models_str}'?",
1369
+ level=logging.WARNING,
1105
1370
  )
1106
1371
  return False
1107
1372
  else:
1108
- logger.error(
1373
+ log(
1109
1374
  f"Failed to check if model {model_id!r} exists after {num_attempts} "
1110
- "attempts. Assuming it does not exist."
1375
+ "attempts. Assuming it does not exist.",
1376
+ level=logging.ERROR,
1111
1377
  )
1112
1378
  return False
1113
1379
 
@@ -1126,10 +1392,30 @@ class LiteLLMModel(BenchmarkModule):
1126
1392
  Returns:
1127
1393
  The model configuration.
1128
1394
  """
1129
- model_id, revision = model_id.split("@") if "@" in model_id else (model_id, "")
1395
+ model_id_components = split_model_id(model_id=model_id)
1396
+
1397
+ # Backwards compatibility: If the revision is set but not the parameter, we
1398
+ # assume that the revision is actually the parameter and log this as a warning.
1399
+ if model_id_components.revision != "main" and model_id_components.param is None:
1400
+ proper_model_id = (
1401
+ f"{model_id_components.model_id}#{model_id_components.revision}"
1402
+ )
1403
+ log_once(
1404
+ f"The model ID {model_id!r} specifies a revision "
1405
+ f"{model_id_components.revision!r} but not a parameter. We assume "
1406
+ "that the revision is actually the parameter and set the revision "
1407
+ "to 'main'. In the future, use the new '#' syntax to specify the "
1408
+ f"parameter (in this case, this would be {proper_model_id!r}), as this "
1409
+ "will be an error in future versions of EuroEval.",
1410
+ level=logging.WARNING,
1411
+ )
1412
+ model_id_components.param = model_id_components.revision
1413
+ model_id_components.revision = "main"
1414
+
1130
1415
  return ModelConfig(
1131
- model_id=model_id,
1132
- revision=revision,
1416
+ model_id=model_id_components.model_id,
1417
+ revision=model_id_components.revision,
1418
+ param=model_id_components.param,
1133
1419
  task="text-generation",
1134
1420
  languages=list(),
1135
1421
  merge=False,
@@ -1184,7 +1470,10 @@ class LiteLLMModel(BenchmarkModule):
1184
1470
 
1185
1471
  if self.benchmark_config.few_shot:
1186
1472
  few_shot_examples = extract_few_shot_examples(
1187
- dataset=dataset, dataset_config=self.dataset_config, itr_idx=itr_idx
1473
+ dataset=dataset,
1474
+ dataset_config=self.dataset_config,
1475
+ benchmark_config=self.benchmark_config,
1476
+ itr_idx=itr_idx,
1188
1477
  )
1189
1478
  else:
1190
1479
  few_shot_examples = list()
@@ -1195,9 +1484,9 @@ class LiteLLMModel(BenchmarkModule):
1195
1484
  few_shot_examples=few_shot_examples,
1196
1485
  model_config=self.model_config,
1197
1486
  dataset_config=self.dataset_config,
1198
- instruction_model=True,
1487
+ generative_type=self.generative_type,
1199
1488
  always_populate_text_field=False,
1200
- tokenizer=None,
1489
+ tokeniser=None,
1201
1490
  ),
1202
1491
  batched=True,
1203
1492
  load_from_cache_file=False,
@@ -1206,46 +1495,169 @@ class LiteLLMModel(BenchmarkModule):
1206
1495
 
1207
1496
  return dataset
1208
1497
 
1498
+ @cache_arguments()
1499
+ def get_generation_kwargs(self, dataset_config: DatasetConfig) -> dict[str, t.Any]:
1500
+ """Get the generation arguments for the model.
1209
1501
 
1210
- def raise_if_wrong_params(
1211
- model_config: ModelConfig, allowed_params: dict[str, list[str]]
1212
- ) -> None:
1213
- """Raise an error if the model configuration has invalid parameters.
1502
+ Args:
1503
+ dataset_config:
1504
+ The dataset configuration, which is used to determine the generative
1505
+ type of the model. We use this as an argument here rather than using
1506
+ `self.dataset_config` to ensure that that the cache is updated when the
1507
+ dataset configuration changes.
 
- Args:
- model_config:
- The model configuration.
- allowed_params:
- The allowed parameters for the model.
+ Returns:
+ The generation arguments for the model.
+ """
+ # Set the core generation arguments
+ generation_kwargs: dict[str, t.Any] = dict(
+ max_completion_tokens=(
+ REASONING_MAX_TOKENS
+ if self.generative_type == GenerativeType.REASONING
+ else dataset_config.max_generated_tokens
+ ),
+ stop=[],
+ temperature=0.0,
+ seed=4242,
+ api_key=self.benchmark_config.api_key,
+ api_base=self.benchmark_config.api_base,
+ api_version=self.benchmark_config.api_version,
+ max_retries=3,
+ )
 
- Raises:
- InvalidModel:
- If the model configuration has invalid parameters.
- """
- param = model_config.revision
- if param == "":
- return
- for model_regex, allowed_params_list in allowed_params.items():
- if re.fullmatch(pattern=model_regex, string=model_config.model_id):
- if param not in allowed_params_list:
- msg = (
- f"Invalid parameter {param!r} for model {model_config.model_id!r}."
+ # Set up the `response_format` generation argument if we are dealing with a task
+ # using structured generation
+ if dataset_config.task.uses_structured_output:
+ if self.generative_type == GenerativeType.REASONING:
+ log_once(
+ f"The model {self.model_config.model_id!r} is a reasoning model "
+ "and thus does not support structured generation, so we do not "
+ "enable it.",
+ level=logging.DEBUG,
  )
- if allowed_params_list:
- msg += f" Allowed parameters are: {', '.join(allowed_params_list)}."
+ elif supports_response_schema(model=self.model_config.model_id):
+ if dataset_config.task == NER:
+ ner_tag_names = list(dataset_config.prompt_label_mapping.values())
+ keys_and_their_types: dict[str, t.Any] = {
+ tag_name: (conlist(str, max_length=5), ...)
+ for tag_name in ner_tag_names
+ }
+ pydantic_class = create_model(
+ "AnswerFormat", **keys_and_their_types
+ )
  else:
- msg += " No parameters are allowed."
- raise InvalidModel(msg)
- return
+ raise InvalidBenchmark(
+ "This task requires structured generation, but it has not "
+ "been implemented for this task yet. Please open an issue "
+ "at https://github.com/EuroEval/EuroEval/issues."
+ )
+ generation_kwargs["response_format"] = pydantic_class
+ log_once(
+ "Enabling structured generation for model "
+ f"{self.model_config.model_id!r} with the JSON schema "
+ f"{pydantic_class.model_json_schema()}",
+ level=logging.DEBUG,
+ )
+ else:
+ generation_kwargs["response_format"] = dict(type="json_object")
+ log_once(
+ "Enabling structured JSON generation for model "
+ f"{self.model_config.model_id!r} with no custom JSON schema, as "
+ "the model does not support schemas.",
+ level=logging.DEBUG,
+ )
+ elif self.dataset_config.task.uses_logprobs and self.dataset_config.labels:
+ localised_labels = [
+ self.dataset_config.prompt_label_mapping[label]
+ for label in self.dataset_config.labels
+ ]
+ keys_and_their_types = {
+ LITELLM_CLASSIFICATION_OUTPUT_KEY: (t.Literal[*localised_labels], ...)
+ }
+ pydantic_class = create_model("AnswerFormat", **keys_and_their_types)
+ generation_kwargs["response_format"] = pydantic_class
 
+ # If the model is an Ollama reasoning model, we ensure that thinking is enabled
+ if self.is_ollama and self.generative_type == GenerativeType.REASONING:
+ generation_kwargs["think"] = True
+ log_once(
+ "Enabling thinking mode for Ollama model "
+ f"{self.model_config.model_id!r}",
+ level=logging.DEBUG,
+ )
 
- def try_download_ollama_model(model_id: str) -> bool:
+ # Handle manually set parameters
+ if self.buffer["first_label_token_mapping"]:
+ generation_kwargs["logprobs"] = True
+ generation_kwargs["top_logprobs"] = MAX_LITELLM_LOGPROBS
+ if self.model_config.param == "thinking":
+ generation_kwargs["thinking"] = dict(
+ type="enabled", budget_tokens=REASONING_MAX_TOKENS - 1
+ )
+ log_once(
+ f"Enabling thinking mode for model {self.model_config.model_id!r}",
+ level=logging.DEBUG,
+ )
+ elif self.model_config.param == "no-thinking":
+ generation_kwargs["thinking"] = dict(budget_tokens=0)
+ log_once(
+ f"Disabling thinking mode for model {self.model_config.model_id!r}",
+ level=logging.DEBUG,
+ )
+ elif self.model_config.param in {"minimal", "low", "medium", "high"}:
+ generation_kwargs["reasoning_effort"] = self.model_config.param
+ log_once(
+ f"Enabling reasoning effort {self.model_config.param!r} for model "
+ f"{self.model_config.model_id!r}",
+ level=logging.DEBUG,
+ )
+
+ # First attempt is a test run with a single conversation to handle errors
+ # quickly. We repeat this multiple times to deal with different types of
+ # errors, and stop if we get a successful response.
+ test_input: c.Sequence[litellm.AllMessageValues] | str
+ if self.generative_type == GenerativeType.BASE:
+ test_input = "Test message"
+ else:
+ test_input = [
+ litellm.ChatCompletionUserMessage(role="user", content="Test message")
+ ]
+ for _ in range(num_attempts := 10):
+ _, failures = safe_run(
+ self._generate_async(
+ model_id=clean_model_id(
+ model_id=self.model_config.model_id,
+ benchmark_config=self.benchmark_config,
+ ),
+ inputs=[test_input],
+ max_concurrent_calls=1,
+ **generation_kwargs,
+ )
+ )
+ if not failures:
+ break
+ for _, error in failures:
+ generation_kwargs = self._handle_exception(
+ error=error, **generation_kwargs
+ )
+ else:
+ raise InvalidModel(
+ "Failed to get a successful response from the model "
+ f"{self.model_config.model_id!r} after {num_attempts} attempts."
+ )
+
+ return generation_kwargs
+
+
+ def try_download_ollama_model(model_id: str, progress_bar: bool) -> bool:
  """Try to download an Ollama model.
 
  Args:
  model_id:
  The model ID. If the model does not start with "ollama/" or "ollama_chat/"
  then this function will return False.
+ progress_bar:
+ Whether to show a progress bar while downloading the model.
 
  Returns:
  Whether the model was downloaded successfully.
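For context on the `response_format` logic added above: when a task requires structured output and the model supports JSON schemas, the code builds a pydantic model on the fly, with one list-valued field per NER tag, and passes that class to LiteLLM as the response format; classification-style tasks instead constrain a single output key to the localised labels. The sketch below reproduces that pattern standalone so the resulting schema is easier to picture; the tag names, labels and the 'label' key are made-up examples, not EuroEval's actual prompt label mapping or its `LITELLM_CLASSIFICATION_OUTPUT_KEY` constant.

import typing as t

from pydantic import conlist, create_model

# Hypothetical NER tags; EuroEval takes these from
# dataset_config.prompt_label_mapping.values()
ner_tag_names = ["person", "location", "organization", "miscellaneous"]

# One field per tag, each a list of strings capped at five entities, mirroring
# the conlist(str, max_length=5) call in the diff
keys_and_their_types: dict[str, t.Any] = {
    tag_name: (conlist(str, max_length=5), ...) for tag_name in ner_tag_names
}
AnswerFormat = create_model("AnswerFormat", **keys_and_their_types)

# This is the JSON schema that would be logged and sent as `response_format`
print(AnswerFormat.model_json_schema())

# Classification-style sketch: one key restricted to the localised labels
ClassificationFormat = create_model(
    "AnswerFormat", label=(t.Literal["positive", "negative", "neutral"], ...)
)
print(ClassificationFormat.model_json_schema())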
@@ -1268,16 +1680,16 @@ def try_download_ollama_model(model_id: str) -> bool:
  )
 
  try:
- downloaded_ollama_models: list[str] = [
+ downloaded_ollama_models: c.Sequence[str] = [
  model_obj.model
  for model_obj in ollama.list().models
  if model_obj.model is not None
  ]
- except ConnectionError:
+ except ConnectionError as e:
  raise InvalidModel(
  "Ollama does not seem to be running, so we cannot evaluate the model "
  f"{model_id!r}. Please make sure that Ollama is running and try again."
- )
+ ) from e
 
  ollama_model_id = "/".join(model_id.split("/")[1:])
  if ollama_model_id not in downloaded_ollama_models:
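As a side note on the Ollama path above and in the hunks that follow: the availability check and download go through the `ollama` Python client, listing the locally installed models and pulling the missing one while streaming progress updates into a byte-level progress bar. Below is a rough sketch of that flow, assuming a recent `ollama` client where `ollama.list()` and `ollama.pull(..., stream=True)` return objects with the attributes used here; it uses `tqdm` directly instead of EuroEval's `get_pbar` helper, and the model tag is a made-up example.

import ollama
from tqdm.auto import tqdm

ollama_model_id = "llama3.2:1b"  # hypothetical model tag

# Which models are already available locally?
downloaded = [m.model for m in ollama.list().models if m.model is not None]

if ollama_model_id not in downloaded:
    # Stream the pull and mirror its byte counts into a progress bar
    response = ollama.pull(model=ollama_model_id, stream=True)
    with tqdm(
        desc=f"Downloading {ollama_model_id}", unit="B", unit_scale=True
    ) as pbar:
        for status in response:
            if status.total is not None:
                pbar.total = status.total
            if status.completed is not None:
                pbar.n = status.completed
                pbar.refresh()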
@@ -1297,7 +1709,8 @@ def try_download_ollama_model(model_id: str) -> bool:
  f"The model {model_id!r} cannot be found on Ollama, but the "
  f"model {model_id_with_prefix} *was* found, so we would "
  "recommend you cancelling this run and trying the evaluation "
- "with that model ID instead."
+ "with that model ID instead.",
+ level=logging.WARNING,
  )
  return False
  except ollama.ResponseError as inner_e:
@@ -1307,19 +1720,19 @@ def try_download_ollama_model(model_id: str) -> bool:
  raise InvalidModel(
  f"Failed to download Ollama model {ollama_model_id}. "
  f"The error message was: {inner_e}"
- )
+ ) from inner_e
  else:
  raise InvalidModel(
  f"Failed to download Ollama model {ollama_model_id}. "
  f"The error message was: {e}"
- )
+ ) from e
 
  # Download the model
- with tqdm(
+ with get_pbar(
  desc=f"Downloading {ollama_model_id}",
  unit_scale=True,
  unit="B",
- leave=False,
+ disable=not progress_bar,
  ) as pbar:
  for status in response:
  if status.total is not None:
@@ -1335,3 +1748,30 @@ def try_download_ollama_model(model_id: str) -> bool:
  level=logging.DEBUG,
  )
  return True
+
+
+ def clean_model_id(model_id: str, benchmark_config: BenchmarkConfig) -> str:
+ """Clean a model ID.
+
+ This adds the default `openai/` prefix to the model ID if we're benchmarking a
+ custom API inference server and no prefix is used, just to make it more
+ convenient for the user.
+
+ Args:
+ model_id:
+ The model ID.
+ benchmark_config:
+ The benchmark configuration.
+
+ Returns:
+ The cleaned model ID.
+ """
+ if benchmark_config.api_base is not None and not any(
+ model_id.startswith(prefix) for prefix in CUSTOM_INFERENCE_API_PREFIXES
+ ):
+ if benchmark_config.generative_type == GenerativeType.BASE:
+ prefix = "text-completion-openai/"
+ else:
+ prefix = "openai/"
+ model_id = prefix + model_id
+ return model_id
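To make the effect of `clean_model_id` above concrete: when an API base is configured and the model ID carries no recognised prefix, the ID gets routed through LiteLLM's OpenAI-compatible backends, with the text-completion variant used for base models. The sketch below restates that rule with hypothetical stand-ins; the real prefix list lives in `CUSTOM_INFERENCE_API_PREFIXES`, whose contents are not shown in this diff.

# Hypothetical stand-in for the real constant
CUSTOM_INFERENCE_API_PREFIXES = ("openai/", "text-completion-openai/")


def clean_model_id_sketch(model_id: str, api_base: str | None, base_model: bool) -> str:
    """Illustrative restatement of the prefixing rule in clean_model_id above."""
    if api_base is not None and not model_id.startswith(CUSTOM_INFERENCE_API_PREFIXES):
        prefix = "text-completion-openai/" if base_model else "openai/"
        model_id = prefix + model_id
    return model_id


print(clean_model_id_sketch("my-finetuned-model", "http://localhost:8000/v1", False))
# 'openai/my-finetuned-model'
print(clean_model_id_sketch("my-finetuned-model", "http://localhost:8000/v1", True))
# 'text-completion-openai/my-finetuned-model'
print(clean_model_id_sketch("my-finetuned-model", None, False))
# 'my-finetuned-model' (no custom API base, so the ID is left untouched)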