EuroEval 15.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of EuroEval might be problematic. Click here for more details.
- euroeval/__init__.py +72 -0
- euroeval/benchmark_config_factory.py +358 -0
- euroeval/benchmark_modules/__init__.py +7 -0
- euroeval/benchmark_modules/base.py +354 -0
- euroeval/benchmark_modules/fresh.py +286 -0
- euroeval/benchmark_modules/hf.py +1185 -0
- euroeval/benchmark_modules/litellm.py +905 -0
- euroeval/benchmark_modules/vllm.py +1171 -0
- euroeval/benchmarker.py +1074 -0
- euroeval/callbacks.py +72 -0
- euroeval/cli.py +281 -0
- euroeval/constants.py +50 -0
- euroeval/data_loading.py +96 -0
- euroeval/data_models.py +474 -0
- euroeval/dataset_configs.py +2001 -0
- euroeval/enums.py +144 -0
- euroeval/exceptions.py +191 -0
- euroeval/finetuning.py +324 -0
- euroeval/generation.py +296 -0
- euroeval/human_evaluation.py +737 -0
- euroeval/languages.py +200 -0
- euroeval/model_cache.py +253 -0
- euroeval/model_config.py +77 -0
- euroeval/model_loading.py +78 -0
- euroeval/scores.py +90 -0
- euroeval/speed_benchmark.py +124 -0
- euroeval/task_utils/__init__.py +1 -0
- euroeval/task_utils/multiple_choice_classification.py +176 -0
- euroeval/task_utils/question_answering.py +698 -0
- euroeval/task_utils/sequence_classification.py +237 -0
- euroeval/task_utils/text_to_text.py +150 -0
- euroeval/task_utils/token_classification.py +464 -0
- euroeval/tasks.py +202 -0
- euroeval/types.py +97 -0
- euroeval/utils.py +574 -0
- euroeval-15.2.0.dist-info/METADATA +234 -0
- euroeval-15.2.0.dist-info/RECORD +40 -0
- euroeval-15.2.0.dist-info/WHEEL +4 -0
- euroeval-15.2.0.dist-info/entry_points.txt +4 -0
- euroeval-15.2.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,905 @@
|
|
|
1
|
+
"""Generative models from an inference API, using the LiteLLM framework."""
|
|
2
|
+
|
|
3
|
+
import collections.abc as c
|
|
4
|
+
import itertools as it
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import random
|
|
9
|
+
import re
|
|
10
|
+
import typing as t
|
|
11
|
+
from functools import cached_property, partial
|
|
12
|
+
from time import sleep
|
|
13
|
+
|
|
14
|
+
import litellm
|
|
15
|
+
from datasets import DatasetDict
|
|
16
|
+
from huggingface_hub import HfApi
|
|
17
|
+
from huggingface_hub.errors import (
|
|
18
|
+
HFValidationError,
|
|
19
|
+
RepositoryNotFoundError,
|
|
20
|
+
RevisionNotFoundError,
|
|
21
|
+
)
|
|
22
|
+
from litellm.exceptions import (
|
|
23
|
+
APIConnectionError,
|
|
24
|
+
APIError,
|
|
25
|
+
AuthenticationError,
|
|
26
|
+
BadRequestError,
|
|
27
|
+
InternalServerError,
|
|
28
|
+
NotFoundError,
|
|
29
|
+
ServiceUnavailableError,
|
|
30
|
+
Timeout,
|
|
31
|
+
)
|
|
32
|
+
from litellm.types.utils import ModelResponse
|
|
33
|
+
from requests.exceptions import RequestException
|
|
34
|
+
from transformers import Trainer
|
|
35
|
+
|
|
36
|
+
from ..constants import (
|
|
37
|
+
MAX_LOGPROBS,
|
|
38
|
+
REASONING_MAX_TOKENS,
|
|
39
|
+
TASK_GROUPS_USING_LOGPROBS,
|
|
40
|
+
TASKS_USING_JSON,
|
|
41
|
+
)
|
|
42
|
+
from ..data_models import BenchmarkConfig, GenerativeModelOutput, ModelConfig, Task
|
|
43
|
+
from ..enums import (
|
|
44
|
+
BatchingPreference,
|
|
45
|
+
GenerativeType,
|
|
46
|
+
InferenceBackend,
|
|
47
|
+
ModelType,
|
|
48
|
+
TaskGroup,
|
|
49
|
+
)
|
|
50
|
+
from ..exceptions import (
|
|
51
|
+
InvalidBenchmark,
|
|
52
|
+
NeedsAdditionalArgument,
|
|
53
|
+
NeedsEnvironmentVariable,
|
|
54
|
+
NeedsExtraInstalled,
|
|
55
|
+
)
|
|
56
|
+
from ..task_utils import (
|
|
57
|
+
question_answering,
|
|
58
|
+
sequence_classification,
|
|
59
|
+
text_to_text,
|
|
60
|
+
token_classification,
|
|
61
|
+
)
|
|
62
|
+
from ..types import ExtractLabelsFunction
|
|
63
|
+
from ..utils import create_model_cache_dir
|
|
64
|
+
from .base import BenchmarkModule
|
|
65
|
+
from .hf import HuggingFaceEncoderModel, load_hf_model_config, load_tokenizer
|
|
66
|
+
|
|
67
|
+
# Logger used by this module, obtained under the name "euroeval".
logger = logging.getLogger("euroeval")
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Maps regular expressions over model IDs to the vocabulary size of the matching
# models. The patterns are matched against the full model ID with `re.fullmatch`,
# in insertion order, and the first match wins. A value of -1 appears to be a
# placeholder for models whose vocabulary size is not known.
VOCAB_SIZE_MAPPING = {
    # OpenAI models
    r"(text-)?(ada|babbage|curie|davinci)(-001)?": 50_257,
    r"(code|text)-davinci-00[2-9]": 50_281,
    r"gpt-3\.5-turbo(-16k)?(-[0-9]{4})?": 100_256,
    # The hyphen belongs to the optional suffixes, so that plain "gpt-4" and dated
    # variants such as "gpt-4-0613" are matched as well
    r"gpt-4(-32k)?(-[0-9]{4})?": 100_256,
    r"gpt-4-[0-9]{4}-preview": 100_256,
    r"gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 100_256,
    r"gpt-4-(vision|turbo)(-preview)?": 100_256,
    r"gpt-3\.5-turbo-instruct(-[0-9]{4})?": 100_256,
    r"gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_019,
    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
    # Anthropic models
    r"claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": -1,
}
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# Maps regular expressions over model IDs to the maximum sequence length of the
# matching models. The patterns are matched against the full model ID with
# `re.fullmatch`, in insertion order, and the first match wins - which is why the
# more specific "gpt-3.5-turbo-0613" entry is listed before the general one.
MODEL_MAX_LENGTH_MAPPING = {
    # OpenAI models
    r"(text-)?(ada|babbage|curie|davinci)(-001)?": 2_050,
    r"text-davinci-00[2-9]": 4_098,
    r"code-davinci-00[1-9]": 8_002,
    r"gpt-3\.5-turbo-0613": 4_096,
    r"gpt-3\.5-turbo(-[0-9]{4})?": 16_385,
    r"gpt-3\.5-turbo-16k(-[0-9]{4})?": 16_384,
    r"gpt-4(-[0-9]{4})?": 8_191,
    r"gpt-4-32k(-[0-9]{4})?": 32_767,
    r"gpt-4-[0-9]{4}-preview": 128_000,
    r"gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
    r"gpt-4-(vision|turbo)(-preview)?": 128_000,
    r"gpt-3\.5-turbo-instruct(-[0-9]{4})?": 4_095,
    r"gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
    r"o1-(mini|preview)(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
    r"o1(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
    r"o[2-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
    # Anthropic models
    r"claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": 200_000,
}
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
# Maps regular expressions over model IDs to the number of parameters of the
# matching models. The patterns are matched against the full model ID with
# `re.fullmatch`, in insertion order, and the first match wins. A value of -1
# appears to be a placeholder for models whose parameter count is not known.
NUM_PARAMS_MAPPING = {
    # OpenAI models
    r"(text-)?ada(-001)?": 350_000_000,
    r"(text-)?babbage(-001)?": 3_000_000_000,
    r"(text-)?curie(-001)?": 13_000_000_000,
    r"((text|code)-)?davinci(-00[1-9])?": 175_000_000_000,
    # The hyphen belongs to the optional suffixes, so that plain "gpt-3.5-turbo"
    # and dated variants such as "gpt-3.5-turbo-0613" are matched as well
    r"gpt-(3\.5|4)-turbo(-(16|32)k)?(-[0-9]{4})?": -1,
    r"gpt-4-[0-9]{4}-preview": -1,
    r"gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
    r"gpt-4-(vision|turbo)(-preview)?": -1,
    r"gpt-3\.5-turbo-instruct(-[0-9]{4})?": -1,
    r"gpt-4o(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
    r"gpt-4o-mini(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
    # Anthropic models
    r"claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": -1,
}
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
# Regular expressions over model IDs; a model whose full ID matches any of these
# is treated as a reasoning model.
REASONING_MODELS = [
    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?",
]
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class LiteLLMModel(BenchmarkModule):
|
|
133
|
+
"""A generative model from LiteLLM."""
|
|
134
|
+
|
|
135
|
+
fresh_model = False
|
|
136
|
+
batching_preference = BatchingPreference.SINGLE_SAMPLE
|
|
137
|
+
high_priority = False
|
|
138
|
+
|
|
139
|
+
@property
def generative_type(self) -> GenerativeType | None:
    """Get the generative type of the model.

    Returns:
        `GenerativeType.REASONING` if the model ID fully matches one of the
        patterns in `REASONING_MODELS`, and `GenerativeType.INSTRUCTION_TUNED`
        otherwise.
    """
    # Combine all reasoning model patterns into a single alternation
    reasoning_pattern = "|".join(REASONING_MODELS)
    reasoning_match = re.fullmatch(
        pattern=reasoning_pattern, string=self.model_config.model_id
    )
    return (
        GenerativeType.REASONING
        if reasoning_match
        else GenerativeType.INSTRUCTION_TUNED
    )
