EuroEval 15.12.0__py3-none-any.whl → 16.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. euroeval/__init__.py +32 -14
  2. euroeval/benchmark_config_factory.py +92 -180
  3. euroeval/benchmark_modules/base.py +49 -39
  4. euroeval/benchmark_modules/fresh.py +35 -21
  5. euroeval/benchmark_modules/hf.py +280 -244
  6. euroeval/benchmark_modules/litellm.py +752 -312
  7. euroeval/benchmark_modules/vllm.py +570 -268
  8. euroeval/benchmarker.py +651 -528
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/cli.py +49 -38
  12. euroeval/constants.py +44 -25
  13. euroeval/data_loading.py +111 -55
  14. euroeval/data_models.py +490 -323
  15. euroeval/dataset_configs/__init__.py +26 -4
  16. euroeval/dataset_configs/bosnian.py +39 -0
  17. euroeval/dataset_configs/bulgarian.py +56 -0
  18. euroeval/dataset_configs/croatian.py +56 -0
  19. euroeval/dataset_configs/czech.py +75 -0
  20. euroeval/dataset_configs/danish.py +78 -50
  21. euroeval/dataset_configs/dutch.py +74 -44
  22. euroeval/dataset_configs/english.py +71 -36
  23. euroeval/dataset_configs/estonian.py +111 -0
  24. euroeval/dataset_configs/faroese.py +25 -18
  25. euroeval/dataset_configs/finnish.py +63 -26
  26. euroeval/dataset_configs/french.py +65 -32
  27. euroeval/dataset_configs/german.py +77 -36
  28. euroeval/dataset_configs/greek.py +64 -0
  29. euroeval/dataset_configs/icelandic.py +68 -57
  30. euroeval/dataset_configs/italian.py +68 -36
  31. euroeval/dataset_configs/latvian.py +87 -0
  32. euroeval/dataset_configs/lithuanian.py +64 -0
  33. euroeval/dataset_configs/norwegian.py +98 -72
  34. euroeval/dataset_configs/polish.py +96 -0
  35. euroeval/dataset_configs/portuguese.py +63 -40
  36. euroeval/dataset_configs/serbian.py +64 -0
  37. euroeval/dataset_configs/slovak.py +55 -0
  38. euroeval/dataset_configs/slovene.py +56 -0
  39. euroeval/dataset_configs/spanish.py +68 -34
  40. euroeval/dataset_configs/swedish.py +82 -41
  41. euroeval/dataset_configs/ukrainian.py +64 -0
  42. euroeval/enums.py +12 -6
  43. euroeval/exceptions.py +21 -1
  44. euroeval/finetuning.py +34 -26
  45. euroeval/generation.py +76 -41
  46. euroeval/generation_utils.py +169 -34
  47. euroeval/languages.py +1020 -188
  48. euroeval/logging_utils.py +268 -0
  49. euroeval/metrics/__init__.py +6 -0
  50. euroeval/metrics/base.py +85 -0
  51. euroeval/metrics/huggingface.py +216 -0
  52. euroeval/metrics/llm_as_a_judge.py +260 -0
  53. euroeval/metrics/pipeline.py +289 -0
  54. euroeval/metrics/speed.py +48 -0
  55. euroeval/model_cache.py +40 -21
  56. euroeval/model_config.py +4 -5
  57. euroeval/model_loading.py +3 -0
  58. euroeval/prompt_templates/__init__.py +2 -0
  59. euroeval/prompt_templates/classification.py +206 -0
  60. euroeval/prompt_templates/linguistic_acceptability.py +157 -22
  61. euroeval/prompt_templates/multiple_choice.py +159 -17
  62. euroeval/prompt_templates/named_entity_recognition.py +318 -21
  63. euroeval/prompt_templates/reading_comprehension.py +207 -16
  64. euroeval/prompt_templates/sentiment_classification.py +205 -22
  65. euroeval/prompt_templates/summarization.py +122 -22
  66. euroeval/prompt_templates/token_classification.py +279 -0
  67. euroeval/scores.py +20 -9
  68. euroeval/speed_benchmark.py +11 -12
  69. euroeval/task_group_utils/multiple_choice_classification.py +21 -12
  70. euroeval/task_group_utils/question_answering.py +101 -73
  71. euroeval/task_group_utils/sequence_classification.py +144 -61
  72. euroeval/task_group_utils/text_to_text.py +33 -12
  73. euroeval/task_group_utils/token_classification.py +86 -89
  74. euroeval/tasks.py +75 -16
  75. euroeval/tokenisation_utils.py +603 -0
  76. euroeval/types.py +17 -11
  77. euroeval/utils.py +332 -137
  78. euroeval-16.7.1.dist-info/METADATA +623 -0
  79. euroeval-16.7.1.dist-info/RECORD +84 -0
  80. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
  81. euroeval/human_evaluation.py +0 -737
  82. euroeval/metrics.py +0 -452
  83. euroeval/tokenization_utils.py +0 -498
  84. euroeval-15.12.0.dist-info/METADATA +0 -285
  85. euroeval-15.12.0.dist-info/RECORD +0 -63
  86. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
  87. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
@@ -5,8 +5,8 @@ import contextlib
5
5
  import importlib.util
6
6
  import json
7
7
  import logging
8
- import os
9
8
  import re
9
+ import shutil
10
10
  import typing as t
11
11
  from functools import partial
12
12
  from pathlib import Path
@@ -15,23 +15,22 @@ from time import sleep
15
15
  import torch
16
16
  from huggingface_hub import snapshot_download
17
17
  from pydantic import conlist, create_model
18
- from tqdm.auto import tqdm
19
18
  from transformers.models.auto.configuration_auto import AutoConfig
20
19
  from transformers.models.auto.tokenization_auto import AutoTokenizer
20
+ from transformers.tokenization_mistral_common import MistralCommonTokenizer
21
21
  from urllib3.exceptions import RequestError
22
22
 
23
23
  from ..constants import (
24
24
  CUSTOM_STOP_TOKENS,
25
25
  GENERATIVE_PIPELINE_TAGS,
26
26
  MAX_CONTEXT_LENGTH,
27
- MAX_LOGPROBS,
27
+ MAX_VLLM_LOGPROBS,
28
28
  MERGE_TAGS,
29
29
  REASONING_MAX_TOKENS,
30
30
  REASONING_TOKENS,
31
- TASKS_USING_JSON,
32
31
  VLLM_BF16_MIN_CUDA_COMPUTE_CAPABILITY,
33
32
  )
