EuroEval 15.12.0__py3-none-any.whl → 16.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- euroeval/__init__.py +32 -14
- euroeval/benchmark_config_factory.py +92 -180
- euroeval/benchmark_modules/base.py +49 -39
- euroeval/benchmark_modules/fresh.py +35 -21
- euroeval/benchmark_modules/hf.py +280 -244
- euroeval/benchmark_modules/litellm.py +752 -312
- euroeval/benchmark_modules/vllm.py +570 -268
- euroeval/benchmarker.py +651 -528
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/cli.py +49 -38
- euroeval/constants.py +44 -25
- euroeval/data_loading.py +111 -55
- euroeval/data_models.py +490 -323
- euroeval/dataset_configs/__init__.py +26 -4
- euroeval/dataset_configs/bosnian.py +39 -0
- euroeval/dataset_configs/bulgarian.py +56 -0
- euroeval/dataset_configs/croatian.py +56 -0
- euroeval/dataset_configs/czech.py +75 -0
- euroeval/dataset_configs/danish.py +78 -50
- euroeval/dataset_configs/dutch.py +74 -44
- euroeval/dataset_configs/english.py +71 -36
- euroeval/dataset_configs/estonian.py +111 -0
- euroeval/dataset_configs/faroese.py +25 -18
- euroeval/dataset_configs/finnish.py +63 -26
- euroeval/dataset_configs/french.py +65 -32
- euroeval/dataset_configs/german.py +77 -36
- euroeval/dataset_configs/greek.py +64 -0
- euroeval/dataset_configs/icelandic.py +68 -57
- euroeval/dataset_configs/italian.py +68 -36
- euroeval/dataset_configs/latvian.py +87 -0
- euroeval/dataset_configs/lithuanian.py +64 -0
- euroeval/dataset_configs/norwegian.py +98 -72
- euroeval/dataset_configs/polish.py +96 -0
- euroeval/dataset_configs/portuguese.py +63 -40
- euroeval/dataset_configs/serbian.py +64 -0
- euroeval/dataset_configs/slovak.py +55 -0
- euroeval/dataset_configs/slovene.py +56 -0
- euroeval/dataset_configs/spanish.py +68 -34
- euroeval/dataset_configs/swedish.py +82 -41
- euroeval/dataset_configs/ukrainian.py +64 -0
- euroeval/enums.py +12 -6
- euroeval/exceptions.py +21 -1
- euroeval/finetuning.py +34 -26
- euroeval/generation.py +76 -41
- euroeval/generation_utils.py +169 -34
- euroeval/languages.py +1020 -188
- euroeval/logging_utils.py +268 -0
- euroeval/metrics/__init__.py +6 -0
- euroeval/metrics/base.py +85 -0
- euroeval/metrics/huggingface.py +216 -0
- euroeval/metrics/llm_as_a_judge.py +260 -0
- euroeval/metrics/pipeline.py +289 -0
- euroeval/metrics/speed.py +48 -0
- euroeval/model_cache.py +40 -21
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/__init__.py +2 -0
- euroeval/prompt_templates/classification.py +206 -0
- euroeval/prompt_templates/linguistic_acceptability.py +157 -22
- euroeval/prompt_templates/multiple_choice.py +159 -17
- euroeval/prompt_templates/named_entity_recognition.py +318 -21
- euroeval/prompt_templates/reading_comprehension.py +207 -16
- euroeval/prompt_templates/sentiment_classification.py +205 -22
- euroeval/prompt_templates/summarization.py +122 -22
- euroeval/prompt_templates/token_classification.py +279 -0
- euroeval/scores.py +20 -9
- euroeval/speed_benchmark.py +11 -12
- euroeval/task_group_utils/multiple_choice_classification.py +21 -12
- euroeval/task_group_utils/question_answering.py +101 -73
- euroeval/task_group_utils/sequence_classification.py +144 -61
- euroeval/task_group_utils/text_to_text.py +33 -12
- euroeval/task_group_utils/token_classification.py +86 -89
- euroeval/tasks.py +75 -16
- euroeval/tokenisation_utils.py +603 -0
- euroeval/types.py +17 -11
- euroeval/utils.py +332 -137
- euroeval-16.7.1.dist-info/METADATA +623 -0
- euroeval-16.7.1.dist-info/RECORD +84 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
- euroeval/human_evaluation.py +0 -737
- euroeval/metrics.py +0 -452
- euroeval/tokenization_utils.py +0 -498
- euroeval-15.12.0.dist-info/METADATA +0 -285
- euroeval-15.12.0.dist-info/RECORD +0 -63
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -5,8 +5,8 @@ import contextlib
|
|
|
5
5
|
import importlib.util
|
|
6
6
|
import json
|
|
7
7
|
import logging
|
|
8
|
-
import os
|
|
9
8
|
import re
|
|
9
|
+
import shutil
|
|
10
10
|
import typing as t
|
|
11
11
|
from functools import partial
|
|
12
12
|
from pathlib import Path
|
|
@@ -15,23 +15,22 @@ from time import sleep
|
|
|
15
15
|
import torch
|
|
16
16
|
from huggingface_hub import snapshot_download
|
|
17
17
|
from pydantic import conlist, create_model
|
|
18
|
-
from tqdm.auto import tqdm
|
|
19
18
|
from transformers.models.auto.configuration_auto import AutoConfig
|
|
20
19
|
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
|
20
|
+
from transformers.tokenization_mistral_common import MistralCommonTokenizer
|
|
21
21
|
from urllib3.exceptions import RequestError
|
|
22
22
|
|
|
23
23
|
from ..constants import (
|
|
24
24
|
CUSTOM_STOP_TOKENS,
|
|
25
25
|
GENERATIVE_PIPELINE_TAGS,
|
|
26
26
|
MAX_CONTEXT_LENGTH,
|
|
27
|
-
|
|
27
|
+
MAX_VLLM_LOGPROBS,
|
|
28
28
|
MERGE_TAGS,
|
|
29
29
|
REASONING_MAX_TOKENS,
|
|
30
30
|
REASONING_TOKENS,
|
|
31
|
-
TASKS_USING_JSON,
|
|
32
31
|
VLLM_BF16_MIN_CUDA_COMPUTE_CAPABILITY,
|
|
33
32
|
)
|
|
34
|
-
from ..data_models import GenerativeModelOutput, ModelConfig
|
|
33
|
+
from ..data_models import GenerativeModelOutput, HashableDict, ModelConfig
|
|
35
34
|
from ..enums import (
|
|
36
35
|
BatchingPreference,
|
|
37
36
|
GenerativeType,
|
|
@@ -44,29 +43,41 @@ from ..exceptions import (
|
|
|
44
43
|
InvalidModel,
|
|
45
44
|
NeedsEnvironmentVariable,
|
|
46
45
|
NeedsExtraInstalled,
|
|
46
|
+
NeedsSystemDependency,
|
|
47
|
+
)
|
|
48
|
+
from ..generation_utils import (
|
|
49
|
+
apply_prompt,
|
|
50
|
+
extract_few_shot_examples,
|
|
51
|
+
raise_if_wrong_params,
|
|
47
52
|
)
|
|
48
|
-
from ..generation_utils import apply_prompt, extract_few_shot_examples
|
|
49
53
|
from ..languages import get_all_languages
|
|
54
|
+
from ..logging_utils import get_pbar, log, log_once, no_terminal_output
|
|
50
55
|
from ..task_group_utils import (
|
|
51
56
|
question_answering,
|
|
52
57
|
sequence_classification,
|
|
53
58
|
text_to_text,
|
|
54
59
|
token_classification,
|
|
55
60
|
)
|
|
56
|
-
from ..
|
|
61
|
+
from ..tokenisation_utils import (
|
|
62
|
+
apply_chat_template,
|
|
57
63
|
get_bos_token,
|
|
58
64
|
get_end_of_chat_token_ids,
|
|
59
65
|
get_eos_token,
|
|
60
66
|
get_first_label_token_mapping,
|
|
61
67
|
get_pad_token,
|
|
68
|
+
has_chat_template,
|
|
62
69
|
should_prompts_be_stripped,
|
|
63
70
|
)
|
|
64
71
|
from ..types import ExtractLabelsFunction
|
|
65
72
|
from ..utils import (
|
|
66
73
|
clear_memory,
|
|
67
74
|
create_model_cache_dir,
|
|
75
|
+
flash_attention_backend,
|
|
76
|
+
get_hf_token,
|
|
68
77
|
get_min_cuda_compute_capability,
|
|
69
|
-
|
|
78
|
+
internet_connection_available,
|
|
79
|
+
resolve_model_path,
|
|
80
|
+
split_model_id,
|
|
70
81
|
)
|
|
71
82
|
from .hf import HuggingFaceEncoderModel, get_model_repo_info, load_hf_model_config
|
|
72
83
|
|
|
@@ -77,13 +88,7 @@ if t.TYPE_CHECKING or importlib.util.find_spec("vllm") is not None:
|
|
|
77
88
|
destroy_model_parallel,
|
|
78
89
|
)
|
|
79
90
|
from vllm.lora.request import LoRARequest
|
|
80
|
-
|
|
81
|
-
if t.TYPE_CHECKING or importlib.util.find_spec("outlines") is not None:
|
|
82
|
-
from outlines.models.vllm import adapt_tokenizer
|
|
83
|
-
from outlines.processors.structured import JSONLogitsProcessor
|
|
84
|
-
|
|
85
|
-
if t.TYPE_CHECKING or importlib.util.find_spec("ray") is not None:
|
|
86
|
-
import ray
|
|
91
|
+
from vllm.sampling_params import StructuredOutputsParams
|
|
87
92
|
|
|
88
93
|
if t.TYPE_CHECKING:
|
|
89
94
|
from datasets import DatasetDict
|
|
@@ -92,7 +97,10 @@ if t.TYPE_CHECKING:
|
|
|
92
97
|
|
|
93
98
|
from ..data_models import BenchmarkConfig, DatasetConfig, Task
|
|
94
99
|
|
|
95
|
-
|
|
100
|
+
|
|
101
|
+
MODELS_REQUIRING_FLASH_ATTENTION: list[re.Pattern] = [
|
|
102
|
+
re.compile(r".*gpt-oss.*", flags=re.IGNORECASE)
|
|
103
|
+
]
|
|
96
104
|
|
|
97
105
|
|
|
98
106
|
class VLLMModel(HuggingFaceEncoderModel):
|
|
@@ -101,12 +109,17 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
101
109
|
fresh_model = False
|
|
102
110
|
batching_preference = BatchingPreference.ALL_AT_ONCE
|
|
103
111
|
high_priority = True
|
|
112
|
+
allowed_params = {
|
|
113
|
+
re.compile(r".*"): ["thinking", "no-thinking", "slow-tokenizer"],
|
|
114
|
+
re.compile(r".*gpt-oss.*", flags=re.IGNORECASE): ["low", "medium", "high"],
|
|
115
|
+
}
|
|
104
116
|
|
|
105
117
|
def __init__(
|
|
106
118
|
self,
|
|
107
119
|
model_config: "ModelConfig",
|
|
108
120
|
dataset_config: "DatasetConfig",
|
|
109
121
|
benchmark_config: "BenchmarkConfig",
|
|
122
|
+
log_metadata: bool = True,
|
|
110
123
|
) -> None:
|
|
111
124
|
"""Initialise the vLLM model.
|
|
112
125
|
|
|
@@ -117,30 +130,40 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
117
130
|
The dataset configuration.
|
|
118
131
|
benchmark_config:
|
|
119
132
|
The benchmark configuration.
|
|
133
|
+
log_metadata:
|
|
134
|
+
Whether to log the model and dataset metadata.
|
|
120
135
|
"""
|
|
121
|
-
if (
|
|
122
|
-
importlib.util.find_spec("vllm") is None
|
|
123
|
-
or importlib.util.find_spec("ray") is None
|
|
124
|
-
):
|
|
136
|
+
if importlib.util.find_spec("vllm") is None:
|
|
125
137
|
raise NeedsExtraInstalled(extra="generative")
|
|
126
138
|
|
|
127
|
-
|
|
128
|
-
|
|
139
|
+
if shutil.which("nvcc") is None:
|
|
140
|
+
raise NeedsSystemDependency(
|
|
141
|
+
dependency="nvcc",
|
|
142
|
+
instructions=(
|
|
143
|
+
"Please install the CUDA Toolkit from "
|
|
144
|
+
"https://developer.nvidia.com/cuda-downloads or ensure that NVCC "
|
|
145
|
+
"is available in your PATH."
|
|
146
|
+
),
|
|
147
|
+
)
|
|
148
|
+
|
|
149
|
+
raise_if_wrong_params(
|
|
150
|
+
model_config=model_config, allowed_params=self.allowed_params
|
|
129
151
|
)
|
|
152
|
+
|
|
153
|
+
with (
|
|
154
|
+
no_terminal_output(disable=benchmark_config.verbose),
|
|
155
|
+
flash_attention_backend(
|
|
156
|
+
disabled=all(
|
|
157
|
+
not re.search(pattern=pattern, string=model_config.model_id)
|
|
158
|
+
for pattern in MODELS_REQUIRING_FLASH_ATTENTION
|
|
159
|
+
)
|
|
160
|
+
),
|
|
161
|
+
):
|
|
162
|
+
model, tokeniser = load_model_and_tokeniser(
|
|
163
|
+
model_config=model_config, benchmark_config=benchmark_config
|
|
164
|
+
)
|
|
130
165
|
self._model: "LLM" = model
|
|
131
|
-
self.
|
|
132
|
-
self.end_of_reasoning_token = get_end_of_reasoning_token(
|
|
133
|
-
model=self._model, tokenizer=self._tokenizer, model_id=model_config.model_id
|
|
134
|
-
)
|
|
135
|
-
self.end_of_chat_token_ids = get_end_of_chat_token_ids(
|
|
136
|
-
tokenizer=self._tokenizer
|
|
137
|
-
)
|
|
138
|
-
self.custom_stop_tokens = get_custom_stop_tokens(
|
|
139
|
-
model=self._model,
|
|
140
|
-
tokenizer=self._tokenizer,
|
|
141
|
-
model_id=model_config.model_id,
|
|
142
|
-
is_reasoning_model=self.end_of_reasoning_token is not None,
|
|
143
|
-
)
|
|
166
|
+
self._tokeniser: "PreTrainedTokenizer" = tokeniser
|
|
144
167
|
|
|
145
168
|
# We specify `HuggingFaceEncoderModel` here instead of `VLLMModel`, as we want
|
|
146
169
|
# to call the `__init__` method of the `BenchmarkModule` class.
|
|
@@ -148,16 +171,30 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
148
171
|
model_config=model_config,
|
|
149
172
|
dataset_config=dataset_config,
|
|
150
173
|
benchmark_config=benchmark_config,
|
|
174
|
+
log_metadata=log_metadata,
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
self.end_of_reasoning_token = get_end_of_reasoning_token(
|
|
178
|
+
model=self._model, tokeniser=self._tokeniser, model_config=model_config
|
|
179
|
+
)
|
|
180
|
+
self.end_of_chat_token_ids = get_end_of_chat_token_ids(
|
|
181
|
+
tokeniser=self._tokeniser, generative_type=self.generative_type
|
|
182
|
+
)
|
|
183
|
+
self.custom_stop_tokens = get_custom_stop_tokens(
|
|
184
|
+
model=self._model,
|
|
185
|
+
tokeniser=self._tokeniser,
|
|
186
|
+
model_id=model_config.model_id,
|
|
187
|
+
generative_type=self.generative_type,
|
|
151
188
|
)
|
|
152
189
|
|
|
153
190
|
self.buffer |= dict(
|
|
154
|
-
instruction_model=self._tokenizer.chat_template is not None,
|
|
155
191
|
first_label_token_mapping=get_first_label_token_mapping(
|
|
156
192
|
dataset_config=self.dataset_config,
|
|
157
193
|
model_config=self.model_config,
|
|
158
|
-
|
|
194
|
+
tokeniser=self._tokeniser,
|
|
159
195
|
generative_type=self.generative_type,
|
|
160
|
-
|
|
196
|
+
log_metadata=self.log_metadata,
|
|
197
|
+
)
|
|
161
198
|
)
|
|
162
199
|
if self.model_config.adapter_base_model_id is not None:
|
|
163
200
|
adapter_path = snapshot_download(
|
|
@@ -170,12 +207,16 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
170
207
|
)
|
|
171
208
|
|
|
172
209
|
def __del__(self) -> None:
|
|
173
|
-
"""Clean up the model and
|
|
174
|
-
|
|
210
|
+
"""Clean up the model and tokeniser."""
|
|
211
|
+
try:
|
|
212
|
+
if importlib.util.find_spec("vllm") is not None:
|
|
213
|
+
clear_vllm()
|
|
214
|
+
except ImportError:
|
|
215
|
+
pass
|
|
175
216
|
if hasattr(self, "_model"):
|
|
176
217
|
del self._model
|
|
177
|
-
if hasattr(self, "
|
|
178
|
-
del self.
|
|
218
|
+
if hasattr(self, "_tokeniser"):
|
|
219
|
+
del self._tokeniser
|
|
179
220
|
|
|
180
221
|
@property
|
|
181
222
|
def generative_type(self) -> GenerativeType | None:
|
|
@@ -184,17 +225,37 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
184
225
|
Returns:
|
|
185
226
|
The generative type of the model, or None if it has not been set yet.
|
|
186
227
|
"""
|
|
187
|
-
if not hasattr(self, "
|
|
228
|
+
if not hasattr(self, "_tokeniser"):
|
|
229
|
+
log_once(
|
|
230
|
+
"The generative type of the model has not been set yet as the "
|
|
231
|
+
"tokeniser has not been loaded.",
|
|
232
|
+
level=logging.DEBUG,
|
|
233
|
+
)
|
|
188
234
|
return None
|
|
189
|
-
elif self.
|
|
190
|
-
|
|
235
|
+
elif self.benchmark_config.generative_type is not None:
|
|
236
|
+
type_ = self.benchmark_config.generative_type
|
|
237
|
+
elif self.model_config.param in {"thinking"}:
|
|
238
|
+
type_ = GenerativeType.REASONING
|
|
239
|
+
elif self.model_config.param in {"no-thinking"}:
|
|
240
|
+
type_ = GenerativeType.INSTRUCTION_TUNED
|
|
241
|
+
elif (
|
|
242
|
+
hasattr(self, "end_of_reasoning_token")
|
|
243
|
+
and self.end_of_reasoning_token is not None
|
|
244
|
+
):
|
|
245
|
+
type_ = GenerativeType.REASONING
|
|
191
246
|
elif (
|
|
192
|
-
self.
|
|
247
|
+
has_chat_template(tokeniser=self._tokeniser)
|
|
193
248
|
or "instruct" in self.model_config.model_id.lower()
|
|
194
249
|
):
|
|
195
|
-
|
|
250
|
+
type_ = GenerativeType.INSTRUCTION_TUNED
|
|
196
251
|
else:
|
|
197
|
-
|
|
252
|
+
type_ = GenerativeType.BASE
|
|
253
|
+
log_once(
|
|
254
|
+
f"Detected generative type {type_.name!r} for model "
|
|
255
|
+
f"{self.model_config.model_id!r}",
|
|
256
|
+
level=logging.DEBUG,
|
|
257
|
+
)
|
|
258
|
+
return type_
|
|
198
259
|
|
|
199
260
|
@property
|
|
200
261
|
def extract_labels_from_generation(self) -> ExtractLabelsFunction:
|
|
@@ -211,6 +272,7 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
211
272
|
return partial(
|
|
212
273
|
sequence_classification.extract_labels_from_generation,
|
|
213
274
|
dataset_config=self.dataset_config,
|
|
275
|
+
model_config=self.model_config,
|
|
214
276
|
first_label_token_mapping=self.buffer["first_label_token_mapping"],
|
|
215
277
|
)
|
|
216
278
|
case TaskGroup.TEXT_TO_TEXT:
|
|
@@ -269,7 +331,10 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
269
331
|
|
|
270
332
|
if self.benchmark_config.few_shot:
|
|
271
333
|
few_shot_examples = extract_few_shot_examples(
|
|
272
|
-
dataset=dataset,
|
|
334
|
+
dataset=dataset,
|
|
335
|
+
dataset_config=self.dataset_config,
|
|
336
|
+
benchmark_config=self.benchmark_config,
|
|
337
|
+
itr_idx=itr_idx,
|
|
273
338
|
)
|
|
274
339
|
else:
|
|
275
340
|
few_shot_examples = list()
|
|
@@ -280,9 +345,9 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
280
345
|
few_shot_examples=few_shot_examples,
|
|
281
346
|
model_config=self.model_config,
|
|
282
347
|
dataset_config=self.dataset_config,
|
|
283
|
-
|
|
348
|
+
generative_type=self.generative_type,
|
|
284
349
|
always_populate_text_field=True,
|
|
285
|
-
|
|
350
|
+
tokeniser=self._tokeniser,
|
|
286
351
|
),
|
|
287
352
|
batched=True,
|
|
288
353
|
load_from_cache_file=False,
|
|
@@ -300,68 +365,111 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
300
365
|
|
|
301
366
|
Returns:
|
|
302
367
|
The generated model outputs.
|
|
368
|
+
|
|
369
|
+
Raises:
|
|
370
|
+
InvalidBenchmark:
|
|
371
|
+
If the dataset requires logprobs, but we could not get the first token
|
|
372
|
+
of each label in the dataset.
|
|
303
373
|
"""
|
|
304
374
|
# Get stopping tokens
|
|
305
375
|
stop_tokens: list[str] = self.custom_stop_tokens.copy()
|
|
306
|
-
if self.
|
|
376
|
+
if self.generative_type == GenerativeType.BASE:
|
|
307
377
|
stop_tokens.append("\n\n")
|
|
308
|
-
if self.
|
|
309
|
-
assert isinstance(self.
|
|
378
|
+
if self._tokeniser.pad_token_id is not None:
|
|
379
|
+
assert isinstance(self._tokeniser.pad_token, str), (
|
|
310
380
|
f"The pad token for the model {self.model_config.model_id!r} "
|
|
311
|
-
f"is not a string, which is unexpected: {self.
|
|
381
|
+
f"is not a string, which is unexpected: {self._tokeniser.pad_token!r}."
|
|
312
382
|
)
|
|
313
|
-
stop_tokens.append(self.
|
|
314
|
-
if self.
|
|
315
|
-
assert isinstance(self.
|
|
383
|
+
stop_tokens.append(self._tokeniser.pad_token)
|
|
384
|
+
if self._tokeniser.eos_token_id is not None:
|
|
385
|
+
assert isinstance(self._tokeniser.eos_token, str), (
|
|
316
386
|
f"The EOS token for the model {self.model_config.model_id!r} "
|
|
317
|
-
f"is not a string, which is unexpected: {self.
|
|
387
|
+
f"is not a string, which is unexpected: {self._tokeniser.eos_token!r}."
|
|
318
388
|
)
|
|
319
|
-
stop_tokens.append(self.
|
|
320
|
-
if self.
|
|
321
|
-
self.
|
|
322
|
-
self.
|
|
389
|
+
stop_tokens.append(self._tokeniser.eos_token)
|
|
390
|
+
if self._tokeniser.pad_token_id is None:
|
|
391
|
+
self._tokeniser.pad_token_id = self._tokeniser.eos_token_id
|
|
392
|
+
self._tokeniser.pad_token = self._tokeniser.eos_token
|
|
323
393
|
if self.end_of_chat_token_ids is not None:
|
|
324
|
-
end_of_chat_token = self.
|
|
394
|
+
end_of_chat_token = self._tokeniser.decode(
|
|
325
395
|
self.end_of_chat_token_ids
|
|
326
396
|
).strip()
|
|
327
397
|
if end_of_chat_token:
|
|
328
398
|
stop_tokens.append(end_of_chat_token)
|
|
329
399
|
|
|
330
|
-
logits_processor = None
|
|
331
|
-
if self.dataset_config.task in TASKS_USING_JSON:
|
|
332
|
-
if self.generative_type == GenerativeType.REASONING:
|
|
333
|
-
log_once(
|
|
334
|
-
f"The model {self.model_config.model_id!r} is a reasoning model "
|
|
335
|
-
"and thus does not support structured generation, so we do not "
|
|
336
|
-
"enable it.",
|
|
337
|
-
level=logging.DEBUG,
|
|
338
|
-
)
|
|
339
|
-
else:
|
|
340
|
-
ner_tag_names = list(self.dataset_config.prompt_label_mapping.values())
|
|
341
|
-
keys_and_their_types: dict[str, t.Any] = {
|
|
342
|
-
tag_name: (conlist(str, max_length=5), ...)
|
|
343
|
-
for tag_name in ner_tag_names
|
|
344
|
-
}
|
|
345
|
-
pydantic_class = create_model("AnswerFormat", **keys_and_their_types)
|
|
346
|
-
logits_processor = JSONLogitsProcessor(
|
|
347
|
-
schema=pydantic_class,
|
|
348
|
-
tokenizer=adapt_tokenizer(tokenizer=self._tokenizer), # type: ignore
|
|
349
|
-
whitespace_pattern=r" ?",
|
|
350
|
-
)
|
|
351
|
-
log_once(
|
|
352
|
-
"Using structured generation with the JSON schema "
|
|
353
|
-
f"{pydantic_class.model_json_schema()}",
|
|
354
|
-
level=logging.DEBUG,
|
|
355
|
-
)
|
|
356
|
-
|
|
357
400
|
# Get the mapping from labels to the first token in the label. We call this each
|
|
358
401
|
# time we generate a new dataset since the dataset config can change
|
|
359
402
|
self.buffer["first_label_token_mapping"] = get_first_label_token_mapping(
|
|
360
403
|
dataset_config=self.dataset_config,
|
|
361
404
|
model_config=self.model_config,
|
|
362
|
-
|
|
405
|
+
tokeniser=self._tokeniser,
|
|
363
406
|
generative_type=self.generative_type,
|
|
407
|
+
log_metadata=self.log_metadata,
|
|
364
408
|
)
|
|
409
|
+
if (
|
|
410
|
+
not self.buffer["first_label_token_mapping"]
|
|
411
|
+
and self.dataset_config.task.requires_logprobs
|
|
412
|
+
):
|
|
413
|
+
raise InvalidBenchmark(
|
|
414
|
+
"The dataset requires logprobs, but we encountered an error when "
|
|
415
|
+
"trying to get the first token of each label in the dataset. You can "
|
|
416
|
+
"try running this benchmark with the --verbose flag to see what the "
|
|
417
|
+
"error was. Skipping this evaluation."
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
structured_generation_schema = None
|
|
421
|
+
if (
|
|
422
|
+
self.dataset_config.task.uses_structured_output
|
|
423
|
+
or (self.dataset_config.task.uses_logprobs and self.dataset_config.labels)
|
|
424
|
+
) and self.generative_type == GenerativeType.REASONING:
|
|
425
|
+
structured_outputs = None
|
|
426
|
+
log_once(
|
|
427
|
+
"The dataset uses structured output, but we are not using it as the "
|
|
428
|
+
f"model {self.model_config.model_id!r} is a reasoning model.",
|
|
429
|
+
level=logging.DEBUG,
|
|
430
|
+
)
|
|
431
|
+
elif self.dataset_config.task.uses_structured_output:
|
|
432
|
+
ner_tag_names = list(self.dataset_config.prompt_label_mapping.values())
|
|
433
|
+
keys_and_their_types: dict[str, t.Any] = {
|
|
434
|
+
tag_name: (conlist(str, max_length=5), ...)
|
|
435
|
+
for tag_name in ner_tag_names
|
|
436
|
+
}
|
|
437
|
+
answer_format_class = create_model("AnswerFormat", **keys_and_their_types)
|
|
438
|
+
structured_generation_schema = answer_format_class.model_json_schema()
|
|
439
|
+
log_once(
|
|
440
|
+
"Using structured generation with the JSON schema: "
|
|
441
|
+
f"{json.dumps(structured_generation_schema)}",
|
|
442
|
+
level=logging.DEBUG,
|
|
443
|
+
)
|
|
444
|
+
structured_outputs = StructuredOutputsParams(
|
|
445
|
+
json=structured_generation_schema
|
|
446
|
+
)
|
|
447
|
+
elif (
|
|
448
|
+
self.dataset_config.task.uses_logprobs
|
|
449
|
+
and self.dataset_config.labels
|
|
450
|
+
and self.buffer.get("first_label_token_mapping", False)
|
|
451
|
+
):
|
|
452
|
+
choice_labels = [
|
|
453
|
+
self.dataset_config.prompt_label_mapping[label]
|
|
454
|
+
for label in self.dataset_config.labels
|
|
455
|
+
]
|
|
456
|
+
if isinstance(self.buffer["first_label_token_mapping"], dict):
|
|
457
|
+
choice_labels = [
|
|
458
|
+
self.buffer["first_label_token_mapping"][label]
|
|
459
|
+
for label in choice_labels
|
|
460
|
+
]
|
|
461
|
+
structured_outputs = StructuredOutputsParams(choice=choice_labels)
|
|
462
|
+
log_once(
|
|
463
|
+
"Using structured generation with the choices: "
|
|
464
|
+
f"{structured_outputs.choice!r}.",
|
|
465
|
+
level=logging.DEBUG,
|
|
466
|
+
)
|
|
467
|
+
else:
|
|
468
|
+
structured_outputs = None
|
|
469
|
+
log_once(
|
|
470
|
+
"Not using structured generation as the dataset does not require it.",
|
|
471
|
+
level=logging.DEBUG,
|
|
472
|
+
)
|
|
365
473
|
|
|
366
474
|
# Define the parameters used for vLLM generation
|
|
367
475
|
max_tokens: int = (
|
|
@@ -371,19 +479,21 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
371
479
|
)
|
|
372
480
|
sampling_params = SamplingParams(
|
|
373
481
|
max_tokens=max_tokens,
|
|
374
|
-
logprobs=
|
|
482
|
+
logprobs=MAX_VLLM_LOGPROBS
|
|
483
|
+
if self.buffer["first_label_token_mapping"]
|
|
484
|
+
else None,
|
|
375
485
|
temperature=0.0,
|
|
376
486
|
stop=[stop_token for stop_token in stop_tokens if stop_token],
|
|
377
|
-
|
|
487
|
+
structured_outputs=structured_outputs,
|
|
378
488
|
)
|
|
379
489
|
|
|
380
490
|
# If any of the prompts are empty then we need to replace them with a BOS token
|
|
381
491
|
# so that the vLLM model can generate from them
|
|
382
|
-
prompts:
|
|
492
|
+
prompts: c.Sequence[str] = inputs["text"]
|
|
383
493
|
if any(len(prompt) == 0 for prompt in prompts):
|
|
384
|
-
|
|
494
|
+
log("Found empty prompts, replacing with BOS token.", level=logging.DEBUG)
|
|
385
495
|
prompts = [
|
|
386
|
-
prompt if len(prompt) > 0 else str(self.
|
|
496
|
+
prompt if len(prompt) > 0 else str(self._tokeniser.bos_token)
|
|
387
497
|
for prompt in prompts
|
|
388
498
|
]
|
|
389
499
|
|
|
@@ -391,10 +501,8 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
391
501
|
labels_to_be_generated = list(self.dataset_config.prompt_label_mapping.values())
|
|
392
502
|
if len(labels_to_be_generated) == 0:
|
|
393
503
|
labels_to_be_generated = ["negative", "positive"]
|
|
394
|
-
if
|
|
395
|
-
|
|
396
|
-
) and should_prompts_be_stripped(
|
|
397
|
-
labels_to_be_generated=labels_to_be_generated, tokenizer=self._tokenizer
|
|
504
|
+
if self.generative_type == GenerativeType.BASE and should_prompts_be_stripped(
|
|
505
|
+
labels_to_be_generated=labels_to_be_generated, tokeniser=self._tokeniser
|
|
398
506
|
):
|
|
399
507
|
log_once(
|
|
400
508
|
f"Stripping prompts for model {self.model_config.model_id!r}.",
|
|
@@ -402,21 +510,35 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
402
510
|
)
|
|
403
511
|
prompts = [prompt.strip() for prompt in prompts]
|
|
404
512
|
|
|
513
|
+
# Truncate the prompts if needed, but only if it's not a reasoning model
|
|
514
|
+
if self.generative_type != GenerativeType.REASONING:
|
|
515
|
+
max_tokens_per_prompt = (
|
|
516
|
+
min(self._tokeniser.model_max_length, MAX_CONTEXT_LENGTH) - max_tokens
|
|
517
|
+
)
|
|
518
|
+
tokenized_prompts = self._tokeniser(
|
|
519
|
+
text=list(prompts), truncation=True, max_length=max_tokens_per_prompt
|
|
520
|
+
)
|
|
521
|
+
prompts = self._tokeniser.batch_decode(
|
|
522
|
+
sequences=tokenized_prompts.input_ids, skip_special_tokens=True
|
|
523
|
+
)
|
|
524
|
+
|
|
405
525
|
# Generate sequences using vLLM
|
|
406
526
|
input_is_a_test = len(prompts) == 1 and len(set(prompts[0])) == 1
|
|
407
527
|
num_attempts = 3
|
|
528
|
+
truncation_attempts = 1
|
|
408
529
|
for _ in range(num_attempts):
|
|
409
530
|
try:
|
|
410
531
|
raw_outputs = self._model.generate(
|
|
411
532
|
prompts=prompts,
|
|
412
533
|
sampling_params=sampling_params,
|
|
413
|
-
use_tqdm=False if input_is_a_test else
|
|
534
|
+
use_tqdm=False if input_is_a_test else get_pbar,
|
|
414
535
|
lora_request=self.buffer.get("lora_request"),
|
|
415
536
|
)
|
|
416
537
|
break
|
|
417
538
|
except TypeError as e:
|
|
418
|
-
|
|
419
|
-
f"Encountered error during vLLM generation: {str(e)}. Retrying..."
|
|
539
|
+
log(
|
|
540
|
+
f"Encountered error during vLLM generation: {str(e)}. Retrying...",
|
|
541
|
+
level=logging.DEBUG,
|
|
420
542
|
)
|
|
421
543
|
sleep(1)
|
|
422
544
|
except ValueError as e:
|
|
@@ -428,26 +550,34 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
428
550
|
re.search(pattern, str(e), flags=re.IGNORECASE) is not None
|
|
429
551
|
for pattern in truncate_error_messages
|
|
430
552
|
):
|
|
431
|
-
|
|
432
|
-
"Prompts are too long, so truncating them and trying again..."
|
|
553
|
+
log(
|
|
554
|
+
"Prompts are too long, so truncating them and trying again...",
|
|
555
|
+
level=logging.WARNING,
|
|
433
556
|
)
|
|
434
|
-
|
|
435
|
-
|
|
557
|
+
log(f"The error message was: {str(e)}", level=logging.DEBUG)
|
|
558
|
+
|
|
559
|
+
# If we have already tried truncating the prompts a few times, then
|
|
560
|
+
# we truncate a bit more aggressively
|
|
561
|
+
extra_truncation = 50 * truncation_attempts
|
|
562
|
+
truncation_attempts += 1
|
|
563
|
+
|
|
564
|
+
tokenized_prompts = self._tokeniser(
|
|
436
565
|
text=prompts,
|
|
437
566
|
truncation=True,
|
|
438
567
|
max_length=max(
|
|
439
|
-
min(self.
|
|
440
|
-
- max_tokens
|
|
568
|
+
min(self._tokeniser.model_max_length, MAX_CONTEXT_LENGTH)
|
|
569
|
+
- max_tokens
|
|
570
|
+
- extra_truncation,
|
|
441
571
|
0,
|
|
442
572
|
),
|
|
443
573
|
)
|
|
444
|
-
prompts = self.
|
|
574
|
+
prompts = self._tokeniser.batch_decode(
|
|
445
575
|
sequences=tokenized_prompts.input_ids, skip_special_tokens=True
|
|
446
576
|
)
|
|
447
577
|
else:
|
|
448
578
|
raise InvalidBenchmark(
|
|
449
579
|
f"An error occurred during vLLM generation: {str(e)}"
|
|
450
|
-
)
|
|
580
|
+
) from e
|
|
451
581
|
else:
|
|
452
582
|
raise InvalidBenchmark(
|
|
453
583
|
f"Could not generate sequences after {num_attempts} attempts."
|
|
@@ -467,34 +597,73 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
467
597
|
f"{num_extra_outputs!r} extra outputs."
|
|
468
598
|
)
|
|
469
599
|
else:
|
|
470
|
-
|
|
600
|
+
log(
|
|
471
601
|
f"Filtered out {num_extra_outputs:,} extra outputs from the model, "
|
|
472
602
|
"which occured as we interupted the generation when we truncated "
|
|
473
|
-
"the prompts."
|
|
603
|
+
"the prompts.",
|
|
604
|
+
level=logging.DEBUG,
|
|
474
605
|
)
|
|
475
606
|
|
|
476
|
-
# Parse the raw model outputs
|
|
477
|
-
|
|
478
|
-
|
|
607
|
+
# Parse the raw model outputs. We keep the special tokens for now, as we need
|
|
608
|
+
# them to potentially remove reasoning content and stop tokens
|
|
609
|
+
completion_ids: c.Sequence[c.Sequence[int]] = [
|
|
610
|
+
list(output.outputs[0].token_ids) for output in raw_outputs
|
|
479
611
|
]
|
|
480
|
-
completions = self.
|
|
612
|
+
completions = self._tokeniser.batch_decode(
|
|
481
613
|
sequences=[
|
|
482
614
|
torch.LongTensor(completion_id) for completion_id in completion_ids
|
|
483
|
-
]
|
|
615
|
+
],
|
|
616
|
+
skip_special_tokens=False,
|
|
484
617
|
)
|
|
485
|
-
if
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
618
|
+
if (
|
|
619
|
+
self.end_of_reasoning_token is not None
|
|
620
|
+
and self.generative_type == GenerativeType.REASONING
|
|
621
|
+
):
|
|
622
|
+
num_samples_without_eor_token = 0
|
|
623
|
+
for idx in range(len(completions)):
|
|
624
|
+
if (
|
|
625
|
+
isinstance(self.end_of_reasoning_token, str)
|
|
626
|
+
and self.end_of_reasoning_token in completions[idx]
|
|
627
|
+
):
|
|
628
|
+
completions[idx] = completions[idx].split(
|
|
629
|
+
self.end_of_reasoning_token
|
|
630
|
+
)[-1]
|
|
631
|
+
elif isinstance(
|
|
632
|
+
self.end_of_reasoning_token, re.Pattern
|
|
633
|
+
) and self.end_of_reasoning_token.search(completions[idx]):
|
|
634
|
+
completions[idx] = self.end_of_reasoning_token.split(
|
|
635
|
+
completions[idx]
|
|
636
|
+
)[-1]
|
|
637
|
+
else:
|
|
638
|
+
num_samples_without_eor_token += 1
|
|
639
|
+
completions[idx] = ""
|
|
640
|
+
if num_samples_without_eor_token > 0:
|
|
641
|
+
log_once(
|
|
642
|
+
f"The model {self.model_config.model_id!r} is a reasoning "
|
|
643
|
+
"model, but the generated output did not contain the end of "
|
|
644
|
+
f"reasoning token ({self.end_of_reasoning_token!r}) in "
|
|
645
|
+
f"{num_samples_without_eor_token:,}/{len(completions):,} of "
|
|
646
|
+
"the samples. Using an empty string for all these samples "
|
|
647
|
+
"instead.",
|
|
648
|
+
level=(
|
|
649
|
+
logging.WARNING
|
|
650
|
+
if num_samples_without_eor_token / len(completions) > 0.5
|
|
651
|
+
else logging.DEBUG
|
|
652
|
+
),
|
|
653
|
+
)
|
|
490
654
|
stop_token_pattern = re.compile(
|
|
491
655
|
"|".join(re.escape(stop_token) for stop_token in stop_tokens)
|
|
492
656
|
)
|
|
493
657
|
completions = [
|
|
494
|
-
re.split(pattern=stop_token_pattern, string=completion)[0]
|
|
658
|
+
re.split(pattern=stop_token_pattern, string=completion)[0].strip()
|
|
495
659
|
for completion in completions
|
|
496
660
|
]
|
|
497
|
-
|
|
661
|
+
|
|
662
|
+
# Remove all the special tokens from the completions, if any are present
|
|
663
|
+
completion_ids = self._tokeniser(text=completions).input_ids
|
|
664
|
+
completions = self._tokeniser.batch_decode(
|
|
665
|
+
sequences=completion_ids, skip_special_tokens=True
|
|
666
|
+
)
|
|
498
667
|
|
|
499
668
|
# Sanity check
|
|
500
669
|
if len(completions) != len(prompts):
|
|
@@ -504,13 +673,13 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
504
673
|
|
|
505
674
|
# Add logprobs scores to the output
|
|
506
675
|
if self.buffer["first_label_token_mapping"]:
|
|
507
|
-
scores:
|
|
676
|
+
scores: c.Sequence[c.Sequence[c.Sequence[tuple[str, float]]]] = [
|
|
508
677
|
[
|
|
509
678
|
[
|
|
510
|
-
(obj.decoded_token, obj.logprob)
|
|
679
|
+
(obj.decoded_token or "", obj.logprob)
|
|
511
680
|
for obj in token_logprobs_dict.values()
|
|
512
681
|
]
|
|
513
|
-
for token_logprobs_dict in raw_output.outputs[0].logprobs
|
|
682
|
+
for token_logprobs_dict in raw_output.outputs[0].logprobs or list()
|
|
514
683
|
]
|
|
515
684
|
for raw_output in raw_outputs
|
|
516
685
|
]
|
|
@@ -543,11 +712,18 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
543
712
|
if using_api:
|
|
544
713
|
return False
|
|
545
714
|
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
715
|
+
model_id_components = split_model_id(model_id=model_id)
|
|
716
|
+
model_id = model_id_components.model_id
|
|
717
|
+
revision = model_id_components.revision
|
|
718
|
+
|
|
549
719
|
model_info = get_model_repo_info(
|
|
550
|
-
model_id=model_id,
|
|
720
|
+
model_id=model_id,
|
|
721
|
+
revision=revision,
|
|
722
|
+
api_key=benchmark_config.api_key,
|
|
723
|
+
cache_dir=benchmark_config.cache_dir,
|
|
724
|
+
trust_remote_code=benchmark_config.trust_remote_code,
|
|
725
|
+
requires_safetensors=benchmark_config.requires_safetensors,
|
|
726
|
+
run_with_cli=benchmark_config.run_with_cli,
|
|
551
727
|
)
|
|
552
728
|
return (
|
|
553
729
|
model_info is not None
|
|
@@ -569,11 +745,15 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
569
745
|
Returns:
|
|
570
746
|
The model configuration.
|
|
571
747
|
"""
|
|
572
|
-
|
|
573
|
-
model_id.split("@") if "@" in model_id else (model_id, "main")
|
|
574
|
-
)
|
|
748
|
+
model_id_components = split_model_id(model_id=model_id)
|
|
575
749
|
model_info = get_model_repo_info(
|
|
576
|
-
model_id=model_id,
|
|
750
|
+
model_id=model_id_components.model_id,
|
|
751
|
+
revision=model_id_components.revision,
|
|
752
|
+
api_key=benchmark_config.api_key,
|
|
753
|
+
cache_dir=benchmark_config.cache_dir,
|
|
754
|
+
trust_remote_code=benchmark_config.trust_remote_code,
|
|
755
|
+
requires_safetensors=benchmark_config.requires_safetensors,
|
|
756
|
+
run_with_cli=benchmark_config.run_with_cli,
|
|
577
757
|
)
|
|
578
758
|
if model_info is None:
|
|
579
759
|
raise InvalidModel(f"The model {model_id!r} could not be found.")
|
|
@@ -582,8 +762,9 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
582
762
|
language_codes = list(language_mapping.keys())
|
|
583
763
|
|
|
584
764
|
model_config = ModelConfig(
|
|
585
|
-
model_id=model_id,
|
|
586
|
-
revision=revision,
|
|
765
|
+
model_id=model_id_components.model_id,
|
|
766
|
+
revision=model_id_components.revision,
|
|
767
|
+
param=model_id_components.param,
|
|
587
768
|
task=model_info.pipeline_tag,
|
|
588
769
|
languages=[
|
|
589
770
|
language_mapping[tag]
|
|
@@ -603,7 +784,7 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
603
784
|
return model_config
|
|
604
785
|
|
|
605
786
|
@property
|
|
606
|
-
def data_collator(self) -> c.Callable[[
|
|
787
|
+
def data_collator(self) -> c.Callable[[c.Sequence[t.Any]], dict[str, t.Any]]:
|
|
607
788
|
"""The data collator used to prepare samples during finetuning.
|
|
608
789
|
|
|
609
790
|
Returns:
|
|
@@ -625,10 +806,10 @@ class VLLMModel(HuggingFaceEncoderModel):
|
|
|
625
806
|
)
|
|
626
807
|
|
|
627
808
|
|
|
628
|
-
def
|
|
809
|
+
def load_model_and_tokeniser(
|
|
629
810
|
model_config: "ModelConfig", benchmark_config: "BenchmarkConfig"
|
|
630
811
|
) -> tuple["LLM", "PreTrainedTokenizer"]:
|
|
631
|
-
"""Load the model and
|
|
812
|
+
"""Load the model and tokeniser.
|
|
632
813
|
|
|
633
814
|
Args:
|
|
634
815
|
model_config:
|
|
@@ -637,7 +818,7 @@ def load_model_and_tokenizer(
|
|
|
637
818
|
The benchmark configuration.
|
|
638
819
|
|
|
639
820
|
Returns:
|
|
640
|
-
A pair (model,
|
|
821
|
+
A pair (model, tokeniser), with the loaded model and tokeniser
|
|
641
822
|
"""
|
|
642
823
|
# Prefer base model ID if the model is an adapter - the adapter will be added on
|
|
643
824
|
# during inference in this case
|
|
@@ -649,8 +830,8 @@ def load_model_and_tokenizer(
|
|
|
649
830
|
hf_model_config = load_hf_model_config(
|
|
650
831
|
model_id=model_id,
|
|
651
832
|
num_labels=0,
|
|
652
|
-
id2label=
|
|
653
|
-
label2id=
|
|
833
|
+
id2label=HashableDict(),
|
|
834
|
+
label2id=HashableDict(),
|
|
654
835
|
revision=revision,
|
|
655
836
|
model_cache_dir=model_config.model_cache_dir,
|
|
656
837
|
api_key=benchmark_config.api_key,
|
|
@@ -675,46 +856,55 @@ def load_model_and_tokenizer(
|
|
|
675
856
|
dtype: str | torch.dtype = "auto"
|
|
676
857
|
|
|
677
858
|
# Choose bf16 over fp16 if the model is a fp32 model and the GPU supports it
|
|
678
|
-
if hf_model_config.
|
|
859
|
+
if hf_model_config.dtype == torch.float32:
|
|
679
860
|
if torch.cuda.is_bf16_supported():
|
|
680
|
-
|
|
861
|
+
log(
|
|
681
862
|
"You are loading a model with dtype FP32, which we will convert to "
|
|
682
863
|
"BF16 as FP32 is not supported by vLLM and BF16 is supported by your "
|
|
683
|
-
"GPU."
|
|
864
|
+
"GPU.",
|
|
865
|
+
level=logging.WARNING,
|
|
684
866
|
)
|
|
685
867
|
dtype = torch.bfloat16
|
|
686
868
|
else:
|
|
687
|
-
|
|
869
|
+
log(
|
|
688
870
|
"You are loading a model with dtype FP32, which we will convert to "
|
|
689
871
|
"FP16 as FP32 is not supported by vLLM and BF16 is not supported by "
|
|
690
|
-
"your GPU."
|
|
872
|
+
"your GPU.",
|
|
873
|
+
level=logging.WARNING,
|
|
691
874
|
)
|
|
692
875
|
dtype = torch.float16
|
|
693
876
|
|
|
694
|
-
# If the model is a quantized model, we need to
|
|
695
|
-
if quantization
|
|
696
|
-
|
|
877
|
+
# If the model is a quantized model, we might need to change the dtype
|
|
878
|
+
if quantization == "mxfp4" and hf_model_config.dtype is None:
|
|
879
|
+
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
|
880
|
+
log(
|
|
881
|
+
"You are loading a quantized model where `dtype` has not been set. "
|
|
882
|
+
f"Setting dtype to {dtype!r}.",
|
|
883
|
+
level=logging.DEBUG,
|
|
884
|
+
)
|
|
885
|
+
elif quantization is not None and hf_model_config.dtype != torch.float16:
|
|
886
|
+
log(
|
|
697
887
|
"You are loading a quantized model with dtype "
|
|
698
|
-
f"{hf_model_config.
|
|
699
|
-
"dtype to float16 instead."
|
|
888
|
+
f"{hf_model_config.dtype}, which vLLM does not support. Setting "
|
|
889
|
+
"dtype to float16 instead.",
|
|
890
|
+
level=logging.WARNING,
|
|
700
891
|
)
|
|
701
892
|
dtype = torch.float16
|
|
702
893
|
|
|
703
894
|
# If the model is a bf16 model, we need to check the CUDA compute capability
|
|
704
|
-
if hf_model_config.
|
|
895
|
+
if hf_model_config.dtype == torch.bfloat16:
|
|
705
896
|
min_cuda_compute_capability = get_min_cuda_compute_capability()
|
|
706
897
|
required_capability = VLLM_BF16_MIN_CUDA_COMPUTE_CAPABILITY
|
|
707
898
|
|
|
708
899
|
if min_cuda_compute_capability is not None:
|
|
709
900
|
if min_cuda_compute_capability < required_capability:
|
|
710
|
-
|
|
711
|
-
"You are loading a model with "
|
|
712
|
-
|
|
713
|
-
"
|
|
714
|
-
f"
|
|
715
|
-
"
|
|
716
|
-
|
|
717
|
-
"Setting dtype to float16 instead."
|
|
901
|
+
log(
|
|
902
|
+
f"You are loading a model with dtype {hf_model_config.dtype}, "
|
|
903
|
+
"which vLLM only supports for CUDA devices with CUDA compute "
|
|
904
|
+
f"capability >={required_capability}. You are using one or more "
|
|
905
|
+
f"devices with compute capability {min_cuda_compute_capability}. "
|
|
906
|
+
"Setting dtype to float16 instead.",
|
|
907
|
+
level=logging.WARNING,
|
|
718
908
|
)
|
|
719
909
|
dtype = torch.float16
|
|
720
910
|
|
|
@@ -741,31 +931,40 @@ def load_model_and_tokenizer(
|
|
|
741
931
|
else:
|
|
742
932
|
true_max_model_len = MAX_CONTEXT_LENGTH
|
|
743
933
|
|
|
744
|
-
|
|
934
|
+
tokeniser = load_tokeniser(
|
|
745
935
|
model_id=model_config.model_id,
|
|
746
936
|
revision=model_config.revision,
|
|
747
937
|
adapter_base_model_id=model_config.adapter_base_model_id,
|
|
748
938
|
trust_remote_code=benchmark_config.trust_remote_code,
|
|
749
939
|
model_max_length=true_max_model_len,
|
|
750
|
-
|
|
751
|
-
token=benchmark_config.api_key
|
|
940
|
+
model_config=model_config,
|
|
941
|
+
token=get_hf_token(api_key=benchmark_config.api_key),
|
|
942
|
+
)
|
|
943
|
+
vllm_tokenisation_params = get_vllm_tokenisation_params(
|
|
944
|
+
tokeniser=tokeniser, model_config=model_config
|
|
752
945
|
)
|
|
753
946
|
|
|
754
947
|
clear_vllm()
|
|
755
948
|
|
|
756
949
|
try:
|
|
757
950
|
model = LLM(
|
|
758
|
-
model=
|
|
759
|
-
|
|
951
|
+
model=(
|
|
952
|
+
model_id
|
|
953
|
+
if internet_connection_available()
|
|
954
|
+
else resolve_model_path(download_dir=download_dir)
|
|
955
|
+
),
|
|
956
|
+
tokenizer=(
|
|
957
|
+
model_id
|
|
958
|
+
if internet_connection_available()
|
|
959
|
+
else resolve_model_path(download_dir=download_dir)
|
|
960
|
+
),
|
|
760
961
|
gpu_memory_utilization=benchmark_config.gpu_memory_utilization,
|
|
761
962
|
max_model_len=min(true_max_model_len, MAX_CONTEXT_LENGTH),
|
|
762
963
|
download_dir=download_dir,
|
|
763
964
|
trust_remote_code=benchmark_config.trust_remote_code,
|
|
764
965
|
revision=revision,
|
|
765
966
|
seed=4242,
|
|
766
|
-
distributed_executor_backend=
|
|
767
|
-
"ray" if torch.cuda.device_count() > 1 else "mp"
|
|
768
|
-
),
|
|
967
|
+
distributed_executor_backend="mp",
|
|
769
968
|
tensor_parallel_size=torch.cuda.device_count(),
|
|
770
969
|
disable_custom_all_reduce=True,
|
|
771
970
|
quantization=quantization,
|
|
@@ -776,38 +975,65 @@ def load_model_and_tokenizer(
|
|
|
776
975
|
enable_prefix_caching=False,
|
|
777
976
|
enable_lora=model_config.adapter_base_model_id is not None,
|
|
778
977
|
max_lora_rank=256,
|
|
978
|
+
**vllm_tokenisation_params,
|
|
779
979
|
)
|
|
780
980
|
except (RuntimeError, ValueError, OSError) as e:
|
|
781
981
|
if "awaiting a review from the repo authors" in str(e):
|
|
782
982
|
raise InvalidModel(
|
|
783
983
|
f"The model {model_id!r} is awaiting a review from the repository "
|
|
784
984
|
"authors. Please try again later."
|
|
785
|
-
)
|
|
985
|
+
) from e
|
|
786
986
|
elif "trust_remote_code" in str(e):
|
|
787
987
|
raise InvalidModel(
|
|
788
988
|
f"Loading the model {model_id!r} needs to trust remote code. "
|
|
789
989
|
"If you trust the suppliers of this model, then you can enable "
|
|
790
990
|
"this by setting the `--trust-remote-code` flag."
|
|
991
|
+
) from e
|
|
992
|
+
elif "See stack trace for root cause." in str(
|
|
993
|
+
e
|
|
994
|
+
) or "See root cause above." in str(e):
|
|
995
|
+
msg = (
|
|
996
|
+
f"The model {model_id!r} could not be loaded, but vLLM did not "
|
|
997
|
+
"mention exactly what happened. "
|
|
998
|
+
)
|
|
999
|
+
msg += (
|
|
1000
|
+
(
|
|
1001
|
+
"Since you're running in verbose mode, you might see a descriptive "
|
|
1002
|
+
"error above already. Note however that if the error message urges "
|
|
1003
|
+
"you to set the environment variable `VLLM_ATTENTION_BACKEND` to "
|
|
1004
|
+
"'FLEX_ATTENTION', please try setting it to 'FLASH_ATTN' first, as "
|
|
1005
|
+
"that often solves the issue, whereas 'FLEX_ATTENTION' usually "
|
|
1006
|
+
"doesn't. If you don't see any descriptive error above, then you "
|
|
1007
|
+
"can try "
|
|
1008
|
+
)
|
|
1009
|
+
if benchmark_config.verbose
|
|
1010
|
+
else "Try "
|
|
1011
|
+
)
|
|
1012
|
+
msg += (
|
|
1013
|
+
"re-running the benchmark with the environment variable `FULL_LOG` "
|
|
1014
|
+
"set to `1` to see the full stack trace. E.g., "
|
|
1015
|
+
f"`FULL_LOG=1 euroeval --model {model_id}`."
|
|
791
1016
|
)
|
|
1017
|
+
raise InvalidModel(msg) from e
|
|
792
1018
|
raise InvalidModel(
|
|
793
1019
|
f"The model {model_id!r} could not be loaded. The error was {e!r}."
|
|
794
|
-
)
|
|
1020
|
+
) from e
|
|
795
1021
|
|
|
796
1022
|
model.config = hf_model_config
|
|
797
1023
|
|
|
798
|
-
return model,
|
|
1024
|
+
return model, tokeniser
|
|
799
1025
|
|
|
800
1026
|
|
|
801
|
-
def
|
|
1027
|
+
def load_tokeniser(
|
|
802
1028
|
model_id: str,
|
|
803
1029
|
revision: str,
|
|
804
1030
|
adapter_base_model_id: str | None,
|
|
805
1031
|
trust_remote_code: bool,
|
|
806
1032
|
model_max_length: int,
|
|
807
|
-
|
|
1033
|
+
model_config: "ModelConfig",
|
|
808
1034
|
token: str | bool,
|
|
809
1035
|
) -> "PreTrainedTokenizer":
|
|
810
|
-
"""Load the
|
|
1036
|
+
"""Load the tokeniser.
|
|
811
1037
|
|
|
812
1038
|
Args:
|
|
813
1039
|
model_id:
|
|
@@ -821,64 +1047,97 @@ def load_tokenizer(
|
|
|
821
1047
|
Whether to trust remote code.
|
|
822
1048
|
model_max_length:
|
|
823
1049
|
The maximum length of the model.
|
|
824
|
-
|
|
825
|
-
The
|
|
1050
|
+
model_config:
|
|
1051
|
+
The model configuration.
|
|
826
1052
|
token:
|
|
827
1053
|
The Hugging Face API token.
|
|
828
1054
|
|
|
829
1055
|
Returns:
|
|
830
|
-
The loaded
|
|
1056
|
+
The loaded tokeniser.
|
|
831
1057
|
"""
|
|
832
1058
|
revision = revision if adapter_base_model_id is None else "main"
|
|
833
1059
|
config = AutoConfig.from_pretrained(
|
|
834
1060
|
adapter_base_model_id or model_id,
|
|
835
1061
|
revision=revision,
|
|
836
|
-
cache_dir=model_cache_dir,
|
|
1062
|
+
cache_dir=model_config.model_cache_dir,
|
|
837
1063
|
token=token,
|
|
838
1064
|
trust_remote_code=trust_remote_code,
|
|
1065
|
+
local_files_only=not internet_connection_available(),
|
|
839
1066
|
)
|
|
840
1067
|
num_retries = 5
|
|
841
1068
|
for _ in range(num_retries):
|
|
842
1069
|
try:
|
|
843
|
-
|
|
1070
|
+
# Mistral instruction-tuned models need a custom tokeniser
|
|
1071
|
+
if model_id.startswith("mistralai/") and "base" not in model_id.lower():
|
|
1072
|
+
tokeniser = MistralCommonTokenizer.from_pretrained(
|
|
1073
|
+
model_id,
|
|
1074
|
+
padding_side="left",
|
|
1075
|
+
truncation_side="left",
|
|
1076
|
+
model_max_length=model_max_length,
|
|
1077
|
+
token=token,
|
|
1078
|
+
)
|
|
1079
|
+
break
|
|
1080
|
+
tokeniser = AutoTokenizer.from_pretrained(
|
|
844
1081
|
model_id,
|
|
845
|
-
|
|
1082
|
+
revision=revision,
|
|
1083
|
+
use_fast=False if model_config.param == "slow-tokenizer" else True,
|
|
846
1084
|
verbose=False,
|
|
847
1085
|
trust_remote_code=trust_remote_code,
|
|
848
1086
|
padding_side="left",
|
|
849
1087
|
truncation_side="left",
|
|
850
1088
|
model_max_length=model_max_length,
|
|
1089
|
+
cache_dir=model_config.model_cache_dir,
|
|
851
1090
|
config=config,
|
|
852
1091
|
token=token,
|
|
1092
|
+
local_files_only=not internet_connection_available(),
|
|
853
1093
|
)
|
|
854
1094
|
break
|
|
855
1095
|
except (json.JSONDecodeError, OSError, TypeError) as e:
|
|
856
1096
|
if adapter_base_model_id is None or model_id == adapter_base_model_id:
|
|
857
1097
|
raise InvalidModel(
|
|
858
|
-
f"Could not load
|
|
1098
|
+
f"Could not load tokeniser for model {model_id!r}. The error was "
|
|
859
1099
|
f"{str(e)}."
|
|
860
|
-
)
|
|
861
|
-
|
|
862
|
-
f"Could not load
|
|
863
|
-
f"{adapter_base_model_id!r}."
|
|
1100
|
+
) from e
|
|
1101
|
+
log(
|
|
1102
|
+
f"Could not load tokeniser for {model_id!r}. Falling back to "
|
|
1103
|
+
f"{adapter_base_model_id!r}.",
|
|
1104
|
+
level=logging.DEBUG,
|
|
864
1105
|
)
|
|
865
1106
|
model_id = adapter_base_model_id
|
|
866
1107
|
except (TimeoutError, RequestError):
|
|
867
|
-
|
|
1108
|
+
log(
|
|
1109
|
+
f"Couldn't load tokeniser for {model_id!r}. Retrying.",
|
|
1110
|
+
level=logging.WARNING,
|
|
1111
|
+
)
|
|
868
1112
|
sleep(5)
|
|
869
1113
|
continue
|
|
1114
|
+
except (KeyError, ValueError) as e:
|
|
1115
|
+
if "mistral" in str(e).lower():
|
|
1116
|
+
tokeniser = MistralCommonTokenizer.from_pretrained(
|
|
1117
|
+
model_id,
|
|
1118
|
+
padding_side="left",
|
|
1119
|
+
truncation_side="left",
|
|
1120
|
+
model_max_length=model_max_length,
|
|
1121
|
+
token=token,
|
|
1122
|
+
)
|
|
1123
|
+
break
|
|
1124
|
+
raise InvalidModel(
|
|
1125
|
+
f"Could not load tokeniser for model {model_id!r}. The error was "
|
|
1126
|
+
f"{str(e)}."
|
|
1127
|
+
) from e
|
|
870
1128
|
else:
|
|
871
1129
|
raise InvalidModel(
|
|
872
|
-
f"Could not load
|
|
1130
|
+
f"Could not load tokeniser for model {model_id!r} after {num_retries} "
|
|
873
1131
|
"attempts."
|
|
874
1132
|
)
|
|
875
1133
|
|
|
876
1134
|
# Ensure that BOS, EOS and PAD tokens are set
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
1135
|
+
if not isinstance(tokeniser, MistralCommonTokenizer):
|
|
1136
|
+
tokeniser.bos_token, tokeniser.bos_token_id = get_bos_token(tokeniser=tokeniser)
|
|
1137
|
+
tokeniser.eos_token, tokeniser.eos_token_id = get_eos_token(tokeniser=tokeniser)
|
|
1138
|
+
tokeniser.pad_token, tokeniser.pad_token_id = get_pad_token(tokeniser=tokeniser)
|
|
880
1139
|
|
|
881
|
-
return
|
|
1140
|
+
return tokeniser
|
|
882
1141
|
|
|
883
1142
|
|
|
884
1143
|
def clear_vllm() -> None:
|
|
@@ -886,80 +1145,93 @@ def clear_vllm() -> None:
|
|
|
886
1145
|
with contextlib.suppress(ValueError):
|
|
887
1146
|
destroy_model_parallel()
|
|
888
1147
|
destroy_distributed_environment()
|
|
889
|
-
if ray.is_initialized():
|
|
890
|
-
ray.shutdown()
|
|
891
1148
|
with contextlib.suppress(AssertionError):
|
|
892
1149
|
torch.distributed.destroy_process_group()
|
|
893
|
-
if ray.is_initialized():
|
|
894
|
-
ray.shutdown()
|
|
895
1150
|
clear_memory()
|
|
896
1151
|
|
|
897
1152
|
|
|
898
1153
|
def get_end_of_reasoning_token(
|
|
899
|
-
model: "LLM",
|
|
900
|
-
) -> str | None:
|
|
1154
|
+
model: "LLM", tokeniser: "PreTrainedTokenizer", model_config: "ModelConfig"
|
|
1155
|
+
) -> str | re.Pattern | None:
|
|
901
1156
|
"""Get the end-of-reasoning token for a generative model.
|
|
902
1157
|
|
|
903
1158
|
Args:
|
|
904
1159
|
model:
|
|
905
1160
|
The vLLM model.
|
|
906
|
-
|
|
907
|
-
The
|
|
908
|
-
|
|
909
|
-
The model
|
|
1161
|
+
tokeniser:
|
|
1162
|
+
The tokeniser.
|
|
1163
|
+
model_config:
|
|
1164
|
+
The model configuration.
|
|
910
1165
|
|
|
911
1166
|
Returns:
|
|
912
1167
|
The end of reasoning token, or None if it could not be found.
|
|
913
1168
|
"""
|
|
1169
|
+
model_id = model_config.model_id
|
|
1170
|
+
|
|
914
1171
|
# Create a prompt to check if the model uses the reasoning tokens
|
|
915
1172
|
prompt = "What is your name?"
|
|
916
|
-
if
|
|
917
|
-
|
|
1173
|
+
if has_chat_template(tokeniser=tokeniser):
|
|
1174
|
+
extra_kwargs = dict()
|
|
1175
|
+
if model_config.param in {"thinking", "no-thinking"}:
|
|
1176
|
+
extra_kwargs["enable_thinking"] = model_config.param == "thinking"
|
|
1177
|
+
templated_prompt = apply_chat_template(
|
|
918
1178
|
conversation=[dict(role="user", content=prompt)],
|
|
1179
|
+
tokeniser=tokeniser,
|
|
1180
|
+
tokenise=False,
|
|
919
1181
|
add_generation_prompt=True,
|
|
920
|
-
|
|
1182
|
+
**extra_kwargs,
|
|
921
1183
|
)
|
|
922
1184
|
assert isinstance(templated_prompt, str)
|
|
923
1185
|
prompt = templated_prompt
|
|
924
1186
|
|
|
925
1187
|
# Check that the beginning-of-reasoning token is actually used by the model
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
use_tqdm=False,
|
|
931
|
-
)[0]
|
|
932
|
-
.outputs[0]
|
|
933
|
-
.text
|
|
934
|
-
)
|
|
1188
|
+
output = model.generate(
|
|
1189
|
+
prompts=[prompt], sampling_params=SamplingParams(max_tokens=10), use_tqdm=False
|
|
1190
|
+
)[0]
|
|
1191
|
+
completion = tokeniser.decode(token_ids=output.outputs[0].token_ids)
|
|
935
1192
|
bor_reasoning_matches = [
|
|
936
1193
|
(bor_token, eor_token)
|
|
937
1194
|
for bor_token, eor_token in REASONING_TOKENS
|
|
938
|
-
if
|
|
1195
|
+
if (
|
|
1196
|
+
(
|
|
1197
|
+
isinstance(bor_token, str)
|
|
1198
|
+
and (bor_token in prompt or bor_token in completion)
|
|
1199
|
+
)
|
|
1200
|
+
or (
|
|
1201
|
+
isinstance(bor_token, re.Pattern)
|
|
1202
|
+
and (
|
|
1203
|
+
bor_token.search(prompt) is not None
|
|
1204
|
+
or bor_token.search(completion) is not None
|
|
1205
|
+
)
|
|
1206
|
+
)
|
|
1207
|
+
)
|
|
939
1208
|
]
|
|
940
1209
|
if not bor_reasoning_matches:
|
|
941
1210
|
log_once(
|
|
942
1211
|
f"The model {model_id!r} did not generate any beginning-of-reasoning "
|
|
943
|
-
"tokens in the prompt or the completion. Assuming the model is not "
|
|
944
|
-
"
|
|
945
|
-
level=logging.
|
|
1212
|
+
"tokens in the prompt or the completion. Assuming the model is not a "
|
|
1213
|
+
"reasoning model.",
|
|
1214
|
+
level=logging.DEBUG,
|
|
946
1215
|
)
|
|
947
1216
|
return None
|
|
948
1217
|
|
|
949
|
-
# Check that the
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
.outputs[0]
|
|
957
|
-
.text
|
|
958
|
-
)
|
|
1218
|
+
# Check that the end-of-reasoning token is actually used by the model
|
|
1219
|
+
output = model.generate(
|
|
1220
|
+
prompts=[prompt],
|
|
1221
|
+
sampling_params=SamplingParams(max_tokens=REASONING_MAX_TOKENS),
|
|
1222
|
+
use_tqdm=False,
|
|
1223
|
+
)[0]
|
|
1224
|
+
completion = tokeniser.decode(token_ids=output.outputs[0].token_ids)
|
|
959
1225
|
eor_reasoning_matches = [
|
|
960
1226
|
(bor_token, eor_token)
|
|
961
1227
|
for bor_token, eor_token in bor_reasoning_matches
|
|
962
|
-
if
|
|
1228
|
+
if (
|
|
1229
|
+
(isinstance(eor_token, str) and eor_token in completion)
|
|
1230
|
+
or (
|
|
1231
|
+
isinstance(eor_token, re.Pattern)
|
|
1232
|
+
and eor_token.search(completion) is not None
|
|
1233
|
+
)
|
|
1234
|
+
)
|
|
963
1235
|
]
|
|
964
1236
|
if not eor_reasoning_matches:
|
|
965
1237
|
log_once(
|
|
@@ -968,7 +1240,7 @@ def get_end_of_reasoning_token(
|
|
|
968
1240
|
"the beginning-of-reasoning tokens "
|
|
969
1241
|
f"{[bor_token for bor_token, _ in bor_reasoning_matches]!r}. "
|
|
970
1242
|
"This is probably not correct, so please report this issue.",
|
|
971
|
-
level=logging.
|
|
1243
|
+
level=logging.WARNING,
|
|
972
1244
|
)
|
|
973
1245
|
return None
|
|
974
1246
|
|
|
@@ -978,14 +1250,21 @@ def get_end_of_reasoning_token(
|
|
|
978
1250
|
f"model {model_id!r}. Using {eor_reasoning_matches[0]!r} as "
|
|
979
1251
|
"the reasoning token. If this is not the correct reasoning token, "
|
|
980
1252
|
"please report this issue.",
|
|
981
|
-
level=logging.
|
|
1253
|
+
level=logging.WARNING,
|
|
982
1254
|
)
|
|
983
1255
|
|
|
984
1256
|
bor_token, eor_token = eor_reasoning_matches[0]
|
|
1257
|
+
|
|
1258
|
+
bor_token_logging: str = (
|
|
1259
|
+
bor_token if isinstance(bor_token, str) else bor_token.pattern
|
|
1260
|
+
)
|
|
1261
|
+
eor_token_logging: str = (
|
|
1262
|
+
eor_token if isinstance(eor_token, str) else eor_token.pattern
|
|
1263
|
+
)
|
|
985
1264
|
log_once(
|
|
986
|
-
f"Detected beginning-of-reasoning token {
|
|
987
|
-
f"token {
|
|
988
|
-
level=logging.
|
|
1265
|
+
f"Detected beginning-of-reasoning token {bor_token_logging!r} and "
|
|
1266
|
+
f"end-of-reasoning token {eor_token_logging!r} for model {model_id!r}.",
|
|
1267
|
+
level=logging.DEBUG,
|
|
989
1268
|
)
|
|
990
1269
|
|
|
991
1270
|
return eor_token
|
|
@@ -993,22 +1272,21 @@ def get_end_of_reasoning_token(
|
|
|
993
1272
|
|
|
994
1273
|
def get_custom_stop_tokens(
|
|
995
1274
|
model: "LLM",
|
|
996
|
-
|
|
1275
|
+
tokeniser: "PreTrainedTokenizer",
|
|
997
1276
|
model_id: str,
|
|
998
|
-
|
|
1277
|
+
generative_type: GenerativeType | None,
|
|
999
1278
|
) -> list[str]:
|
|
1000
1279
|
"""Get the stop tokens for a generative model.
|
|
1001
1280
|
|
|
1002
1281
|
Args:
|
|
1003
1282
|
model:
|
|
1004
1283
|
The vLLM model.
|
|
1005
|
-
|
|
1006
|
-
The
|
|
1284
|
+
tokeniser:
|
|
1285
|
+
The tokeniser.
|
|
1007
1286
|
model_id:
|
|
1008
1287
|
The model ID.
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
of generated tokens to allow before stopping the generation.
|
|
1288
|
+
generative_type:
|
|
1289
|
+
The generative type of the model.
|
|
1012
1290
|
|
|
1013
1291
|
Returns:
|
|
1014
1292
|
A list of stop tokens.
|
|
@@ -1016,25 +1294,26 @@ def get_custom_stop_tokens(
|
|
|
1016
1294
|
candidate_stop_tokens = CUSTOM_STOP_TOKENS
|
|
1017
1295
|
|
|
1018
1296
|
prompt = "Hello"
|
|
1019
|
-
if
|
|
1020
|
-
templated_prompt =
|
|
1297
|
+
if has_chat_template(tokeniser=tokeniser):
|
|
1298
|
+
templated_prompt = apply_chat_template(
|
|
1021
1299
|
conversation=[dict(role="user", content=prompt)],
|
|
1300
|
+
tokeniser=tokeniser,
|
|
1301
|
+
tokenise=False,
|
|
1022
1302
|
add_generation_prompt=True,
|
|
1023
|
-
|
|
1303
|
+
enable_thinking=generative_type == GenerativeType.REASONING,
|
|
1024
1304
|
)
|
|
1025
1305
|
assert isinstance(templated_prompt, str)
|
|
1026
1306
|
prompt = templated_prompt
|
|
1027
1307
|
|
|
1028
|
-
max_tokens =
|
|
1029
|
-
|
|
1030
|
-
model.generate(
|
|
1031
|
-
prompts=[prompt],
|
|
1032
|
-
sampling_params=SamplingParams(max_tokens=max_tokens, temperature=0.0),
|
|
1033
|
-
use_tqdm=False,
|
|
1034
|
-
)[0]
|
|
1035
|
-
.outputs[0]
|
|
1036
|
-
.text
|
|
1308
|
+
max_tokens = (
|
|
1309
|
+
REASONING_MAX_TOKENS if generative_type == GenerativeType.REASONING else 10
|
|
1037
1310
|
)
|
|
1311
|
+
output = model.generate(
|
|
1312
|
+
prompts=[prompt],
|
|
1313
|
+
sampling_params=SamplingParams(max_tokens=max_tokens, temperature=0.0),
|
|
1314
|
+
use_tqdm=False,
|
|
1315
|
+
)[0]
|
|
1316
|
+
completion = tokeniser.decode(token_ids=output.outputs[0].token_ids)
|
|
1038
1317
|
|
|
1039
1318
|
stop_tokens = [
|
|
1040
1319
|
stop_token
|
|
@@ -1042,27 +1321,50 @@ def get_custom_stop_tokens(
|
|
|
1042
1321
|
if stop_token in prompt or stop_token in completion
|
|
1043
1322
|
]
|
|
1044
1323
|
if stop_tokens:
|
|
1045
|
-
|
|
1324
|
+
log(
|
|
1046
1325
|
f"Found the following custom stop tokens for model {model_id!r}: "
|
|
1047
|
-
f"{stop_tokens}."
|
|
1326
|
+
f"{stop_tokens}.",
|
|
1327
|
+
level=logging.DEBUG,
|
|
1048
1328
|
)
|
|
1049
1329
|
else:
|
|
1050
|
-
|
|
1330
|
+
log(f"Found no custom stop tokens for model {model_id!r}.", level=logging.DEBUG)
|
|
1051
1331
|
|
|
1052
1332
|
return stop_tokens
|
|
1053
1333
|
|
|
1054
1334
|
|
|
1055
|
-
def
|
|
1056
|
-
""
|
|
1335
|
+
def get_vllm_tokenisation_params(
|
|
1336
|
+
tokeniser: "PreTrainedTokenizer", model_config: "ModelConfig"
|
|
1337
|
+
) -> dict[str, t.Any]:
|
|
1338
|
+
"""Get the tokenisation parameters for vLLM.
|
|
1057
1339
|
|
|
1058
1340
|
Args:
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1341
|
+
tokeniser:
|
|
1342
|
+
The tokeniser.
|
|
1343
|
+
model_config:
|
|
1344
|
+
The model configuration.
|
|
1063
1345
|
|
|
1064
1346
|
Returns:
|
|
1065
|
-
A
|
|
1347
|
+
A dictionary of tokenisation parameters to pass to vLLM.
|
|
1066
1348
|
"""
|
|
1067
|
-
|
|
1068
|
-
|
|
1349
|
+
if isinstance(tokeniser, MistralCommonTokenizer):
|
|
1350
|
+
tokeniser_mode = "mistral"
|
|
1351
|
+
elif model_config.param == "slow-tokenizer":
|
|
1352
|
+
tokeniser_mode = "slow"
|
|
1353
|
+
else:
|
|
1354
|
+
tokeniser_mode = "auto"
|
|
1355
|
+
|
|
1356
|
+
if isinstance(tokeniser, MistralCommonTokenizer):
|
|
1357
|
+
config_format = "mistral"
|
|
1358
|
+
else:
|
|
1359
|
+
config_format = "auto"
|
|
1360
|
+
|
|
1361
|
+
if isinstance(tokeniser, MistralCommonTokenizer):
|
|
1362
|
+
load_format = "mistral"
|
|
1363
|
+
else:
|
|
1364
|
+
load_format = "auto"
|
|
1365
|
+
|
|
1366
|
+
return dict(
|
|
1367
|
+
tokenizer_mode=tokeniser_mode,
|
|
1368
|
+
config_format=config_format,
|
|
1369
|
+
load_format=load_format,
|
|
1370
|
+
)
|