|
|
152
|
+
|
|
153
|
+
def generate(self, inputs: dict) -> GenerativeModelOutput:
    """Generate outputs from the model.

    Args:
        inputs:
            A batch of inputs to pass through the model. It must contain a
            "messages" key holding exactly one conversation, as API models only
            support single-sample batching.

    Returns:
        The generated model outputs, including the top log probabilities of the
        generated tokens if the API returned them.

    Raises:
        InvalidBenchmark:
            If the API rejects the request for a reason that cannot be worked
            around, or if no response was obtained after all attempts.
        NeedsAdditionalArgument:
            If the API reports an authentication error.
    """
    assert "messages" in inputs, "The input must contain a 'messages' key."
    assert len(inputs["messages"]) == 1, (
        "API models only support single-sample batching."
    )
    messages = inputs["messages"][0]

    generation_kwargs: dict[str, t.Any] = dict(
        model=self.model_config.model_id,
        max_completion_tokens=(
            REASONING_MAX_TOKENS
            if self.generative_type == GenerativeType.REASONING
            else self.dataset_config.max_generated_tokens
        ),
        stop=[],
        temperature=0.0,
        seed=4242,
        api_key=self.benchmark_config.api_key,
        api_base=self.benchmark_config.api_base,
        api_version=self.benchmark_config.api_version,
    )

    if self.dataset_config.task.task_group in TASK_GROUPS_USING_LOGPROBS:
        generation_kwargs["logprobs"] = True
        generation_kwargs["top_logprobs"] = MAX_LOGPROBS

    if self.dataset_config.task in TASKS_USING_JSON:
        assert "json" in messages[0]["content"].lower(), (
            "Prompt must contain 'json' for JSON tasks."
        )
        generation_kwargs["response_format"] = dict(type="json_object")

    # This drops generation kwargs that are not supported by the model
    litellm.drop_params = True

    # Query the API. If it rejects one of the generation arguments then we remove
    # that argument and try again, and if the service is temporarily unavailable
    # then we wait a bit before trying again.
    num_attempts = 10
    for _ in range(num_attempts):
        try:
            model_response = litellm.completion(
                messages=messages, max_retries=3, **generation_kwargs
            )
            break
        except BadRequestError as e:
            error_message = str(e).lower()
            if "stop_sequences" in error_message:
                generation_kwargs["stop"] = None
            elif "you are not allowed to request logprobs" in error_message:
                # Use a default when popping, as the keys are already gone if the
                # API reports the same error on a later attempt
                generation_kwargs.pop("logprobs", None)
                generation_kwargs.pop("top_logprobs", None)
            elif "'temperature' is not supported with this model." in error_message:
                generation_kwargs.pop("temperature", None)
            else:
                raise InvalidBenchmark(
                    f"Failed to generate text. The error message was: {e}"
                ) from e
        except (
            Timeout,
            ServiceUnavailableError,
            APIConnectionError,
            InternalServerError,
        ):
            logger.debug(
                "Service temporarily unavailable. Retrying in 5 seconds..."
            )
            sleep(5)
        except APIError as e:
            raise InvalidBenchmark(
                f"Failed to generate text. The error message was: {e}"
            ) from e
        except AuthenticationError as e:
            raise NeedsAdditionalArgument(
                cli_argument="--api-key",
                script_argument="api_key=<your-api-key>",
                run_with_cli=self.benchmark_config.run_with_cli,
            ) from e
    else:
        raise InvalidBenchmark(
            message=f"Failed to generate text, after {num_attempts} attempts."
        )

    assert isinstance(model_response, ModelResponse)
    model_response_choices = model_response.choices[0]
    assert isinstance(model_response_choices, litellm.Choices)
    generation_output = model_response_choices.message["content"] or ""
    generation_output = generation_output.strip()

    # Structure the model output as a GenerativeModelOutput object
    model_output = GenerativeModelOutput(sequences=[generation_output])

    # The `logprobs` attribute can be present but set to None, so we check the
    # value rather than merely the presence of the attribute
    logprobs = getattr(model_response_choices, "logprobs", None)
    if logprobs is not None:
        logprobs_list: list[list[tuple[str, float]]] = [
            [
                (top_logprob.token, top_logprob.logprob)
                for top_logprob in content.top_logprobs or list()
            ]
            for content in logprobs.content or list()
        ]
        model_output.scores = [logprobs_list]

    return model_output
