EuroEval 16.0.0__py3-none-any.whl → 16.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- euroeval/__init__.py +5 -0
- euroeval/benchmark_config_factory.py +6 -1
- euroeval/benchmark_modules/base.py +2 -0
- euroeval/benchmark_modules/fresh.py +7 -1
- euroeval/benchmark_modules/hf.py +26 -21
- euroeval/benchmark_modules/litellm.py +258 -131
- euroeval/benchmark_modules/vllm.py +120 -68
- euroeval/benchmarker.py +11 -2
- euroeval/cli.py +14 -1
- euroeval/constants.py +7 -1
- euroeval/data_models.py +95 -20
- euroeval/dataset_configs/__init__.py +1 -0
- euroeval/dataset_configs/danish.py +14 -3
- euroeval/dataset_configs/dutch.py +14 -0
- euroeval/dataset_configs/english.py +22 -0
- euroeval/dataset_configs/estonian.py +15 -7
- euroeval/dataset_configs/finnish.py +14 -0
- euroeval/dataset_configs/french.py +14 -0
- euroeval/dataset_configs/german.py +23 -0
- euroeval/dataset_configs/italian.py +14 -0
- euroeval/dataset_configs/latvian.py +14 -0
- euroeval/dataset_configs/norwegian.py +14 -0
- euroeval/dataset_configs/polish.py +126 -0
- euroeval/dataset_configs/portuguese.py +14 -0
- euroeval/dataset_configs/spanish.py +14 -0
- euroeval/dataset_configs/swedish.py +25 -0
- euroeval/enums.py +12 -0
- euroeval/generation.py +17 -8
- euroeval/generation_utils.py +102 -16
- euroeval/metrics/pipeline.py +51 -9
- euroeval/model_cache.py +13 -1
- euroeval/prompt_templates/linguistic_acceptability.py +9 -0
- euroeval/prompt_templates/multiple_choice.py +27 -1
- euroeval/prompt_templates/named_entity_recognition.py +20 -0
- euroeval/prompt_templates/reading_comprehension.py +11 -0
- euroeval/prompt_templates/sentiment_classification.py +15 -0
- euroeval/prompt_templates/summarization.py +27 -1
- euroeval/scores.py +5 -0
- euroeval/task_group_utils/multiple_choice_classification.py +2 -2
- euroeval/task_group_utils/question_answering.py +29 -29
- euroeval/task_group_utils/sequence_classification.py +71 -81
- euroeval/task_group_utils/token_classification.py +17 -3
- euroeval/tasks.py +12 -10
- euroeval/{tokenization_utils.py → tokenisation_utils.py} +41 -25
- euroeval/utils.py +67 -3
- {euroeval-16.0.0.dist-info → euroeval-16.1.0.dist-info}/METADATA +3 -1
- euroeval-16.1.0.dist-info/RECORD +70 -0
- euroeval-16.0.0.dist-info/RECORD +0 -69
- {euroeval-16.0.0.dist-info → euroeval-16.1.0.dist-info}/WHEEL +0 -0
- {euroeval-16.0.0.dist-info → euroeval-16.1.0.dist-info}/entry_points.txt +0 -0
- {euroeval-16.0.0.dist-info → euroeval-16.1.0.dist-info}/licenses/LICENSE +0 -0
euroeval/dataset_configs/polish.py
ADDED

@@ -0,0 +1,126 @@
+"""All Polish dataset configurations used in EuroEval."""
+
+from ..data_models import DatasetConfig
+from ..enums import ModelType
+from ..languages import PL
+from ..tasks import COMMON_SENSE, EUROPEAN_VALUES, KNOW, LA, NER, RC, SENT, SUMM
+
+### Official datasets ###
+
+POLEMO2_CONFIG = DatasetConfig(
+    name="polemo2",
+    pretty_name="the Polish sentiment classification dataset PolEmo2",
+    huggingface_id="EuroEval/polemo2-mini",
+    task=SENT,
+    languages=[PL],
+)
+
+SCALA_PL_CONFIG = DatasetConfig(
+    name="scala-pl",
+    pretty_name="the Polish part of the linguistic acceptability dataset ScaLA",
+    huggingface_id="EuroEval/scala-pl",
+    task=LA,
+    languages=[PL],
+)
+
+KPWR_NER_CONFIG = DatasetConfig(
+    name="kpwr-ner",
+    pretty_name="the Polish entity recognition dataset KPWr-NER",
+    huggingface_id="EuroEval/kpwr-ner",
+    task=NER,
+    languages=[PL],
+)
+
+POQUAD_CONFIG = DatasetConfig(
+    name="poquad",
+    pretty_name="the Polish question answering dataset PoQuAD",
+    huggingface_id="EuroEval/poquad-mini",
+    task=RC,
+    languages=[PL],
+)
+
+PSC_CONFIG = DatasetConfig(
+    name="psc",
+    pretty_name="the Polish summarisation dataset PSC",
+    huggingface_id="EuroEval/psc-mini",
+    task=SUMM,
+    languages=[PL],
+)
+
+LLMZSZL_CONFIG = DatasetConfig(
+    name="llmzszl",
+    pretty_name="the Polish knowledge dataset LLMzSzŁ",
+    huggingface_id="EuroEval/llmzszl-mini",
+    task=KNOW,
+    languages=[PL],
+)
+
+WINOGRANDE_PL_CONFIG = DatasetConfig(
+    name="winogrande-pl",
+    pretty_name="the Polish common-sense reasoning dataset Winogrande-pl, translated "
+    "from the English Winogrande dataset",
+    huggingface_id="EuroEval/winogrande-pl",
+    task=COMMON_SENSE,
+    languages=[PL],
+    splits=["train", "test"],
+    _labels=["a", "b"],
+    _allowed_model_types=[ModelType.GENERATIVE],
+)
+
+EUROPEAN_VALUES_PL_CONFIG = DatasetConfig(
+    name="european-values-pl",
+    pretty_name="the Polish version of the European values evaluation dataset",
+    huggingface_id="EuroEval/european-values-pl",
+    task=EUROPEAN_VALUES,
+    languages=[PL],
+    splits=["test"],
+    bootstrap_samples=False,
+    _instruction_prompt="{text}",
+)
+
+
+### Unofficial datasets ###
+
+MULTI_WIKI_QA_PL_CONFIG = DatasetConfig(
+    name="multi-wiki-qa-pl",
+    pretty_name="the truncated version of the Polish part of the reading "
+    "comprehension dataset MultiWikiQA",
+    huggingface_id="EuroEval/multi-wiki-qa-pl-mini",
+    task=RC,
+    languages=[PL],
+    unofficial=True,
+)
+
+GOLDENSWAG_PL_CONFIG = DatasetConfig(
+    name="goldenswag-pl",
+    pretty_name="the truncated version of the Polish common-sense reasoning "
+    "dataset GoldenSwag-pl, translated from the English GoldenSwag dataset",
+    huggingface_id="EuroEval/goldenswag-pl-mini",
+    task=COMMON_SENSE,
+    languages=[PL],
+    unofficial=True,
+)
+
+EUROPEAN_VALUES_SITUATIONAL_PL_CONFIG = DatasetConfig(
+    name="european-values-situational-pl",
+    pretty_name="the Polish version of the European values evaluation dataset, where "
+    "the questions are phrased in a situational way",
+    huggingface_id="EuroEval/european-values-situational-pl",
+    task=EUROPEAN_VALUES,
+    languages=[PL],
+    splits=["test"],
+    bootstrap_samples=False,
+    unofficial=True,
+)
+
+EUROPEAN_VALUES_COMPLETIONS_PL_CONFIG = DatasetConfig(
+    name="european-values-completions-pl",
+    pretty_name="the Polish version of the European values evaluation dataset, where "
+    "the questions are phrased as sentence completions",
+    huggingface_id="EuroEval/european-values-completions-pl",
+    task=EUROPEAN_VALUES,
+    languages=[PL],
+    splits=["test"],
+    bootstrap_samples=False,
+    unofficial=True,
+)
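Note: the new Polish module only defines module-level DatasetConfig objects; the one-line change to euroeval/dataset_configs/__init__.py in the file list suggests the module is simply registered next to the other language modules. A minimal sketch of collecting these configs programmatically (the collection loop below is illustrative, not EuroEval's own registration code):

import importlib

from euroeval.data_models import DatasetConfig

# Import the module added in this release and gather every module-level
# DatasetConfig it defines.
polish = importlib.import_module("euroeval.dataset_configs.polish")
polish_configs = [
    obj for obj in vars(polish).values() if isinstance(obj, DatasetConfig)
]
print(sorted(cfg.name for cfg in polish_configs))
# ['european-values-completions-pl', 'european-values-pl', ..., 'winogrande-pl']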
euroeval/dataset_configs/portuguese.py
CHANGED

@@ -1,6 +1,7 @@
 """All Portuguese dataset configurations used in EuroEval."""
 
 from ..data_models import DatasetConfig
+from ..enums import ModelType
 from ..languages import PT
 from ..tasks import COMMON_SENSE, EUROPEAN_VALUES, KNOW, LA, MCRC, NER, RC, SENT, SUMM
 

@@ -91,6 +92,19 @@ BOOLQ_PT_CONFIG = DatasetConfig(
     unofficial=True,
 )
 
+WINOGRANDE_PT_CONFIG = DatasetConfig(
+    name="winogrande-pt",
+    pretty_name="the Portuguese common-sense reasoning dataset Winogrande-pt, "
+    "translated from the English Winogrande dataset",
+    huggingface_id="EuroEval/winogrande-pt",
+    task=COMMON_SENSE,
+    languages=[PT],
+    splits=["train", "test"],
+    _labels=["a", "b"],
+    _allowed_model_types=[ModelType.GENERATIVE],
+    unofficial=True,
+)
+
 EUROPEAN_VALUES_SITUATIONAL_PT_CONFIG = DatasetConfig(
     name="european-values-situational-pt",
     pretty_name="the Portuguese version of the European values evaluation dataset, "
euroeval/dataset_configs/spanish.py
CHANGED

@@ -1,6 +1,7 @@
 """All Spanish dataset configurations used in EuroEval."""
 
 from ..data_models import DatasetConfig
+from ..enums import ModelType
 from ..languages import ES
 from ..tasks import COMMON_SENSE, EUROPEAN_VALUES, KNOW, LA, MCRC, NER, RC, SENT, SUMM
 

@@ -119,6 +120,19 @@ GOLDENSWAG_ES_CONFIG = DatasetConfig(
     unofficial=True,
 )
 
+WINOGRANDE_ES_CONFIG = DatasetConfig(
+    name="winogrande-es",
+    pretty_name="the Spanish common-sense reasoning dataset Winogrande-es, translated "
+    "from the English Winogrande dataset",
+    huggingface_id="EuroEval/winogrande-es",
+    task=COMMON_SENSE,
+    languages=[ES],
+    splits=["train", "test"],
+    _labels=["a", "b"],
+    _allowed_model_types=[ModelType.GENERATIVE],
+    unofficial=True,
+)
+
 EUROPEAN_VALUES_SITUATIONAL_ES_CONFIG = DatasetConfig(
     name="european-values-situational-es",
     pretty_name="the Spanish version of the European values evaluation dataset, where "
euroeval/dataset_configs/swedish.py
CHANGED

@@ -1,6 +1,7 @@
 """All Swedish dataset configurations used in EuroEval."""
 
 from ..data_models import DatasetConfig
+from ..enums import ModelType
 from ..languages import SV
 from ..tasks import COMMON_SENSE, EUROPEAN_VALUES, KNOW, LA, MCRC, NER, RC, SENT, SUMM
 

@@ -130,6 +131,19 @@ GOLDENSWAG_SV_CONFIG = DatasetConfig(
     unofficial=True,
 )
 
+WINOGRANDE_SV_CONFIG = DatasetConfig(
+    name="winogrande-sv",
+    pretty_name="the Swedish common-sense reasoning dataset Winogrande-sv, translated "
+    "from the English Winogrande dataset",
+    huggingface_id="EuroEval/winogrande-sv",
+    task=COMMON_SENSE,
+    languages=[SV],
+    splits=["train", "test"],
+    _labels=["a", "b"],
+    _allowed_model_types=[ModelType.GENERATIVE],
+    unofficial=True,
+)
+
 EUROPEAN_VALUES_SITUATIONAL_SV_CONFIG = DatasetConfig(
     name="european-values-situational-sv",
     pretty_name="the Swedish version of the European values evaluation dataset, where "

@@ -155,3 +169,14 @@ EUROPEAN_VALUES_COMPLETIONS_SV_CONFIG = DatasetConfig(
     _instruction_prompt="{text}",
     unofficial=True,
 )
+
+SKOLPROV_CONFIG = DatasetConfig(
+    name="skolprov",
+    pretty_name="the Swedish knowledge dataset Skolprov",
+    huggingface_id="EuroEval/skolprov",
+    task=KNOW,
+    languages=[SV],
+    splits=["train", "test"],
+    _allowed_model_types=[ModelType.GENERATIVE],
+    unofficial=True,
+)
euroeval/enums.py
CHANGED
@@ -12,6 +12,14 @@ class AutoStrEnum(str, Enum):
     ) -> str:
         return name.lower()
 
+    def __str__(self) -> str:
+        """Return the value in upper case for better readability."""
+        return self.value.upper()
+
+    def __repr__(self) -> str:
+        """Return the value in upper case for better readability."""
+        return self.value.upper()
+
 
 class Device(AutoStrEnum):
     """The compute device to use for the evaluation.

@@ -60,6 +68,10 @@ class ModelType(AutoStrEnum):
     ENCODER = auto()
     GENERATIVE = auto()
 
+    def __repr__(self) -> str:
+        """Return the value in upper case for better readability."""
+        return self.value.upper()
+
 
 class GenerativeType(AutoStrEnum):
     """The type of a generative model.
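Note: AutoStrEnum still derives each member's value from its lower-cased name, so the new __str__/__repr__ overrides only change how members are displayed. A small illustrative check (not part of the package):

from euroeval.enums import ModelType

# Values remain the lower-cased member names, so comparisons and serialisation
# are unchanged; only the printed form is upper-cased now.
assert ModelType.GENERATIVE.value == "generative"
assert str(ModelType.GENERATIVE) == "GENERATIVE"
print(repr(ModelType.GENERATIVE))  # GENERATIVE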
euroeval/generation.py
CHANGED
@@ -307,7 +307,7 @@ def debug_log(
                     for label in batch["label"]
                 ]
             else:
-                labels = [
+                labels = [None] * len(extracted_labels)
 
         case TaskGroup.QUESTION_ANSWERING:
             extracted_labels = [

@@ -330,12 +330,21 @@ def debug_log(
     else:
         input_texts = batch["text"]
 
-
-
-
+    metadata_keys: list[str] = [
+        key
+        for key in batch.keys()
+        if key not in ["text", "messages", "label", "labels", "target_text"]
+    ]
+
+    for idx in range(len(input_texts)):
+        data_to_log: dict[str, t.Any] = {
+            "Input": input_texts[idx],
+            "Raw output": model_output.sequences[idx],
+            "Prediction": extracted_labels[idx],
+        }
+        if labels[idx]:
+            data_to_log["Label"] = labels[idx]
+        data_to_log |= {key.capitalize(): batch[key][idx] for key in metadata_keys}
         logger.info(
-            f"
-            f"Raw output: '{raw_output}'\n"
-            f"Prediction: '{prediction}'\n"
-            f"Label: '{label}'"
+            "\n".join(f"{key}: {value!r}" for key, value in data_to_log.items())
         )
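Note: the reworked debug_log now emits one key/value block per example, built from the data_to_log dict above. A rough illustration with invented values:

# The join expression mirrors the new logger.info call in debug_log.
data_to_log = {
    "Input": "Tekst: Ala ma kota.",
    "Raw output": "pozytywny",
    "Prediction": "positive",
    "Label": "positive",
}
print("\n".join(f"{key}: {value!r}" for key, value in data_to_log.items()))
# Input: 'Tekst: Ala ma kota.'
# Raw output: 'pozytywny'
# Prediction: 'positive'
# Label: 'positive'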
euroeval/generation_utils.py
CHANGED
@@ -4,12 +4,13 @@ import itertools as it
 import json
 import logging
 import random
+import re
 import typing as t
 
-from .enums import TaskGroup
-from .exceptions import InvalidBenchmark
-from .
-from .utils import log_once
+from .enums import GenerativeType, TaskGroup
+from .exceptions import InvalidBenchmark, InvalidModel
+from .tokenisation_utils import apply_chat_template
+from .utils import extract_multiple_choice_labels, log_once
 
 if t.TYPE_CHECKING:
     from datasets import DatasetDict

@@ -173,7 +174,7 @@ def apply_prompt(
     few_shot_examples: list[dict[str, t.Any]],
     model_config: "ModelConfig",
     dataset_config: "DatasetConfig",
-
+    generative_type: GenerativeType | None,
     always_populate_text_field: bool,
     tokeniser: "PreTrainedTokenizer | None",
 ) -> dict[str, t.Any]:

@@ -184,10 +185,12 @@ def apply_prompt(
            The examples to apply the few-shot examples to.
        few_shot_examples:
            The few-shot examples to apply.
+       model_config:
+           The model configuration.
        dataset_config:
            The dataset configuration.
-
-
+       generative_type:
+           The generative type of the model.
        always_populate_text_field:
            Whether to always populate the 'text' field in the examples, as opposed to
            the 'messages' field.

@@ -198,7 +201,11 @@ def apply_prompt(
        The example with the few-shot examples applied.
     """
     # Sanity check
-    if
+    if (
+        generative_type == GenerativeType.INSTRUCTION_TUNED
+        and always_populate_text_field
+        and tokeniser is None
+    ):
         raise ValueError(
             "The `tokeniser` argument must be provided when the model is instruction "
             "tuned and when we are not just returning the raw messages."

@@ -222,7 +229,7 @@ def apply_prompt(
         )
         label_mapping = dataset_config.prompt_label_mapping
         label = label_mapping.get(label, label)
-        if
+        if generative_type == GenerativeType.INSTRUCTION_TUNED:
             prompt = dataset_config.instruction_prompt.format(**kwargs)
             return prompt, label
         else:

@@ -230,18 +237,49 @@ def apply_prompt(
             return dataset_config.prompt_template.format(**kwargs), ""
 
     match dataset_config.task.task_group:
-        case
-
-        ):
+        case TaskGroup.SEQUENCE_CLASSIFICATION:
+            labels_str = dataset_config.get_labels_str()
             few_shot_sections = [
                 create_prompt(
                     text=example["text"].replace("\n", " ").strip(),
                     label=example["label"].replace("\n", " ").strip(),
+                    labels_str=labels_str,
                 )
                 for example in few_shot_examples
             ]
             new_sections = [
-                create_prompt(
+                create_prompt(
+                    text=text.replace("\n", " ").strip(),
+                    label="",
+                    labels_str=labels_str,
+                )
+                for text in examples["text"]
+            ]
+
+        case TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION:
+            few_shot_sections = [
+                create_prompt(
+                    text=example["text"].replace("\n", " ").strip(),
+                    label=example["label"].replace("\n", " ").strip(),
+                    labels_str=dataset_config.get_labels_str(
+                        labels=extract_multiple_choice_labels(
+                            prompt=example["text"],
+                            candidate_labels=dataset_config.labels,
+                        )
+                    ),
+                )
+                for example in few_shot_examples
+            ]
+            new_sections = [
+                create_prompt(
+                    text=text.replace("\n", " ").strip(),
+                    label="",
+                    labels_str=dataset_config.get_labels_str(
+                        labels=extract_multiple_choice_labels(
+                            prompt=text, candidate_labels=dataset_config.labels
+                        )
+                    ),
+                )
                 for text in examples["text"]
             ]
 

@@ -259,6 +297,7 @@ def apply_prompt(
             ]
 
         case TaskGroup.TOKEN_CLASSIFICATION:
+            labels_str = dataset_config.get_labels_str()
 
             def create_label(example: dict) -> str:
                 prompt_labels = dataset_config.prompt_label_mapping.values()

@@ -280,12 +319,15 @@ def apply_prompt(
                 create_prompt(
                     text=" ".join(example["tokens"]).replace("\n", " ").strip(),
                     label=create_label(example=example),
+                    labels_str=labels_str,
                 )
                 for example in few_shot_examples
             ]
             new_sections = [
                 create_prompt(
-                    text=" ".join(tokens).replace("\n", " ").strip(),
+                    text=" ".join(tokens).replace("\n", " ").strip(),
+                    label="",
+                    labels_str=labels_str,
                 )
                 for tokens in examples["tokens"]
             ]

@@ -313,7 +355,7 @@ def apply_prompt(
                 f"Unsupported task group: {dataset_config.task.task_group}."
             )
 
-    if
+    if generative_type == GenerativeType.INSTRUCTION_TUNED:
         few_shot_messages = [
             dict(role=role, content=content)
             for prompt, label in few_shot_sections

@@ -327,7 +369,6 @@ def apply_prompt(
 
     if not always_populate_text_field:
         examples["messages"] = messages_list
-
     else:
         assert tokeniser is not None
 

@@ -354,6 +395,9 @@ def apply_prompt(
             apply_chat_template(
                 conversation=messages,
                 tokeniser=tokeniser,
+                tokenise=False,
+                add_generation_prompt=True,
+                enable_thinking=(generative_type == GenerativeType.REASONING),
                 chat_template=chat_template,
             )
             for messages in messages_list

@@ -375,4 +419,46 @@ def apply_prompt(
         for new_prompt, _ in new_sections
     ]
 
+    # Always add the final prompts without few-shot examples, too, for analysis
+    examples["prompt"] = [new_prompt for new_prompt, _ in new_sections]
+
     return examples
+
+
+def raise_if_wrong_params(
+    model_config: "ModelConfig", allowed_params: dict[re.Pattern, list[str]]
+) -> None:
+    """Raise an error if the model configuration has invalid parameters.
+
+    Args:
+        model_config:
+            The model configuration.
+        allowed_params:
+            The allowed parameters for the model, being a dictionary mapping a regex
+            pattern matching the model ID to a list of allowed parameters for those
+            models.
+
+    Raises:
+        InvalidModel:
+            If the model configuration has invalid parameters.
+    """
+    if model_config.param is None:
+        return
+    for model_regex, allowed_params_list in allowed_params.items():
+        if re.fullmatch(pattern=model_regex, string=model_config.model_id):
+            if model_config.param not in allowed_params_list:
+                msg = (
+                    f"Invalid parameter {model_config.param!r} for model "
+                    f"{model_config.model_id!r}."
+                )
+                if allowed_params_list:
+                    msg += f" Allowed parameters are: {', '.join(allowed_params_list)}."
+                else:
+                    msg += " No parameters are allowed."
+                raise InvalidModel(msg)
+            return
+    else:
+        raise InvalidModel(
+            f"The parameter {model_config.param!r} is not supported for the model "
+            f"{model_config.model_id!r}."
+        )
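Note: the new raise_if_wrong_params helper only reads the param and model_id attributes of the model configuration, so its behaviour can be illustrated with a stand-in object; the regex pattern, model ID and parameter names below are invented for the example:

import re
from dataclasses import dataclass

from euroeval.exceptions import InvalidModel
from euroeval.generation_utils import raise_if_wrong_params

@dataclass
class FakeModelConfig:  # stand-in; the real ModelConfig lives in euroeval.data_models
    model_id: str
    param: str | None

allowed = {re.compile(r"example-org/.*"): ["low", "high"]}

# Matching regex and an allowed parameter: passes silently.
raise_if_wrong_params(
    model_config=FakeModelConfig("example-org/model", "low"), allowed_params=allowed
)

# Matching regex but a disallowed parameter: raises InvalidModel.
try:
    raise_if_wrong_params(
        model_config=FakeModelConfig("example-org/model", "tiny"), allowed_params=allowed
    )
except InvalidModel as err:
    print(err)
    # Invalid parameter 'tiny' for model 'example-org/model'. Allowed parameters are: low, high.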
euroeval/metrics/pipeline.py
CHANGED
@@ -26,6 +26,27 @@ logger: logging.Logger = logging.getLogger("euroeval")
 T = t.TypeVar("T", bound=int | float | str | bool)
 
 
+class PreprocessingFunction(t.Protocol):
+    """A protocol for a preprocessing function."""
+
+    def __call__(
+        self, predictions: c.Sequence[int], dataset: "Dataset"
+    ) -> c.Sequence[int]:
+        """Preprocess the model predictions before they are passed to the pipeline.
+
+        Args:
+            predictions:
+                The model predictions.
+            dataset:
+                The dataset used for evaluation. This is only used in case any
+                additional metadata is used to compute the metrics.
+
+        Returns:
+            The preprocessed model predictions.
+        """
+        ...
+
+
 class PipelineMetric(Metric):
     """Load a scikit-learn pipeline and use it to get scores from the predictions."""
 

@@ -36,7 +57,7 @@ class PipelineMetric(Metric):
         pipeline_repo: str,
         pipeline_scoring_function: c.Callable[["Pipeline", c.Sequence], float],
         pipeline_file_name: str = "pipeline.pkl",
-        preprocessing_fn:
+        preprocessing_fn: PreprocessingFunction | None = None,
         postprocessing_fn: c.Callable[[float], tuple[float, str]] | None = None,
     ) -> None:
         """Initialise the pipeline transform metric.

@@ -101,7 +122,10 @@ class PipelineMetric(Metric):
         """
         if self.pipeline is None:
             self.pipeline = self._download_pipeline()
-
+        if self.preprocessing_fn is not None:
+            predictions = self.preprocessing_fn(
+                predictions=predictions, dataset=dataset
+            )
         return self.pipeline_scoring_function(self.pipeline, predictions)
 
     def _download_pipeline(self) -> "Pipeline":

@@ -133,13 +157,18 @@
 ### European Values Metric ###
 
 
-def european_values_preprocessing_fn(
+def european_values_preprocessing_fn(
+    predictions: c.Sequence[int], dataset: "Dataset"
+) -> c.Sequence[int]:
     """Preprocess the model predictions for the European Values metric.
 
     Args:
         predictions:
             The model predictions, a sequence of integers representing the predicted
             choices for each question.
+        dataset:
+            The dataset used for evaluation. This is only used in case any additional
+            metadata is used to compute the metrics.
 
     Returns:
         The preprocessed model predictions, a sequence of integers representing the

@@ -154,6 +183,17 @@ def european_values_preprocessing_fn(predictions: c.Sequence[int]) -> c.Sequence
     num_questions = 53
     num_phrasings_per_question = 5
 
+    # Convert the predictions to integers
+    integer_predictions = []
+    for prediction, idx_to_choice in zip(predictions, dataset["idx_to_choice"]):
+        idx_to_choice = {
+            int(idx): int(choice)
+            for idx, choice in idx_to_choice.items()
+            if choice is not None
+        }
+        integer_prediction = idx_to_choice[prediction]
+        integer_predictions.append(integer_prediction)
+
     assert len(predictions) % num_questions == 0, (
         f"The number of predictions ({len(predictions)}) is not a multiple of "
         f"{num_questions}, which is required for the European Values metric."

@@ -171,13 +211,13 @@ def european_values_preprocessing_fn(predictions: c.Sequence[int]) -> c.Sequence
     # Shape: (num_questions, num_phrasings_per_question)
     arr = np.array(
         [
-
+            integer_predictions[i : i + num_phrasings_per_question]
             for i in range(0, len(predictions), num_phrasings_per_question)
         ]
     )
 
     # Double check that we reshaped the predictions correctly
-    for idx, pred in enumerate(
+    for idx, pred in enumerate(integer_predictions):
         assert arr[idx // 5, idx % 5] == pred, (
             f"Reshaped predictions do not match the original predictions at index "
             f"{idx}: {arr[idx // 5, idx % 5]} != {pred}."

@@ -188,7 +228,7 @@ def european_values_preprocessing_fn(predictions: c.Sequence[int]) -> c.Sequence
     arr = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=1, arr=arr)
 
     # Convert the array to a list
-
+    integer_predictions = arr.tolist()
 
     # Some of the questions are categorical and we're only interested in whether the
     # model chooses a specific choice or not. This mapping takes the question index

@@ -208,11 +248,13 @@ def european_values_preprocessing_fn(predictions: c.Sequence[int]) -> c.Sequence
     }
 
     # Map the predictions to the choices we're interested in
-
+    integer_predictions = list(integer_predictions)
     for question_idx, choice in question_choices.items():
-
+        integer_predictions[question_idx] = (
+            1 if integer_predictions[question_idx] == choice else 0
+        )
 
-    return
+    return integer_predictions
 
 
 def european_values_scoring_function(
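Note: any callable taking the parameters predictions and dataset satisfies the new PreprocessingFunction protocol; PipelineMetric now calls it with keyword arguments, so the parameter names matter. A minimal conforming example (illustrative only, not part of the package):

import collections.abc as c

from datasets import Dataset

def clip_preprocessing_fn(predictions: c.Sequence[int], dataset: Dataset) -> c.Sequence[int]:
    # Toy preprocessing step: clamp every predicted choice index to be non-negative.
    return [max(0, pred) for pred in predictions]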
euroeval/model_cache.py
CHANGED
@@ -10,7 +10,9 @@ from dataclasses import asdict
 
 from tqdm.auto import tqdm
 
+from .constants import NUM_GENERATION_TOKENS_FOR_CLASSIFICATION
 from .data_models import GenerativeModelOutput, SingleGenerativeModelOutput
+from .utils import log_once
 
 if t.TYPE_CHECKING:
     from pathlib import Path

@@ -189,10 +191,20 @@ class ModelCache:
                 # the indices of the top scores, to save space. Further, we only store
                 # the scores if the generated sequence is shorter than the maximum
                 # length
-                if
+                if (
+                    model_output.scores is not None
+                    and self.max_generated_tokens
+                    <= NUM_GENERATION_TOKENS_FOR_CLASSIFICATION
+                ):
                     assert model_output.scores is not None
                     scores = model_output.scores[sample_idx]
                 else:
+                    if model_output.scores is not None:
+                        log_once(
+                            "The generated sequence is longer than the maximum "
+                            "length for classification. Not caching the scores.",
+                            level=logging.DEBUG,
+                        )
                     scores = None
                 self[model_input] = SingleGenerativeModelOutput(
                     sequence=model_output.sequences[sample_idx], scores=scores
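Note: the new caching rule means token-level scores are only stored when the generation budget is classification-sized. Restated as a standalone predicate (a sketch; NUM_GENERATION_TOKENS_FOR_CLASSIFICATION is defined in euroeval.constants, but its value is not shown in this diff):

from euroeval.constants import NUM_GENERATION_TOKENS_FOR_CLASSIFICATION

def should_cache_scores(scores: object, max_generated_tokens: int) -> bool:
    # Mirrors the condition added to ModelCache above.
    return scores is not None and (
        max_generated_tokens <= NUM_GENERATION_TOKENS_FOR_CLASSIFICATION
    )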
euroeval/prompt_templates/linguistic_acceptability.py
CHANGED

@@ -19,6 +19,7 @@ from ..languages import (
     NL,
     NN,
     NO,
+    PL,
     PT,
     SV,
 )

@@ -67,6 +68,14 @@ LA_TEMPLATES: dict["Language", PromptConfig] = {
         default_instruction_prompt="Lause: {text}\n\nOtsusta, kas lause on "
         "grammatiliselt õige või mitte. Vasta {labels_str}, ja mitte midagi muud.",
     ),
+    PL: PromptConfig(
+        default_prompt_label_mapping=dict(correct="tak", incorrect="nie"),
+        default_prompt_prefix="Poniżej znajdują się teksty i czy są "
+        "gramatycznie poprawne.",
+        default_prompt_template="Tekst: {text}\nGramatycznie poprawny: {label}",
+        default_instruction_prompt="Tekst: {text}\n\nOkreśl czy tekst jest "
+        "gramatycznie poprawny czy nie. Odpowiedz {labels_str}, i nic więcej.",
+    ),
     PT: PromptConfig(
         default_prompt_label_mapping=dict(correct="sim", incorrect="não"),
         default_prompt_prefix="Seguem-se abaixo textos e se são "
|