EuroEval 15.12.0 (py3-none-any.whl) → 16.7.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. euroeval/__init__.py +32 -14
  2. euroeval/benchmark_config_factory.py +92 -180
  3. euroeval/benchmark_modules/base.py +49 -39
  4. euroeval/benchmark_modules/fresh.py +35 -21
  5. euroeval/benchmark_modules/hf.py +280 -244
  6. euroeval/benchmark_modules/litellm.py +752 -312
  7. euroeval/benchmark_modules/vllm.py +570 -268
  8. euroeval/benchmarker.py +651 -528
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/cli.py +49 -38
  12. euroeval/constants.py +44 -25
  13. euroeval/data_loading.py +111 -55
  14. euroeval/data_models.py +490 -323
  15. euroeval/dataset_configs/__init__.py +26 -4
  16. euroeval/dataset_configs/bosnian.py +39 -0
  17. euroeval/dataset_configs/bulgarian.py +56 -0
  18. euroeval/dataset_configs/croatian.py +56 -0
  19. euroeval/dataset_configs/czech.py +75 -0
  20. euroeval/dataset_configs/danish.py +78 -50
  21. euroeval/dataset_configs/dutch.py +74 -44
  22. euroeval/dataset_configs/english.py +71 -36
  23. euroeval/dataset_configs/estonian.py +111 -0
  24. euroeval/dataset_configs/faroese.py +25 -18
  25. euroeval/dataset_configs/finnish.py +63 -26
  26. euroeval/dataset_configs/french.py +65 -32
  27. euroeval/dataset_configs/german.py +77 -36
  28. euroeval/dataset_configs/greek.py +64 -0
  29. euroeval/dataset_configs/icelandic.py +68 -57
  30. euroeval/dataset_configs/italian.py +68 -36
  31. euroeval/dataset_configs/latvian.py +87 -0
  32. euroeval/dataset_configs/lithuanian.py +64 -0
  33. euroeval/dataset_configs/norwegian.py +98 -72
  34. euroeval/dataset_configs/polish.py +96 -0
  35. euroeval/dataset_configs/portuguese.py +63 -40
  36. euroeval/dataset_configs/serbian.py +64 -0
  37. euroeval/dataset_configs/slovak.py +55 -0
  38. euroeval/dataset_configs/slovene.py +56 -0
  39. euroeval/dataset_configs/spanish.py +68 -34
  40. euroeval/dataset_configs/swedish.py +82 -41
  41. euroeval/dataset_configs/ukrainian.py +64 -0
  42. euroeval/enums.py +12 -6
  43. euroeval/exceptions.py +21 -1
  44. euroeval/finetuning.py +34 -26
  45. euroeval/generation.py +76 -41
  46. euroeval/generation_utils.py +169 -34
  47. euroeval/languages.py +1020 -188
  48. euroeval/logging_utils.py +268 -0
  49. euroeval/metrics/__init__.py +6 -0
  50. euroeval/metrics/base.py +85 -0
  51. euroeval/metrics/huggingface.py +216 -0
  52. euroeval/metrics/llm_as_a_judge.py +260 -0
  53. euroeval/metrics/pipeline.py +289 -0
  54. euroeval/metrics/speed.py +48 -0
  55. euroeval/model_cache.py +40 -21
  56. euroeval/model_config.py +4 -5
  57. euroeval/model_loading.py +3 -0
  58. euroeval/prompt_templates/__init__.py +2 -0
  59. euroeval/prompt_templates/classification.py +206 -0
  60. euroeval/prompt_templates/linguistic_acceptability.py +157 -22
  61. euroeval/prompt_templates/multiple_choice.py +159 -17
  62. euroeval/prompt_templates/named_entity_recognition.py +318 -21
  63. euroeval/prompt_templates/reading_comprehension.py +207 -16
  64. euroeval/prompt_templates/sentiment_classification.py +205 -22
  65. euroeval/prompt_templates/summarization.py +122 -22
  66. euroeval/prompt_templates/token_classification.py +279 -0
  67. euroeval/scores.py +20 -9
  68. euroeval/speed_benchmark.py +11 -12
  69. euroeval/task_group_utils/multiple_choice_classification.py +21 -12
  70. euroeval/task_group_utils/question_answering.py +101 -73
  71. euroeval/task_group_utils/sequence_classification.py +144 -61
  72. euroeval/task_group_utils/text_to_text.py +33 -12
  73. euroeval/task_group_utils/token_classification.py +86 -89
  74. euroeval/tasks.py +75 -16
  75. euroeval/tokenisation_utils.py +603 -0
  76. euroeval/types.py +17 -11
  77. euroeval/utils.py +332 -137
  78. euroeval-16.7.1.dist-info/METADATA +623 -0
  79. euroeval-16.7.1.dist-info/RECORD +84 -0
  80. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
  81. euroeval/human_evaluation.py +0 -737
  82. euroeval/metrics.py +0 -452
  83. euroeval/tokenization_utils.py +0 -498
  84. euroeval-15.12.0.dist-info/METADATA +0 -285
  85. euroeval-15.12.0.dist-info/RECORD +0 -63
  86. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
  87. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
euroeval/task_group_utils/sequence_classification.py
@@ -1,5 +1,6 @@
 """Utility functions related to the sequence-classification task group."""
 
+import collections.abc as c
 import logging
 import re
 import typing as t
@@ -7,22 +8,32 @@
 import Levenshtein
 import numpy as np
 
+from ..enums import TaskGroup
 from ..exceptions import InvalidBenchmark
-from ..utils import log_once, raise_if_model_output_contains_nan_values
+from ..utils import (
+    extract_multiple_choice_labels,
+    log_once,
+    raise_if_model_output_contains_nan_values,
+)
 
 if t.TYPE_CHECKING:
+    from datasets.arrow_dataset import Dataset
     from transformers.trainer_utils import EvalPrediction
 
-    from ..data_models import DatasetConfig, GenerativeModelOutput
+    from ..data_models import (
+        BenchmarkConfig,
+        DatasetConfig,
+        GenerativeModelOutput,
+        ModelConfig,
+    )
     from ..types import Labels, Predictions
 
 
-logger = logging.getLogger("euroeval")
-
-
 def compute_metrics(
     model_outputs_and_labels: "tuple[Predictions, Labels] | EvalPrediction",
     dataset_config: "DatasetConfig",
+    benchmark_config: "BenchmarkConfig",
+    dataset: "Dataset",
 ) -> dict[str, float]:
     """Compute the metrics needed for evaluation.
 
@@ -32,6 +43,11 @@
             contains the true labels.
         dataset_config:
             The configuration of the dataset.
+        benchmark_config:
+            The configuration of the benchmark.
+        dataset:
+            The dataset used for evaluation. This is only used in case any additional
+            metadata is used to compute the metrics.
 
     Returns:
         A dictionary with the names of the metrics as keys and the metric values as
@@ -73,7 +89,13 @@
 
     results: dict[str, float] = dict()
     for metric in dataset_config.task.metrics:
-        score: float | None = metric(predictions=predictions, references=label_ids)
+        score: float | None = metric(
+            predictions=predictions,
+            references=label_ids,
+            dataset=dataset,
+            dataset_config=dataset_config,
+            benchmark_config=benchmark_config,
+        )
 
         # The metric returns None if we are running on multi-GPU and the current
         # process is not the main process
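Note on the new call signature above: metrics are now invoked with the evaluation dataset and both configs in addition to predictions and references. Below is a minimal sketch of a metric callable compatible with that call site; only the keyword arguments mirror the diff, and the metric itself and every name in it are hypothetical rather than part of the EuroEval API.

```python
# Minimal sketch of a metric callable matching the call site above. The extra
# keyword arguments (dataset, dataset_config, benchmark_config) are accepted
# and may simply be ignored by metrics that do not need them.
import typing as t


def exact_match_metric(
    predictions: t.Sequence[str],
    references: t.Sequence[str],
    dataset: t.Any = None,  # extra metadata, unused by this toy metric
    dataset_config: t.Any = None,
    benchmark_config: t.Any = None,
) -> float | None:
    """Toy metric: the share of predictions equal to their reference."""
    if not references:
        return None
    correct = sum(pred == ref for pred, ref in zip(predictions, references))
    return correct / len(references)


print(exact_match_metric(predictions=["a", "b"], references=["a", "c"]))  # 0.5
```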
@@ -87,8 +109,9 @@
     input_batch: dict[str, list],
     model_output: "GenerativeModelOutput",
     dataset_config: "DatasetConfig",
+    model_config: "ModelConfig",
     first_label_token_mapping: dict[str, str] | bool,
-) -> list[str]:
+) -> c.Sequence[str]:
     """Extract the predicted labels from the generated output.
 
     Args:
@@ -99,6 +122,8 @@
             The raw generated output of the model.
         dataset_config:
             The configuration of the dataset.
+        model_config:
+            The configuration of the model.
         first_label_token_mapping:
             A mapping from labels to the first token in each label, or alternatively a
             Boolean value indicating whether the model should output scores (if the
@@ -106,7 +131,28 @@
 
     Returns:
         The predicted labels.
+
+    Raises:
+        InvalidBenchmark:
+            If the task requires log probabilities, but the model did not output them,
+            or if the model outputted log probabilities but the first label token
+            mapping is not provided.
     """
+    # Get the candidate labels, which are the labels that the model can predict
+    default_labels = [
+        dataset_config.prompt_label_mapping[lbl]
+        for lbl in dataset_config.id2label.values()
+    ]
+    if dataset_config.task.task_group == TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION:
+        sample_candidate_labels = [
+            extract_multiple_choice_labels(
+                prompt=prompt, candidate_labels=default_labels
+            )
+            for prompt in input_batch["prompt"]
+        ]
+    else:
+        sample_candidate_labels = [default_labels] * len(input_batch["prompt"])
+
     if model_output.scores is not None:
         if first_label_token_mapping is False:
             raise InvalidBenchmark(
@@ -115,39 +161,93 @@
             )
         labels = get_closest_logprobs_labels(
             generation_logprobs=model_output.scores,
-            dataset_config=dataset_config,
             first_label_token_mapping=first_label_token_mapping,
+            candidate_labels=sample_candidate_labels,
         )
         if labels is not None:
             return labels
+    elif dataset_config.task.requires_logprobs:
+        raise InvalidBenchmark(
+            "This task requires the model to output logprobs, and this model "
+            "does not seem to be able to do that. Skipping the evaluation."
+        )
 
-    candidate_labels = [
-        dataset_config.prompt_label_mapping[lbl]
-        for lbl in dataset_config.id2label.values()
-    ]
     new_predicted_labels: list[str] = list()
-    for predicted_label in model_output.sequences:
+    num_predictions_being_very_off = 0
+    for idx, predicted_label in enumerate(model_output.sequences):
         # If the prediction includes a boxed answer, use that instead of the full
         # generation
         if (m := re.search(r"boxed\{(.*?)\}", predicted_label)) is not None:
             predicted_label = m.group(1)
 
-        # Pick the label with the smallest word edit distance to the predicted label
+        # We set the word edit distance weights such that we heavily penalise insertions
+        # and substitutions, so that we don't just insert the correct label, but that we
+        # want the model to have included the correct label in its output.
+        insertion_weight = 1000
+        deletion_weight = 1
+        substitution_weight = 1000
+
+        # Compute the word edit distances between the predicted label and all candidate
+        # labels
         edit_distances = [
-            Levenshtein.distance(s1=predicted_label.lower(), s2=candidate_label.lower())
-            for candidate_label in candidate_labels
+            Levenshtein.distance(
+                s1=predicted_label.lower(),
+                s2=candidate_label.lower(),
+                weights=(insertion_weight, deletion_weight, substitution_weight),
+            )
+            for candidate_label in sample_candidate_labels[idx]
         ]
-        predicted_label = candidate_labels[np.argmin(edit_distances).item()]
-        new_predicted_labels.append(predicted_label)
+
+        best_candidate_label = sample_candidate_labels[idx][
+            np.argmin(edit_distances).item()
+        ]
+
+        # If no candidate labels were found, we either pick the label with the smallest
+        # word edit distance to the predicted label (if invalid model outputs are
+        # allowed), or we raise an error
+        if min(edit_distances) >= 1000:
+            num_predictions_being_very_off += 1
+
+        new_predicted_labels.append(best_candidate_label)
+
+    if num_predictions_being_very_off > 0:
+        if dataset_config.allow_invalid_model_outputs:
+            log_msg = (
+                "No candidate labels found for the predicted label in "
+                f"{num_predictions_being_very_off:,}/{len(model_output.sequences):,} "
+                f"of the samples with the model {model_config.model_id!r}. This "
+                "likely means that the model were completely off in these cases, "
+                "but since invalid model outputs are allowed for this task, we used "
+                "the closest candidate labels as the output labels."
+            )
+            level = logging.DEBUG
+            if num_predictions_being_very_off / len(model_output.sequences) > 0.5:
+                log_msg += (
+                    " Since this happened for most of the model's predictions, please "
+                    "report this issue to the EuroEval team at "
+                    "github.com/EuroEval/EuroEval/issues."
+                )
+                level = logging.WARNING
+            log_once(log_msg, level=level)
+        else:
+            raise InvalidBenchmark(
+                "No candidate labels found for the predicted label in "
+                f"{num_predictions_being_very_off:,}/{len(model_output.sequences):,} "
+                "of the samples. This likely means that the model were completely "
+                "off in these cases. Since this task does not allow invalid model "
+                "outputs, we have to abort the evaluation. Please re-run the "
+                "evaluation with the `--debug` flag (or `debug=True` if you're using "
+                "the `Benchmarker` API) to see the precise model outputs."
+            )
 
     return new_predicted_labels
 
 
 def get_closest_logprobs_labels(
-    generation_logprobs: list[list[list[tuple[str, float]]]],
-    dataset_config: "DatasetConfig",
+    generation_logprobs: c.Sequence[c.Sequence[c.Sequence[tuple[str, float]]]],
     first_label_token_mapping: dict[str, str] | t.Literal[True],
-) -> list[str] | None:
+    candidate_labels: c.Sequence[c.Sequence[str]],
+) -> c.Sequence[str] | None:
     """Get the labels with the highest predicted logprob value.
 
     In case a candidate label is split into multiple tokens, we only use the first
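The weights tuple in the hunk above makes insertions and substitutions effectively forbidden (cost 1000) while deletions stay cheap (cost 1), so a candidate label only receives a small distance when it can be reached by deleting characters from the model output, i.e. when the output already contains that label; any distance of 1000 or more therefore flags a prediction with no recognisable label. A small standalone illustration follows, assuming the `Levenshtein` package is installed; the labels and prediction are made up.

```python
# Illustration of the weighted edit distance used above. With insertion and
# substitution weighted at 1000 and deletion at 1, a candidate label is only
# "cheap" to reach when the model output already contains it.
import Levenshtein

candidate_labels = ["positive", "negative", "neutral"]
prediction = "the sentiment is clearly negative."

weights = (1000, 1, 1000)  # (insertion, deletion, substitution)
distances = [
    Levenshtein.distance(prediction.lower(), label.lower(), weights=weights)
    for label in candidate_labels
]
print(dict(zip(candidate_labels, distances)))
# "negative" needs only deletions and scores low; "positive" and "neutral"
# require insertions or substitutions and score >= 1000, which is what the
# `min(edit_distances) >= 1000` check above treats as "no label found".
best = min(zip(distances, candidate_labels))[1]
print(best)  # negative
```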
@@ -159,11 +259,11 @@
         generation_logprobs:
             The logprobs of the generated tokens, for all samples in the batch. Of shape
             (batch_size, num_tokens, num_logprobs).
-        dataset_config:
-            The configuration of the dataset.
         first_label_token_mapping:
             A mapping from labels to the first token in each label, or alternatively a
             `True` value indicating that the model should output logprobs.
+        candidate_labels:
+            The candidate labels for each sample in the batch.
 
     Returns:
         The predicted labels, or None if labels could not be extracted.
@@ -172,19 +272,11 @@
         InvalidBenchmark:
             If no candidate label can be found for any of the generated labels.
     """
-    english_labels = list(dataset_config.id2label.values())
-    english2local = dataset_config.prompt_label_mapping
-    candidate_labels = [english2local[lbl].lower() for lbl in english_labels]
-
     output_labels: list[str] = list()
-    for sample in generation_logprobs:
+    for idx, sample in enumerate(generation_logprobs):
         for logprob_list in sample:
             generated_labels = [
-                re.sub(
-                    pattern=r"^[^a-zæøåüöä]+|[^a-zæøåüöä]+$",
-                    repl="",
-                    string=label.lower(),
-                )
+                re.sub(pattern=r"^[^a-zæøåüöä0-9]+$", repl="", string=label.lower())
                 for label, _ in logprob_list
             ]
             generated_labels = [label for label in generated_labels if label != ""]
@@ -199,7 +291,7 @@
             if isinstance(first_label_token_mapping, dict):
                 if any(
                     candidate_label not in first_label_token_mapping
-                    for candidate_label in candidate_labels
+                    for candidate_label in candidate_labels[idx]
                 ):
                     raise InvalidBenchmark(
                         "There is a label not present in the first label token "
@@ -210,14 +302,14 @@
 
                 candidate_output_labels = {
                     candidate_label
-                    for candidate_label in candidate_labels
+                    for candidate_label in candidate_labels[idx]
                     if generated_label == first_label_token_mapping[candidate_label]
                 }
             else:
                 candidate_output_labels = {
                     candidate_label
-                    for candidate_label in candidate_labels
-                    if candidate_label.startswith(generated_label)
+                    for candidate_label in candidate_labels[idx]
+                    if candidate_label.startswith(generated_label.strip())
                 }
 
             # If we can uniquely determine the output label, we break the loop.
@@ -250,33 +342,22 @@
             elif len(candidate_output_labels) == 0:
                 candidate_output_labels_starting_with_generated_label = [
                     candidate_label
-                    for candidate_label in candidate_labels
+                    for candidate_label in candidate_labels[idx]
                    if candidate_label.startswith(generated_label)
                 ]
                 if candidate_output_labels_starting_with_generated_label:
                     log_once(
                         f"No candidate label found for the generated label "
-                        f"{generated_label!r}. This means that using logprobs to "
-                        "extract the labels is not reliable, and we will instead "
-                        "fall back to extracting the labels using word edit "
-                        "distance.",
+                        f"{generated_label!r}, but there are candidate labels "
+                        f"starting with it: "
+                        f"{candidate_output_labels_starting_with_generated_label}. "
+                        "This means that the first label token mapping is not "
+                        "reliable, and we will instead fall back to extracting "
+                        "the labels using word edit distance.",
                         level=logging.DEBUG,
                     )
                     return None
 
-                # If we did not find any candidate label for any of the generated labels, we
-                # assume that something is wrong with the model output, and we fall back to
-                # using word edit distance to extract the labels
-                else:
-                    log_once(
-                        f"No candidate label found for any of the generated labels "
-                        f"{generated_labels}. This means that using logprobs to extract "
-                        "the labels is not reliable, and we will instead fall back to "
-                        "extracting the labels using word edit distance.",
-                        level=logging.DEBUG,
-                    )
-                    return None
-
             if output_label is not None:
                 output_labels.append(output_label)
                 break
@@ -284,18 +365,20 @@
             if len(sample) == 0:
                 log_once(
                     "The model outputted an empty string, so no candidate labels could "
-                    f"be determined. Using {candidate_labels[0]!r} as the output "
-                    "label.",
+                    "be determined. This means that using logprobs to extract the "
+                    "labels is not reliable, and we will instead fall back to "
+                    "extracting the labels using word edit distance.",
                     level=logging.DEBUG,
                 )
             else:
                 log_once(
-                    "Could not find a candidate label for any of the generated "
-                    f"labels in the sample {sample}. Using {candidate_labels[0]!r} "
-                    "as the output label.",
+                    "No candidate label found for any of the generated labels, which "
+                    "means that using logprobs to extract the labels is not reliable, "
+                    "and we will instead fall back to extracting the labels using "
+                    "word edit distance.",
                     level=logging.DEBUG,
                 )
-            output_labels.append(candidate_labels[0])
+            return None
 
     assert len(output_labels) == len(generation_logprobs)
     return output_labels
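For context, the logprob path above matches the model's top tokens against a first-token mapping per candidate label and now returns `None` whenever the match is missing or ambiguous, so the caller falls back to the edit-distance path instead of silently picking the first label. The standalone sketch below illustrates that matching idea only; it is not the EuroEval implementation, and all names and data in it are made up.

```python
# Standalone sketch of the first-token matching idea: compare the top-logprob
# tokens against a mapping from each candidate label to its first token, and
# return None when the match is missing or ambiguous so the caller can fall
# back to word-edit-distance matching.
def pick_label_from_logprobs(
    logprob_list: list[tuple[str, float]],
    first_label_token_mapping: dict[str, str],
    candidate_labels: list[str],
) -> str | None:
    for token, _ in sorted(logprob_list, key=lambda pair: pair[1], reverse=True):
        token = token.strip().lower()
        if not token:
            continue
        matches = {
            label
            for label in candidate_labels
            if token == first_label_token_mapping[label]
        }
        if len(matches) == 1:
            return matches.pop()
    return None  # ambiguous or missing: fall back to word edit distance


mapping = {"positiv": "posit", "negativ": "negativ"}
logprobs = [("posit", -0.1), ("negativ", -2.3)]
print(pick_label_from_logprobs(logprobs, mapping, ["positiv", "negativ"]))  # positiv
```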
euroeval/task_group_utils/text_to_text.py
@@ -1,5 +1,6 @@
 """Utility functions related to the text-to-text task group."""
 
+import collections.abc as c
 import logging
 import typing as t
 
@@ -7,23 +8,23 @@
 from ..constants import METRIC_ATTRIBUTES_TAKING_UP_MEMORY
 from ..exceptions import InvalidBenchmark
+from ..logging_utils import log
 from ..metrics import HuggingFaceMetric
 from ..utils import raise_if_model_output_contains_nan_values
 
 if t.TYPE_CHECKING:
+    from datasets.arrow_dataset import Dataset
     from transformers.trainer_utils import EvalPrediction
 
     from ..data_models import BenchmarkConfig, DatasetConfig, GenerativeModelOutput
     from ..types import Labels, Predictions
 
 
-logger = logging.getLogger("euroeval")
-
-
 def compute_metrics(
     model_outputs_and_labels: "tuple[Predictions, Labels] | EvalPrediction",
     dataset_config: "DatasetConfig",
     benchmark_config: "BenchmarkConfig",
+    dataset: "Dataset",
 ) -> dict[str, float]:
     """Compute the metrics needed for evaluation.
 
@@ -35,10 +36,17 @@
             The configuration of the dataset.
         benchmark_config:
             The configuration of the benchmark.
+        dataset:
+            The dataset used for evaluation. This is only used in case any additional
+            metadata is used to compute the metrics.
 
     Returns:
         A dictionary with the names of the metrics as keys and the metric values as
         values.
+
+    Raises:
+        InvalidBenchmark:
+            If the metric computation fails.
     """
     model_outputs, labels = model_outputs_and_labels
 
@@ -67,9 +75,15 @@
     ):
         metric.compute_kwargs["device"] = benchmark_config.device.type
 
-        while True:
+        for _ in range(num_attempts := 5):
             try:
-                score: float | None = metric(predictions=predictions, references=labels)
+                score: float | None = metric(
+                    predictions=predictions,
+                    references=labels,
+                    dataset=dataset,
+                    dataset_config=dataset_config,
+                    benchmark_config=benchmark_config,
+                )
                 break
             except Exception as e:
                 oom_error = [
@@ -78,28 +92,35 @@
                     "MPS backend out of memory",
                 ]
                 if not any(error in str(e) for error in oom_error):
-                    raise InvalidBenchmark(str(e))
+                    raise InvalidBenchmark(str(e)) from e
 
                 if (
                     isinstance(metric, HuggingFaceMetric)
                     and metric.compute_kwargs.get("device", "cpu") != "cpu"
                 ):
                     metric.compute_kwargs["device"] = "cpu"
-                    logger.debug(
+                    log(
                         "Out of memory error occurred during the computation of "
                         f"the metric {metric.pretty_name}. Moving the computation to "
-                        "the CPU."
+                        "the CPU.",
+                        level=logging.DEBUG,
                     )
                 else:
-                    raise InvalidBenchmark(str(e))
+                    raise InvalidBenchmark(str(e)) from e
             finally:
                 for attribute in METRIC_ATTRIBUTES_TAKING_UP_MEMORY:
                     if hasattr(metric, attribute):
-                        logger.debug(
+                        log(
                             f"Deleting the {attribute!r} attribute of the metric "
-                            f"{metric.pretty_name} to free up memory."
+                            f"{metric.pretty_name} to free up memory.",
+                            level=logging.DEBUG,
                         )
                         delattr(metric, attribute)
+        else:
+            raise InvalidBenchmark(
+                f"Could not compute the metric {metric.pretty_name} after "
+                f"{num_attempts} attempts due to out of memory errors."
+            )
 
         # The metric returns None if we are running on multi-GPU and the current
         # process is not the main process
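In the hunk above, the unbounded `while True` loop is replaced by a bounded retry: up to five attempts, moving the metric to the CPU after an out-of-memory error, with the `for ... else` branch raising only when every attempt finished without a `break`. Here is a self-contained sketch of that control flow, with a simulated computation standing in for the real metric call; the exception class and helper below are made up for illustration.

```python
# Self-contained sketch of the bounded-retry pattern used above. Python's
# `for ... else` runs the `else` block only when the loop finishes without
# hitting `break`, i.e. when every attempt failed.
class OutOfMemory(Exception):
    pass


outcomes = iter([True, True, False])  # fail twice, then succeed


def flaky_compute() -> float:
    if next(outcomes):  # simulate an out-of-memory failure
        raise OutOfMemory("CUDA out of memory")
    return 0.42


for attempt in range(num_attempts := 5):
    try:
        score = flaky_compute()
        break
    except OutOfMemory as exc:
        print(f"Attempt {attempt + 1} failed ({exc}); retrying with a smaller footprint.")
else:
    raise RuntimeError(f"Computation failed after {num_attempts} attempts.")

print(score)  # 0.42, obtained on the third attempt
```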
@@ -111,7 +132,7 @@
 
 def extract_labels_from_generation(
     input_batch: dict[str, list], model_output: "GenerativeModelOutput"
-) -> list[t.Any]:
+) -> c.Sequence[t.Any]:
     """Extract the predicted labels from the generated output.
 
     Args: