EuroEval 16.0.0-py3-none-any.whl → 16.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of EuroEval might be problematic.

Files changed (51)
  1. euroeval/__init__.py +5 -0
  2. euroeval/benchmark_config_factory.py +6 -1
  3. euroeval/benchmark_modules/base.py +2 -0
  4. euroeval/benchmark_modules/fresh.py +7 -1
  5. euroeval/benchmark_modules/hf.py +26 -21
  6. euroeval/benchmark_modules/litellm.py +258 -131
  7. euroeval/benchmark_modules/vllm.py +120 -68
  8. euroeval/benchmarker.py +11 -2
  9. euroeval/cli.py +14 -1
  10. euroeval/constants.py +7 -1
  11. euroeval/data_models.py +95 -20
  12. euroeval/dataset_configs/__init__.py +1 -0
  13. euroeval/dataset_configs/danish.py +14 -3
  14. euroeval/dataset_configs/dutch.py +14 -0
  15. euroeval/dataset_configs/english.py +22 -0
  16. euroeval/dataset_configs/estonian.py +15 -7
  17. euroeval/dataset_configs/finnish.py +14 -0
  18. euroeval/dataset_configs/french.py +14 -0
  19. euroeval/dataset_configs/german.py +23 -0
  20. euroeval/dataset_configs/italian.py +14 -0
  21. euroeval/dataset_configs/latvian.py +14 -0
  22. euroeval/dataset_configs/norwegian.py +14 -0
  23. euroeval/dataset_configs/polish.py +126 -0
  24. euroeval/dataset_configs/portuguese.py +14 -0
  25. euroeval/dataset_configs/spanish.py +14 -0
  26. euroeval/dataset_configs/swedish.py +25 -0
  27. euroeval/enums.py +12 -0
  28. euroeval/generation.py +17 -8
  29. euroeval/generation_utils.py +102 -16
  30. euroeval/metrics/pipeline.py +51 -9
  31. euroeval/model_cache.py +13 -1
  32. euroeval/prompt_templates/linguistic_acceptability.py +9 -0
  33. euroeval/prompt_templates/multiple_choice.py +27 -1
  34. euroeval/prompt_templates/named_entity_recognition.py +20 -0
  35. euroeval/prompt_templates/reading_comprehension.py +11 -0
  36. euroeval/prompt_templates/sentiment_classification.py +15 -0
  37. euroeval/prompt_templates/summarization.py +27 -1
  38. euroeval/scores.py +5 -0
  39. euroeval/task_group_utils/multiple_choice_classification.py +2 -2
  40. euroeval/task_group_utils/question_answering.py +29 -29
  41. euroeval/task_group_utils/sequence_classification.py +71 -81
  42. euroeval/task_group_utils/token_classification.py +17 -3
  43. euroeval/tasks.py +12 -10
  44. euroeval/{tokenization_utils.py → tokenisation_utils.py} +41 -25
  45. euroeval/utils.py +67 -3
  46. {euroeval-16.0.0.dist-info → euroeval-16.1.0.dist-info}/METADATA +3 -1
  47. euroeval-16.1.0.dist-info/RECORD +70 -0
  48. euroeval-16.0.0.dist-info/RECORD +0 -69
  49. {euroeval-16.0.0.dist-info → euroeval-16.1.0.dist-info}/WHEEL +0 -0
  50. {euroeval-16.0.0.dist-info → euroeval-16.1.0.dist-info}/entry_points.txt +0 -0
  51. {euroeval-16.0.0.dist-info → euroeval-16.1.0.dist-info}/licenses/LICENSE +0 -0
euroeval/dataset_configs/polish.py ADDED
@@ -0,0 +1,126 @@
+ """All Polish dataset configurations used in EuroEval."""
+
+ from ..data_models import DatasetConfig
+ from ..enums import ModelType
+ from ..languages import PL
+ from ..tasks import COMMON_SENSE, EUROPEAN_VALUES, KNOW, LA, NER, RC, SENT, SUMM
+
+ ### Official datasets ###
+
+ POLEMO2_CONFIG = DatasetConfig(
+ name="polemo2",
+ pretty_name="the Polish sentiment classification dataset PolEmo2",
+ huggingface_id="EuroEval/polemo2-mini",
+ task=SENT,
+ languages=[PL],
+ )
+
+ SCALA_PL_CONFIG = DatasetConfig(
+ name="scala-pl",
+ pretty_name="the Polish part of the linguistic acceptability dataset ScaLA",
+ huggingface_id="EuroEval/scala-pl",
+ task=LA,
+ languages=[PL],
+ )
+
+ KPWR_NER_CONFIG = DatasetConfig(
+ name="kpwr-ner",
+ pretty_name="the Polish entity recognition dataset KPWr-NER",
+ huggingface_id="EuroEval/kpwr-ner",
+ task=NER,
+ languages=[PL],
+ )
+
+ POQUAD_CONFIG = DatasetConfig(
+ name="poquad",
+ pretty_name="the Polish question answering dataset PoQuAD",
+ huggingface_id="EuroEval/poquad-mini",
+ task=RC,
+ languages=[PL],
+ )
+
+ PSC_CONFIG = DatasetConfig(
+ name="psc",
+ pretty_name="the Polish summarisation dataset PSC",
+ huggingface_id="EuroEval/psc-mini",
+ task=SUMM,
+ languages=[PL],
+ )
+
+ LLMZSZL_CONFIG = DatasetConfig(
+ name="llmzszl",
+ pretty_name="the Polish knowledge dataset LLMzSzŁ",
+ huggingface_id="EuroEval/llmzszl-mini",
+ task=KNOW,
+ languages=[PL],
+ )
+
+ WINOGRANDE_PL_CONFIG = DatasetConfig(
+ name="winogrande-pl",
+ pretty_name="the Polish common-sense reasoning dataset Winogrande-pl, translated "
+ "from the English Winogrande dataset",
+ huggingface_id="EuroEval/winogrande-pl",
+ task=COMMON_SENSE,
+ languages=[PL],
+ splits=["train", "test"],
+ _labels=["a", "b"],
+ _allowed_model_types=[ModelType.GENERATIVE],
+ )
+
+ EUROPEAN_VALUES_PL_CONFIG = DatasetConfig(
+ name="european-values-pl",
+ pretty_name="the Polish version of the European values evaluation dataset",
+ huggingface_id="EuroEval/european-values-pl",
+ task=EUROPEAN_VALUES,
+ languages=[PL],
+ splits=["test"],
+ bootstrap_samples=False,
+ _instruction_prompt="{text}",
+ )
+
+
+ ### Unofficial datasets ###
+
+ MULTI_WIKI_QA_PL_CONFIG = DatasetConfig(
+ name="multi-wiki-qa-pl",
+ pretty_name="the truncated version of the Polish part of the reading "
+ "comprehension dataset MultiWikiQA",
+ huggingface_id="EuroEval/multi-wiki-qa-pl-mini",
+ task=RC,
+ languages=[PL],
+ unofficial=True,
+ )
+
+ GOLDENSWAG_PL_CONFIG = DatasetConfig(
+ name="goldenswag-pl",
+ pretty_name="the truncated version of the Polish common-sense reasoning "
+ "dataset GoldenSwag-pl, translated from the English GoldenSwag dataset",
+ huggingface_id="EuroEval/goldenswag-pl-mini",
+ task=COMMON_SENSE,
+ languages=[PL],
+ unofficial=True,
+ )
+
+ EUROPEAN_VALUES_SITUATIONAL_PL_CONFIG = DatasetConfig(
+ name="european-values-situational-pl",
+ pretty_name="the Polish version of the European values evaluation dataset, where "
+ "the questions are phrased in a situational way",
+ huggingface_id="EuroEval/european-values-situational-pl",
+ task=EUROPEAN_VALUES,
+ languages=[PL],
+ splits=["test"],
+ bootstrap_samples=False,
+ unofficial=True,
+ )
+
+ EUROPEAN_VALUES_COMPLETIONS_PL_CONFIG = DatasetConfig(
+ name="european-values-completions-pl",
+ pretty_name="the Polish version of the European values evaluation dataset, where "
+ "the questions are phrased as sentence completions",
+ huggingface_id="EuroEval/european-values-completions-pl",
+ task=EUROPEAN_VALUES,
+ languages=[PL],
+ splits=["test"],
+ bootstrap_samples=False,
+ unofficial=True,
+ )
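
Note: a minimal usage sketch for the new Polish datasets, assuming the Benchmarker API from earlier EuroEval releases; the model ID below is a placeholder and not part of this diff.

from euroeval import Benchmarker

# Evaluate a placeholder model on the new official Polish sentiment dataset;
# "polemo2" is the dataset name registered by POLEMO2_CONFIG above.
benchmarker = Benchmarker()
benchmarker.benchmark(model="example-org/example-model", dataset="polemo2")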
euroeval/dataset_configs/portuguese.py CHANGED
@@ -1,6 +1,7 @@
  """All Portuguese dataset configurations used in EuroEval."""

  from ..data_models import DatasetConfig
+ from ..enums import ModelType
  from ..languages import PT
  from ..tasks import COMMON_SENSE, EUROPEAN_VALUES, KNOW, LA, MCRC, NER, RC, SENT, SUMM

@@ -91,6 +92,19 @@ BOOLQ_PT_CONFIG = DatasetConfig(
  unofficial=True,
  )

+ WINOGRANDE_PT_CONFIG = DatasetConfig(
+ name="winogrande-pt",
+ pretty_name="the Portuguese common-sense reasoning dataset Winogrande-pt, "
+ "translated from the English Winogrande dataset",
+ huggingface_id="EuroEval/winogrande-pt",
+ task=COMMON_SENSE,
+ languages=[PT],
+ splits=["train", "test"],
+ _labels=["a", "b"],
+ _allowed_model_types=[ModelType.GENERATIVE],
+ unofficial=True,
+ )
+
  EUROPEAN_VALUES_SITUATIONAL_PT_CONFIG = DatasetConfig(
  name="european-values-situational-pt",
  pretty_name="the Portuguese version of the European values evaluation dataset, "
euroeval/dataset_configs/spanish.py CHANGED
@@ -1,6 +1,7 @@
  """All Spanish dataset configurations used in EuroEval."""

  from ..data_models import DatasetConfig
+ from ..enums import ModelType
  from ..languages import ES
  from ..tasks import COMMON_SENSE, EUROPEAN_VALUES, KNOW, LA, MCRC, NER, RC, SENT, SUMM

@@ -119,6 +120,19 @@ GOLDENSWAG_ES_CONFIG = DatasetConfig(
  unofficial=True,
  )

+ WINOGRANDE_ES_CONFIG = DatasetConfig(
+ name="winogrande-es",
+ pretty_name="the Spanish common-sense reasoning dataset Winogrande-es, translated "
+ "from the English Winogrande dataset",
+ huggingface_id="EuroEval/winogrande-es",
+ task=COMMON_SENSE,
+ languages=[ES],
+ splits=["train", "test"],
+ _labels=["a", "b"],
+ _allowed_model_types=[ModelType.GENERATIVE],
+ unofficial=True,
+ )
+
  EUROPEAN_VALUES_SITUATIONAL_ES_CONFIG = DatasetConfig(
  name="european-values-situational-es",
  pretty_name="the Spanish version of the European values evaluation dataset, where "
euroeval/dataset_configs/swedish.py CHANGED
@@ -1,6 +1,7 @@
  """All Swedish dataset configurations used in EuroEval."""

  from ..data_models import DatasetConfig
+ from ..enums import ModelType
  from ..languages import SV
  from ..tasks import COMMON_SENSE, EUROPEAN_VALUES, KNOW, LA, MCRC, NER, RC, SENT, SUMM

@@ -130,6 +131,19 @@ GOLDENSWAG_SV_CONFIG = DatasetConfig(
  unofficial=True,
  )

+ WINOGRANDE_SV_CONFIG = DatasetConfig(
+ name="winogrande-sv",
+ pretty_name="the Swedish common-sense reasoning dataset Winogrande-sv, translated "
+ "from the English Winogrande dataset",
+ huggingface_id="EuroEval/winogrande-sv",
+ task=COMMON_SENSE,
+ languages=[SV],
+ splits=["train", "test"],
+ _labels=["a", "b"],
+ _allowed_model_types=[ModelType.GENERATIVE],
+ unofficial=True,
+ )
+
  EUROPEAN_VALUES_SITUATIONAL_SV_CONFIG = DatasetConfig(
  name="european-values-situational-sv",
  pretty_name="the Swedish version of the European values evaluation dataset, where "
@@ -155,3 +169,14 @@ EUROPEAN_VALUES_COMPLETIONS_SV_CONFIG = DatasetConfig(
  _instruction_prompt="{text}",
  unofficial=True,
  )
+
+ SKOLPROV_CONFIG = DatasetConfig(
+ name="skolprov",
+ pretty_name="the Swedish knowledge dataset Skolprov",
+ huggingface_id="EuroEval/skolprov",
+ task=KNOW,
+ languages=[SV],
+ splits=["train", "test"],
+ _allowed_model_types=[ModelType.GENERATIVE],
+ unofficial=True,
+ )
euroeval/enums.py CHANGED
@@ -12,6 +12,14 @@ class AutoStrEnum(str, Enum):
  ) -> str:
  return name.lower()

+ def __str__(self) -> str:
+ """Return the value in upper case for better readability."""
+ return self.value.upper()
+
+ def __repr__(self) -> str:
+ """Return the value in upper case for better readability."""
+ return self.value.upper()
+

  class Device(AutoStrEnum):
  """The compute device to use for the evaluation.
@@ -60,6 +68,10 @@ class ModelType(AutoStrEnum):
  ENCODER = auto()
  GENERATIVE = auto()

+ def __repr__(self) -> str:
+ """Return the value in upper case for better readability."""
+ return self.value.upper()
+

  class GenerativeType(AutoStrEnum):
  """The type of a generative model.
euroeval/generation.py CHANGED
@@ -307,7 +307,7 @@ def debug_log(
  for label in batch["label"]
  ]
  else:
- labels = ["N/A"] * len(extracted_labels)
+ labels = [None] * len(extracted_labels)

  case TaskGroup.QUESTION_ANSWERING:
  extracted_labels = [
@@ -330,12 +330,21 @@ def debug_log(
  else:
  input_texts = batch["text"]

- for input_text, raw_output, prediction, label in zip(
- input_texts, model_output.sequences, extracted_labels, labels
- ):
+ metadata_keys: list[str] = [
+ key
+ for key in batch.keys()
+ if key not in ["text", "messages", "label", "labels", "target_text"]
+ ]
+
+ for idx in range(len(input_texts)):
+ data_to_log: dict[str, t.Any] = {
+ "Input": input_texts[idx],
+ "Raw output": model_output.sequences[idx],
+ "Prediction": extracted_labels[idx],
+ }
+ if labels[idx]:
+ data_to_log["Label"] = labels[idx]
+ data_to_log |= {key.capitalize(): batch[key][idx] for key in metadata_keys}
  logger.info(
- f"Input: '{input_text}'\n"
- f"Raw output: '{raw_output}'\n"
- f"Prediction: '{prediction}'\n"
- f"Label: '{label}'"
+ "\n".join(f"{key}: {value!r}" for key, value in data_to_log.items())
  )
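
Note: a self-contained illustration of the new debug-log assembly; the batch columns and values below are made up, but the formatting logic follows the diff above.

import typing as t

batch: dict[str, list[t.Any]] = {
    "text": ["Great phone, terrible battery."],
    "label": ["neutral"],
    "category": ["electronics"],  # extra metadata column, now logged too
}
input_texts = batch["text"]
raw_outputs = ["neutral"]
extracted_labels = ["neutral"]
labels = batch["label"]

# Any batch column that is not an input/label field is treated as metadata.
metadata_keys = [
    key
    for key in batch
    if key not in ["text", "messages", "label", "labels", "target_text"]
]
for idx in range(len(input_texts)):
    data_to_log: dict[str, t.Any] = {
        "Input": input_texts[idx],
        "Raw output": raw_outputs[idx],
        "Prediction": extracted_labels[idx],
    }
    if labels[idx]:  # "Label" is only logged when a label is present
        data_to_log["Label"] = labels[idx]
    data_to_log |= {key.capitalize(): batch[key][idx] for key in metadata_keys}
    print("\n".join(f"{key}: {value!r}" for key, value in data_to_log.items()))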
euroeval/generation_utils.py CHANGED
@@ -4,12 +4,13 @@ import itertools as it
  import json
  import logging
  import random
+ import re
  import typing as t

- from .enums import TaskGroup
- from .exceptions import InvalidBenchmark
- from .tokenization_utils import apply_chat_template
- from .utils import log_once
+ from .enums import GenerativeType, TaskGroup
+ from .exceptions import InvalidBenchmark, InvalidModel
+ from .tokenisation_utils import apply_chat_template
+ from .utils import extract_multiple_choice_labels, log_once

  if t.TYPE_CHECKING:
  from datasets import DatasetDict
@@ -173,7 +174,7 @@ def apply_prompt(
  few_shot_examples: list[dict[str, t.Any]],
  model_config: "ModelConfig",
  dataset_config: "DatasetConfig",
- instruction_model: bool,
+ generative_type: GenerativeType | None,
  always_populate_text_field: bool,
  tokeniser: "PreTrainedTokenizer | None",
  ) -> dict[str, t.Any]:
@@ -184,10 +185,12 @@
  The examples to apply the few-shot examples to.
  few_shot_examples:
  The few-shot examples to apply.
+ model_config:
+ The model configuration.
  dataset_config:
  The dataset configuration.
- instruction_model:
- Whether the model is instruction-tuned.
+ generative_type:
+ The generative type of the model.
  always_populate_text_field:
  Whether to always populate the 'text' field in the examples, as opposed to
  the 'messages' field.
@@ -198,7 +201,11 @@
  The example with the few-shot examples applied.
  """
  # Sanity check
- if instruction_model and always_populate_text_field and tokeniser is None:
+ if (
+ generative_type == GenerativeType.INSTRUCTION_TUNED
+ and always_populate_text_field
+ and tokeniser is None
+ ):
  raise ValueError(
  "The `tokeniser` argument must be provided when the model is instruction "
  "tuned and when we are not just returning the raw messages."
@@ -222,7 +229,7 @@
  )
  label_mapping = dataset_config.prompt_label_mapping
  label = label_mapping.get(label, label)
- if instruction_model:
+ if generative_type == GenerativeType.INSTRUCTION_TUNED:
  prompt = dataset_config.instruction_prompt.format(**kwargs)
  return prompt, label
  else:
@@ -230,18 +237,49 @@
  return dataset_config.prompt_template.format(**kwargs), ""

  match dataset_config.task.task_group:
- case (
- TaskGroup.SEQUENCE_CLASSIFICATION | TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION
- ):
+ case TaskGroup.SEQUENCE_CLASSIFICATION:
+ labels_str = dataset_config.get_labels_str()
  few_shot_sections = [
  create_prompt(
  text=example["text"].replace("\n", " ").strip(),
  label=example["label"].replace("\n", " ").strip(),
+ labels_str=labels_str,
  )
  for example in few_shot_examples
  ]
  new_sections = [
- create_prompt(text=text.replace("\n", " ").strip(), label="")
+ create_prompt(
+ text=text.replace("\n", " ").strip(),
+ label="",
+ labels_str=labels_str,
+ )
+ for text in examples["text"]
+ ]
+
+ case TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION:
+ few_shot_sections = [
+ create_prompt(
+ text=example["text"].replace("\n", " ").strip(),
+ label=example["label"].replace("\n", " ").strip(),
+ labels_str=dataset_config.get_labels_str(
+ labels=extract_multiple_choice_labels(
+ prompt=example["text"],
+ candidate_labels=dataset_config.labels,
+ )
+ ),
+ )
+ for example in few_shot_examples
+ ]
+ new_sections = [
+ create_prompt(
+ text=text.replace("\n", " ").strip(),
+ label="",
+ labels_str=dataset_config.get_labels_str(
+ labels=extract_multiple_choice_labels(
+ prompt=text, candidate_labels=dataset_config.labels
+ )
+ ),
+ )
  for text in examples["text"]
  ]

@@ -259,6 +297,7 @@ def apply_prompt(
  ]

  case TaskGroup.TOKEN_CLASSIFICATION:
+ labels_str = dataset_config.get_labels_str()

  def create_label(example: dict) -> str:
  prompt_labels = dataset_config.prompt_label_mapping.values()
@@ -280,12 +319,15 @@ def apply_prompt(
  create_prompt(
  text=" ".join(example["tokens"]).replace("\n", " ").strip(),
  label=create_label(example=example),
+ labels_str=labels_str,
  )
  for example in few_shot_examples
  ]
  new_sections = [
  create_prompt(
- text=" ".join(tokens).replace("\n", " ").strip(), label=""
+ text=" ".join(tokens).replace("\n", " ").strip(),
+ label="",
+ labels_str=labels_str,
  )
  for tokens in examples["tokens"]
  ]
@@ -313,7 +355,7 @@ def apply_prompt(
  f"Unsupported task group: {dataset_config.task.task_group}."
  )

- if instruction_model:
+ if generative_type == GenerativeType.INSTRUCTION_TUNED:
  few_shot_messages = [
  dict(role=role, content=content)
  for prompt, label in few_shot_sections
@@ -327,7 +369,6 @@ def apply_prompt(

  if not always_populate_text_field:
  examples["messages"] = messages_list
-
  else:
  assert tokeniser is not None

@@ -354,6 +395,9 @@ def apply_prompt(
  apply_chat_template(
  conversation=messages,
  tokeniser=tokeniser,
+ tokenise=False,
+ add_generation_prompt=True,
+ enable_thinking=(generative_type == GenerativeType.REASONING),
  chat_template=chat_template,
  )
  for messages in messages_list
@@ -375,4 +419,46 @@ def apply_prompt(
  for new_prompt, _ in new_sections
  ]

+ # Always add the final prompts without few-shot examples, too, for analysis
+ examples["prompt"] = [new_prompt for new_prompt, _ in new_sections]
+
  return examples
+
+
+ def raise_if_wrong_params(
+ model_config: "ModelConfig", allowed_params: dict[re.Pattern, list[str]]
+ ) -> None:
+ """Raise an error if the model configuration has invalid parameters.
+
+ Args:
+ model_config:
+ The model configuration.
+ allowed_params:
+ The allowed parameters for the model, being a dictionary mapping a regex
+ pattern matching the model ID to a list of allowed parameters for those
+ models.
+
+ Raises:
+ InvalidModel:
+ If the model configuration has invalid parameters.
+ """
+ if model_config.param is None:
+ return
+ for model_regex, allowed_params_list in allowed_params.items():
+ if re.fullmatch(pattern=model_regex, string=model_config.model_id):
+ if model_config.param not in allowed_params_list:
+ msg = (
+ f"Invalid parameter {model_config.param!r} for model "
+ f"{model_config.model_id!r}."
+ )
+ if allowed_params_list:
+ msg += f" Allowed parameters are: {', '.join(allowed_params_list)}."
+ else:
+ msg += " No parameters are allowed."
+ raise InvalidModel(msg)
+ return
+ else:
+ raise InvalidModel(
+ f"The parameter {model_config.param!r} is not supported for the model "
+ f"{model_config.model_id!r}."
+ )
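
Note: a sketch of the allowed_params mapping that the new raise_if_wrong_params helper consumes; the model IDs and parameter names are invented, and check() only mirrors the helper's control flow instead of importing it.

import re

allowed_params: dict[re.Pattern, list[str]] = {
    re.compile(r"provider/reasoning-model.*"): ["low", "medium", "high"],
    re.compile(r"provider/basic-model.*"): [],  # matched, but no parameters allowed
}

def check(model_id: str, param: str | None) -> str:
    """Mirror the control flow above, returning a message instead of raising."""
    if param is None:
        return "ok: no parameter given"
    for pattern, allowed in allowed_params.items():
        if re.fullmatch(pattern=pattern, string=model_id):
            if param not in allowed:
                if allowed:
                    return f"InvalidModel: {param!r} not allowed; allowed: {', '.join(allowed)}"
                return f"InvalidModel: {param!r} not allowed; no parameters are allowed"
            return "ok: parameter allowed"
    return "InvalidModel: parameters are not supported for this model"

print(check("provider/reasoning-model-v2", "high"))  # ok: parameter allowed
print(check("provider/basic-model", "high"))         # no parameters are allowed
print(check("provider/other-model", "high"))         # parameters are not supported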
euroeval/metrics/pipeline.py CHANGED
@@ -26,6 +26,27 @@ logger: logging.Logger = logging.getLogger("euroeval")
  T = t.TypeVar("T", bound=int | float | str | bool)


+ class PreprocessingFunction(t.Protocol):
+ """A protocol for a preprocessing function."""
+
+ def __call__(
+ self, predictions: c.Sequence[int], dataset: "Dataset"
+ ) -> c.Sequence[int]:
+ """Preprocess the model predictions before they are passed to the pipeline.
+
+ Args:
+ predictions:
+ The model predictions.
+ dataset:
+ The dataset used for evaluation. This is only used in case any
+ additional metadata is used to compute the metrics.
+
+ Returns:
+ The preprocessed model predictions.
+ """
+ ...
+
+
  class PipelineMetric(Metric):
  """Load a scikit-learn pipeline and use it to get scores from the predictions."""

@@ -36,7 +57,7 @@ class PipelineMetric(Metric):
  pipeline_repo: str,
  pipeline_scoring_function: c.Callable[["Pipeline", c.Sequence], float],
  pipeline_file_name: str = "pipeline.pkl",
- preprocessing_fn: c.Callable[[c.Sequence[T]], c.Sequence[T]] = lambda x: x,
+ preprocessing_fn: PreprocessingFunction | None = None,
  postprocessing_fn: c.Callable[[float], tuple[float, str]] | None = None,
  ) -> None:
  """Initialise the pipeline transform metric.
@@ -101,7 +122,10 @@ class PipelineMetric(Metric):
  """
  if self.pipeline is None:
  self.pipeline = self._download_pipeline()
- predictions = self.preprocessing_fn(predictions)
+ if self.preprocessing_fn is not None:
+ predictions = self.preprocessing_fn(
+ predictions=predictions, dataset=dataset
+ )
  return self.pipeline_scoring_function(self.pipeline, predictions)

  def _download_pipeline(self) -> "Pipeline":
@@ -133,13 +157,18 @@ class PipelineMetric(Metric):
  ### European Values Metric ###


- def european_values_preprocessing_fn(predictions: c.Sequence[int]) -> c.Sequence[int]:
+ def european_values_preprocessing_fn(
+ predictions: c.Sequence[int], dataset: "Dataset"
+ ) -> c.Sequence[int]:
  """Preprocess the model predictions for the European Values metric.

  Args:
  predictions:
  The model predictions, a sequence of integers representing the predicted
  choices for each question.
+ dataset:
+ The dataset used for evaluation. This is only used in case any additional
+ metadata is used to compute the metrics.

  Returns:
  The preprocessed model predictions, a sequence of integers representing the
@@ -154,6 +183,17 @@ def european_values_preprocessing_fn(predictions: c.Sequence[int]) -> c.Sequence
  num_questions = 53
  num_phrasings_per_question = 5

+ # Convert the predictions to integers
+ integer_predictions = []
+ for prediction, idx_to_choice in zip(predictions, dataset["idx_to_choice"]):
+ idx_to_choice = {
+ int(idx): int(choice)
+ for idx, choice in idx_to_choice.items()
+ if choice is not None
+ }
+ integer_prediction = idx_to_choice[prediction]
+ integer_predictions.append(integer_prediction)
+
  assert len(predictions) % num_questions == 0, (
  f"The number of predictions ({len(predictions)}) is not a multiple of "
  f"{num_questions}, which is required for the European Values metric."
@@ -171,13 +211,13 @@ def european_values_preprocessing_fn(predictions: c.Sequence[int]) -> c.Sequence
  # Shape: (num_questions, num_phrasings_per_question)
  arr = np.array(
  [
- predictions[i : i + num_phrasings_per_question]
+ integer_predictions[i : i + num_phrasings_per_question]
  for i in range(0, len(predictions), num_phrasings_per_question)
  ]
  )

  # Double check that we reshaped the predictions correctly
- for idx, pred in enumerate(predictions):
+ for idx, pred in enumerate(integer_predictions):
  assert arr[idx // 5, idx % 5] == pred, (
  f"Reshaped predictions do not match the original predictions at index "
  f"{idx}: {arr[idx // 5, idx % 5]} != {pred}."
@@ -188,7 +228,7 @@ def european_values_preprocessing_fn(predictions: c.Sequence[int]) -> c.Sequence
  arr = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=1, arr=arr)

  # Convert the array to a list
- predictions = arr.tolist()
+ integer_predictions = arr.tolist()

  # Some of the questions are categorical and we're only interested in whether the
  # model chooses a specific choice or not. This mapping takes the question index
@@ -208,11 +248,13 @@ def european_values_preprocessing_fn(predictions: c.Sequence[int]) -> c.Sequence
  }

  # Map the predictions to the choices we're interested in
- predictions = list(predictions)
+ integer_predictions = list(integer_predictions)
  for question_idx, choice in question_choices.items():
- predictions[question_idx] = 1 if predictions[question_idx] == choice else 0
+ integer_predictions[question_idx] = (
+ 1 if integer_predictions[question_idx] == choice else 0
+ )

- return predictions
+ return integer_predictions


  def european_values_scoring_function(
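
Note: a toy illustration of the new idx_to_choice conversion step in european_values_preprocessing_fn; the predictions and mappings below are made up.

# Each example carries a mapping from the predicted answer index to the underlying
# survey choice; the preprocessing now translates index predictions into those choices.
predictions = [0, 2]
idx_to_choice_column = [
    {"0": 1, "1": 2, "2": None},  # prediction 0 -> choice 1
    {"0": 3, "1": 4, "2": 5},     # prediction 2 -> choice 5
]

integer_predictions = []
for prediction, idx_to_choice in zip(predictions, idx_to_choice_column):
    idx_to_choice = {
        int(idx): int(choice)
        for idx, choice in idx_to_choice.items()
        if choice is not None
    }
    integer_predictions.append(idx_to_choice[prediction])

print(integer_predictions)  # [1, 5]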
euroeval/model_cache.py CHANGED
@@ -10,7 +10,9 @@ from dataclasses import asdict

  from tqdm.auto import tqdm

+ from .constants import NUM_GENERATION_TOKENS_FOR_CLASSIFICATION
  from .data_models import GenerativeModelOutput, SingleGenerativeModelOutput
+ from .utils import log_once

  if t.TYPE_CHECKING:
  from pathlib import Path
@@ -189,10 +191,20 @@ class ModelCache:
  # the indices of the top scores, to save space. Further, we only store
  # the scores if the generated sequence is shorter than the maximum
  # length
- if model_output.scores is not None and self.max_generated_tokens < 8:
+ if (
+ model_output.scores is not None
+ and self.max_generated_tokens
+ <= NUM_GENERATION_TOKENS_FOR_CLASSIFICATION
+ ):
  assert model_output.scores is not None
  scores = model_output.scores[sample_idx]
  else:
+ if model_output.scores is not None:
+ log_once(
+ "The generated sequence is longer than the maximum "
+ "length for classification. Not caching the scores.",
+ level=logging.DEBUG,
+ )
  scores = None
  self[model_input] = SingleGenerativeModelOutput(
  sequence=model_output.sequences[sample_idx], scores=scores
euroeval/prompt_templates/linguistic_acceptability.py CHANGED
@@ -19,6 +19,7 @@ from ..languages import (
  NL,
  NN,
  NO,
+ PL,
  PT,
  SV,
  )
@@ -67,6 +68,14 @@ LA_TEMPLATES: dict["Language", PromptConfig] = {
  default_instruction_prompt="Lause: {text}\n\nOtsusta, kas lause on "
  "grammatiliselt õige või mitte. Vasta {labels_str}, ja mitte midagi muud.",
  ),
+ PL: PromptConfig(
+ default_prompt_label_mapping=dict(correct="tak", incorrect="nie"),
+ default_prompt_prefix="Poniżej znajdują się teksty i czy są "
+ "gramatycznie poprawne.",
+ default_prompt_template="Tekst: {text}\nGramatycznie poprawny: {label}",
+ default_instruction_prompt="Tekst: {text}\n\nOkreśl czy tekst jest "
+ "gramatycznie poprawny czy nie. Odpowiedz {labels_str}, i nic więcej.",
+ ),
  PT: PromptConfig(
  default_prompt_label_mapping=dict(correct="sim", incorrect="não"),
  default_prompt_prefix="Seguem-se abaixo textos e se são 