EuroEval 15.6.1-py3-none-any.whl → 15.7.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This version of EuroEval might be problematic.
- euroeval/benchmark_modules/litellm.py +148 -284
- euroeval/benchmark_modules/vllm.py +115 -338
- euroeval/benchmarker.py +13 -2
- euroeval/constants.py +1 -1
- euroeval/data_loading.py +48 -26
- euroeval/data_models.py +3 -9
- euroeval/dataset_configs/dutch.py +5 -16
- euroeval/dataset_configs/finnish.py +60 -0
- euroeval/generation_utils.py +346 -0
- euroeval/prompt_templates/linguistic_acceptability.py +9 -1
- euroeval/prompt_templates/multiple_choice.py +8 -1
- euroeval/prompt_templates/named_entity_recognition.py +20 -1
- euroeval/prompt_templates/reading_comprehension.py +11 -1
- euroeval/prompt_templates/sentiment_classification.py +11 -1
- euroeval/prompt_templates/summarization.py +9 -1
- euroeval/scores.py +7 -1
- euroeval/task_group_utils/sequence_classification.py +27 -32
- euroeval/task_group_utils/text_to_text.py +10 -27
- euroeval/tasks.py +1 -1
- euroeval/tokenization_utils.py +22 -6
- {euroeval-15.6.1.dist-info → euroeval-15.7.1.dist-info}/METADATA +14 -2
- {euroeval-15.6.1.dist-info → euroeval-15.7.1.dist-info}/RECORD +25 -23
- {euroeval-15.6.1.dist-info → euroeval-15.7.1.dist-info}/WHEEL +0 -0
- {euroeval-15.6.1.dist-info → euroeval-15.7.1.dist-info}/entry_points.txt +0 -0
- {euroeval-15.6.1.dist-info → euroeval-15.7.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,11 +1,8 @@
 """Generative models from an inference API, using the LiteLLM framework."""
 
 import collections.abc as c
-import itertools as it
-import json
 import logging
 import os
-import random
 import re
 import typing as t
 from functools import cached_property, partial
@@ -33,6 +30,7 @@ from litellm.exceptions import (
 )
 from litellm.llms.vertex_ai.common_utils import VertexAIError
 from litellm.types.utils import ChoiceLogprobs, ModelResponse
+from pydantic import conlist, create_model
 from requests.exceptions import RequestException
 from tqdm.auto import tqdm
 from transformers.trainer import Trainer
@@ -59,6 +57,7 @@ from ..exceptions import (
     NeedsEnvironmentVariable,
     NeedsExtraInstalled,
 )
+from ..generation_utils import apply_prompt, extract_few_shot_examples
 from ..task_group_utils import (
     question_answering,
     sequence_classification,
@@ -104,6 +103,7 @@ MODEL_MAX_LENGTH_MAPPING = {
     r"o1-(mini|preview)(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
     r"o1(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
     r"o[2-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
+    r"gpt-4.1.*": 1_047_576,
     # Anthropic models
     r"(anthropic/)?claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": 200_000,
     # Gemini models
@@ -135,20 +135,23 @@ ALLOWED_PARAMS = {
     r"gpt-4.*": [],
     r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": ["low", "high"],
     # Anthropic models
-    r"(anthropic/)?claude-3
-    r"(anthropic/)?claude-3
-    r"(anthropic/)?claude-3
+    r"(anthropic/)?claude-3-(haiku|sonnet|opus).*": [],
+    r"(anthropic/)?claude-3-5-.*": [],
+    r"(anthropic/)?claude-3-7-sonnet.*": ["thinking"],
     # Gemini models
     r"(gemini/)?gemini-.*": [],
     # xAI models
-    r"(xai/)?grok.*": [],
+    r"(xai/)?grok-2.*": [],
+    r"(xai/)?grok-3(-fast)?(-beta)?": [],
+    r"(xai/)?grok-3-mini(-fast)?(-beta)?": ["low", "high"],
 }
 
 
 REASONING_MODELS = [
     r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?",
     r"(gemini/)?gemini.*thinking.*",
-    r"(gemini/)?gemini-2.5
+    r"(gemini/)?gemini-2.5.*",
+    r"(xai/)?grok-3-mini.*",
 ]
 
 
@@ -190,7 +193,10 @@ class LiteLLMModel(BenchmarkModule):
         )
 
         self.buffer["first_label_token_mapping"] = get_first_label_token_mapping(
-            dataset_config=self.dataset_config,
+            dataset_config=self.dataset_config,
+            model_config=self.model_config,
+            tokenizer=None,
+            generative_type=self.generative_type,
         )
 
     @property
@@ -201,13 +207,20 @@ class LiteLLMModel(BenchmarkModule):
             The generative type of the model, or None if it has not been set yet.
         """
         if self.model_config.revision == "thinking":
-
+            type_ = GenerativeType.REASONING
         elif re.fullmatch(
             pattern="|".join(REASONING_MODELS), string=self.model_config.model_id
         ):
-
+            type_ = GenerativeType.REASONING
         else:
-
+            type_ = GenerativeType.INSTRUCTION_TUNED
+
+        log_once(
+            f"Detected generative type {type_.name!r} for model "
+            f"{self.model_config.model_id!r}",
+            level=logging.DEBUG,
+        )
+        return type_
 
     def generate(self, inputs: dict) -> GenerativeModelOutput:
         """Generate outputs from the model.
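The rewritten `generative_type` property now resolves the type once, logs it, and returns it. As a minimal sketch of the regex detection it relies on, using the updated REASONING_MODELS patterns from this diff (the helper name is illustrative, not part of EuroEval):

import re

REASONING_MODELS = [
    r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?",
    r"(gemini/)?gemini.*thinking.*",
    r"(gemini/)?gemini-2.5.*",
    r"(xai/)?grok-3-mini.*",
]

def is_reasoning_model(model_id: str) -> bool:
    # Join all patterns into one alternation and require a full match,
    # mirroring the re.fullmatch call in the property above.
    return re.fullmatch("|".join(REASONING_MODELS), model_id) is not None

assert is_reasoning_model("xai/grok-3-mini-beta")
assert not is_reasoning_model("gpt-4.1")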
@@ -243,7 +256,10 @@ class LiteLLMModel(BenchmarkModule):
         # Get the mapping from labels to the first token in the label. We call this each
         # time we generate a new dataset since the dataset config can change
         self.buffer["first_label_token_mapping"] = get_first_label_token_mapping(
-            dataset_config=self.dataset_config,
+            dataset_config=self.dataset_config,
+            model_config=self.model_config,
+            tokenizer=None,
+            generative_type=self.generative_type,
         )
 
         if self.buffer["first_label_token_mapping"]:
@@ -254,16 +270,41 @@ class LiteLLMModel(BenchmarkModule):
             assert "json" in messages[0]["content"].lower(), (
                 "Prompt must contain 'json' for JSON tasks."
             )
-
-
-
-
-
-
+            if self.generative_type == GenerativeType.REASONING:
+                log_once(
+                    f"The model {self.model_config.model_id!r} is a reasoning model "
+                    "and thus does not support structured generation, so we do not "
+                    "enable it.",
+                    level=logging.DEBUG,
+                )
+            elif litellm.utils.supports_response_schema(
+                model=self.model_config.model_id
+            ):
+                ner_tag_names = list(self.dataset_config.prompt_label_mapping.values())
+                keys_and_their_types: dict[str, t.Any] = {
+                    tag_name: (conlist(str, max_length=5), ...)
+                    for tag_name in ner_tag_names
+                }
+                pydantic_class = create_model("AnswerFormat", **keys_and_their_types)
+                generation_kwargs["response_format"] = pydantic_class
+                log_once(
+                    "Enabling structured generation for model "
+                    f"{self.model_config.model_id!r} with the JSON schema "
+                    f"{pydantic_class.model_json_schema()}",
+                    level=logging.DEBUG,
+                )
+            else:
+                generation_kwargs["response_format"] = dict(type="json_object")
+                log_once(
+                    "Enabling structured JSON generation for model "
+                    f"{self.model_config.model_id!r} with no custom JSON schema, as "
+                    "the model does not support schemas.",
+                    level=logging.DEBUG,
+                )
 
         if self.model_config.revision == "thinking":
             generation_kwargs["thinking"] = dict(
-                type="enabled", budget_tokens=REASONING_MAX_TOKENS
+                type="enabled", budget_tokens=REASONING_MAX_TOKENS - 1
             )
             log_once(
                 f"Enabling thinking mode for model {self.model_config.model_id!r}",
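The structured-generation branch above builds its response schema dynamically from the dataset's NER tags. A self-contained sketch of the same pydantic v2 technique, with made-up tag names standing in for `dataset_config.prompt_label_mapping.values()`:

import typing as t

from pydantic import conlist, create_model

# Hypothetical tag names; in EuroEval these come from the dataset config
ner_tag_names = ["person", "location", "organisation", "miscellaneous"]

# Each tag becomes a required field holding a list of at most five strings
keys_and_their_types: dict[str, t.Any] = {
    tag_name: (conlist(str, max_length=5), ...) for tag_name in ner_tag_names
}
AnswerFormat = create_model("AnswerFormat", **keys_and_their_types)

# The JSON schema of this model is what gets passed as response_format
print(AnswerFormat.model_json_schema())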
@@ -280,28 +321,42 @@ class LiteLLMModel(BenchmarkModule):
         # This drops generation kwargs that are not supported by the model
         litellm.drop_params = True
 
+        # Error messages that we want to catch and handle
+        stop_messages = ["stop_sequences", "'stop' is not supported with this model"]
+        logprobs_messages = [
+            "you are not allowed to request logprobs",
+            "you've reached the maximum number of requests with logprobs",
+            "logprobs is not supported",
+            "logprobs is not enabled",
+        ]
+        temperature_messages = [
+            "'temperature' is not supported with this model.",
+            "temperature is not supported with this model",
+        ]
+        temperature_must_be_one_messages = [
+            "`temperature` may only be set to 1",
+            "'temperature' does not support 0.0 with this model. Only the default "
+            "(1) value is supported",
+        ]
+        max_items_messages = ["'maxItems' is not permitted."]
+        no_json_schema_messages = ["Property keys should match pattern"]
+
         # Extract the generated sequences from the model response. Some APIs cannot
         # handle using newlines as stop sequences, so we try both.
         num_attempts = 10
         for _ in range(num_attempts):
-            stop_messages = ["stop_sequences"]
-            logprobs_messages = [
-                "you are not allowed to request logprobs",
-                "you've reached the maximum number of requests with logprobs",
-                "logprobs is not supported",
-                "logprobs is not enabled",
-            ]
-            temperature_messages = [
-                "'temperature' is not supported with this model.",
-                "temperature is not supported with this model",
-            ]
             try:
-                model_response = litellm.
-                    messages=messages,
+                model_response = litellm.completion_with_retries(
+                    messages=messages, **generation_kwargs
                 )
                 break
             except (BadRequestError, RateLimitError) as e:
                 if any(msg.lower() in str(e).lower() for msg in stop_messages):
+                    log_once(
+                        f"The model {self.model_config.model_id!r} does not support "
+                        "stop sequences, so disabling them.",
+                        level=logging.DEBUG,
+                    )
                     generation_kwargs["stop"] = None
                 elif (
                     any(msg.lower() in str(e).lower() for msg in logprobs_messages)
@@ -310,10 +365,55 @@ class LiteLLMModel(BenchmarkModule):
                     # we ignore this since the rate limiting makes it unusable anyway.
                     or (isinstance(e, VertexAIError) and "logprobs" in str(e).lower())
                 ):
+                    log_once(
+                        f"The model {self.model_config.model_id!r} does not support "
+                        "logprobs, so disabling it.",
+                        level=logging.DEBUG,
+                    )
                     generation_kwargs.pop("logprobs")
                     generation_kwargs.pop("top_logprobs")
                 elif any(msg.lower() in str(e).lower() for msg in temperature_messages):
+                    log_once(
+                        f"The model {self.model_config.model_id!r} does not support "
+                        "temperature, so disabling it.",
+                        level=logging.DEBUG,
+                    )
                     generation_kwargs.pop("temperature")
+                elif any(
+                    msg.lower() in str(e).lower()
+                    for msg in temperature_must_be_one_messages
+                ):
+                    log_once(
+                        f"The model {self.model_config.model_id!r} requires "
+                        "temperature to be set to 1, so setting it.",
+                        level=logging.DEBUG,
+                    )
+                    generation_kwargs["temperature"] = 1.0
+                elif any(msg.lower() in str(e).lower() for msg in max_items_messages):
+                    log_once(
+                        f"The model {self.model_config.model_id!r} does not support "
+                        "maxItems in the JSON schema, so disabling it.",
+                        level=logging.DEBUG,
+                    )
+                    ner_tag_names = list(
+                        self.dataset_config.prompt_label_mapping.values()
+                    )
+                    keys_and_their_types = {
+                        tag_name: (list[str], ...) for tag_name in ner_tag_names
+                    }
+                    pydantic_class = create_model(
+                        "AnswerFormat", **keys_and_their_types
+                    )
+                    generation_kwargs["response_format"] = pydantic_class
+                elif any(
+                    msg.lower() in str(e).lower() for msg in no_json_schema_messages
+                ):
+                    log_once(
+                        f"The model {self.model_config.model_id!r} does not support "
+                        "JSON schemas, so using the vanilla JSON format.",
+                        level=logging.DEBUG,
+                    )
+                    generation_kwargs["response_format"] = dict(type="json_object")
                 elif isinstance(e, RateLimitError):
                     raise InvalidModel(
                         "You have encountered your rate limit for model "
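The elif chain above amounts to capability discovery by error message: each rejected request tells the client which generation parameter to relax before retrying. A stripped-down sketch of the pattern (the `complete` callable and the matching rules are placeholders, not LiteLLM's API):

def generate_with_fallbacks(complete, messages, **generation_kwargs):
    for _ in range(10):
        try:
            return complete(messages=messages, **generation_kwargs)
        except Exception as e:
            error = str(e).lower()
            if "stop" in error:
                generation_kwargs["stop"] = None  # disable stop sequences
            elif "logprobs" in error:
                generation_kwargs.pop("logprobs", None)
                generation_kwargs.pop("top_logprobs", None)
            elif "temperature" in error:
                generation_kwargs.pop("temperature", None)
            else:
                raise  # unknown error: re-raise rather than loop forever
    raise RuntimeError("Model kept rejecting the request after 10 attempts")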
@@ -332,6 +432,7 @@ class LiteLLMModel(BenchmarkModule):
                 Timeout,
                 ServiceUnavailableError,
                 InternalServerError,
+                SystemError,
             ) as e:
                 logger.debug(
                     f"Service temporarily unavailable. The error message was: {e}. "
@@ -359,9 +460,11 @@ class LiteLLMModel(BenchmarkModule):
                 "reasoning. Returning an empty string."
             )
             return GenerativeModelOutput(sequences=[""])
+
         model_response_choices = model_response.choices[0]
         assert isinstance(model_response_choices, litellm.Choices)
-
+        generated_message: litellm.Message = model_response_choices.message
+        generation_output = generated_message.content or ""
         generation_output = generation_output.strip()
 
         # Structure the model output as a GenerativeModelOutput object
@@ -838,14 +941,22 @@ class LiteLLMModel(BenchmarkModule):
         )
 
         if self.benchmark_config.few_shot:
-            few_shot_examples =
-                dataset=dataset,
+            few_shot_examples = extract_few_shot_examples(
+                dataset=dataset, dataset_config=self.dataset_config, itr_idx=itr_idx
             )
         else:
             few_shot_examples = list()
 
         dataset["test"] = dataset["test"].map(
-            partial(
+            partial(
+                apply_prompt,
+                few_shot_examples=few_shot_examples,
+                model_config=self.model_config,
+                dataset_config=self.dataset_config,
+                instruction_model=True,
+                always_populate_text_field=False,
+                tokenizer=None,
+            ),
             batched=True,
             load_from_cache_file=False,
             keep_in_memory=True,
@@ -853,253 +964,6 @@ class LiteLLMModel(BenchmarkModule):
 
         return dataset
 
-    def _extract_few_shot_examples(
-        self, dataset: DatasetDict, task: Task, itr_idx: int
-    ) -> list[dict[str, t.Any]]:
-        """Extract few-shot examples from a dataset.
-
-        This will always extract the examples from the training split.
-
-        We ensure that the few-shot examples are unique by picking them one at a time.
-
-        Args:
-            dataset:
-                The dataset to extract the few-shot examples from.
-            task:
-                The task that is being benchmarked.
-            itr_idx:
-                The index of the dataset in the iterator.
-
-        Returns:
-            The few-shot examples.
-        """
-        random_seed = 4242 + itr_idx
-        num_few_shots = self.dataset_config.num_few_shot_examples
-        few_shot_examples: list[dict[str, t.Any]] = list()
-        shuffled_train = dataset["train"].shuffle(seed=random_seed)
-
-        match task.task_group:
-            case (
-                TaskGroup.SEQUENCE_CLASSIFICATION
-                | TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION
-            ):
-                labels = it.cycle(self.dataset_config.labels)
-                while (
-                    len(few_shot_examples) < num_few_shots and len(shuffled_train) > 0
-                ):
-                    label = next(labels)
-                    possible_examples = shuffled_train.filter(
-                        lambda x: x["label"].lower() == label.lower()
-                    )
-                    if len(possible_examples) == 0:
-                        continue
-                    example = possible_examples.select(range(1))[0]
-                    few_shot_examples.append(example)
-                    shuffled_train = shuffled_train.filter(
-                        lambda x: x["text"] != example["text"]
-                    )
-
-            case TaskGroup.TEXT_TO_TEXT:
-                while (
-                    len(few_shot_examples) < num_few_shots and len(shuffled_train) > 0
-                ):
-                    example = shuffled_train.select(range(1))[0]
-                    few_shot_examples.append(example)
-                    shuffled_train = shuffled_train.filter(
-                        lambda x: x["text"] != example["text"]
-                    )
-
-            case TaskGroup.TOKEN_CLASSIFICATION:
-                labels = it.cycle(
-                    [
-                        label.lower()
-                        for label in self.dataset_config.labels
-                        if label.lower().startswith("b-")
-                    ]
-                )
-                while (
-                    len(few_shot_examples) < num_few_shots and len(shuffled_train) > 0
-                ):
-                    label = next(labels)
-                    possible_examples = shuffled_train.filter(
-                        lambda x: label in [tag.lower() for tag in x["labels"]]
-                    )
-                    if len(possible_examples) == 0:
-                        continue
-                    example = possible_examples.select(range(1))[0]
-                    few_shot_examples.append(example)
-                    shuffled_train = shuffled_train.filter(
-                        lambda x: x["tokens"] != example["tokens"]
-                    )
-
-            case TaskGroup.QUESTION_ANSWERING:
-                # Locate the maximum number of tokens that constitutes a short example
-                for max_num_tokens in [512, 1024, 2048, 4096, 8192]:
-                    train_with_short_examples = dataset["train"].filter(
-                        lambda example: len(example["context"]) < max_num_tokens
-                    )
-                    num_short_examples = len(train_with_short_examples)
-                    if num_short_examples >= self.dataset_config.num_few_shot_examples:
-                        break
-                else:
-                    raise InvalidBenchmark(
-                        "Could not find enough short examples for few-shot learning."
-                    )
-
-                shuffled_train = train_with_short_examples.shuffle(seed=random_seed)
-                while (
-                    len(few_shot_examples) < num_few_shots and len(shuffled_train) > 0
-                ):
-                    example = shuffled_train.select(range(1))[0]
-                    few_shot_examples.append(example)
-                    shuffled_train = shuffled_train.filter(
-                        lambda x: x["context"] != example["context"]
-                    )
-
-            case _:
-                raise NotImplementedError(f"Unsupported task group: {task.task_group}.")
-
-        random.seed(random_seed)
-        random.shuffle(few_shot_examples)
-        return few_shot_examples
-
-    def _apply_prompt(
-        self,
-        examples: dict[str, t.Any],
-        few_shot_examples: list[dict[str, t.Any]],
-        task: Task,
-    ) -> dict[str, t.Any]:
-        """Apply prompt template to an example, potentially with few-shot examples.
-
-        Args:
-            examples:
-                The examples to apply the few-shot examples to.
-            few_shot_examples:
-                The few-shot examples to apply.
-            task:
-                The task that is being benchmarked.
-
-        Returns:
-            The example with the few-shot examples applied.
-        """
-
-        def create_prompt(**kwargs: str) -> tuple[str, str]:
-            """Create a prompt from the given keyword arguments.
-
-            Args:
-                kwargs:
-                    The keyword arguments to use in the prompt.
-
-            Returns:
-                A pair (prompt, label), where "label" is an empty string if the model is
-                not instruction tuned (as in this case it is included in the prompt).
-            """
-            label_key = "label" if "label" in kwargs else "target_text"
-            label = kwargs.pop(label_key)
-            label_mapping = self.dataset_config.prompt_label_mapping
-            label = label_mapping.get(label, label)
-            prompt = self.dataset_config.instruction_prompt.format(**kwargs)
-            return prompt, label
-
-        match task.task_group:
-            case (
-                TaskGroup.SEQUENCE_CLASSIFICATION
-                | TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION
-            ):
-                few_shot_sections = [
-                    create_prompt(
-                        text=example["text"].replace("\n", " ").strip(),
-                        label=example["label"].replace("\n", " ").strip(),
-                    )
-                    for example in few_shot_examples
-                ]
-                new_sections = [
-                    create_prompt(text=text.replace("\n", " ").strip(), label="")
-                    for text in examples["text"]
-                ]
-
-            case TaskGroup.TEXT_TO_TEXT:
-                few_shot_sections = [
-                    create_prompt(
-                        text=example["text"].replace("\n", " ").strip(),
-                        target_text=example["target_text"].replace("\n", " ").strip(),
-                    )
-                    for example in few_shot_examples
-                ]
-                new_sections = [
-                    create_prompt(text=text.replace("\n", " ").strip(), target_text="")
-                    for text in examples["text"]
-                ]
-
-            case TaskGroup.TOKEN_CLASSIFICATION:
-
-                def create_label(example: dict) -> str:
-                    prompt_labels = self.dataset_config.prompt_label_mapping.values()
-                    labels: dict[str, list[str]] = {
-                        prompt_label: list() for prompt_label in prompt_labels
-                    }
-                    for token, label in zip(example["tokens"], example["labels"]):
-                        label = label.lower()
-                        if label == "o":
-                            continue
-                        prompt_label = self.dataset_config.prompt_label_mapping[label]
-                        if label.startswith("b-"):
-                            labels[prompt_label].append(token)
-                        elif label.startswith("i-"):
-                            labels[prompt_label][-1] += " " + token
-                    return json.dumps(labels, ensure_ascii=False)
-
-                few_shot_sections = [
-                    create_prompt(
-                        text=" ".join(example["tokens"]).replace("\n", " ").strip(),
-                        label=create_label(example=example),
-                    )
-                    for example in few_shot_examples
-                ]
-                new_sections = [
-                    create_prompt(
-                        text=" ".join(tokens).replace("\n", " ").strip(), label=""
-                    )
-                    for tokens in examples["tokens"]
-                ]
-
-            case TaskGroup.QUESTION_ANSWERING:
-                few_shot_sections = [
-                    create_prompt(
-                        text=example["context"].replace("\n", " ").strip(),
-                        question=example["question"].replace("\n", " ").strip(),
-                        label=example["answers"]["text"][0].replace("\n", " "),
-                    )
-                    for example in few_shot_examples
-                ]
-                new_sections = [
-                    create_prompt(
-                        text=context.replace("\n", " ").strip(),
-                        question=question.replace("\n", " ").strip(),
-                        label="",
-                    )
-                    for context, question in zip(
-                        examples["context"], examples["question"]
-                    )
-                ]
-
-            case _:
-                raise NotImplementedError(f"Unsupported task group: {task.task_group}.")
-
-        few_shot_messages = [
-            dict(role=role, content=content)
-            for prompt, label in few_shot_sections
-            for role, content in [("user", prompt), ("assistant", label)]
-        ]
-
-        messages_list = [
-            few_shot_messages + [dict(role="user", content=prompt)]
-            for prompt, _ in new_sections
-        ]
-
-        examples["messages"] = messages_list
-        return examples
-
 
 def raise_if_wrong_params(
     model_config: ModelConfig, allowed_params: dict[str, list[str]]