|
|
264
|
+
|
|
265
|
+
@cached_property
def num_params(self) -> int:
    """The number of parameters in the model.

    Returns:
        The number of parameters in the model, or -1 if it could not be determined.
    """
    full_model_id = self.model_config.model_id

    # Models with a registered pattern are looked up directly; first match wins
    registered_num_params = next(
        (
            num_params
            for pattern, num_params in NUM_PARAMS_MAPPING.items()
            if re.fullmatch(pattern=pattern, string=full_model_id) is not None
        ),
        None,
    )
    if registered_num_params is not None:
        return registered_num_params

    # Beyond the registered models, only models with a "huggingface/" prefix that
    # can be found by the Hugging Face module are inspected further
    if not full_model_id.startswith("huggingface/"):
        return -1
    model_id = full_model_id.split(sep="/", maxsplit=1)[-1]
    if not HuggingFaceEncoderModel.model_exists(
        model_id=model_id, benchmark_config=self.benchmark_config
    ):
        return -1

    hf_config = load_hf_model_config(
        model_id=model_id,
        num_labels=self.dataset_config.num_labels,
        id2label=self.dataset_config.id2label,
        label2id=self.dataset_config.label2id,
        revision=self.model_config.revision,
        model_cache_dir=self.model_config.model_cache_dir,
        api_key=self.benchmark_config.api_key,
        trust_remote_code=self.benchmark_config.trust_remote_code,
        run_with_cli=self.benchmark_config.run_with_cli,
    )

    hf_api = HfApi()
    hub_token = (
        os.getenv("HUGGINGFACE_API_KEY") or self.benchmark_config.api_key or True
    )
    try:
        repo_info = hf_api.model_info(
            repo_id=model_id, revision=self.model_config.revision, token=hub_token
        )
    except (
        RepositoryNotFoundError,
        RevisionNotFoundError,
        RequestException,
        HFValidationError,
    ):
        repo_info = None

    # Prefer the parameter count from the safetensors metadata of the repository,
    # and fall back to the model configuration
    safetensors = getattr(repo_info, "safetensors", None)
    if safetensors is not None and "total" in safetensors:
        return safetensors["total"]

    config_num_params = getattr(hf_config, "num_params", None)
    if config_num_params is not None:
        return config_num_params

    return -1
