EuroEval 16.0.0-py3-none-any.whl → 16.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of EuroEval might be problematic.

Files changed (51)
  1. euroeval/__init__.py +5 -0
  2. euroeval/benchmark_config_factory.py +6 -1
  3. euroeval/benchmark_modules/base.py +2 -0
  4. euroeval/benchmark_modules/fresh.py +7 -1
  5. euroeval/benchmark_modules/hf.py +26 -21
  6. euroeval/benchmark_modules/litellm.py +258 -131
  7. euroeval/benchmark_modules/vllm.py +120 -68
  8. euroeval/benchmarker.py +11 -2
  9. euroeval/cli.py +14 -1
  10. euroeval/constants.py +7 -1
  11. euroeval/data_models.py +95 -20
  12. euroeval/dataset_configs/__init__.py +1 -0
  13. euroeval/dataset_configs/danish.py +14 -3
  14. euroeval/dataset_configs/dutch.py +14 -0
  15. euroeval/dataset_configs/english.py +22 -0
  16. euroeval/dataset_configs/estonian.py +15 -7
  17. euroeval/dataset_configs/finnish.py +14 -0
  18. euroeval/dataset_configs/french.py +14 -0
  19. euroeval/dataset_configs/german.py +23 -0
  20. euroeval/dataset_configs/italian.py +14 -0
  21. euroeval/dataset_configs/latvian.py +14 -0
  22. euroeval/dataset_configs/norwegian.py +14 -0
  23. euroeval/dataset_configs/polish.py +126 -0
  24. euroeval/dataset_configs/portuguese.py +14 -0
  25. euroeval/dataset_configs/spanish.py +14 -0
  26. euroeval/dataset_configs/swedish.py +25 -0
  27. euroeval/enums.py +12 -0
  28. euroeval/generation.py +17 -8
  29. euroeval/generation_utils.py +102 -16
  30. euroeval/metrics/pipeline.py +51 -9
  31. euroeval/model_cache.py +13 -1
  32. euroeval/prompt_templates/linguistic_acceptability.py +9 -0
  33. euroeval/prompt_templates/multiple_choice.py +27 -1
  34. euroeval/prompt_templates/named_entity_recognition.py +20 -0
  35. euroeval/prompt_templates/reading_comprehension.py +11 -0
  36. euroeval/prompt_templates/sentiment_classification.py +15 -0
  37. euroeval/prompt_templates/summarization.py +27 -1
  38. euroeval/scores.py +5 -0
  39. euroeval/task_group_utils/multiple_choice_classification.py +2 -2
  40. euroeval/task_group_utils/question_answering.py +29 -29
  41. euroeval/task_group_utils/sequence_classification.py +71 -81
  42. euroeval/task_group_utils/token_classification.py +17 -3
  43. euroeval/tasks.py +12 -10
  44. euroeval/{tokenization_utils.py → tokenisation_utils.py} +41 -25
  45. euroeval/utils.py +67 -3
  46. {euroeval-16.0.0.dist-info → euroeval-16.1.0.dist-info}/METADATA +3 -1
  47. euroeval-16.1.0.dist-info/RECORD +70 -0
  48. euroeval-16.0.0.dist-info/RECORD +0 -69
  49. {euroeval-16.0.0.dist-info → euroeval-16.1.0.dist-info}/WHEEL +0 -0
  50. {euroeval-16.0.0.dist-info → euroeval-16.1.0.dist-info}/entry_points.txt +0 -0
  51. {euroeval-16.0.0.dist-info → euroeval-16.1.0.dist-info}/licenses/LICENSE +0 -0
euroeval/prompt_templates/multiple_choice.py CHANGED
@@ -3,7 +3,25 @@
  import typing as t

  from ..data_models import PromptConfig
- from ..languages import DA, DE, EN, ES, ET, FI, FR, IS, IT, LV, NB, NL, NN, NO, PT, SV
+ from ..languages import (
+ DA,
+ DE,
+ EN,
+ ES,
+ ET,
+ FI,
+ FR,
+ IS,
+ IT,
+ LV,
+ NB,
+ NL,
+ NN,
+ NO,
+ PL,
+ PT,
+ SV,
+ )

  if t.TYPE_CHECKING:
  from ..data_models import Language
@@ -123,6 +141,14 @@ MULTIPLE_CHOICE_TEMPLATES: dict["Language", PromptConfig] = {
  "{labels_str}, og ikke noe annet.",
  default_prompt_label_mapping="auto",
  ),
+ PL: PromptConfig(
+ default_prompt_prefix="Poniżej znajdują się pytania wielokrotnego wyboru "
+ "(z odpowiedziami).",
+ default_prompt_template="Pytanie: {text}\nOdpowiedź: {label}",
+ default_instruction_prompt="Pytanie: {text}\n\nOdpowiedz na powyższe pytanie, "
+ "odpowiadając {labels_str}, i nic więcej.",
+ default_prompt_label_mapping="auto",
+ ),
  SV: PromptConfig(
  default_prompt_prefix="Följande är flervalsfrågor (med svar).",
  default_prompt_template="Fråga: {text}\nSvar: {label}",
euroeval/prompt_templates/named_entity_recognition.py CHANGED
@@ -19,6 +19,7 @@ from ..languages import (
  NL,
  NN,
  NO,
+ PL,
  PT,
  SV,
  )
@@ -336,6 +337,25 @@ NER_TEMPLATES: dict["Language", PromptConfig] = {
  "Verdiene skal være lister over de navngitte enhetene "
  "av den typen, akkurat som de vises i frasen.",
  ),
+ PL: PromptConfig(
+ default_prompt_label_mapping={
+ "b-per": "osoba",
+ "i-per": "osoba",
+ "b-loc": "lokalizacja",
+ "i-loc": "lokalizacja",
+ "b-org": "organizacja",
+ "i-org": "organizacja",
+ "b-misc": "różne",
+ "i-misc": "różne",
+ },
+ default_prompt_prefix="Poniżej znajdują się zdania i słowniki JSON z nazwanymi "
+ "jednostkami występującymi w danym zdaniu.",
+ default_prompt_template="Zdanie: {text}\nNazwane jednostki: {label}",
+ default_instruction_prompt="Zdanie: {text}\n\nZidentyfikuj nazwane jednostki "
+ "w zdaniu. Powinieneś wypisać to jako słownik JSON z kluczami "
+ "{labels_str}. Wartości powinny być listami nazwanych jednostek "
+ "tego typu, dokładnie tak jak pojawiają się w zdaniu.",
+ ),
  SV: PromptConfig(
  default_prompt_label_mapping={
  "b-per": "person",
euroeval/prompt_templates/reading_comprehension.py CHANGED
@@ -19,6 +19,7 @@ from ..languages import (
  NL,
  NN,
  NO,
+ PL,
  PT,
  SV,
  )
@@ -157,6 +158,16 @@ RC_TEMPLATES: dict["Language", PromptConfig] = {
  "teksten ovenfor med maks 3 ord.\n\nSpørsmål: {question}",
  default_prompt_label_mapping=dict(),
  ),
+ PL: PromptConfig(
+ default_prompt_prefix=(
+ "Poniżej znajdują się teksty z towarzyszącymi pytaniami i odpowiedziami."
+ ),
+ default_prompt_template="Tekst: {text}\nPytanie: {question}\nOdpowiedź w "
+ "maksymalnie 3 słowach: {label}",
+ default_instruction_prompt="Tekst: {text}\n\nOdpowiedz na następujące pytanie "
+ "dotyczące powyższego tekstu w maksymalnie 3 słowach.\n\nPytanie: {question}",
+ default_prompt_label_mapping=dict(),
+ ),
  PT: PromptConfig(
  default_prompt_prefix="Os textos que se seguem são acompanhados de perguntas "
  "e respostas.",
euroeval/prompt_templates/sentiment_classification.py CHANGED
@@ -19,6 +19,7 @@ from ..languages import (
  NL,
  NN,
  NO,
+ PL,
  PT,
  SV,
  )
@@ -78,6 +79,20 @@ SENT_TEMPLATES: dict["Language", PromptConfig] = {
  "meelestatuse järgi. Võimalikud vastused: {labels_str}. Muud vastused "
  "ei ole lubatud.",
  ),
+ PL: PromptConfig(
+ default_prompt_label_mapping=dict(
+ positive="pozytywny", neutral="neutralny", negative="negatywny"
+ ),
+ default_prompt_prefix=(
+ "Poniżej znajdują się dokumenty i ich sentyment, który może być "
+ "{labels_str}."
+ ),
+ default_prompt_template="Dokument: {text}\nSentyment: {label}",
+ default_instruction_prompt=(
+ "Dokument: {text}\n\nKlasyfikuj sentyment w dokumencie. "
+ "Odpowiedz z {labels_str}, i nic więcej."
+ ),
+ ),
  PT: PromptConfig(
  default_prompt_label_mapping=dict(
  positive="positivo", neutral="neutro", negative="negativo"
euroeval/prompt_templates/summarization.py CHANGED
@@ -3,7 +3,25 @@
  import typing as t

  from ..data_models import PromptConfig
- from ..languages import DA, DE, EN, ES, ET, FI, FR, IS, IT, LV, NB, NL, NN, NO, PT, SV
+ from ..languages import (
+ DA,
+ DE,
+ EN,
+ ES,
+ ET,
+ FI,
+ FR,
+ IS,
+ IT,
+ LV,
+ NB,
+ NL,
+ NN,
+ NO,
+ PL,
+ PT,
+ SV,
+ )

  if t.TYPE_CHECKING:
  from ..data_models import Language
@@ -122,6 +140,14 @@ SUMM_TEMPLATES: dict["Language", PromptConfig] = {
  "dokumentet ovenfor.",
  default_prompt_label_mapping=dict(),
  ),
+ PL: PromptConfig(
+ default_prompt_prefix="Poniżej znajdują się artykuły z towarzyszącymi "
+ "streszczeniami.",
+ default_prompt_template="Artykuł: {text}\nStreszczenie: {target_text}",
+ default_instruction_prompt="Artykuł: {text}\n\nNapisz streszczenie "
+ "powyższego artykułu.",
+ default_prompt_label_mapping=dict(),
+ ),
  SV: PromptConfig(
  default_prompt_prefix="Nedan följer dokument med tillhörande sammanfattningar.",
  default_prompt_template="Dokument: {text}\nSammanfattning: {target_text}",
euroeval/scores.py CHANGED
@@ -19,6 +19,7 @@ def log_scores(
  scores: list[dict[str, float]],
  model_id: str,
  model_revision: str,
+ model_param: str | None,
  ) -> "ScoreDict":
  """Log the scores.

@@ -34,6 +35,8 @@
  The model ID of the model that was evaluated.
  model_revision:
  The revision of the model.
+ model_param:
+ The model parameter, if any.

  Returns:
  A dictionary with keys 'raw_scores' and 'total', with 'raw_scores' being
@@ -42,6 +45,8 @@
  """
  if model_revision and model_revision != "main":
  model_id += f"@{model_revision}"
+ if model_param is not None:
+ model_id += f"#{model_param}"

  logger.info(f"Finished evaluation of {model_id} on {dataset_name}.")
euroeval/task_group_utils/multiple_choice_classification.py CHANGED
@@ -126,7 +126,7 @@
  ):
  choice_idxs.append(idx)

- choices = [sections[idx] for idx in choice_idxs]
+ choices = [sections[idx] for idx in reversed(choice_idxs)]

  # Check that the choices are present, and that all of them are at the end
  assert len(choices) > 0, "No choices found in the document."
@@ -146,7 +146,7 @@
  )
  new_examples["label"] = [
  int(choice.startswith(f"{letter}. ") and letter == examples["label"][0])
- for letter, choice in zip("abcde", choices)
+ for letter, choice in zip("abcdefghijklmnopqrstuvwxyz", choices)
  ]
  new_examples["id"] = [hashlib.md5(string=doc.encode()).hexdigest()] * len(choices)
  return new_examples
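A small sketch of what widening the letter range changes: zip only consumes as many letters as there are choices, so documents with more than five options now get a full label vector instead of being silently truncated (the example choices and gold label below are invented):

# Sketch: the label vector marks which lettered choice matches the gold label.
# With the old "abcde" string, any choice after the fifth was dropped by zip.
choices = ["a. Red", "b. Green", "c. Blue", "d. Cyan", "e. Magenta", "f. Yellow"]
gold_label = "f"

labels = [
    int(choice.startswith(f"{letter}. ") and letter == gold_label)
    for letter, choice in zip("abcdefghijklmnopqrstuvwxyz", choices)
]
print(labels)  # [0, 0, 0, 0, 0, 1]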
euroeval/task_group_utils/question_answering.py CHANGED
@@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
  from transformers.trainer import Trainer

  from ..exceptions import InvalidBenchmark
- from ..tokenization_utils import get_special_token_metadata
+ from ..tokenisation_utils import get_special_token_metadata
  from ..utils import raise_if_model_output_contains_nan_values

  if t.TYPE_CHECKING:
@@ -261,7 +261,7 @@ def prepare_train_examples(
  ]
  examples["context"] = [f"{c}{sep_token}" for c in examples["context"]]

- # Set the stride used during tokenization, when the context is long enough to be
+ # Set the stride used during tokenisation, when the context is long enough to be
  # split into several features. Since we are always keeping the question tokens, we
  # need to make sure that the stride does not exceed the resulting maximum context
  # length.
@@ -272,11 +272,11 @@
  stride = min(stride, max_length - max_question_tokens - num_special_tokens)
  max_length = tokeniser.model_max_length - stride

- # Tokenize our examples with truncation and padding, but keep the overflows using a
+ # Tokenise our examples with truncation and padding, but keep the overflows using a
  # stride. This results in one example possible giving several features when a
  # context is long, each of those features having a context that overlaps a bit the
  # context of the previous feature.
- tokenized_examples = tokeniser(
+ tokenised_examples = tokeniser(
  text=examples["question"],
  text_pair=examples["context"],
  truncation="only_second",
@@ -290,27 +290,27 @@
  # Since one example might give us several features if it has a long context, we
  # need a map from a feature to its corresponding example. This key gives us just
  # that
- sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
+ sample_mapping = tokenised_examples.pop("overflow_to_sample_mapping")

  # The offset mappings will give us a map from token to character position in the
  # original context. This will help us compute the start_positions and
  # end_positions.
- offset_mapping = tokenized_examples.pop("offset_mapping")
+ offset_mapping = tokenised_examples.pop("offset_mapping")

  # Initialise the start- and end positions of the answers
- tokenized_examples["start_positions"] = list()
- tokenized_examples["end_positions"] = list()
+ tokenised_examples["start_positions"] = list()
+ tokenised_examples["end_positions"] = list()

  for i, offsets in enumerate(offset_mapping):
  # Get the input IDs for the current example
- input_ids = tokenized_examples.input_ids[i]
+ input_ids = tokenised_examples.input_ids[i]

  # We will label impossible answers with the index of the CLS token
  cls_index = input_ids.index(cls_token_id)

  # Grab the sequence corresponding to that example (to know what is the context
  # and what is the question).
- sequence_ids = tokenized_examples.sequence_ids(i)
+ sequence_ids = tokenised_examples.sequence_ids(i)

  # Manually ensure that the special tokens are set to None in `sequence_ids`
  for special_token in tokeniser.special_tokens_map.keys():
@@ -329,8 +329,8 @@

  # If no answers are given, set the cls_index as answer.
  if len(answers["answer_start"]) == 0:
- tokenized_examples.start_positions.append(cls_index)
- tokenized_examples.end_positions.append(cls_index)
+ tokenised_examples.start_positions.append(cls_index)
+ tokenised_examples.end_positions.append(cls_index)

  else:
  # Start/end character index of the answer in the text.
@@ -353,8 +353,8 @@
  offsets[token_start_index][0] <= start_char
  and offsets[token_end_index][1] >= end_char
  ):
- tokenized_examples.start_positions.append(cls_index)
- tokenized_examples.end_positions.append(cls_index)
+ tokenised_examples.start_positions.append(cls_index)
+ tokenised_examples.end_positions.append(cls_index)

  # Otherwise move the token_start_index and token_end_index to the two ends
  # of the answer. Note: we could go after the last offset if the answer is
@@ -366,17 +366,17 @@
  ):
  token_start_index += 1
  token_start_index -= 1
- tokenized_examples.start_positions.append(token_start_index)
+ tokenised_examples.start_positions.append(token_start_index)
  while (
  token_start_index <= token_end_index
  and offsets[token_end_index][1] >= end_char
  ):
  token_end_index -= 1
  token_end_index += 1
- tokenized_examples.end_positions.append(token_end_index)
+ tokenised_examples.end_positions.append(token_end_index)
  assert token_end_index >= token_start_index

- return tokenized_examples
+ return tokenised_examples


  def prepare_test_examples(
@@ -394,7 +394,7 @@ def prepare_test_examples(
  The prepared test examples.
  """
  # Some of the questions have lots of whitespace on the left, which is not useful
- # and will make the truncation of the context fail (the tokenized question will
+ # and will make the truncation of the context fail (the tokenised question will
  # take a lots of space). So we remove that left whitespace
  examples["question"] = [q.lstrip() for q in examples["question"]]

@@ -412,7 +412,7 @@
  ]
  examples["context"] = [f"{c}{sep_token}" for c in examples["context"]]

- # Set the stride used during tokenization, when the context is long enough to be
+ # Set the stride used during tokenisation, when the context is long enough to be
  # split into several features. Since we are always keeping the question tokens, we
  # need to make sure that the stride does not exceed the resulting maximum context
  # length.
@@ -423,11 +423,11 @@
  stride = min(stride, max_length - max_question_tokens - num_special_tokens)
  max_length = tokeniser.model_max_length - stride

- # Tokenize our examples with truncation and maybe padding, but keep the overflows
+ # Tokenise our examples with truncation and maybe padding, but keep the overflows
  # using a stride. This results in one example possible giving several features when
  # a context is long, each of those features having a context that overlaps a bit
  # the context of the previous feature.
- tokenized_examples = tokeniser(
+ tokenised_examples = tokeniser(
  text=examples["question"],
  text_pair=examples["context"],
  truncation="only_second",
@@ -441,30 +441,30 @@
  # Since one example might give us several features if it has a long context, we
  # need a map from a feature to its corresponding example. This key gives us just
  # that.
- sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
+ sample_mapping = tokenised_examples.pop("overflow_to_sample_mapping")

  # We keep the id that gave us this feature and we will store the offset mappings.
- tokenized_examples["id"] = list()
+ tokenised_examples["id"] = list()

- for i in range(len(tokenized_examples.input_ids)):
+ for i in range(len(tokenised_examples.input_ids)):
  # Grab the sequence corresponding to that example (to know what is the context
  # and what is the question).
- sequence_ids = tokenized_examples.sequence_ids(i)
+ sequence_ids = tokenised_examples.sequence_ids(i)
  context_index = 1

  # One example can give several spans, this is the index of the example
  # containing this span of text.
  sample_index = sample_mapping[i]
- tokenized_examples.id.append(examples["id"][sample_index])
+ tokenised_examples.id.append(examples["id"][sample_index])

  # Set to (-1, -1) the offset_mapping that are not part of the context so it's
  # easy to determine if a token position is part of the context or not.
- tokenized_examples.offset_mapping[i] = [
+ tokenised_examples.offset_mapping[i] = [
  (o if sequence_ids[k] == context_index else (-1, -1))
- for k, o in enumerate(tokenized_examples.offset_mapping[i])
+ for k, o in enumerate(tokenised_examples.offset_mapping[i])
  ]

- return tokenized_examples
+ return tokenised_examples


  def postprocess_predictions_and_labels(
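Aside from the spelling change, the renamed tokenised_examples object is the BatchEncoding that Hugging Face tokenisers return when asked for overflowing windows; a minimal standalone sketch of that call pattern is shown below (the model name and input texts are placeholders, not taken from EuroEval):

# Standalone sketch of the overflow/stride tokenisation used above.
from transformers import AutoTokenizer

tokeniser = AutoTokenizer.from_pretrained("bert-base-cased")
tokenised_examples = tokeniser(
    text=["What colour is the sky?"],
    text_pair=["A long context paragraph. " * 200],
    truncation="only_second",
    max_length=128,
    stride=32,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length",
)
# One input can yield several features; this maps each feature back to its example.
sample_mapping = tokenised_examples.pop("overflow_to_sample_mapping")
print(len(tokenised_examples["input_ids"]), list(sample_mapping))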
euroeval/task_group_utils/sequence_classification.py CHANGED
@@ -9,7 +9,11 @@ import numpy as np

  from ..enums import TaskGroup
  from ..exceptions import InvalidBenchmark
- from ..utils import log_once, raise_if_model_output_contains_nan_values
+ from ..utils import (
+ extract_multiple_choice_labels,
+ log_once,
+ raise_if_model_output_contains_nan_values,
+ )

  if t.TYPE_CHECKING:
  from datasets.arrow_dataset import Dataset
@@ -128,6 +132,21 @@ def extract_labels_from_generation(
  or if the model outputted log probabilities but the first label token
  mapping is not provided.
  """
+ # Get the candidate labels, which are the labels that the model can predict
+ default_labels = [
+ dataset_config.prompt_label_mapping[lbl]
+ for lbl in dataset_config.id2label.values()
+ ]
+ if dataset_config.task.task_group == TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION:
+ sample_candidate_labels = [
+ extract_multiple_choice_labels(
+ prompt=prompt, candidate_labels=default_labels
+ )
+ for prompt in input_batch["prompt"]
+ ]
+ else:
+ sample_candidate_labels = [default_labels] * len(input_batch["prompt"])
+
  if model_output.scores is not None:
  if first_label_token_mapping is False:
  raise InvalidBenchmark(
@@ -136,8 +155,8 @@
  )
  labels = get_closest_logprobs_labels(
  generation_logprobs=model_output.scores,
- dataset_config=dataset_config,
  first_label_token_mapping=first_label_token_mapping,
+ candidate_labels=sample_candidate_labels,
  )
  if labels is not None:
  return labels
@@ -147,31 +166,8 @@
  "does not seem to be able to do that. Skipping the evaluation."
  )

- # Get the candidate labels, which are the labels that the model can predict
- candidate_labels = [
- dataset_config.prompt_label_mapping[lbl]
- for lbl in dataset_config.id2label.values()
- ]
-
  new_predicted_labels: list[str] = list()
  for idx, predicted_label in enumerate(model_output.sequences):
- # Special case if we are doing multiple choice classification: we in this case
- # dynamically change the candidate labels to the labels mentioned in the prompt
- if dataset_config.task.task_group == TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION:
- prompt = input_batch["text"][idx]
- sample_candidate_labels = [
- candidate_label
- for candidate_label in candidate_labels
- if re.search(
- pattern=rf"\b{candidate_label}. ",
- string=prompt,
- flags=re.IGNORECASE,
- )
- is not None
- ]
- else:
- sample_candidate_labels = candidate_labels
-
  # If the prediction includes a boxed answer, use that instead of the full
  # generation
  if (m := re.search(r"boxed\{(.*?)\}", predicted_label)) is not None:
@@ -192,22 +188,43 @@
  s2=candidate_label.lower(),
  weights=(insertion_weight, deletion_weight, substitution_weight),
  )
- for candidate_label in sample_candidate_labels
+ for candidate_label in sample_candidate_labels[idx]
  ]

- # If no candidate labels were found, we assume that something is wrong with the
- # model output, and we raise an error
+ best_candidate_label = sample_candidate_labels[idx][
+ np.argmin(edit_distances).item()
+ ]
+
+ # If no candidate labels were found, we either pick the label with the smallest
+ # word edit distance to the predicted label (if invalid model outputs are
+ # allowed), or we raise an error
  if min(edit_distances) > 100:
- raise InvalidBenchmark(
- "No candidate labels found for the predicted label "
- f"{predicted_label!r}, out of the candidate labels "
- f"{sample_candidate_labels}. This likely means that the model output "
- "is completely off, and we cannot extract any labels from it. Please "
- "check the model output and the candidate labels."
- )
+ if dataset_config.allow_invalid_model_outputs:
+ logger.warning(
+ "No candidate labels found for the predicted label "
+ f"{predicted_label!r}, out of the candidate labels "
+ f"{sample_candidate_labels[idx]}. This likely means that the model "
+ "output is completely off, but since invalid model outputs are "
+ "allowed for this task, we will use the closest candidate label "
+ f"({best_candidate_label})) as the output label. If you see this "
+ "warning very often, please report this issue to the EuroEval "
+ "team at github.com/EuroEval/EuroEval/issues."
+ )
+ logger.debug(
+ "The candidate labels were extracted from the prompt: "
+ f"{input_batch['text'][idx]!r}."
+ )
+ else:
+ raise InvalidBenchmark(
+ "No candidate labels found for the predicted label "
+ f"{predicted_label!r}, out of the candidate labels "
+ f"{sample_candidate_labels[idx]}. This likely means that the model "
+ "output is completely off, and we cannot extract any labels from "
+ "it. Please check the model output and the candidate labels. The "
+ "candidate labels were extracted from the prompt: "
+ f"{input_batch['text'][idx]!r}."
+ )

- # Pick the label with the smallest word edit distance to the predicted label
- best_candidate_label = sample_candidate_labels[np.argmin(edit_distances).item()]
  new_predicted_labels.append(best_candidate_label)

  return new_predicted_labels
@@ -215,8 +232,8 @@

  def get_closest_logprobs_labels(
  generation_logprobs: list[list[list[tuple[str, float]]]],
- dataset_config: "DatasetConfig",
  first_label_token_mapping: dict[str, str] | t.Literal[True],
+ candidate_labels: list[list[str]],
  ) -> list[str] | None:
  """Get the labels with the highest predicted logprob value.

@@ -229,11 +246,11 @@
  generation_logprobs:
  The logprobs of the generated tokens, for all samples in the batch. Of shape
  (batch_size, num_tokens, num_logprobs).
- dataset_config:
- The configuration of the dataset.
  first_label_token_mapping:
  A mapping from labels to the first token in each label, or alternatively a
  `True` value indicating that the model should output logprobs.
+ candidate_labels:
+ The candidate labels for each sample in the batch.

  Returns:
  The predicted labels, or None if labels could not be extracted.
@@ -242,12 +259,8 @@
  InvalidBenchmark:
  If no candidate label can be found for any of the generated labels.
  """
- english_labels = list(dataset_config.id2label.values())
- english2local = dataset_config.prompt_label_mapping
- candidate_labels = [english2local[lbl].lower() for lbl in english_labels]
-
  output_labels: list[str] = list()
- for sample in generation_logprobs:
+ for idx, sample in enumerate(generation_logprobs):
  for logprob_list in sample:
  generated_labels = [
  re.sub(pattern=r"^[^a-zæøåüöä0-9]+$", repl="", string=label.lower())
@@ -265,7 +278,7 @@
  if isinstance(first_label_token_mapping, dict):
  if any(
  candidate_label not in first_label_token_mapping
- for candidate_label in candidate_labels
+ for candidate_label in candidate_labels[idx]
  ):
  raise InvalidBenchmark(
  "There is a label not present in the first label token "
@@ -276,26 +289,14 @@

  candidate_output_labels = {
  candidate_label
- for candidate_label in candidate_labels
+ for candidate_label in candidate_labels[idx]
  if generated_label == first_label_token_mapping[candidate_label]
  }
  else:
  candidate_output_labels = {
  candidate_label
- for candidate_label in candidate_labels
- if candidate_label.startswith(generated_label)
- }
-
- # If the generated label is a numeral (e.g., "1", "2", "3") and there is
- # a matching candidate label, we only keep the full match
- if re.match(r"^\d+$", generated_label) and any(
- candidate_label == generated_label
- for candidate_label in candidate_output_labels
- ):
- candidate_output_labels = {
- candidate_label
- for candidate_label in candidate_output_labels
- if candidate_label == generated_label
+ for candidate_label in candidate_labels[idx]
+ if candidate_label.startswith(generated_label.strip())
  }

  # If we can uniquely determine the output label, we break the loop.
@@ -328,7 +329,7 @@
  elif len(candidate_output_labels) == 0:
  candidate_output_labels_starting_with_generated_label = [
  candidate_label
- for candidate_label in candidate_labels
+ for candidate_label in candidate_labels[idx]
  if candidate_label.startswith(generated_label)
  ]
  if candidate_output_labels_starting_with_generated_label:
@@ -344,19 +345,6 @@
  )
  return None

- # If we did not find any candidate label for any of the generated labels, we
- # assume that something is wrong with the model output, and we fall back to
- # using word edit distance to extract the labels
- else:
- log_once(
- f"No candidate label found for any of the generated labels "
- f"{generated_labels}. This means that using logprobs to extract "
- "the labels is not reliable, and we will instead fall back to "
- "extracting the labels using word edit distance.",
- level=logging.DEBUG,
- )
- return None
-
  if output_label is not None:
  output_labels.append(output_label)
  break
@@ -364,18 +352,20 @@
  if len(sample) == 0:
  log_once(
  "The model outputted an empty string, so no candidate labels could "
- f"be determined. Using the first label, {candidate_labels[0]!r}, "
- "as the output label.",
+ "be determined. This means that using logprobs to extract the "
+ "labels is not reliable, and we will instead fall back to "
+ "extracting the labels using word edit distance.",
  level=logging.INFO,
  )
  else:
  log_once(
- "Could not find a candidate label for any of the generated "
- f"labels in the sample {sample}. Using the first label, "
- f"{candidate_labels[0]!r}, as the output label.",
+ "No candidate label found for any of the generated labels, which "
+ "means that using logprobs to extract the labels is not reliable, "
+ "and we will instead fall back to extracting the labels using "
+ "word edit distance.",
  level=logging.INFO,
  )
- output_labels.append(candidate_labels[0])
+ return None

  assert len(output_labels) == len(generation_logprobs)
  return output_labels
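The new extract_multiple_choice_labels helper is imported from euroeval/utils.py, whose diff is not shown above. Based on the inline regex logic it replaces (the removed block in extract_labels_from_generation), a hypothetical sketch of its behaviour would be:

# Hypothetical sketch of extract_multiple_choice_labels, reconstructed from the
# inline regex logic it replaces above; the real implementation in
# euroeval/utils.py may differ.
import re

def extract_multiple_choice_labels(prompt: str, candidate_labels: list[str]) -> list[str]:
    """Keep only the candidate labels that actually appear as options in the prompt."""
    return [
        candidate_label
        for candidate_label in candidate_labels
        if re.search(rf"\b{candidate_label}. ", prompt, flags=re.IGNORECASE) is not None
    ]

print(extract_multiple_choice_labels(
    prompt="Pytanie: ...\na. pierwsza opcja\nb. druga opcja\nOdpowiedź:",
    candidate_labels=["a", "b", "c", "d"],
))
# ['a', 'b']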