34
- from ..data_models import GenerativeModelOutput, ModelConfig
33
+ from ..data_models import GenerativeModelOutput, HashableDict, ModelConfig
35
34
  from ..enums import (
36
35
  BatchingPreference,
37
36
  GenerativeType,
@@ -44,29 +43,41 @@ from ..exceptions import (
44
43
  InvalidModel,
45
44
  NeedsEnvironmentVariable,
46
45
  NeedsExtraInstalled,
46
+ NeedsSystemDependency,
47
+ )
48
+ from ..generation_utils import (
49
+ apply_prompt,
50
+ extract_few_shot_examples,
51
+ raise_if_wrong_params,
47
52
  )
48
- from ..generation_utils import apply_prompt, extract_few_shot_examples
49
53
  from ..languages import get_all_languages
54
+ from ..logging_utils import get_pbar, log, log_once, no_terminal_output
50
55
  from ..task_group_utils import (
51
56
  question_answering,
52
57
  sequence_classification,
53
58
  text_to_text,
54
59
  token_classification,
55
60
  )
56
- from ..tokenization_utils import (
61
+ from ..tokenisation_utils import (
62
+ apply_chat_template,
57
63
  get_bos_token,
58
64
  get_end_of_chat_token_ids,
59
65
  get_eos_token,
60
66
  get_first_label_token_mapping,
61
67
  get_pad_token,
68
+ has_chat_template,
62
69
  should_prompts_be_stripped,
63
70
  )
64
71
  from ..types import ExtractLabelsFunction
65
72
  from ..utils import (
66
73
  clear_memory,
67
74
  create_model_cache_dir,
75
+ flash_attention_backend,
76
+ get_hf_token,
68
77
  get_min_cuda_compute_capability,
69
- log_once,
78
+ internet_connection_available,
79
+ resolve_model_path,
80
+ split_model_id,
70
81
  )
71
82
  from .hf import HuggingFaceEncoderModel, get_model_repo_info, load_hf_model_config
72
83
 
@@ -77,13 +88,7 @@ if t.TYPE_CHECKING or importlib.util.find_spec("vllm") is not None:
77
88
  destroy_model_parallel,
78
89
  )
79
90
  from vllm.lora.request import LoRARequest
80
-
81
- if t.TYPE_CHECKING or importlib.util.find_spec("outlines") is not None:
82
- from outlines.models.vllm import adapt_tokenizer
83
- from outlines.processors.structured import JSONLogitsProcessor
84
-
85
- if t.TYPE_CHECKING or importlib.util.find_spec("ray") is not None:
86
- import ray
91
+ from vllm.sampling_params import StructuredOutputsParams
87
92
 
88
93
  if t.TYPE_CHECKING:
89
94
  from datasets import DatasetDict
@@ -92,7 +97,10 @@ if t.TYPE_CHECKING:
92
97
 
93
98
  from ..data_models import BenchmarkConfig, DatasetConfig, Task
94
99
 
95
- logger = logging.getLogger("euroeval")
100
+
101
+ MODELS_REQUIRING_FLASH_ATTENTION: list[re.Pattern] = [
102
+ re.compile(r".*gpt-oss.*", flags=re.IGNORECASE)
103
+ ]
96
104
 
97
105
 
98
106
  class VLLMModel(HuggingFaceEncoderModel):
@@ -101,12 +109,17 @@ class VLLMModel(HuggingFaceEncoderModel):
101
109
  fresh_model = False
102
110
  batching_preference = BatchingPreference.ALL_AT_ONCE
103
111
  high_priority = True
112
+ allowed_params = {
113
+ re.compile(r".*"): ["thinking", "no-thinking", "slow-tokenizer"],
114
+ re.compile(r".*gpt-oss.*", flags=re.IGNORECASE): ["low", "medium", "high"],
115
+ }
104
116
 
105
117
  def __init__(
106
118
  self,
107
119
  model_config: "ModelConfig",
108
120
  dataset_config: "DatasetConfig",
109
121
  benchmark_config: "BenchmarkConfig",
122
+ log_metadata: bool = True,
110
123
  ) -> None:
111
124
  """Initialise the vLLM model.
112
125
 
@@ -117,30 +130,40 @@ class VLLMModel(HuggingFaceEncoderModel):
117
130
  The dataset configuration.
118
131
  benchmark_config:
119
132
  The benchmark configuration.
133
+ log_metadata:
134
+ Whether to log the model and dataset metadata.
120
135
  """
121
- if (
122
- importlib.util.find_spec("vllm") is None
123
- or importlib.util.find_spec("ray") is None
124
- ):
136
+ if importlib.util.find_spec("vllm") is None:
125
137
  raise NeedsExtraInstalled(extra="generative")
126
138
 
127
- model, tokenizer = load_model_and_tokenizer(
128
- model_config=model_config, benchmark_config=benchmark_config
139
+ if shutil.which("nvcc") is None:
140
+ raise NeedsSystemDependency(
141
+ dependency="nvcc",
142
+ instructions=(
143
+ "Please install the CUDA Toolkit from "
144
+ "https://developer.nvidia.com/cuda-downloads or ensure that NVCC "
145
+ "is available in your PATH."
146
+ ),
147
+ )
148
+
149
+ raise_if_wrong_params(
150
+ model_config=model_config, allowed_params=self.allowed_params
129
151
  )
152
+
153
+ with (
154
+ no_terminal_output(disable=benchmark_config.verbose),
155
+ flash_attention_backend(
156
+ disabled=all(
157
+ not re.search(pattern=pattern, string=model_config.model_id)
158
+ for pattern in MODELS_REQUIRING_FLASH_ATTENTION
159
+ )
160
+ ),
161
+ ):
162
+ model, tokeniser = load_model_and_tokeniser(
163
+ model_config=model_config, benchmark_config=benchmark_config
164
+ )
130
165
  self._model: "LLM" = model
131
- self._tokenizer: "PreTrainedTokenizer" = tokenizer
132
- self.end_of_reasoning_token = get_end_of_reasoning_token(
133
- model=self._model, tokenizer=self._tokenizer, model_id=model_config.model_id
134
- )
135
- self.end_of_chat_token_ids = get_end_of_chat_token_ids(
136
- tokenizer=self._tokenizer
137
- )
138
- self.custom_stop_tokens = get_custom_stop_tokens(
139
- model=self._model,
140
- tokenizer=self._tokenizer,
141
- model_id=model_config.model_id,
142
- is_reasoning_model=self.end_of_reasoning_token is not None,
143
- )
166
+ self._tokeniser: "PreTrainedTokenizer" = tokeniser
144
167
 
145
168
  # We specify `HuggingFaceEncoderModel` here instead of `VLLMModel`, as we want
146
169
  # to call the `__init__` method of the `BenchmarkModule` class.
@@ -148,16 +171,30 @@ class VLLMModel(HuggingFaceEncoderModel):
148
171
  model_config=model_config,
149
172
  dataset_config=dataset_config,
150
173
  benchmark_config=benchmark_config,
174
+ log_metadata=log_metadata,
175
+ )
176
+
177
+ self.end_of_reasoning_token = get_end_of_reasoning_token(
178
+ model=self._model, tokeniser=self._tokeniser, model_config=model_config
179
+ )
180
+ self.end_of_chat_token_ids = get_end_of_chat_token_ids(
181
+ tokeniser=self._tokeniser, generative_type=self.generative_type
182
+ )
183
+ self.custom_stop_tokens = get_custom_stop_tokens(
184
+ model=self._model,
185
+ tokeniser=self._tokeniser,
186
+ model_id=model_config.model_id,
187
+ generative_type=self.generative_type,
151
188
  )
152
189
 
153
190
  self.buffer |= dict(
154
- instruction_model=self._tokenizer.chat_template is not None,
155
191
  first_label_token_mapping=get_first_label_token_mapping(
156
192
  dataset_config=self.dataset_config,
157
193
  model_config=self.model_config,
158
- tokenizer=self._tokenizer,
194
+ tokeniser=self._tokeniser,
159
195
  generative_type=self.generative_type,
160
- ),
196
+ log_metadata=self.log_metadata,
197
+ )
161
198
  )
162
199
  if self.model_config.adapter_base_model_id is not None:
163
200
  adapter_path = snapshot_download(
@@ -170,12 +207,16 @@ class VLLMModel(HuggingFaceEncoderModel):
170
207
  )
171
208
 
172
209
  def __del__(self) -> None:
173
- """Clean up the model and tokenizer."""
174
- clear_vllm()
210
+ """Clean up the model and tokeniser."""
211
+ try:
212
+ if importlib.util.find_spec("vllm") is not None:
213
+ clear_vllm()
214
+ except ImportError:
215
+ pass
175
216
  if hasattr(self, "_model"):
176
217
  del self._model
177
- if hasattr(self, "_tokenizer"):
178
- del self._tokenizer
218
+ if hasattr(self, "_tokeniser"):
219
+ del self._tokeniser
179
220
 
180
221
  @property
181
222
  def generative_type(self) -> GenerativeType | None:
@@ -184,17 +225,37 @@ class VLLMModel(HuggingFaceEncoderModel):
184
225
  Returns:
185
226
  The generative type of the model, or None if it has not been set yet.