|
|
324
|
+
|
|
325
|
+
@cached_property
def vocab_size(self) -> int:
    """The vocabulary size of the model.

    Returns:
        The vocabulary size of the model, or -1 if it could not be determined.
    """
    full_model_id = self.model_config.model_id

    # Models with a registered pattern are looked up directly; first match wins
    registered_vocab_size = next(
        (
            size
            for pattern, size in VOCAB_SIZE_MAPPING.items()
            if re.fullmatch(pattern=pattern, string=full_model_id) is not None
        ),
        None,
    )
    if registered_vocab_size is not None:
        return registered_vocab_size

    # Beyond the registered models, only models with a "huggingface/" prefix that
    # can be found by the Hugging Face module are inspected further
    if not full_model_id.startswith("huggingface/"):
        return -1
    model_id = full_model_id.split(sep="/", maxsplit=1)[-1]
    if not HuggingFaceEncoderModel.model_exists(
        model_id=model_id, benchmark_config=self.benchmark_config
    ):
        return -1

    hf_config = load_hf_model_config(
        model_id=model_id,
        num_labels=self.dataset_config.num_labels,
        id2label=self.dataset_config.id2label,
        label2id=self.dataset_config.label2id,
        revision=self.model_config.revision,
        model_cache_dir=self.model_config.model_cache_dir,
        api_key=self.benchmark_config.api_key,
        trust_remote_code=self.benchmark_config.trust_remote_code,
        run_with_cli=self.benchmark_config.run_with_cli,
    )
    tokenizer = load_tokenizer(
        model=None,
        model_id=model_id,
        trust_remote_code=self.benchmark_config.trust_remote_code,
    )

    # Prefer the vocabulary size of the model configuration over the tokenizer's
    config_vocab_size = getattr(hf_config, "vocab_size", None)
    if config_vocab_size is not None:
        return config_vocab_size

    tokenizer_vocab_size = getattr(tokenizer, "vocab_size", None)
    if tokenizer_vocab_size is not None:
        return tokenizer_vocab_size

    return -1
|
|
374
|
+
|
|
375
|
+
@cached_property
def model_max_length(self) -> int:
    """The maximum length of the model.

    Returns:
        The maximum length of the model, or -1 if it could not be determined.
    """
    full_model_id = self.model_config.model_id

    # Models with a registered pattern are looked up directly; first match wins
    registered_max_length = next(
        (
            length
            for pattern, length in MODEL_MAX_LENGTH_MAPPING.items()
            if re.fullmatch(pattern=pattern, string=full_model_id) is not None
        ),
        None,
    )
    if registered_max_length is not None:
        return registered_max_length

    # Beyond the registered models, only models with a "huggingface/" prefix that
    # can be found by the Hugging Face module are inspected further
    if not full_model_id.startswith("huggingface/"):
        return -1
    model_id = full_model_id.split(sep="/", maxsplit=1)[-1]
    if not HuggingFaceEncoderModel.model_exists(
        model_id=model_id, benchmark_config=self.benchmark_config
    ):
        return -1

    hf_config = load_hf_model_config(
        model_id=model_id,
        num_labels=self.dataset_config.num_labels,
        id2label=self.dataset_config.id2label,
        label2id=self.dataset_config.label2id,
        revision=self.model_config.revision,
        model_cache_dir=self.model_config.model_cache_dir,
        api_key=self.benchmark_config.api_key,
        trust_remote_code=self.benchmark_config.trust_remote_code,
        run_with_cli=self.benchmark_config.run_with_cli,
    )
    tokenizer = load_tokenizer(
        model=None,
        model_id=model_id,
        trust_remote_code=self.benchmark_config.trust_remote_code,
    )

    candidate_lengths: list[int] = list()

    # The max length registered on the tokenizer, unless it is an enormous value
    if hasattr(tokenizer, "model_max_length"):
        tokenizer_max_length = tokenizer.model_max_length
        if tokenizer_max_length < int(1e30):
            candidate_lengths.append(tokenizer_max_length)

    # The max lengths derived from the model's input sizes
    if hasattr(tokenizer, "max_model_input_sizes"):
        for input_size in tokenizer.max_model_input_sizes.values():
            if input_size is not None:
                candidate_lengths.append(input_size)

    # The max length candidates from the model's configuration
    config_attributes = (
        "max_position_embeddings",
        "max_sequence_length",
        "model_max_length",
        "sliding_window",
        "sliding_window_size",
        "n_positions",
    )
    for config_attribute in config_attributes:
        config_value = getattr(hf_config, config_attribute, None)
        if config_value is not None:
            candidate_lengths.append(config_value)

    # To avoid models having artificially low max lengths, we disregard any max
    # lengths that are less than 128, and use the smallest of the remaining ones
    plausible_lengths = [length for length in candidate_lengths if length >= 128]
    if plausible_lengths:
        return min(plausible_lengths)

    return -1
