EuroEval 15.15.0-py3-none-any.whl → 16.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of EuroEval has been flagged as potentially problematic.
- euroeval/__init__.py +3 -7
- euroeval/benchmark_config_factory.py +3 -7
- euroeval/benchmark_modules/base.py +35 -19
- euroeval/benchmark_modules/fresh.py +24 -19
- euroeval/benchmark_modules/hf.py +136 -154
- euroeval/benchmark_modules/litellm.py +323 -193
- euroeval/benchmark_modules/vllm.py +166 -112
- euroeval/benchmarker.py +59 -33
- euroeval/cli.py +3 -3
- euroeval/constants.py +13 -15
- euroeval/data_loading.py +33 -28
- euroeval/data_models.py +53 -7
- euroeval/dataset_configs/__init__.py +2 -0
- euroeval/dataset_configs/danish.py +38 -1
- euroeval/dataset_configs/dutch.py +38 -1
- euroeval/dataset_configs/english.py +38 -1
- euroeval/dataset_configs/estonian.py +95 -0
- euroeval/dataset_configs/faroese.py +38 -0
- euroeval/dataset_configs/finnish.py +39 -1
- euroeval/dataset_configs/french.py +38 -1
- euroeval/dataset_configs/german.py +38 -1
- euroeval/dataset_configs/icelandic.py +39 -1
- euroeval/dataset_configs/italian.py +38 -1
- euroeval/dataset_configs/latvian.py +81 -0
- euroeval/dataset_configs/norwegian.py +38 -1
- euroeval/dataset_configs/portuguese.py +38 -1
- euroeval/dataset_configs/spanish.py +38 -1
- euroeval/dataset_configs/swedish.py +38 -1
- euroeval/enums.py +0 -6
- euroeval/finetuning.py +8 -7
- euroeval/generation.py +25 -14
- euroeval/generation_utils.py +46 -14
- euroeval/languages.py +947 -187
- euroeval/metrics/__init__.py +6 -0
- euroeval/metrics/base.py +76 -0
- euroeval/metrics/huggingface.py +192 -0
- euroeval/metrics/llm_as_a_judge.py +257 -0
- euroeval/metrics/pipeline.py +234 -0
- euroeval/metrics/speed.py +51 -0
- euroeval/prompt_templates/linguistic_acceptability.py +40 -2
- euroeval/prompt_templates/multiple_choice.py +23 -2
- euroeval/prompt_templates/named_entity_recognition.py +65 -2
- euroeval/prompt_templates/reading_comprehension.py +42 -2
- euroeval/prompt_templates/sentiment_classification.py +46 -2
- euroeval/prompt_templates/summarization.py +24 -4
- euroeval/scores.py +7 -2
- euroeval/speed_benchmark.py +6 -6
- euroeval/task_group_utils/multiple_choice_classification.py +17 -6
- euroeval/task_group_utils/question_answering.py +35 -28
- euroeval/task_group_utils/sequence_classification.py +96 -23
- euroeval/task_group_utils/text_to_text.py +7 -3
- euroeval/task_group_utils/token_classification.py +47 -75
- euroeval/tasks.py +31 -6
- euroeval/tokenization_utils.py +295 -207
- euroeval/utils.py +118 -34
- {euroeval-15.15.0.dist-info → euroeval-16.0.0.dist-info}/METADATA +12 -14
- euroeval-16.0.0.dist-info/RECORD +69 -0
- {euroeval-15.15.0.dist-info → euroeval-16.0.0.dist-info}/entry_points.txt +0 -1
- euroeval/human_evaluation.py +0 -738
- euroeval/metrics.py +0 -468
- euroeval-15.15.0.dist-info/RECORD +0 -63
- {euroeval-15.15.0.dist-info → euroeval-16.0.0.dist-info}/WHEEL +0 -0
- {euroeval-15.15.0.dist-info → euroeval-16.0.0.dist-info}/licenses/LICENSE +0 -0
euroeval/dataset_configs/spanish.py
CHANGED

@@ -2,7 +2,7 @@
 
 from ..data_models import DatasetConfig
 from ..languages import ES
-from ..tasks import COMMON_SENSE, KNOW, LA, MCRC, NER, RC, SENT, SUMM
+from ..tasks import COMMON_SENSE, EUROPEAN_VALUES, KNOW, LA, MCRC, NER, RC, SENT, SUMM
 
 ### Official datasets ###
 
@@ -66,6 +66,17 @@ HELLASWAG_ES_CONFIG = DatasetConfig(
     languages=[ES],
 )
 
+EUROPEAN_VALUES_ES_CONFIG = DatasetConfig(
+    name="european-values-es",
+    pretty_name="the Spanish version of the European values evaluation dataset",
+    huggingface_id="EuroEval/european-values-es",
+    task=EUROPEAN_VALUES,
+    languages=[ES],
+    splits=["test"],
+    bootstrap_samples=False,
+    _instruction_prompt="{text}",
+)
+
 
 ### Unofficial datasets ###
 
@@ -107,3 +118,29 @@ GOLDENSWAG_ES_CONFIG = DatasetConfig(
     languages=[ES],
     unofficial=True,
 )
+
+EUROPEAN_VALUES_SITUATIONAL_ES_CONFIG = DatasetConfig(
+    name="european-values-situational-es",
+    pretty_name="the Spanish version of the European values evaluation dataset, where "
+    "the questions are phrased in a situational way",
+    huggingface_id="EuroEval/european-values-situational-es",
+    task=EUROPEAN_VALUES,
+    languages=[ES],
+    splits=["test"],
+    bootstrap_samples=False,
+    _instruction_prompt="{text}",
+    unofficial=True,
+)
+
+EUROPEAN_VALUES_COMPLETIONS_ES_CONFIG = DatasetConfig(
+    name="european-values-completions-es",
+    pretty_name="the Spanish version of the European values evaluation dataset, where "
+    "the questions are phrased as sentence completions",
+    huggingface_id="EuroEval/european-values-completions-es",
+    task=EUROPEAN_VALUES,
+    languages=[ES],
+    splits=["test"],
+    bootstrap_samples=False,
+    _instruction_prompt="{text}",
+    unofficial=True,
+)
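All three additions share the same shape, which the Swedish file below and the other language files in this release mirror: a test-only dataset for the new EUROPEAN_VALUES task that disables bootstrapping and passes the question text straight through as the instruction prompt. A minimal sketch of that shared pattern, assuming EuroEval 16.0.0 is installed; the variable name and comments are illustrative, while the field values are copied from the hunks above:

from euroeval.data_models import DatasetConfig
from euroeval.languages import ES
from euroeval.tasks import EUROPEAN_VALUES

# Illustrative declaration mirroring the three new Spanish configs above.
EXAMPLE_EUROPEAN_VALUES_CONFIG = DatasetConfig(
    name="european-values-es",                     # identifier used when selecting the dataset
    pretty_name="the Spanish version of the European values evaluation dataset",
    huggingface_id="EuroEval/european-values-es",  # Hugging Face Hub repo the data is pulled from
    task=EUROPEAN_VALUES,                          # task introduced in 16.0.0
    languages=[ES],
    splits=["test"],                               # only a test split; no train/val few-shot pool
    bootstrap_samples=False,                       # skips bootstrapped resampling (see generation.py below)
    _instruction_prompt="{text}",                  # the question text is used verbatim as the prompt
)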
euroeval/dataset_configs/swedish.py
CHANGED

@@ -2,7 +2,7 @@
 
 from ..data_models import DatasetConfig
 from ..languages import SV
-from ..tasks import COMMON_SENSE, KNOW, LA, MCRC, NER, RC, SENT, SUMM
+from ..tasks import COMMON_SENSE, EUROPEAN_VALUES, KNOW, LA, MCRC, NER, RC, SENT, SUMM
 
 ### Official datasets ###
 
@@ -67,6 +67,17 @@ HELLASWAG_SV_CONFIG = DatasetConfig(
     languages=[SV],
 )
 
+EUROPEAN_VALUES_SV_CONFIG = DatasetConfig(
+    name="european-values-sv",
+    pretty_name="the Swedish version of the European values evaluation dataset",
+    huggingface_id="EuroEval/european-values-sv",
+    task=EUROPEAN_VALUES,
+    languages=[SV],
+    splits=["test"],
+    bootstrap_samples=False,
+    _instruction_prompt="{text}",
+)
+
 
 ### Unofficial datasets ###
 
@@ -118,3 +129,29 @@ GOLDENSWAG_SV_CONFIG = DatasetConfig(
     languages=[SV],
     unofficial=True,
 )
+
+EUROPEAN_VALUES_SITUATIONAL_SV_CONFIG = DatasetConfig(
+    name="european-values-situational-sv",
+    pretty_name="the Swedish version of the European values evaluation dataset, where "
+    "the questions are phrased in a situational way",
+    huggingface_id="EuroEval/european-values-situational-sv",
+    task=EUROPEAN_VALUES,
+    languages=[SV],
+    splits=["test"],
+    bootstrap_samples=False,
+    _instruction_prompt="{text}",
+    unofficial=True,
+)
+
+EUROPEAN_VALUES_COMPLETIONS_SV_CONFIG = DatasetConfig(
+    name="european-values-completions-sv",
+    pretty_name="the Swedish version of the European values evaluation dataset, where "
+    "the questions are phrased as sentence completions",
+    huggingface_id="EuroEval/european-values-completions-sv",
+    task=EUROPEAN_VALUES,
+    languages=[SV],
+    splits=["test"],
+    bootstrap_samples=False,
+    _instruction_prompt="{text}",
+    unofficial=True,
+)
euroeval/enums.py
CHANGED

@@ -40,14 +40,11 @@ class InferenceBackend(AutoStrEnum):
             VLLM library.
         LITELLM:
             LiteLLM library.
-        NONE:
-            No inference backend used (e.g., for human evaluation).
     """
 
     TRANSFORMERS = auto()
     VLLM = auto()
     LITELLM = auto()
-    NONE = auto()
 
 
 class ModelType(AutoStrEnum):
@@ -58,13 +55,10 @@ class ModelType(AutoStrEnum):
             An encoder (i.e., BERT-style) model.
         GENERATIVE:
             A generative model. Can be either decoder or encoder-decoder (aka seq2seq).
-        HUMAN:
-            Human evaluator.
     """
 
     ENCODER = auto()
     GENERATIVE = auto()
-    HUMAN = auto()
 
 
 class GenerativeType(AutoStrEnum):
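These two members disappear together with euroeval/human_evaluation.py (removed in this release, see the file list above), so downstream code that still references InferenceBackend.NONE or ModelType.HUMAN will now fail with an AttributeError. A tiny stand-in illustrating the effect, using a plain enum.Enum rather than EuroEval's AutoStrEnum:

from enum import Enum, auto

class ModelType(Enum):
    # Stand-in for euroeval.enums.ModelType after 16.0.0: only two members remain.
    ENCODER = auto()
    GENERATIVE = auto()

print(ModelType.GENERATIVE)         # ModelType.GENERATIVE
print(hasattr(ModelType, "HUMAN"))  # False -- callers must drop this reference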
euroeval/finetuning.py
CHANGED

@@ -3,6 +3,7 @@
 import logging
 import sys
 import typing as t
+from functools import partial
 
 import torch
 from tqdm.auto import tqdm
@@ -118,7 +119,7 @@ def finetune(
         # NaN values can appear in the model output when using mixed precision, as
         # the hidden states get overflowed. In this case we try to disable mixed
         # precision and try again.
-        except NaNValueInModelOutput:
+        except NaNValueInModelOutput as e:
             if dtype != DataType.FP32:
                 dtype = DataType.FP32
                 model_already_initialized = False
@@ -130,11 +131,11 @@ def finetune(
                 raise InvalidBenchmark(
                     "NaN value detected in model outputs, even with mixed "
                     "precision disabled."
-                )
+                ) from e
 
         except Exception as e:
             if "CUDA" not in str(e) and "out of memory" not in str(e):
-                raise InvalidBenchmark(str(e))
+                raise InvalidBenchmark(str(e)) from e
 
             if bs <= 1:
                 msg = "Could not benchmark the model, even with a batch size of 1!"
@@ -145,7 +146,7 @@ def finetune(
                     "environment variable set, as this removes the upper bound "
                     "on the memory usage."
                 )
-                raise InvalidBenchmark(msg)
+                raise InvalidBenchmark(msg) from e
 
             model_already_initialized = False
 
@@ -194,11 +195,11 @@ def finetune_single_iteration(
 
     trainer = model.trainer_class(
         model=model.get_pytorch_module(),
-        processing_class=model.
+        processing_class=model.get_tokeniser(),
         args=training_args,
         train_dataset=dataset["train"],
         eval_dataset=dataset["val"],
-        compute_metrics=model.compute_metrics,
+        compute_metrics=partial(model.compute_metrics, dataset=None),
         callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
         data_collator=model.data_collator,
         preprocess_logits_for_metrics=remove_extra_tensors_from_logits,
@@ -244,7 +245,7 @@ def finetune_single_iteration(
         clear_memory()
         raise e
     except (RuntimeError, ValueError, IndexError) as e:
-        raise InvalidBenchmark(str(e))
+        raise InvalidBenchmark(str(e)) from e
 
     return test_scores
 
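Two recurring changes in this file: wrapper exceptions are now raised with `raise ... from e`, so the original error stays attached as `__cause__`, and `compute_metrics` is wrapped in `functools.partial` so the trainer can keep calling it with a single argument while the new `dataset` parameter is pre-bound. A self-contained sketch of both idioms; the class and functions here are stand-ins, not EuroEval's own:

from functools import partial

class InvalidBenchmark(Exception):
    """Stand-in for euroeval.exceptions.InvalidBenchmark."""

def compute_metrics(predictions_and_labels, dataset=None):
    # Stand-in for BenchmarkModule.compute_metrics: the trainer supplies only the
    # first argument, so the extra `dataset` keyword has to be bound in advance.
    predictions, labels = predictions_and_labels
    correct = sum(p == label for p, label in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}

metrics_fn = partial(compute_metrics, dataset=None)
print(metrics_fn(([1, 0, 1], [1, 1, 1])))  # {'accuracy': 0.666...}

def run_step() -> None:
    raise RuntimeError("CUDA out of memory")

try:
    try:
        run_step()
    except RuntimeError as e:
        # `from e` keeps the original error as __cause__, so the CUDA message is
        # still visible in the traceback when InvalidBenchmark propagates.
        raise InvalidBenchmark(str(e)) from e
except InvalidBenchmark as wrapped:
    print(repr(wrapped.__cause__))  # RuntimeError('CUDA out of memory')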
euroeval/generation.py
CHANGED

@@ -6,6 +6,7 @@ import typing as t
 from pathlib import Path
 
 import more_itertools as mit
+from datasets import Dataset
 from tqdm.auto import tqdm
 
 from .enums import BatchingPreference, TaskGroup
@@ -15,10 +16,10 @@ from .model_cache import (
     load_cached_model_outputs,
     split_dataset_into_cached_and_non_cached,
 )
-from .utils import clear_memory
+from .utils import clear_memory, log_once
 
 if t.TYPE_CHECKING:
-    from datasets import
+    from datasets import DatasetDict
 
     from .benchmark_modules import BenchmarkModule
     from .data_models import (
@@ -78,7 +79,7 @@ def generate(
 
     scores: list[dict[str, float]] = list()
     for idx in tqdm(
-        iterable=range(
+        iterable=range(len(datasets)),
         desc="Benchmarking",
         disable=not benchmark_config.progress_bar,
     ):
@@ -89,7 +90,6 @@ def generate(
             dataset_config=dataset_config,
             benchmark_config=benchmark_config,
         )
-
         logger.debug(f"Test scores for iteration {idx}: {test_scores}")
         scores.append(test_scores)
         clear_memory()
@@ -126,10 +126,15 @@ def generate_single_iteration(
     """
     cache.load()
 
-    # Split up the dataset into a cached and non-cached part
-
-
-
+    # Split up the dataset into a cached and non-cached part, unless we are not
+    # bootstrapping the samples. In that case, we just use the dataset as is.
+    if dataset_config.bootstrap_samples:
+        cached_dataset, non_cached_dataset = split_dataset_into_cached_and_non_cached(
+            dataset=dataset, cache=cache
+        )
+    else:
+        cached_dataset = Dataset.from_dict({})
+        non_cached_dataset = dataset
 
     all_preds: list[str] = list()
 
@@ -230,9 +235,12 @@ def generate_single_iteration(
         cached_labels = list(cached_labels)
         ground_truth = non_cached_labels + cached_labels
     else:
-
-        "
+        log_once(
+            "No labels found in the dataset. We assume that this is intentional, and "
+            "will not supply any ground truth labels for evaluation.",
+            level=logging.DEBUG,
         )
+        ground_truth = []
 
     itr_scores: dict[str, float] = model.compute_metrics(
         model_outputs_and_labels=(all_preds, ground_truth), dataset=dataset
@@ -293,10 +301,13 @@ def debug_log(
         case (
             TaskGroup.SEQUENCE_CLASSIFICATION | TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION
         ):
-
-
-
-
+            if "label" in batch:
+                labels = [
+                    dataset_config.prompt_label_mapping.get(label, label).lower()
+                    for label in batch["label"]
+                ]
+            else:
+                labels = ["N/A"] * len(extracted_labels)
 
         case TaskGroup.QUESTION_ANSWERING:
             extracted_labels = [
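The new branch around `split_dataset_into_cached_and_non_cached` means that datasets with `bootstrap_samples=False` (such as the European-values configs above) skip the cache split and are generated in a single pass, with an empty `Dataset` standing in for the cached half so the downstream concatenation logic is unchanged. A rough sketch of that control flow in isolation; the splitting helper below is a dummy, not EuroEval's real cache logic:

from datasets import Dataset

def split_cached_and_non_cached(dataset: Dataset) -> tuple[Dataset, Dataset]:
    # Dummy stand-in for split_dataset_into_cached_and_non_cached: pretend the
    # first half of the examples already has cached model outputs.
    midpoint = len(dataset) // 2
    return dataset.select(range(midpoint)), dataset.select(range(midpoint, len(dataset)))

def prepare(dataset: Dataset, bootstrap_samples: bool) -> tuple[Dataset, Dataset]:
    if bootstrap_samples:
        # Bootstrapped iterations reuse cached generations where possible.
        return split_cached_and_non_cached(dataset)
    # Single-pass datasets bypass the cache entirely: everything is generated
    # fresh, and an empty Dataset plays the role of the cached part.
    return Dataset.from_dict({}), dataset

data = Dataset.from_dict({"text": ["a", "b", "c", "d"]})
cached, non_cached = prepare(data, bootstrap_samples=False)
print(len(cached), len(non_cached))  # 0 4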
euroeval/generation_utils.py
CHANGED

@@ -8,19 +8,23 @@ import typing as t
 
 from .enums import TaskGroup
 from .exceptions import InvalidBenchmark
+from .tokenization_utils import apply_chat_template
 from .utils import log_once
 
 if t.TYPE_CHECKING:
     from datasets import DatasetDict
     from transformers.tokenization_utils import PreTrainedTokenizer
 
-    from .data_models import DatasetConfig, ModelConfig
+    from .data_models import BenchmarkConfig, DatasetConfig, ModelConfig
 
 logger = logging.getLogger("euroeval")
 
 
 def extract_few_shot_examples(
-    dataset: "DatasetDict",
+    dataset: "DatasetDict",
+    dataset_config: "DatasetConfig",
+    benchmark_config: "BenchmarkConfig",
+    itr_idx: int,
 ) -> list[dict[str, t.Any]]:
     """Extract few-shot examples from a dataset.
 
@@ -33,12 +37,32 @@ def extract_few_shot_examples(
             The dataset to extract the few-shot examples from.
         dataset_config:
             The dataset configuration.
+        benchmark_config:
+            The benchmark configuration.
         itr_idx:
             The index of the dataset in the iterator.
 
     Returns:
         The few-shot examples.
+
+    Raises:
+        InvalidBenchmark:
+            If there are not enough short examples for few-shot learning.
     """
+    if dataset_config.task.requires_zero_shot and benchmark_config.few_shot:
+        msg = (
+            "This task only allows zero-shot evaluation, so even though you have "
+            "requested few-shot evaluation "
+        )
+        if benchmark_config.run_with_cli:
+            msg += "(by not setting the --zero-shot flag), "
+        else:
+            msg += "(by setting the default `few_shot=True` argument), "
+        msg += "we will run the evaluation in zero-shot mode."
+        benchmark_config.few_shot = False
+        log_once(msg, level=logging.DEBUG)
+        return []
+
     random_seed = 4242 + itr_idx
     num_few_shots = dataset_config.num_few_shot_examples
     few_shot_examples: list[dict[str, t.Any]] = list()
@@ -63,12 +87,19 @@ def extract_few_shot_examples(
 
     shuffled_train = train_with_short_examples.shuffle(seed=random_seed)
     labels = it.cycle(dataset_config.labels)
+    labels_with_no_samples: set[str] = set()
     while len(few_shot_examples) < num_few_shots and len(shuffled_train) > 0:
+        if len(labels_with_no_samples) == len(dataset_config.labels):
+            raise InvalidBenchmark(
+                "Could not find enough examples for few-shot learning. "
+                "Please check the dataset and the labels."
+            )
         label = next(labels)
         possible_examples = shuffled_train.filter(
             lambda x: x["label"].lower() == label.lower()
         )
         if len(possible_examples) == 0:
+            labels_with_no_samples.add(label)
             continue
         example = possible_examples.select(range(1))[0]
         few_shot_examples.append(example)
@@ -144,7 +175,7 @@ def apply_prompt(
     dataset_config: "DatasetConfig",
     instruction_model: bool,
     always_populate_text_field: bool,
-
+    tokeniser: "PreTrainedTokenizer | None",
 ) -> dict[str, t.Any]:
     """Apply prompt template to an example, potentially with few-shot examples.
 
@@ -160,16 +191,16 @@ def apply_prompt(
         always_populate_text_field:
             Whether to always populate the 'text' field in the examples, as opposed to
             the 'messages' field.
-
-            The
+        tokeniser:
+            The tokeniser to use for the model. If None, the tokeniser is not used.
 
     Returns:
         The example with the few-shot examples applied.
     """
     # Sanity check
-    if instruction_model and always_populate_text_field and
+    if instruction_model and always_populate_text_field and tokeniser is None:
         raise ValueError(
-            "The `
+            "The `tokeniser` argument must be provided when the model is instruction "
             "tuned and when we are not just returning the raw messages."
         )
 
@@ -298,30 +329,31 @@ def apply_prompt(
         examples["messages"] = messages_list
 
     else:
-        assert
+        assert tokeniser is not None
 
         # Pick the chat template that matches the language of the dataset, if such a
         # template exists
        chat_template: str | None = None
-        if
+        if hasattr(tokeniser, "chat_template") and isinstance(
+            tokeniser.chat_template, dict
+        ):
             language_codes = [
                 language.code for language in dataset_config.languages
             ]
-            for name, candidate_template in
+            for name, candidate_template in tokeniser.chat_template.items():
                 if name.lower() in language_codes:
                     chat_template = candidate_template
                     log_once(
-                        f"Using the {name!r} chat template for the
+                        f"Using the {name!r} chat template for the tokeniser for "
                         f"model {model_config.model_id!r}.",
                         level=logging.DEBUG,
                     )
                     break
 
         texts = [
-
+            apply_chat_template(
                 conversation=messages,
-
-                add_generation_prompt=True,
+                tokeniser=tokeniser,
                 chat_template=chat_template,
             )
             for messages in messages_list