186
227
  """
187
- if not hasattr(self, "_tokenizer"):
228
+ if not hasattr(self, "_tokeniser"):
229
+ log_once(
230
+ "The generative type of the model has not been set yet as the "
231
+ "tokeniser has not been loaded.",
232
+ level=logging.DEBUG,
233
+ )
188
234
  return None
189
- elif self.end_of_reasoning_token is not None:
190
- return GenerativeType.REASONING
235
+ elif self.benchmark_config.generative_type is not None:
236
+ type_ = self.benchmark_config.generative_type
237
+ elif self.model_config.param in {"thinking"}:
238
+ type_ = GenerativeType.REASONING
239
+ elif self.model_config.param in {"no-thinking"}:
240
+ type_ = GenerativeType.INSTRUCTION_TUNED
241
+ elif (
242
+ hasattr(self, "end_of_reasoning_token")
243
+ and self.end_of_reasoning_token is not None
244
+ ):
245
+ type_ = GenerativeType.REASONING
191
246
  elif (
192
- self._tokenizer.chat_template is not None
247
+ has_chat_template(tokeniser=self._tokeniser)
193
248
  or "instruct" in self.model_config.model_id.lower()
194
249
  ):
195
- return GenerativeType.INSTRUCTION_TUNED
250
+ type_ = GenerativeType.INSTRUCTION_TUNED
196
251
  else:
197
- return GenerativeType.BASE
252
+ type_ = GenerativeType.BASE
253
+ log_once(
254
+ f"Detected generative type {type_.name!r} for model "
255
+ f"{self.model_config.model_id!r}",
256
+ level=logging.DEBUG,
257
+ )
258
+ return type_
198
259
 
199
260
  @property
200
261
  def extract_labels_from_generation(self) -> ExtractLabelsFunction:
@@ -211,6 +272,7 @@ class VLLMModel(HuggingFaceEncoderModel):
211
272
  return partial(
212
273
  sequence_classification.extract_labels_from_generation,
213
274
  dataset_config=self.dataset_config,
275
+ model_config=self.model_config,
214
276
  first_label_token_mapping=self.buffer["first_label_token_mapping"],
215
277
  )
216
278
  case TaskGroup.TEXT_TO_TEXT:
@@ -269,7 +331,10 @@ class VLLMModel(HuggingFaceEncoderModel):
269
331
 
270
332
  if self.benchmark_config.few_shot:
271
333
  few_shot_examples = extract_few_shot_examples(
272
- dataset=dataset, dataset_config=self.dataset_config, itr_idx=itr_idx
334
+ dataset=dataset,
335
+ dataset_config=self.dataset_config,
336
+ benchmark_config=self.benchmark_config,
337
+ itr_idx=itr_idx,
273
338
  )
274
339
  else:
275
340
  few_shot_examples = list()
@@ -280,9 +345,9 @@ class VLLMModel(HuggingFaceEncoderModel):
280
345
  few_shot_examples=few_shot_examples,
281
346
  model_config=self.model_config,
282
347
  dataset_config=self.dataset_config,
283
- instruction_model=self.buffer["instruction_model"],
348
+ generative_type=self.generative_type,
284
349
  always_populate_text_field=True,
285
- tokenizer=self._tokenizer,
350
+ tokeniser=self._tokeniser,
286
351
  ),
287
352
  batched=True,
288
353
  load_from_cache_file=False,
@@ -300,68 +365,111 @@ class VLLMModel(HuggingFaceEncoderModel):
300
365
 
301
366
  Returns:
302
367
  The generated model outputs.
368
+
369
+ Raises:
370
+ InvalidBenchmark:
371
+ If the dataset requires logprobs, but we could not get the first token
372
+ of each label in the dataset.
303
373
  """
304
374
  # Get stopping tokens
305
375
  stop_tokens: list[str] = self.custom_stop_tokens.copy()
306
- if self.buffer["instruction_model"] is False:
376
+ if self.generative_type == GenerativeType.BASE:
307
377
  stop_tokens.append("\n\n")