|
|
454
|
+
|
|
455
|
+
@property
def data_collator(self) -> c.Callable[[list[t.Any]], dict[str, t.Any]]:
    """The data collator used to prepare samples during finetuning.

    Raises:
        NotImplementedError:
            Always, as this property is not implemented for LiteLLM models.
    """
    error_message = (
        "The `data_collator` property has not been implemented for LiteLLM models."
    )
    raise NotImplementedError(error_message)
|
|
465
|
+
|
|
466
|
+
@property
def extract_labels_from_generation(self) -> ExtractLabelsFunction:
    """The function used to extract the labels from the generated output.

    Returns:
        The label extraction function belonging to the task group of the dataset.

    Raises:
        NotImplementedError:
            If the task group of the dataset is not supported.
    """
    task_group = self.dataset_config.task.task_group

    if (
        task_group == TaskGroup.SEQUENCE_CLASSIFICATION
        or task_group == TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION
    ):
        return partial(
            sequence_classification.extract_labels_from_generation,
            dataset_config=self.dataset_config,
        )

    if task_group == TaskGroup.TEXT_TO_TEXT:
        return text_to_text.extract_labels_from_generation

    if task_group == TaskGroup.TOKEN_CLASSIFICATION:
        return partial(
            token_classification.extract_labels_from_generation,
            dataset_config=self.dataset_config,
        )

    if task_group == TaskGroup.QUESTION_ANSWERING:
        return question_answering.extract_labels_from_generation

    raise NotImplementedError(f"Unsupported task group: {task_group}.")
|
|
495
|
+
|
|
496
|
+
@property
def trainer_class(self) -> t.Type["Trainer"]:
    """The Trainer class to use for finetuning.

    Raises:
        NotImplementedError:
            Always, as this property is not implemented for LiteLLM models.
    """
    error_message = (
        "The `trainer_class` property has not been implemented for LiteLLM models."
    )
    raise NotImplementedError(error_message)
|
|
506
|
+
|
|
507
|
+
@classmethod
|
|
508
|
+
def model_exists(
|
|
509
|
+
cls, model_id: str, benchmark_config: BenchmarkConfig
|
|
510
|
+
) -> bool | NeedsExtraInstalled | NeedsEnvironmentVariable:
|
|
511
|
+
"""Check if a model exists.
|
|
512
|
+
|
|
513
|
+
Args:
|
|
514
|
+
model_id:
|
|
515
|
+
The model ID.
|
|
516
|
+
benchmark_config:
|
|
517
|
+
The benchmark configuration.
|
|
518
|
+
|
|
519
|
+
Returns:
|
|
520
|
+
Whether the model exists, or an error describing why we cannot check
|
|
521
|
+
whether the model exists.
|
|
522
|
+
"""
|
|
523
|
+
if model_id in litellm.model_list:
|
|
524
|
+
return True
|
|
525
|
+
|
|
526
|
+
num_attempts = 10
|
|
527
|
+
for _ in range(num_attempts):
|
|
528
|
+
try:
|
|
529
|
+
litellm.completion(
|
|
530
|
+
messages=[dict(role="user", content="X")],
|
|
531
|
+
model=model_id,
|
|
532
|
+
max_tokens=1,
|
|
533
|
+
api_key=benchmark_config.api_key,
|
|
534
|
+
api_base=benchmark_config.api_base,
|
|
535
|
+
api_version=benchmark_config.api_version,
|
|
536
|
+
)
|
|
537
|
+
return True
|
|
538
|
+
except APIError as e:
|
|
539
|
+
if "'503 Service Unavailable" not in str(e):
|
|
540
|
+
raise e
|
|
541
|
+
logger.warning(
|
|
542
|
+
f"Failed to check if model {model_id!r} exists. Retrying in "
|
|
543
|
+
f"{num_attempts} seconds..."
|
|
544
|
+
)
|
|
545
|
+
sleep(10)
|
|
546
|
+
except (BadRequestError, NotFoundError):
|
|
547
|
+
candidate_models = [
|
|
548
|
+
candidate_model_id
|
|
549
|
+
for candidate_model_id in litellm.model_list
|
|
550
|
+
if candidate_model_id.startswith(model_id)
|
|
551
|
+
]
|
|
552
|
+
match len(candidate_models):
|
|
553
|
+
case 0:
|
|
554
|
+
pass
|
|
555
|
+
case 1:
|
|
556
|
+
logger.warning(
|
|
557
|
+
f"Could not find the model ID {model_id!r}. Did you mean "
|
|
558
|
+
f"{candidate_models[0]!r}?"
|
|
559
|
+
)
|
|
560
|
+
case _:
|
|
561
|
+
candidate_models_str = "', '".join(candidate_models)
|
|
562
|
+
logger.warning(
|
|
563
|
+
f"Could not find the model ID {model_id!r}. Did you mean "
|
|
564
|
+
f"any of the following model IDs: '{candidate_models_str}'?"
|
|
565
|
+
)
|
|
566
|
+
return False
|
|
567
|
+
else:
|
|
568
|
+
logger.error(
|
|
569
|
+
f"Failed to check if model {model_id!r} exists after {num_attempts} "
|
|
570
|
+
"attempts. Assuming it does not exist."
|
|
571
|
+
)
|
|
572
|
+
return False
|
|
573
|
+
|
|
574
|
+
@classmethod
def get_model_config(
    cls, model_id: str, benchmark_config: BenchmarkConfig
) -> ModelConfig:
    """Fetch the model configuration.

    Args:
        model_id:
            The model ID.
        benchmark_config:
            The benchmark configuration, from which the cache directory is read.

    Returns:
        The model configuration, describing a generative model which uses the
        LiteLLM inference backend.
    """
    model_cache_dir = create_model_cache_dir(
        cache_dir=benchmark_config.cache_dir, model_id=model_id
    )
    model_config = ModelConfig(
        model_id=model_id,
        revision="main",
        task="text-generation",
        languages=[],
        merge=False,
        inference_backend=InferenceBackend.LITELLM,
        model_type=ModelType.GENERATIVE,
        fresh=False,
        model_cache_dir=model_cache_dir,
        adapter_base_model_id=None,
    )
    return model_config
|
|
603
|
+
|
|
604
|
+
def prepare_dataset(
    self, dataset: DatasetDict, task: Task, itr_idx: int
) -> DatasetDict:
    """Prepare the dataset for the model.

    For question answering tasks a "label" column is built, containing the ID
    and the lower-cased answers of every example. After that, the prompt
    template is applied to the test split, together with few-shot examples if
    few-shot evaluation is enabled in the benchmark configuration.

    Args:
        dataset:
            The dataset to prepare.
        task:
            The task to prepare the dataset for.
        itr_idx:
            The index of the dataset in the iterator.

    Returns:
        The prepared dataset.
    """
    if task.task_group == TaskGroup.QUESTION_ANSWERING:

        def build_labels(examples: dict[str, t.Any]) -> dict[str, t.Any]:
            """Build the label column for a batch of question answering examples."""
            labels = list()
            for example_id, answer_dct in zip(examples["id"], examples["answers"]):
                lowercased_texts = [
                    answer_text.lower() for answer_text in answer_dct["text"]
                ]
                answers = dict(
                    answer_start=answer_dct["answer_start"], text=lowercased_texts
                )
                labels.append(dict(id=example_id, answers=answers))
            return dict(label=labels)

        dataset = dataset.map(
            build_labels,
            batched=True,
            load_from_cache_file=False,
            keep_in_memory=True,
        )

    few_shot_examples = (
        self._extract_few_shot_examples(
            dataset=dataset, task=task, itr_idx=itr_idx
        )
        if self.benchmark_config.few_shot
        else list()
    )

    apply_prompt = partial(
        self._apply_prompt, few_shot_examples=few_shot_examples, task=task
    )
    dataset["test"] = dataset["test"].map(
        apply_prompt, batched=True, load_from_cache_file=False, keep_in_memory=True
    )

    return dataset