308
- if self._tokenizer.pad_token_id is not None:
309
- assert isinstance(self._tokenizer.pad_token, str), (
378
+ if self._tokeniser.pad_token_id is not None:
379
+ assert isinstance(self._tokeniser.pad_token, str), (
310
380
  f"The pad token for the model {self.model_config.model_id!r} "
311
- f"is not a string, which is unexpected: {self._tokenizer.pad_token!r}."
381
+ f"is not a string, which is unexpected: {self._tokeniser.pad_token!r}."
312
382
  )
313
- stop_tokens.append(self._tokenizer.pad_token)
314
- if self._tokenizer.eos_token_id is not None:
315
- assert isinstance(self._tokenizer.eos_token, str), (
383
+ stop_tokens.append(self._tokeniser.pad_token)
384
+ if self._tokeniser.eos_token_id is not None:
385
+ assert isinstance(self._tokeniser.eos_token, str), (
316
386
  f"The EOS token for the model {self.model_config.model_id!r} "
317
- f"is not a string, which is unexpected: {self._tokenizer.eos_token!r}."
387
+ f"is not a string, which is unexpected: {self._tokeniser.eos_token!r}."
318
388
  )
319
- stop_tokens.append(self._tokenizer.eos_token)
320
- if self._tokenizer.pad_token_id is None:
321
- self._tokenizer.pad_token_id = self._tokenizer.eos_token_id
322
- self._tokenizer.pad_token = self._tokenizer.eos_token
389
+ stop_tokens.append(self._tokeniser.eos_token)
390
+ if self._tokeniser.pad_token_id is None:
391
+ self._tokeniser.pad_token_id = self._tokeniser.eos_token_id
392
+ self._tokeniser.pad_token = self._tokeniser.eos_token
323
393
  if self.end_of_chat_token_ids is not None:
324
- end_of_chat_token = self._tokenizer.decode(
394
+ end_of_chat_token = self._tokeniser.decode(
325
395
  self.end_of_chat_token_ids
326
396
  ).strip()
327
397
  if end_of_chat_token:
328
398
  stop_tokens.append(end_of_chat_token)
329
399
 
330
- logits_processor = None
331
- if self.dataset_config.task in TASKS_USING_JSON:
332
- if self.generative_type == GenerativeType.REASONING:
333
- log_once(
334
- f"The model {self.model_config.model_id!r} is a reasoning model "
335
- "and thus does not support structured generation, so we do not "
336
- "enable it.",
337
- level=logging.DEBUG,
338
- )
339
- else:
340
- ner_tag_names = list(self.dataset_config.prompt_label_mapping.values())
341
- keys_and_their_types: dict[str, t.Any] = {
342
- tag_name: (conlist(str, max_length=5), ...)
343
- for tag_name in ner_tag_names
344
- }
345
- pydantic_class = create_model("AnswerFormat", **keys_and_their_types)
346
- logits_processor = JSONLogitsProcessor(
347
- schema=pydantic_class,
348
- tokenizer=adapt_tokenizer(tokenizer=self._tokenizer), # type: ignore
349
- whitespace_pattern=r" ?",
350
- )
351
- log_once(
352
- "Using structured generation with the JSON schema "
353
- f"{pydantic_class.model_json_schema()}",
354
- level=logging.DEBUG,
355
- )
356
-
357
400
  # Get the mapping from labels to the first token in the label. We call this each
358
401
  # time we generate a new dataset since the dataset config can change
359
402
  self.buffer["first_label_token_mapping"] = get_first_label_token_mapping(
360
403
  dataset_config=self.dataset_config,
361
404
  model_config=self.model_config,
362
- tokenizer=self._tokenizer,
405
+ tokeniser=self._tokeniser,
363
406
  generative_type=self.generative_type,
407
+ log_metadata=self.log_metadata,
364
408
  )
409
+ if (
410
+ not self.buffer["first_label_token_mapping"]
411
+ and self.dataset_config.task.requires_logprobs
412
+ ):
413
+ raise InvalidBenchmark(
414
+ "The dataset requires logprobs, but we encountered an error when "
415
+ "trying to get the first token of each label in the dataset. You can "
416
+ "try running this benchmark with the --verbose flag to see what the "
417
+ "error was. Skipping this evaluation."
418
+ )
419
+
420
+ structured_generation_schema = None
421
+ if (
422
+ self.dataset_config.task.uses_structured_output
423
+ or (self.dataset_config.task.uses_logprobs and self.dataset_config.labels)
424
+ ) and self.generative_type == GenerativeType.REASONING:
425
+ structured_outputs = None
426
+ log_once(
427
+ "The dataset uses structured output, but we are not using it as the "
428
+ f"model {self.model_config.model_id!r} is a reasoning model.",
429
+ level=logging.DEBUG,
430
+ )
431
+ elif self.dataset_config.task.uses_structured_output:
432
+ ner_tag_names = list(self.dataset_config.prompt_label_mapping.values())
433
+ keys_and_their_types: dict[str, t.Any] = {
434
+ tag_name: (conlist(str, max_length=5), ...)
435
+ for tag_name in ner_tag_names
436
+ }
437
+ answer_format_class = create_model("AnswerFormat", **keys_and_their_types)
438
+ structured_generation_schema = answer_format_class.model_json_schema()
439
+ log_once(
440
+ "Using structured generation with the JSON schema: "
441
+ f"{json.dumps(structured_generation_schema)}",
442
+ level=logging.DEBUG,
443
+ )
444
+ structured_outputs = StructuredOutputsParams(
445
+ json=structured_generation_schema
446
+ )
447
+ elif (
448
+ self.dataset_config.task.uses_logprobs
449
+ and self.dataset_config.labels
450
+ and self.buffer.get("first_label_token_mapping", False)
451
+ ):
452
+ choice_labels = [
453
+ self.dataset_config.prompt_label_mapping[label]
454
+ for label in self.dataset_config.labels
455
+ ]
456
+ if isinstance(self.buffer["first_label_token_mapping"], dict):
457
+ choice_labels = [
458
+ self.buffer["first_label_token_mapping"][label]
459
+ for label in choice_labels
460
+ ]
461
+ structured_outputs = StructuredOutputsParams(choice=choice_labels)
462
+ log_once(
463
+ "Using structured generation with the choices: "
464
+ f"{structured_outputs.choice!r}.",
465
+ level=logging.DEBUG,
466
+ )
467
+ else:
468
+ structured_outputs = None
469
+ log_once(
470
+ "Not using structured generation as the dataset does not require it.",
471
+ level=logging.DEBUG,
472
+ )
365
473
 
366
474
  # Define the parameters used for vLLM generation
367
475
  max_tokens: int = (
@@ -371,19 +479,21 @@ class VLLMModel(HuggingFaceEncoderModel):
371
479
  )
372
480
  sampling_params = SamplingParams(
373
481
  max_tokens=max_tokens,
374
- logprobs=MAX_LOGPROBS if self.buffer["first_label_token_mapping"] else None,
482
+ logprobs=MAX_VLLM_LOGPROBS
483
+ if self.buffer["first_label_token_mapping"]
484
+ else None,
375
485
  temperature=0.0,
376
486
  stop=[stop_token for stop_token in stop_tokens if stop_token],
377
- logits_processors=[logits_processor] if logits_processor else None,
487
+ structured_outputs=structured_outputs,
378
488
  )
379
489
 
380
490
  # If any of the prompts are empty then we need to replace them with a BOS token
381
491
  # so that the vLLM model can generate from them
382
- prompts: list[str] = inputs["text"]
492
+ prompts: c.Sequence[str] = inputs["text"]
383
493
  if any(len(prompt) == 0 for prompt in prompts):
384
- logger.debug("Found empty prompts, replacing with BOS token.")
494
+ log("Found empty prompts, replacing with BOS token.", level=logging.DEBUG)
385
495
  prompts = [
386
- prompt if len(prompt) > 0 else str(self._tokenizer.bos_token)
496
+ prompt if len(prompt) > 0 else str(self._tokeniser.bos_token)
387
497
  for prompt in prompts
388
498
  ]
389
499
 
@@ -391,10 +501,8 @@ class VLLMModel(HuggingFaceEncoderModel):
391
501
  labels_to_be_generated = list(self.dataset_config.prompt_label_mapping.values())
392
502
  if len(labels_to_be_generated) == 0:
393
503
  labels_to_be_generated = ["negative", "positive"]
394
- if not self.buffer.get(
395
- "instruction_model", False
396
- ) and should_prompts_be_stripped(
397
- labels_to_be_generated=labels_to_be_generated, tokenizer=self._tokenizer
504
+ if self.generative_type == GenerativeType.BASE and should_prompts_be_stripped(
505
+ labels_to_be_generated=labels_to_be_generated, tokeniser=self._tokeniser
398
506
  ):
399
507
  log_once(
400
508
  f"Stripping prompts for model {self.model_config.model_id!r}.",
@@ -402,21 +510,35 @@ class VLLMModel(HuggingFaceEncoderModel):
402
510
  )
403
511
  prompts = [prompt.strip() for prompt in prompts]
404
512
 
513
+ # Truncate the prompts if needed, but only if it's not a reasoning model
514
+ if self.generative_type != GenerativeType.REASONING:
515
+ max_tokens_per_prompt = (
516
+ min(self._tokeniser.model_max_length, MAX_CONTEXT_LENGTH) - max_tokens
517
+ )
518
+ tokenized_prompts = self._tokeniser(
519
+ text=list(prompts), truncation=True, max_length=max_tokens_per_prompt
520
+ )
521
+ prompts = self._tokeniser.batch_decode(
522
+ sequences=tokenized_prompts.input_ids, skip_special_tokens=True
523
+ )
524
+
405
525
  # Generate sequences using vLLM
406
526
  input_is_a_test = len(prompts) == 1 and len(set(prompts[0])) == 1
407
527
  num_attempts = 3
528
+ truncation_attempts = 1
408
529
  for _ in range(num_attempts):
409
530
  try:
410
531
  raw_outputs = self._model.generate(
411
532
  prompts=prompts,
412
533
  sampling_params=sampling_params,
413
- use_tqdm=False if input_is_a_test else get_pbar_without_leave,
534
+ use_tqdm=False if input_is_a_test else get_pbar,
414
535
  lora_request=self.buffer.get("lora_request"),
415
536
  )
416
537
  break
417
538
  except TypeError as e:
418
- logger.debug(
419
- f"Encountered error during vLLM generation: {str(e)}. Retrying..."
539
+ log(
540
+ f"Encountered error during vLLM generation: {str(e)}. Retrying...",
541
+ level=logging.DEBUG,
420
542
  )
421
543
  sleep(1)
422
544
  except ValueError as e:
@@ -428,26 +550,34 @@ class VLLMModel(HuggingFaceEncoderModel):
428
550
  re.search(pattern, str(e), flags=re.IGNORECASE) is not None
429
551
  for pattern in truncate_error_messages
430
552
  ):
431
- logger.info(
432
- "Prompts are too long, so truncating them and trying again..."
553
+ log(
554
+ "Prompts are too long, so truncating them and trying again...",
555
+ level=logging.WARNING,
433
556
  )
434
- logger.debug(f"The error message was: {str(e)}")
435
- tokenized_prompts = self._tokenizer(
557
+ log(f"The error message was: {str(e)}", level=logging.DEBUG)
558
+
559
+ # If we have already tried truncating the prompts a few times, then
560
+ # we truncate a bit more aggressively
561
+ extra_truncation = 50 * truncation_attempts
562
+ truncation_attempts += 1
563
+
564
+ tokenized_prompts = self._tokeniser(
436
565
  text=prompts,
437
566
  truncation=True,
438
567
  max_length=max(
439
- min(self._tokenizer.model_max_length, MAX_CONTEXT_LENGTH)
440
- - max_tokens,
568
+ min(self._tokeniser.model_max_length, MAX_CONTEXT_LENGTH)
569
+ - max_tokens
570
+ - extra_truncation,
441
571
  0,
442
572
  ),
443
573
  )
444
- prompts = self._tokenizer.batch_decode(
574
+ prompts = self._tokeniser.batch_decode(
445
575
  sequences=tokenized_prompts.input_ids, skip_special_tokens=True
446
576
  )
447
577
  else:
448
578
  raise InvalidBenchmark(
449
579
  f"An error occurred during vLLM generation: {str(e)}"
450
- )
580
+ ) from e
451
581
  else:
452
582
  raise InvalidBenchmark(
453
583
  f"Could not generate sequences after {num_attempts} attempts."
@@ -467,34 +597,73 @@ class VLLMModel(HuggingFaceEncoderModel):
467
597
  f"{num_extra_outputs!r} extra outputs."
468
598
  )
469
599
  else:
470
- logger.debug(
600
+ log(
471
601
  f"Filtered out {num_extra_outputs:,} extra outputs from the model, "
472
602
  "which occured as we interupted the generation when we truncated "
473
- "the prompts."
603
+ "the prompts.",
604
+ level=logging.DEBUG,
474
605
  )
475
606
 
476
- # Parse the raw model outputs
477
- completion_ids: list[list[int]] = [
478
- output.outputs[0].token_ids for output in raw_outputs
607
+ # Parse the raw model outputs. We keep the special tokens for now, as we need
608
+ # them to potentially remove reasoning content and stop tokens
609
+ completion_ids: c.Sequence[c.Sequence[int]] = [
610
+ list(output.outputs[0].token_ids) for output in raw_outputs
479
611
  ]
480
- completions = self._tokenizer.batch_decode(
612
+ completions = self._tokeniser.batch_decode(
481
613
  sequences=[
482
614
  torch.LongTensor(completion_id) for completion_id in completion_ids
483
- ]
615
+ ],
616
+ skip_special_tokens=False,
484
617
  )
485
- if self.end_of_reasoning_token is not None:
486
- completions = [
487
- completion.split(self.end_of_reasoning_token)[-1]
488
- for completion in completions
489
- ]
618
+ if (
619
+ self.end_of_reasoning_token is not None
620
+ and self.generative_type == GenerativeType.REASONING
621
+ ):
622
+ num_samples_without_eor_token = 0
623
+ for idx in range(len(completions)):
624
+ if (
625
+ isinstance(self.end_of_reasoning_token, str)
626
+ and self.end_of_reasoning_token in completions[idx]
627
+ ):
628
+ completions[idx] = completions[idx].split(
629
+ self.end_of_reasoning_token
630
+ )[-1]
631
+ elif isinstance(
632
+ self.end_of_reasoning_token, re.Pattern
633
+ ) and self.end_of_reasoning_token.search(completions[idx]):
634
+ completions[idx] = self.end_of_reasoning_token.split(
635
+ completions[idx]
636
+ )[-1]
637
+ else:
638
+ num_samples_without_eor_token += 1
639
+ completions[idx] = ""
640
+ if num_samples_without_eor_token > 0:
641
+ log_once(
642
+ f"The model {self.model_config.model_id!r} is a reasoning "
643
+ "model, but the generated output did not contain the end of "
644
+ f"reasoning token ({self.end_of_reasoning_token!r}) in "
645
+ f"{num_samples_without_eor_token:,}/{len(completions):,} of "
646
+ "the samples. Using an empty string for all these samples "
647
+ "instead.",
648
+ level=(
649
+ logging.WARNING
650
+ if num_samples_without_eor_token / len(completions) > 0.5
651
+ else logging.DEBUG
652
+ ),
653
+ )
490
654
  stop_token_pattern = re.compile(
491
655
  "|".join(re.escape(stop_token) for stop_token in stop_tokens)
492
656
  )
493
657
  completions = [
494
- re.split(pattern=stop_token_pattern, string=completion)[0]
658
+ re.split(pattern=stop_token_pattern, string=completion)[0].strip()
495
659
  for completion in completions
496
660
  ]
497
- completions = [completion.strip() for completion in completions]
661
+
662
+ # Remove all the special tokens from the completions, if any are present
663
+ completion_ids = self._tokeniser(text=completions).input_ids
664
+ completions = self._tokeniser.batch_decode(
665
+ sequences=completion_ids, skip_special_tokens=True
666
+ )
498
667
 
499
668
  # Sanity check
500
669
  if len(completions) != len(prompts):
@@ -504,13 +673,13 @@ class VLLMModel(HuggingFaceEncoderModel):
504
673
 
505
674
  # Add logprobs scores to the output
506
675
  if self.buffer["first_label_token_mapping"]:
507
- scores: list[list[list[tuple[str, float]]]] = [
676
+ scores: c.Sequence[c.Sequence[c.Sequence[tuple[str, float]]]] = [
508
677
  [
509
678
  [
510
- (obj.decoded_token, obj.logprob)
679
+ (obj.decoded_token or "", obj.logprob)
511
680
  for obj in token_logprobs_dict.values()
512
681
  ]
513
- for token_logprobs_dict in raw_output.outputs[0].logprobs
682
+ for token_logprobs_dict in raw_output.outputs[0].logprobs or list()
514
683
  ]
515
684
  for raw_output in raw_outputs
516
685
  ]
@@ -543,11 +712,18 @@ class VLLMModel(HuggingFaceEncoderModel):
543
712
  if using_api:
544
713
  return False
545
714
 
546
- model_id, revision = (
547
- model_id.split("@") if "@" in model_id else (model_id, "main")
548
- )
715
+ model_id_components = split_model_id(model_id=model_id)
716
+ model_id = model_id_components.model_id
717
+ revision = model_id_components.revision
718
+
549
719
  model_info = get_model_repo_info(
550
- model_id=model_id, revision=revision, benchmark_config=benchmark_config
720
+ model_id=model_id,
721
+ revision=revision,
722
+ api_key=benchmark_config.api_key,
723
+ cache_dir=benchmark_config.cache_dir,
724
+ trust_remote_code=benchmark_config.trust_remote_code,
725
+ requires_safetensors=benchmark_config.requires_safetensors,
726
+ run_with_cli=benchmark_config.run_with_cli,
551
727
  )
552
728
  return (
553
729
  model_info is not None
@@ -569,11 +745,15 @@ class VLLMModel(HuggingFaceEncoderModel):
569
745
  Returns:
570
746
  The model configuration.
571
747
  """
572
- model_id, revision = (
573
- model_id.split("@") if "@" in model_id else (model_id, "main")
574
- )
748
+ model_id_components = split_model_id(model_id=model_id)
575
749
  model_info = get_model_repo_info(
576
- model_id=model_id, revision=revision, benchmark_config=benchmark_config
750
+ model_id=model_id_components.model_id,
751
+ revision=model_id_components.revision,
752
+ api_key=benchmark_config.api_key,
753
+ cache_dir=benchmark_config.cache_dir,
754
+ trust_remote_code=benchmark_config.trust_remote_code,
755
+ requires_safetensors=benchmark_config.requires_safetensors,
756
+ run_with_cli=benchmark_config.run_with_cli,
577
757
  )
578
758
  if model_info is None:
579
759
  raise InvalidModel(f"The model {model_id!r} could not be found.")
@@ -582,8 +762,9 @@ class VLLMModel(HuggingFaceEncoderModel):
582
762
  language_codes = list(language_mapping.keys())
583
763
 
584
764
  model_config = ModelConfig(
585
- model_id=model_id,
586
- revision=revision,
765
+ model_id=model_id_components.model_id,
766
+ revision=model_id_components.revision,
767
+ param=model_id_components.param,
587
768
  task=model_info.pipeline_tag,
588
769
  languages=[
589
770
  language_mapping[tag]
@@ -603,7 +784,7 @@ class VLLMModel(HuggingFaceEncoderModel):
603
784
  return model_config
604
785
 
605
786
  @property
606
- def data_collator(self) -> c.Callable[[list[t.Any]], dict[str, t.Any]]:
787
+ def data_collator(self) -> c.Callable[[c.Sequence[t.Any]], dict[str, t.Any]]:
607
788
  """The data collator used to prepare samples during finetuning.
608
789
 
609
790
  Returns:
@@ -625,10 +806,10 @@ class VLLMModel(HuggingFaceEncoderModel):
625
806
  )
626
807
 
627
808
 
628
- def load_model_and_tokenizer(
809
+ def load_model_and_tokeniser(
629
810
  model_config: "ModelConfig", benchmark_config: "BenchmarkConfig"
630
811
  ) -> tuple["LLM", "PreTrainedTokenizer"]:
631
- """Load the model and tokenizer.
812
+ """Load the model and tokeniser.
632
813
 
633
814
  Args:
634
815
  model_config:
@@ -637,7 +818,7 @@ def load_model_and_tokenizer(
637
818
  The benchmark configuration.
638
819
 
639
820
  Returns:
640
- A pair (model, tokenizer), with the loaded model and tokenizer
821
+ A pair (model, tokeniser), with the loaded model and tokeniser
641
822
  """
642
823
  # Prefer base model ID if the model is an adapter - the adapter will be added on
643
824
  # during inference in this case
@@ -649,8 +830,8 @@ def load_model_and_tokenizer(
649
830
  hf_model_config = load_hf_model_config(
650
831
  model_id=model_id,
651
832
  num_labels=0,
652
- id2label=dict(),
653
- label2id=dict(),
833
+ id2label=HashableDict(),
834
+ label2id=HashableDict(),
654
835
  revision=revision,
655
836
  model_cache_dir=model_config.model_cache_dir,
656
837
  api_key=benchmark_config.api_key,
@@ -675,46 +856,55 @@ def load_model_and_tokenizer(
675
856
  dtype: str | torch.dtype = "auto"
676
857
 
677
858
  # Choose bf16 over fp16 if the model is a fp32 model and the GPU supports it
678
- if hf_model_config.torch_dtype == torch.float32:
859
+ if hf_model_config.dtype == torch.float32:
679
860
  if torch.cuda.is_bf16_supported():
680
- logger.info(
861
+ log(
681
862
  "You are loading a model with dtype FP32, which we will convert to "
682
863
  "BF16 as FP32 is not supported by vLLM and BF16 is supported by your "
683
- "GPU."
864
+ "GPU.",
865
+ level=logging.WARNING,
684
866
  )
685
867
  dtype = torch.bfloat16
686
868
  else:
687
- logger.info(
869
+ log(
688
870
  "You are loading a model with dtype FP32, which we will convert to "
689
871
  "FP16 as FP32 is not supported by vLLM and BF16 is not supported by "
690
- "your GPU."
872
+ "your GPU.",
873
+ level=logging.WARNING,
691
874
  )
692
875
  dtype = torch.float16
693
876
 
694
- # If the model is a quantized model, we need to set the dtype to float16
695
- if quantization is not None and hf_model_config.torch_dtype != torch.float16:
696
- logger.info(
877
+ # If the model is a quantized model, we might need to change the dtype
878
+ if quantization == "mxfp4" and hf_model_config.dtype is None:
879
+ dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
880
+ log(
881
+ "You are loading a quantized model where `dtype` has not been set. "
882
+ f"Setting dtype to {dtype!r}.",
883
+ level=logging.DEBUG,
884
+ )
885
+ elif quantization is not None and hf_model_config.dtype != torch.float16:
886
+ log(
697
887
  "You are loading a quantized model with dtype "
698
- f"{hf_model_config.torch_dtype}, which vLLM does not support. Setting "
699
- "dtype to float16 instead."
888
+ f"{hf_model_config.dtype}, which vLLM does not support. Setting "
889
+ "dtype to float16 instead.",
890
+ level=logging.WARNING,
700
891
  )
701
892
  dtype = torch.float16
702
893
 
703
894
  # If the model is a bf16 model, we need to check the CUDA compute capability
704
- if hf_model_config.torch_dtype == torch.bfloat16:
895
+ if hf_model_config.dtype == torch.bfloat16:
705
896
  min_cuda_compute_capability = get_min_cuda_compute_capability()
706
897
  required_capability = VLLM_BF16_MIN_CUDA_COMPUTE_CAPABILITY
707
898
 
708
899
  if min_cuda_compute_capability is not None:
709
900
  if min_cuda_compute_capability < required_capability:
710
- logger.info(
711
- "You are loading a model with "
712
- f"dtype {hf_model_config.torch_dtype}, "
713
- "which vLLM only supports for CUDA devices with"
714
- f"CUDA compute capability >={required_capability}. "
715
- "You are using one or more devices with "
716
- f"compute capability {min_cuda_compute_capability}. "
717
- "Setting dtype to float16 instead."
901
+ log(
902
+ f"You are loading a model with dtype {hf_model_config.dtype}, "
903
+ "which vLLM only supports for CUDA devices with CUDA compute "
904
+ f"capability >={required_capability}. You are using one or more "
905
+ f"devices with compute capability {min_cuda_compute_capability}. "
906
+ "Setting dtype to float16 instead.",
907
+ level=logging.WARNING,
718
908
  )
719
909
  dtype = torch.float16
720
910
 
@@ -741,31 +931,40 @@ def load_model_and_tokenizer(
741
931
  else:
742
932
  true_max_model_len = MAX_CONTEXT_LENGTH
743
933
 
744
- tokenizer = load_tokenizer(
934
+ tokeniser = load_tokeniser(
745
935
  model_id=model_config.model_id,
746
936
  revision=model_config.revision,
747
937
  adapter_base_model_id=model_config.adapter_base_model_id,
748
938
  trust_remote_code=benchmark_config.trust_remote_code,
749
939
  model_max_length=true_max_model_len,
750
- model_cache_dir=model_config.model_cache_dir,
751
- token=benchmark_config.api_key or os.getenv("HUGGINGFACE_API_KEY") or True,
940
+ model_config=model_config,
941
+ token=get_hf_token(api_key=benchmark_config.api_key),
942
+ )
943
+ vllm_tokenisation_params = get_vllm_tokenisation_params(
944
+ tokeniser=tokeniser, model_config=model_config
752
945
  )
753
946
 
754
947
  clear_vllm()
755
948
 
756
949
  try:
757
950
  model = LLM(
758
- model=model_id,
759
- tokenizer=model_id,
951
+ model=(
952
+ model_id
953
+ if internet_connection_available()
954
+ else resolve_model_path(download_dir=download_dir)
955
+ ),
956
+ tokenizer=(
957
+ model_id
958
+ if internet_connection_available()
959
+ else resolve_model_path(download_dir=download_dir)
960
+ ),
760
961
  gpu_memory_utilization=benchmark_config.gpu_memory_utilization,
761
962
  max_model_len=min(true_max_model_len, MAX_CONTEXT_LENGTH),
762
963
  download_dir=download_dir,
763
964
  trust_remote_code=benchmark_config.trust_remote_code,
764
965
  revision=revision,
765
966
  seed=4242,
766
- distributed_executor_backend=(
767
- "ray" if torch.cuda.device_count() > 1 else "mp"
768
- ),
967
+ distributed_executor_backend="mp",
769
968
  tensor_parallel_size=torch.cuda.device_count(),
770
969
  disable_custom_all_reduce=True,
771
970
  quantization=quantization,
@@ -776,38 +975,65 @@ def load_model_and_tokenizer(
776
975
  enable_prefix_caching=False,
777
976
  enable_lora=model_config.adapter_base_model_id is not None,
778
977
  max_lora_rank=256,
978
+ **vllm_tokenisation_params,
779
979
  )
780
980
  except (RuntimeError, ValueError, OSError) as e:
781
981
  if "awaiting a review from the repo authors" in str(e):
782
982
  raise InvalidModel(
783
983
  f"The model {model_id!r} is awaiting a review from the repository "
784
984
  "authors. Please try again later."
785
- )
985
+ ) from e
786
986
  elif "trust_remote_code" in str(e):
787
987
  raise InvalidModel(
788
988
  f"Loading the model {model_id!r} needs to trust remote code. "
789
989
  "If you trust the suppliers of this model, then you can enable "
790
990
  "this by setting the `--trust-remote-code` flag."
991
+ ) from e
992
+ elif "See stack trace for root cause." in str(
993
+ e
994
+ ) or "See root cause above." in str(e):
995
+ msg = (
996
+ f"The model {model_id!r} could not be loaded, but vLLM did not "
997
+ "mention exactly what happened. "
998
+ )
999
+ msg += (
1000
+ (
1001
+ "Since you're running in verbose mode, you might see a descriptive "
1002
+ "error above already. Note however that if the error message urges "
1003
+ "you to set the environment variable `VLLM_ATTENTION_BACKEND` to "
1004
+ "'FLEX_ATTENTION', please try setting it to 'FLASH_ATTN' first, as "
1005
+ "that often solves the issue, whereas 'FLEX_ATTENTION' usually "
1006
+ "doesn't. If you don't see any descriptive error above, then you "
1007
+ "can try "
1008
+ )
1009
+ if benchmark_config.verbose
1010
+ else "Try "
1011
+ )
1012
+ msg += (
1013
+ "re-running the benchmark with the environment variable `FULL_LOG` "
1014
+ "set to `1` to see the full stack trace. E.g., "
1015
+ f"`FULL_LOG=1 euroeval --model {model_id}`."
791
1016
  )
1017
+ raise InvalidModel(msg) from e
792
1018
  raise InvalidModel(
793
1019
  f"The model {model_id!r} could not be loaded. The error was {e!r}."
794
- )
1020
+ ) from e
795
1021
 
796
1022
  model.config = hf_model_config
797
1023
 
798
- return model, tokenizer
1024
+ return model, tokeniser
799
1025
 
800
1026
 
801
- def load_tokenizer(
1027
+ def load_tokeniser(
802
1028
  model_id: str,
803
1029
  revision: str,
804
1030
  adapter_base_model_id: str | None,
805
1031
  trust_remote_code: bool,
806
1032
  model_max_length: int,
807
- model_cache_dir: str,
1033
+ model_config: "ModelConfig",
808
1034
  token: str | bool,
809
1035
  ) -> "PreTrainedTokenizer":
810
- """Load the tokenizer.
1036
+ """Load the tokeniser.
811
1037
 
812
1038
  Args:
813
1039
  model_id:
@@ -821,64 +1047,97 @@ def load_tokenizer(
821
1047
  Whether to trust remote code.
822
1048
  model_max_length:
823
1049
  The maximum length of the model.
824
- model_cache_dir:
825
- The cache directory for the model.
1050
+ model_config:
1051
+ The model configuration.
826
1052
  token:
827
1053
  The Hugging Face API token.
828
1054
 
829
1055
  Returns:
830
- The loaded tokenizer.
1056
+ The loaded tokeniser.
831
1057
  """
832
1058
  revision = revision if adapter_base_model_id is None else "main"
833
1059
  config = AutoConfig.from_pretrained(
834
1060
  adapter_base_model_id or model_id,
835
1061
  revision=revision,
836
- cache_dir=model_cache_dir,
1062
+ cache_dir=model_config.model_cache_dir,
837
1063
  token=token,
838
1064
  trust_remote_code=trust_remote_code,
1065
+ local_files_only=not internet_connection_available(),
839
1066
  )
840
1067
  num_retries = 5
841
1068
  for _ in range(num_retries):
842
1069
  try:
843
- tokenizer = AutoTokenizer.from_pretrained(
1070
+ # Mistral instruction-tuned models need a custom tokeniser
1071
+ if model_id.startswith("mistralai/") and "base" not in model_id.lower():
1072
+ tokeniser = MistralCommonTokenizer.from_pretrained(
1073
+ model_id,
1074
+ padding_side="left",
1075
+ truncation_side="left",
1076
+ model_max_length=model_max_length,
1077
+ token=token,
1078
+ )
1079
+ break
1080
+ tokeniser = AutoTokenizer.from_pretrained(
844
1081
  model_id,
845
- use_fast=True,
1082
+ revision=revision,
1083
+ use_fast=False if model_config.param == "slow-tokenizer" else True,
846
1084
  verbose=False,
847
1085
  trust_remote_code=trust_remote_code,
848
1086
  padding_side="left",
849
1087
  truncation_side="left",
850
1088
  model_max_length=model_max_length,
1089
+ cache_dir=model_config.model_cache_dir,
851
1090
  config=config,
852
1091
  token=token,
1092
+ local_files_only=not internet_connection_available(),
853
1093
  )
854
1094
  break
855
1095
  except (json.JSONDecodeError, OSError, TypeError) as e:
856
1096
  if adapter_base_model_id is None or model_id == adapter_base_model_id:
857
1097
  raise InvalidModel(
858
- f"Could not load tokenizer for model {model_id!r}. The error was "
1098
+ f"Could not load tokeniser for model {model_id!r}. The error was "
859
1099
  f"{str(e)}."
860
- )
861
- logger.debug(
862
- f"Could not load tokenizer for {model_id!r}. Falling back to "
863
- f"{adapter_base_model_id!r}."
1100
+ ) from e
1101
+ log(
1102
+ f"Could not load tokeniser for {model_id!r}. Falling back to "
1103
+ f"{adapter_base_model_id!r}.",
1104
+ level=logging.DEBUG,
864
1105
  )
865
1106
  model_id = adapter_base_model_id
866
1107
  except (TimeoutError, RequestError):
867
- logger.info(f"Couldn't load tokenizer for {model_id!r}. Retrying.")
1108
+ log(
1109
+ f"Couldn't load tokeniser for {model_id!r}. Retrying.",
1110
+ level=logging.WARNING,
1111
+ )
868
1112
  sleep(5)
869
1113
  continue
1114
+ except (KeyError, ValueError) as e:
1115
+ if "mistral" in str(e).lower():
1116
+ tokeniser = MistralCommonTokenizer.from_pretrained(
1117
+ model_id,
1118
+ padding_side="left",
1119
+ truncation_side="left",
1120
+ model_max_length=model_max_length,
1121
+ token=token,
1122
+ )
1123
+ break
1124
+ raise InvalidModel(
1125
+ f"Could not load tokeniser for model {model_id!r}. The error was "
1126
+ f"{str(e)}."
1127
+ ) from e
870
1128
  else:
871
1129
  raise InvalidModel(
872
- f"Could not load tokenizer for model {model_id!r} after {num_retries} "
1130
+ f"Could not load tokeniser for model {model_id!r} after {num_retries} "
873
1131
  "attempts."
874
1132
  )
875
1133
 
876
1134
  # Ensure that BOS, EOS and PAD tokens are set
877
- tokenizer.bos_token, tokenizer.bos_token_id = get_bos_token(tokenizer=tokenizer)
878
- tokenizer.eos_token, tokenizer.eos_token_id = get_eos_token(tokenizer=tokenizer)
879
- tokenizer.pad_token, tokenizer.pad_token_id = get_pad_token(tokenizer=tokenizer)
1135
+ if not isinstance(tokeniser, MistralCommonTokenizer):
1136
+ tokeniser.bos_token, tokeniser.bos_token_id = get_bos_token(tokeniser=tokeniser)
1137
+ tokeniser.eos_token, tokeniser.eos_token_id = get_eos_token(tokeniser=tokeniser)
1138
+ tokeniser.pad_token, tokeniser.pad_token_id = get_pad_token(tokeniser=tokeniser)
880
1139
 
881
- return tokenizer
1140
+ return tokeniser
882
1141
 
883
1142
 
884
1143
  def clear_vllm() -> None:
@@ -886,80 +1145,93 @@ def clear_vllm() -> None:
886
1145
  with contextlib.suppress(ValueError):
887
1146
  destroy_model_parallel()
888
1147
  destroy_distributed_environment()
889
- if ray.is_initialized():
890
- ray.shutdown()
891
1148
  with contextlib.suppress(AssertionError):
892
1149
  torch.distributed.destroy_process_group()
893
- if ray.is_initialized():
894
- ray.shutdown()
895
1150
  clear_memory()
896
1151
 
897
1152
 
898
1153
  def get_end_of_reasoning_token(
899
- model: "LLM", tokenizer: "PreTrainedTokenizer", model_id: str
900
- ) -> str | None:
1154
+ model: "LLM", tokeniser: "PreTrainedTokenizer", model_config: "ModelConfig"
1155
+ ) -> str | re.Pattern | None:
901
1156
  """Get the end-of-reasoning token for a generative model.
902
1157
 
903
1158
  Args:
904
1159
  model:
905
1160
  The vLLM model.
906
- tokenizer:
907
- The tokenizer.
908
- model_id:
909
- The model ID.
1161
+ tokeniser:
1162
+ The tokeniser.
1163
+ model_config:
1164
+ The model configuration.
910
1165
 
911
1166
  Returns:
912
1167
  The end of reasoning token, or None if it could not be found.
913
1168
  """
1169
+ model_id = model_config.model_id
1170
+
914
1171
  # Create a prompt to check if the model uses the reasoning tokens
915
1172
  prompt = "What is your name?"
916
- if tokenizer.chat_template is not None:
917
- templated_prompt = tokenizer.apply_chat_template(
1173
+ if has_chat_template(tokeniser=tokeniser):
1174
+ extra_kwargs = dict()
1175
+ if model_config.param in {"thinking", "no-thinking"}:
1176
+ extra_kwargs["enable_thinking"] = model_config.param == "thinking"
1177
+ templated_prompt = apply_chat_template(
918
1178
  conversation=[dict(role="user", content=prompt)],
1179
+ tokeniser=tokeniser,
1180
+ tokenise=False,
919
1181
  add_generation_prompt=True,
920
- tokenize=False,
1182
+ **extra_kwargs,
921
1183
  )
922
1184
  assert isinstance(templated_prompt, str)
923
1185
  prompt = templated_prompt
924
1186
 
925
1187
  # Check that the beginning-of-reasoning token is actually used by the model
926
- completion = (
927
- model.generate(
928
- prompts=[prompt],
929
- sampling_params=SamplingParams(max_tokens=10),
930
- use_tqdm=False,
931
- )[0]
932
- .outputs[0]
933
- .text
934
- )
1188
+ output = model.generate(
1189
+ prompts=[prompt], sampling_params=SamplingParams(max_tokens=10), use_tqdm=False
1190
+ )[0]
1191
+ completion = tokeniser.decode(token_ids=output.outputs[0].token_ids)
935
1192
  bor_reasoning_matches = [
936
1193
  (bor_token, eor_token)
937
1194
  for bor_token, eor_token in REASONING_TOKENS
938
- if bor_token in prompt or bor_token in completion
1195
+ if (
1196
+ (
1197
+ isinstance(bor_token, str)
1198
+ and (bor_token in prompt or bor_token in completion)
1199
+ )
1200
+ or (
1201
+ isinstance(bor_token, re.Pattern)
1202
+ and (
1203
+ bor_token.search(prompt) is not None
1204
+ or bor_token.search(completion) is not None
1205
+ )
1206
+ )
1207
+ )
939
1208
  ]
940
1209
  if not bor_reasoning_matches:
941
1210
  log_once(
942
1211
  f"The model {model_id!r} did not generate any beginning-of-reasoning "
943
- "tokens in the prompt or the completion. Assuming the model is not "
944
- "a reasoning model.",
945
- level=logging.INFO,
1212
+ "tokens in the prompt or the completion. Assuming the model is not a "
1213
+ "reasoning model.",
1214
+ level=logging.DEBUG,
946
1215
  )
947
1216
  return None
948
1217
 
949
- # Check that the beginning-of-reasoning token is actually used by the model
950
- completion = (
951
- model.generate(
952
- prompts=[prompt],
953
- sampling_params=SamplingParams(max_tokens=REASONING_MAX_TOKENS),
954
- use_tqdm=False,
955
- )[0]
956
- .outputs[0]
957
- .text
958
- )
1218
+ # Check that the end-of-reasoning token is actually used by the model
1219
+ output = model.generate(
1220
+ prompts=[prompt],
1221
+ sampling_params=SamplingParams(max_tokens=REASONING_MAX_TOKENS),
1222
+ use_tqdm=False,
1223
+ )[0]
1224
+ completion = tokeniser.decode(token_ids=output.outputs[0].token_ids)
959
1225
  eor_reasoning_matches = [
960
1226
  (bor_token, eor_token)
961
1227
  for bor_token, eor_token in bor_reasoning_matches
962
- if eor_token in completion
1228
+ if (
1229
+ (isinstance(eor_token, str) and eor_token in completion)
1230
+ or (
1231
+ isinstance(eor_token, re.Pattern)
1232
+ and eor_token.search(completion) is not None
1233
+ )
1234
+ )
963
1235
  ]
964
1236
  if not eor_reasoning_matches:
965
1237
  log_once(
@@ -968,7 +1240,7 @@ def get_end_of_reasoning_token(
968
1240
  "the beginning-of-reasoning tokens "
969
1241
  f"{[bor_token for bor_token, _ in bor_reasoning_matches]!r}. "
970
1242
  "This is probably not correct, so please report this issue.",
971
- level=logging.INFO,
1243
+ level=logging.WARNING,
972
1244
  )
973
1245
  return None
974
1246
 
@@ -978,14 +1250,21 @@ def get_end_of_reasoning_token(
978
1250
  f"model {model_id!r}. Using {eor_reasoning_matches[0]!r} as "
979
1251
  "the reasoning token. If this is not the correct reasoning token, "
980
1252
  "please report this issue.",
981
- level=logging.INFO,
1253
+ level=logging.WARNING,
982
1254
  )
983
1255
 
984
1256
  bor_token, eor_token = eor_reasoning_matches[0]
1257
+
1258
+ bor_token_logging: str = (
1259
+ bor_token if isinstance(bor_token, str) else bor_token.pattern
1260
+ )
1261
+ eor_token_logging: str = (
1262
+ eor_token if isinstance(eor_token, str) else eor_token.pattern
1263
+ )
985
1264
  log_once(
986
- f"Detected beginning-of-reasoning token {bor_token!r} and end-of-reasoning "
987
- f"token {eor_token!r} for model {model_id!r}.",
988
- level=logging.INFO,
1265
+ f"Detected beginning-of-reasoning token {bor_token_logging!r} and "
1266
+ f"end-of-reasoning token {eor_token_logging!r} for model {model_id!r}.",
1267
+ level=logging.DEBUG,
989
1268
  )
990
1269
 
991
1270
  return eor_token
@@ -993,22 +1272,21 @@ def get_end_of_reasoning_token(
993
1272
 
994
1273
  def get_custom_stop_tokens(
995
1274
  model: "LLM",
996
- tokenizer: "PreTrainedTokenizer",
1275
+ tokeniser: "PreTrainedTokenizer",
997
1276
  model_id: str,
998
- is_reasoning_model: bool,
1277
+ generative_type: GenerativeType | None,
999
1278
  ) -> list[str]:
1000
1279
  """Get the stop tokens for a generative model.
1001
1280
 
1002
1281
  Args:
1003
1282
  model:
1004
1283
  The vLLM model.
1005
- tokenizer:
1006
- The tokenizer.
1284
+ tokeniser:
1285
+ The tokeniser.
1007
1286
  model_id:
1008
1287
  The model ID.
1009
- is_reasoning_model:
1010
- Whether the model is a reasoning model. This is used to determine the number
1011
- of generated tokens to allow before stopping the generation.
1288
+ generative_type:
1289
+ The generative type of the model.
1012
1290
 
1013
1291
  Returns:
1014
1292
  A list of stop tokens.
@@ -1016,25 +1294,26 @@ def get_custom_stop_tokens(
1016
1294
  candidate_stop_tokens = CUSTOM_STOP_TOKENS
1017
1295
 
1018
1296
  prompt = "Hello"
1019
- if tokenizer.chat_template is not None:
1020
- templated_prompt = tokenizer.apply_chat_template(
1297
+ if has_chat_template(tokeniser=tokeniser):
1298
+ templated_prompt = apply_chat_template(
1021
1299
  conversation=[dict(role="user", content=prompt)],
1300
+ tokeniser=tokeniser,
1301
+ tokenise=False,
1022
1302
  add_generation_prompt=True,
1023
- tokenize=False,
1303
+ enable_thinking=generative_type == GenerativeType.REASONING,
1024
1304
  )
1025
1305
  assert isinstance(templated_prompt, str)
1026
1306
  prompt = templated_prompt
1027
1307
 
1028
- max_tokens = REASONING_MAX_TOKENS if is_reasoning_model else 10
1029
- completion = (
1030
- model.generate(
1031
- prompts=[prompt],
1032
- sampling_params=SamplingParams(max_tokens=max_tokens, temperature=0.0),
1033
- use_tqdm=False,
1034
- )[0]
1035
- .outputs[0]
1036
- .text
1308
+ max_tokens = (
1309
+ REASONING_MAX_TOKENS if generative_type == GenerativeType.REASONING else 10
1037
1310
  )
1311
+ output = model.generate(
1312
+ prompts=[prompt],
1313
+ sampling_params=SamplingParams(max_tokens=max_tokens, temperature=0.0),
1314
+ use_tqdm=False,
1315
+ )[0]
1316
+ completion = tokeniser.decode(token_ids=output.outputs[0].token_ids)
1038
1317
 
1039
1318
  stop_tokens = [
1040
1319
  stop_token
@@ -1042,27 +1321,50 @@ def get_custom_stop_tokens(
1042
1321
  if stop_token in prompt or stop_token in completion
1043
1322
  ]
1044
1323
  if stop_tokens:
1045
- logger.debug(
1324
+ log(
1046
1325
  f"Found the following custom stop tokens for model {model_id!r}: "
1047
- f"{stop_tokens}."
1326
+ f"{stop_tokens}.",
1327
+ level=logging.DEBUG,
1048
1328
  )
1049
1329
  else:
1050
- logger.debug(f"Found no custom stop tokens for model {model_id!r}.")
1330
+ log(f"Found no custom stop tokens for model {model_id!r}.", level=logging.DEBUG)
1051
1331
 
1052
1332
  return stop_tokens
1053
1333
 
1054
1334
 
1055
- def get_pbar_without_leave(*tqdm_args, **tqdm_kwargs) -> tqdm:
1056
- """Get a progress bar for vLLM which disappears after completion.
1335
+ def get_vllm_tokenisation_params(
1336
+ tokeniser: "PreTrainedTokenizer", model_config: "ModelConfig"
1337
+ ) -> dict[str, t.Any]:
1338
+ """Get the tokenisation parameters for vLLM.
1057
1339
 
1058
1340
  Args:
1059
- *tqdm_args:
1060
- Positional arguments to pass to tqdm.
1061
- **tqdm_kwargs:
1062
- Additional keyword arguments to pass to tqdm.
1341
+ tokeniser:
1342
+ The tokeniser.
1343
+ model_config:
1344
+ The model configuration.
1063
1345
 
1064
1346
  Returns:
1065
- A tqdm progress bar.
1347
+ A dictionary of tokenisation parameters to pass to vLLM.
1066
1348
  """
1067
- tqdm_kwargs.pop("leave", None) # Remove the 'leave' key if it exists
1068
- return tqdm(*tqdm_args, leave=False, **tqdm_kwargs)
1349
+ if isinstance(tokeniser, MistralCommonTokenizer):
1350
+ tokeniser_mode = "mistral"
1351
+ elif model_config.param == "slow-tokenizer":
1352
+ tokeniser_mode = "slow"
1353
+ else:
1354
+ tokeniser_mode = "auto"
1355
+
1356
+ if isinstance(tokeniser, MistralCommonTokenizer):
1357
+ config_format = "mistral"
1358
+ else:
1359
+ config_format = "auto"
1360
+
1361
+ if isinstance(tokeniser, MistralCommonTokenizer):
1362
+ load_format = "mistral"
1363
+ else:
1364
+ load_format = "auto"
1365
+
1366
+ return dict(
1367
+ tokenizer_mode=tokeniser_mode,
1368
+ config_format=config_format,
1369
+ load_format=load_format,
1370
+ )