|
|
659
|
+
|
|
660
|
+
def _extract_few_shot_examples(
|
|
661
|
+
self, dataset: DatasetDict, task: Task, itr_idx: int
|
|
662
|
+
) -> list[dict[str, t.Any]]:
|
|
663
|
+
"""Extract few-shot examples from a dataset.
|
|
664
|
+
|
|
665
|
+
This will always extract the examples from the training split.
|
|
666
|
+
|
|
667
|
+
We ensure that the few-shot examples are unique by picking them one at a time.
|
|
668
|
+
|
|
669
|
+
Args:
|
|
670
|
+
dataset:
|
|
671
|
+
The dataset to extract the few-shot examples from.
|
|
672
|
+
task:
|
|
673
|
+
The task that is being benchmarked.
|
|
674
|
+
itr_idx:
|
|
675
|
+
The index of the dataset in the iterator.
|
|
676
|
+
|
|
677
|
+
Returns:
|
|
678
|
+
The few-shot examples.
|
|
679
|
+
"""
|
|
680
|
+
random_seed = 4242 + itr_idx
|
|
681
|
+
num_few_shots = self.dataset_config.num_few_shot_examples
|
|
682
|
+
few_shot_examples: list[dict[str, t.Any]] = list()
|
|
683
|
+
shuffled_train = dataset["train"].shuffle(seed=random_seed)
|
|
684
|
+
|
|
685
|
+
match task.task_group:
|
|
686
|
+
case (
|
|
687
|
+
TaskGroup.SEQUENCE_CLASSIFICATION
|
|
688
|
+
| TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION
|
|
689
|
+
):
|
|
690
|
+
labels = it.cycle(self.dataset_config.labels)
|
|
691
|
+
while (
|
|
692
|
+
len(few_shot_examples) < num_few_shots and len(shuffled_train) > 0
|
|
693
|
+
):
|
|
694
|
+
label = next(labels)
|
|
695
|
+
possible_examples = shuffled_train.filter(
|
|
696
|
+
lambda x: x["label"].lower() == label.lower()
|
|
697
|
+
)
|
|
698
|
+
if len(possible_examples) == 0:
|
|
699
|
+
continue
|
|
700
|
+
example = possible_examples.select(range(1))[0]
|
|
701
|
+
few_shot_examples.append(example)
|
|
702
|
+
shuffled_train = shuffled_train.filter(
|
|
703
|
+
lambda x: x["text"] != example["text"]
|
|
704
|
+
)
|
|
705
|
+
|
|
706
|
+
case TaskGroup.TEXT_TO_TEXT:
|
|
707
|
+
while (
|
|
708
|
+
len(few_shot_examples) < num_few_shots and len(shuffled_train) > 0
|
|
709
|
+
):
|
|
710
|
+
example = shuffled_train.select(range(1))[0]
|
|
711
|
+
few_shot_examples.append(example)
|
|
712
|
+
shuffled_train = shuffled_train.filter(
|
|
713
|
+
lambda x: x["text"] != example["text"]
|
|
714
|
+
)
|
|
715
|
+
|
|
716
|
+
case TaskGroup.TOKEN_CLASSIFICATION:
|
|
717
|
+
labels = it.cycle(
|
|
718
|
+
[
|
|
719
|
+
label.lower()
|
|
720
|
+
for label in self.dataset_config.labels
|
|
721
|
+
if label.lower().startswith("b-")
|
|
722
|
+
]
|
|
723
|
+
)
|
|
724
|
+
while (
|
|
725
|
+
len(few_shot_examples) < num_few_shots and len(shuffled_train) > 0
|
|
726
|
+
):
|
|
727
|
+
label = next(labels)
|
|
728
|
+
possible_examples = shuffled_train.filter(
|
|
729
|
+
lambda x: label in [tag.lower() for tag in x["labels"]]
|
|
730
|
+
)
|
|
731
|
+
if len(possible_examples) == 0:
|
|
732
|
+
continue
|
|
733
|
+
example = possible_examples.select(range(1))[0]
|
|
734
|
+
few_shot_examples.append(example)
|
|
735
|
+
shuffled_train = shuffled_train.filter(
|
|
736
|
+
lambda x: x["tokens"] != example["tokens"]
|
|
737
|
+
)
|
|
738
|
+
|
|
739
|
+
case TaskGroup.QUESTION_ANSWERING:
|
|
740
|
+
# Locate the maximum number of tokens that constitutes a short example
|
|
741
|
+
for max_num_tokens in [512, 1024, 2048, 4096, 8192]:
|
|
742
|
+
train_with_short_examples = dataset["train"].filter(
|
|
743
|
+
lambda example: len(example["context"]) < max_num_tokens
|
|
744
|
+
)
|
|
745
|
+
num_short_examples = len(train_with_short_examples)
|
|
746
|
+
if num_short_examples >= self.dataset_config.num_few_shot_examples:
|
|
747
|
+
break
|
|
748
|
+
else:
|
|
749
|
+
raise InvalidBenchmark(
|
|
750
|
+
"Could not find enough short examples for few-shot learning."
|
|
751
|
+
)
|
|
752
|
+
|
|
753
|
+
shuffled_train = train_with_short_examples.shuffle(seed=random_seed)
|
|
754
|
+
while (
|
|
755
|
+
len(few_shot_examples) < num_few_shots and len(shuffled_train) > 0
|
|
756
|
+
):
|
|
757
|
+
example = shuffled_train.select(range(1))[0]
|
|
758
|
+
few_shot_examples.append(example)
|
|
759
|
+
shuffled_train = shuffled_train.filter(
|
|
760
|
+
lambda x: x["context"] != example["context"]
|
|
761
|
+
)
|
|
762
|
+
|
|
763
|
+
case _:
|
|
764
|
+
raise NotImplementedError(f"Unsupported task group: {task.task_group}.")
|
|
765
|
+
|
|
766
|
+
random.seed(random_seed)
|
|
767
|
+
random.shuffle(few_shot_examples)
|
|
768
|
+
return few_shot_examples
|
|
769
|
+
|
|
770
|
+
def _apply_prompt(
|
|
771
|
+
self,
|
|
772
|
+
examples: dict[str, t.Any],
|
|
773
|
+
few_shot_examples: list[dict[str, t.Any]],
|
|
774
|
+
task: Task,
|
|
775
|
+
) -> dict[str, t.Any]:
|
|
776
|
+
"""Apply prompt template to an example, potentially with few-shot examples.
|
|
777
|
+
|
|
778
|
+
Args:
|
|
779
|
+
examples:
|
|
780
|
+
The examples to apply the few-shot examples to.
|
|
781
|
+
few_shot_examples:
|
|
782
|
+
The few-shot examples to apply.
|
|
783
|
+
task:
|
|
784
|
+
The task that is being benchmarked.
|
|
785
|
+
|
|
786
|
+
Returns:
|
|
787
|
+
The example with the few-shot examples applied.
|
|
788
|
+
"""
|
|
789
|
+
|
|
790
|
+
def create_prompt(**kwargs: str) -> tuple[str, str]:
|
|
791
|
+
"""Create a prompt from the given keyword arguments.
|
|
792
|
+
|
|
793
|
+
Args:
|
|
794
|
+
kwargs:
|
|
795
|
+
The keyword arguments to use in the prompt.
|
|
796
|
+
|
|
797
|
+
Returns:
|
|
798
|
+
A pair (prompt, label), where "label" is an empty string if the model is
|
|
799
|
+
not instruction tuned (as in this case it is included in the prompt).
|
|
800
|
+
"""
|
|
801
|
+
label_key = "label" if "label" in kwargs else "target_text"
|
|
802
|
+
label = kwargs.pop(label_key)
|
|
803
|
+
label_mapping = self.dataset_config.prompt_label_mapping
|
|
804
|
+
label = label_mapping.get(label, label)
|
|
805
|
+
prompt = self.dataset_config.instruction_prompt.format(**kwargs)
|
|
806
|
+
return prompt, label
|
|
807
|
+
|
|
808
|
+
match task.task_group:
|
|
809
|
+
case (
|
|
810
|
+
TaskGroup.SEQUENCE_CLASSIFICATION
|
|
811
|
+
| TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION
|
|
812
|
+
):
|
|
813
|
+
few_shot_sections = [
|
|
814
|
+
create_prompt(
|
|
815
|
+
text=example["text"].replace("\n", " ").strip(),
|
|
816
|
+
label=example["label"].replace("\n", " ").strip(),
|
|
817
|
+
)
|
|
818
|
+
for example in few_shot_examples
|
|
819
|
+
]
|
|
820
|
+
new_sections = [
|
|
821
|
+
create_prompt(text=text.replace("\n", " ").strip(), label="")
|
|
822
|
+
for text in examples["text"]
|
|
823
|
+
]
|
|
824
|
+
|
|
825
|
+
case TaskGroup.TEXT_TO_TEXT:
|
|
826
|
+
few_shot_sections = [
|
|
827
|
+
create_prompt(
|
|
828
|
+
text=example["text"].replace("\n", " ").strip(),
|
|
829
|
+
target_text=example["target_text"].replace("\n", " ").strip(),
|
|
830
|
+
)
|
|
831
|
+
for example in few_shot_examples
|
|
832
|
+
]
|
|
833
|
+
new_sections = [
|
|
834
|
+
create_prompt(text=text.replace("\n", " ").strip(), target_text="")
|
|
835
|
+
for text in examples["text"]
|
|
836
|
+
]
|
|
837
|
+
|
|
838
|
+
case TaskGroup.TOKEN_CLASSIFICATION:
|
|
839
|
+
|
|
840
|
+
def create_label(example: dict) -> str:
|
|
841
|
+
prompt_labels = self.dataset_config.prompt_label_mapping.values()
|
|
842
|
+
labels: dict[str, list[str]] = {
|
|
843
|
+
prompt_label: list() for prompt_label in prompt_labels
|
|
844
|
+
}
|
|
845
|
+
for token, label in zip(example["tokens"], example["labels"]):
|
|
846
|
+
label = label.lower()
|
|
847
|
+
if label == "o":
|
|
848
|
+
continue
|
|
849
|
+
prompt_label = self.dataset_config.prompt_label_mapping[label]
|
|
850
|
+
if label.startswith("b-"):
|
|
851
|
+
labels[prompt_label].append(token)
|
|
852
|
+
elif label.startswith("i-"):
|
|
853
|
+
labels[prompt_label][-1] += " " + token
|
|
854
|
+
return json.dumps(labels, ensure_ascii=False)
|
|
855
|
+
|
|
856
|
+
few_shot_sections = [
|
|
857
|
+
create_prompt(
|
|
858
|
+
text=" ".join(example["tokens"]).replace("\n", " ").strip(),
|
|
859
|
+
label=create_label(example=example),
|
|
860
|
+
)
|
|
861
|
+
for example in few_shot_examples
|
|
862
|
+
]
|
|
863
|
+
new_sections = [
|
|
864
|
+
create_prompt(
|
|
865
|
+
text=" ".join(tokens).replace("\n", " ").strip(), label=""
|
|
866
|
+
)
|
|
867
|
+
for tokens in examples["tokens"]
|
|
868
|
+
]
|
|
869
|
+
|
|
870
|
+
case TaskGroup.QUESTION_ANSWERING:
|
|
871
|
+
few_shot_sections = [
|
|
872
|
+
create_prompt(
|
|
873
|
+
text=example["context"].replace("\n", " ").strip(),
|
|
874
|
+
question=example["question"].replace("\n", " ").strip(),
|
|
875
|
+
label=example["answers"]["text"][0].replace("\n", " "),
|
|
876
|
+
)
|
|
877
|
+
for example in few_shot_examples
|
|
878
|
+
]
|
|
879
|
+
new_sections = [
|
|
880
|
+
create_prompt(
|
|
881
|
+
text=context.replace("\n", " ").strip(),
|
|
882
|
+
question=question.replace("\n", " ").strip(),
|
|
883
|
+
label="",
|
|
884
|
+
)
|
|
885
|
+
for context, question in zip(
|
|
886
|
+
examples["context"], examples["question"]
|
|
887
|
+
)
|
|
888
|
+
]
|
|
889
|
+
|
|
890
|
+
case _:
|
|
891
|
+
raise NotImplementedError(f"Unsupported task group: {task.task_group}.")
|
|
892
|
+
|
|
893
|
+
few_shot_messages = [
|
|
894
|
+
dict(role=role, content=content)
|
|
895
|
+
for prompt, label in few_shot_sections
|
|
896
|
+
for role, content in [("user", prompt), ("assistant", label)]
|
|
897
|
+
]
|
|
898
|
+
|
|
899
|
+
messages_list = [
|
|
900
|
+
few_shot_messages + [dict(role="user", content=prompt)]
|
|
901
|
+
for prompt, _ in new_sections
|
|
902
|
+
]
|
|
903
|
+
|
|
904
|
+
examples["messages"] = messages_list
|
|
905
|
+
return examples
|