EuroEval 15.12.0__py3-none-any.whl → 16.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- euroeval/__init__.py +32 -14
- euroeval/benchmark_config_factory.py +92 -180
- euroeval/benchmark_modules/base.py +49 -39
- euroeval/benchmark_modules/fresh.py +35 -21
- euroeval/benchmark_modules/hf.py +280 -244
- euroeval/benchmark_modules/litellm.py +752 -312
- euroeval/benchmark_modules/vllm.py +570 -268
- euroeval/benchmarker.py +651 -528
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/cli.py +49 -38
- euroeval/constants.py +44 -25
- euroeval/data_loading.py +111 -55
- euroeval/data_models.py +490 -323
- euroeval/dataset_configs/__init__.py +26 -4
- euroeval/dataset_configs/bosnian.py +39 -0
- euroeval/dataset_configs/bulgarian.py +56 -0
- euroeval/dataset_configs/croatian.py +56 -0
- euroeval/dataset_configs/czech.py +75 -0
- euroeval/dataset_configs/danish.py +78 -50
- euroeval/dataset_configs/dutch.py +74 -44
- euroeval/dataset_configs/english.py +71 -36
- euroeval/dataset_configs/estonian.py +111 -0
- euroeval/dataset_configs/faroese.py +25 -18
- euroeval/dataset_configs/finnish.py +63 -26
- euroeval/dataset_configs/french.py +65 -32
- euroeval/dataset_configs/german.py +77 -36
- euroeval/dataset_configs/greek.py +64 -0
- euroeval/dataset_configs/icelandic.py +68 -57
- euroeval/dataset_configs/italian.py +68 -36
- euroeval/dataset_configs/latvian.py +87 -0
- euroeval/dataset_configs/lithuanian.py +64 -0
- euroeval/dataset_configs/norwegian.py +98 -72
- euroeval/dataset_configs/polish.py +96 -0
- euroeval/dataset_configs/portuguese.py +63 -40
- euroeval/dataset_configs/serbian.py +64 -0
- euroeval/dataset_configs/slovak.py +55 -0
- euroeval/dataset_configs/slovene.py +56 -0
- euroeval/dataset_configs/spanish.py +68 -34
- euroeval/dataset_configs/swedish.py +82 -41
- euroeval/dataset_configs/ukrainian.py +64 -0
- euroeval/enums.py +12 -6
- euroeval/exceptions.py +21 -1
- euroeval/finetuning.py +34 -26
- euroeval/generation.py +76 -41
- euroeval/generation_utils.py +169 -34
- euroeval/languages.py +1020 -188
- euroeval/logging_utils.py +268 -0
- euroeval/metrics/__init__.py +6 -0
- euroeval/metrics/base.py +85 -0
- euroeval/metrics/huggingface.py +216 -0
- euroeval/metrics/llm_as_a_judge.py +260 -0
- euroeval/metrics/pipeline.py +289 -0
- euroeval/metrics/speed.py +48 -0
- euroeval/model_cache.py +40 -21
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/__init__.py +2 -0
- euroeval/prompt_templates/classification.py +206 -0
- euroeval/prompt_templates/linguistic_acceptability.py +157 -22
- euroeval/prompt_templates/multiple_choice.py +159 -17
- euroeval/prompt_templates/named_entity_recognition.py +318 -21
- euroeval/prompt_templates/reading_comprehension.py +207 -16
- euroeval/prompt_templates/sentiment_classification.py +205 -22
- euroeval/prompt_templates/summarization.py +122 -22
- euroeval/prompt_templates/token_classification.py +279 -0
- euroeval/scores.py +20 -9
- euroeval/speed_benchmark.py +11 -12
- euroeval/task_group_utils/multiple_choice_classification.py +21 -12
- euroeval/task_group_utils/question_answering.py +101 -73
- euroeval/task_group_utils/sequence_classification.py +144 -61
- euroeval/task_group_utils/text_to_text.py +33 -12
- euroeval/task_group_utils/token_classification.py +86 -89
- euroeval/tasks.py +75 -16
- euroeval/tokenisation_utils.py +603 -0
- euroeval/types.py +17 -11
- euroeval/utils.py +332 -137
- euroeval-16.7.1.dist-info/METADATA +623 -0
- euroeval-16.7.1.dist-info/RECORD +84 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
- euroeval/human_evaluation.py +0 -737
- euroeval/metrics.py +0 -452
- euroeval/tokenization_utils.py +0 -498
- euroeval-15.12.0.dist-info/METADATA +0 -285
- euroeval-15.12.0.dist-info/RECORD +0 -63
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
euroeval/task_group_utils/sequence_classification.py

@@ -1,5 +1,6 @@
 """Utility functions related to the sequence-classification task group."""
 
+import collections.abc as c
 import logging
 import re
 import typing as t
@@ -7,22 +8,32 @@ import typing as t
 import Levenshtein
 import numpy as np
 
+from ..enums import TaskGroup
 from ..exceptions import InvalidBenchmark
-from ..utils import
+from ..utils import (
+    extract_multiple_choice_labels,
+    log_once,
+    raise_if_model_output_contains_nan_values,
+)
 
 if t.TYPE_CHECKING:
+    from datasets.arrow_dataset import Dataset
     from transformers.trainer_utils import EvalPrediction
 
-    from ..data_models import
+    from ..data_models import (
+        BenchmarkConfig,
+        DatasetConfig,
+        GenerativeModelOutput,
+        ModelConfig,
+    )
     from ..types import Labels, Predictions
 
 
-logger = logging.getLogger("euroeval")
-
-
 def compute_metrics(
     model_outputs_and_labels: "tuple[Predictions, Labels] | EvalPrediction",
     dataset_config: "DatasetConfig",
+    benchmark_config: "BenchmarkConfig",
+    dataset: "Dataset",
 ) -> dict[str, float]:
     """Compute the metrics needed for evaluation.
 
@@ -32,6 +43,11 @@ def compute_metrics(
             contains the true labels.
         dataset_config:
            The configuration of the dataset.
+        benchmark_config:
+            The configuration of the benchmark.
+        dataset:
+            The dataset used for evaluation. This is only used in case any additional
+            metadata is used to compute the metrics.
 
     Returns:
        A dictionary with the names of the metrics as keys and the metric values as
@@ -73,7 +89,13 @@ def compute_metrics(
 
     results: dict[str, float] = dict()
     for metric in dataset_config.task.metrics:
-        score: float | None = metric(
+        score: float | None = metric(
+            predictions=predictions,
+            references=label_ids,
+            dataset=dataset,
+            dataset_config=dataset_config,
+            benchmark_config=benchmark_config,
+        )
 
         # The metric returns None if we are running on multi-GPU and the current
        # process is not the main process
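For orientation, the hunk above shows that metrics are now invoked with keyword arguments only, passing the dataset and both configuration objects alongside the predictions and references. A minimal sketch of a metric callable that would satisfy this call site is shown below; the class name and scoring logic are hypothetical, only the keyword parameters are taken from the diff.

```python
import typing as t


class ExactMatch:
    """Hypothetical metric callable compatible with the call site in the diff above."""

    name = "exact_match"
    pretty_name = "Exact match"

    def __call__(
        self,
        predictions: t.Sequence[t.Any],
        references: t.Sequence[t.Any],
        dataset: t.Any = None,
        dataset_config: t.Any = None,
        benchmark_config: t.Any = None,
    ) -> float | None:
        # The extra arguments allow metrics to use dataset metadata or benchmark
        # settings; this sketch ignores them and just compares strings.
        if not references:
            return 0.0
        correct = sum(
            str(pred).strip() == str(ref).strip()
            for pred, ref in zip(predictions, references)
        )
        return correct / len(references)
```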
@@ -87,8 +109,9 @@ def extract_labels_from_generation(
     input_batch: dict[str, list],
     model_output: "GenerativeModelOutput",
     dataset_config: "DatasetConfig",
+    model_config: "ModelConfig",
     first_label_token_mapping: dict[str, str] | bool,
-) ->
+) -> c.Sequence[str]:
     """Extract the predicted labels from the generated output.
 
     Args:
@@ -99,6 +122,8 @@ def extract_labels_from_generation(
            The raw generated output of the model.
         dataset_config:
            The configuration of the dataset.
+        model_config:
+            The configuration of the model.
         first_label_token_mapping:
            A mapping from labels to the first token in each label, or alternatively a
            Boolean value indicating whether the model should output scores (if the
@@ -106,7 +131,28 @@ def extract_labels_from_generation(
 
     Returns:
        The predicted labels.
+
+    Raises:
+        InvalidBenchmark:
+            If the task requires log probabilities, but the model did not output them,
+            or if the model outputted log probabilities but the first label token
+            mapping is not provided.
     """
+    # Get the candidate labels, which are the labels that the model can predict
+    default_labels = [
+        dataset_config.prompt_label_mapping[lbl]
+        for lbl in dataset_config.id2label.values()
+    ]
+    if dataset_config.task.task_group == TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION:
+        sample_candidate_labels = [
+            extract_multiple_choice_labels(
+                prompt=prompt, candidate_labels=default_labels
+            )
+            for prompt in input_batch["prompt"]
+        ]
+    else:
+        sample_candidate_labels = [default_labels] * len(input_batch["prompt"])
+
     if model_output.scores is not None:
         if first_label_token_mapping is False:
             raise InvalidBenchmark(
@@ -115,39 +161,93 @@ def extract_labels_from_generation(
             )
         labels = get_closest_logprobs_labels(
             generation_logprobs=model_output.scores,
-            dataset_config=dataset_config,
             first_label_token_mapping=first_label_token_mapping,
+            candidate_labels=sample_candidate_labels,
         )
         if labels is not None:
             return labels
+        elif dataset_config.task.requires_logprobs:
+            raise InvalidBenchmark(
+                "This task requires the model to output logprobs, and this model "
+                "does not seem to be able to do that. Skipping the evaluation."
+            )
 
-    candidate_labels = [
-        dataset_config.prompt_label_mapping[lbl]
-        for lbl in dataset_config.id2label.values()
-    ]
     new_predicted_labels: list[str] = list()
-
+    num_predictions_being_very_off = 0
+    for idx, predicted_label in enumerate(model_output.sequences):
         # If the prediction includes a boxed answer, use that instead of the full
         # generation
         if (m := re.search(r"boxed\{(.*?)\}", predicted_label)) is not None:
             predicted_label = m.group(1)
 
-        #
+        # We set the word edit distance weights such that we heavily penalise insertions
+        # and substitutions, so that we don't just insert the correct label, but that we
+        # want the model to have included the correct label in its output.
+        insertion_weight = 1000
+        deletion_weight = 1
+        substitution_weight = 1000
+
+        # Compute the word edit distances between the predicted label and all candidate
+        # labels
         edit_distances = [
-            Levenshtein.distance(
-
+            Levenshtein.distance(
+                s1=predicted_label.lower(),
+                s2=candidate_label.lower(),
+                weights=(insertion_weight, deletion_weight, substitution_weight),
+            )
+            for candidate_label in sample_candidate_labels[idx]
         ]
-
-
+
+        best_candidate_label = sample_candidate_labels[idx][
+            np.argmin(edit_distances).item()
+        ]
+
+        # If no candidate labels were found, we either pick the label with the smallest
+        # word edit distance to the predicted label (if invalid model outputs are
+        # allowed), or we raise an error
+        if min(edit_distances) >= 1000:
+            num_predictions_being_very_off += 1
+
+        new_predicted_labels.append(best_candidate_label)
+
+    if num_predictions_being_very_off > 0:
+        if dataset_config.allow_invalid_model_outputs:
+            log_msg = (
+                "No candidate labels found for the predicted label in "
+                f"{num_predictions_being_very_off:,}/{len(model_output.sequences):,} "
+                f"of the samples with the model {model_config.model_id!r}. This "
+                "likely means that the model were completely off in these cases, "
+                "but since invalid model outputs are allowed for this task, we used "
+                "the closest candidate labels as the output labels."
+            )
+            level = logging.DEBUG
+            if num_predictions_being_very_off / len(model_output.sequences) > 0.5:
+                log_msg += (
+                    " Since this happened for most of the model's predictions, please "
+                    "report this issue to the EuroEval team at "
+                    "github.com/EuroEval/EuroEval/issues."
+                )
+                level = logging.WARNING
+            log_once(log_msg, level=level)
+        else:
+            raise InvalidBenchmark(
+                "No candidate labels found for the predicted label in "
+                f"{num_predictions_being_very_off:,}/{len(model_output.sequences):,} "
+                "of the samples. This likely means that the model were completely "
+                "off in these cases. Since this task does not allow invalid model "
+                "outputs, we have to abort the evaluation. Please re-run the "
+                "evaluation with the `--debug` flag (or `debug=True` if you're using "
+                "the `Benchmarker` API) to see the precise model outputs."
+            )
 
     return new_predicted_labels
 
 
 def get_closest_logprobs_labels(
-    generation_logprobs:
-    dataset_config: "DatasetConfig",
+    generation_logprobs: c.Sequence[c.Sequence[c.Sequence[tuple[str, float]]]],
     first_label_token_mapping: dict[str, str] | t.Literal[True],
-
+    candidate_labels: c.Sequence[c.Sequence[str]],
+) -> c.Sequence[str] | None:
     """Get the labels with the highest predicted logprob value.
 
     In case a candidate label is split into multiple tokens, we only use the first
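The hunk above replaces the plain edit distance with a weighted one: insertions and substitutions cost 1000 while deletions cost 1, so a candidate label only gets a low score when it already appears (up to deleted characters) inside the model's output. A small standalone illustration of that behaviour, mirroring the `Levenshtein.distance(s1=..., s2=..., weights=...)` call from the diff and assuming the package's `(insertion, deletion, substitution)` weight order:

```python
import Levenshtein

prediction = "the sentiment of the text is positive"
candidate_labels = ["positive", "negative", "neutral"]

# Heavily penalise insertions and substitutions, keep deletions cheap, mirroring
# the weights used in the diff above.
weights = (1000, 1, 1000)

distances = {
    label: Levenshtein.distance(s1=prediction.lower(), s2=label.lower(), weights=weights)
    for label in candidate_labels
}
print(distances)
# "positive" can be reached from the prediction by deletions alone, so its distance
# stays small; "negative" and "neutral" need at least one insertion or substitution
# and therefore cost 1000 or more.
best = min(distances, key=distances.get)
print(best)  # expected: "positive"
```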
@@ -159,11 +259,11 @@ def get_closest_logprobs_labels(
         generation_logprobs:
            The logprobs of the generated tokens, for all samples in the batch. Of shape
            (batch_size, num_tokens, num_logprobs).
-        dataset_config:
-            The configuration of the dataset.
         first_label_token_mapping:
            A mapping from labels to the first token in each label, or alternatively a
            `True` value indicating that the model should output logprobs.
+        candidate_labels:
+            The candidate labels for each sample in the batch.
 
     Returns:
        The predicted labels, or None if labels could not be extracted.
@@ -172,19 +272,11 @@ def get_closest_logprobs_labels(
         InvalidBenchmark:
            If no candidate label can be found for any of the generated labels.
     """
-    english_labels = list(dataset_config.id2label.values())
-    english2local = dataset_config.prompt_label_mapping
-    candidate_labels = [english2local[lbl].lower() for lbl in english_labels]
-
     output_labels: list[str] = list()
-    for sample in generation_logprobs:
+    for idx, sample in enumerate(generation_logprobs):
         for logprob_list in sample:
             generated_labels = [
-                re.sub(
-                    pattern=r"^[^a-zæøåüöä]+|[^a-zæøåüöä]+$",
-                    repl="",
-                    string=label.lower(),
-                )
+                re.sub(pattern=r"^[^a-zæøåüöä0-9]+$", repl="", string=label.lower())
                 for label, _ in logprob_list
             ]
             generated_labels = [label for label in generated_labels if label != ""]
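As a side note on the regex change in the hunk above: the old pattern stripped leading and trailing non-letter characters from every logprob token, whereas the new pattern `^[^a-zæøåüöä0-9]+$` only blanks out tokens that consist entirely of non-alphanumeric characters, which are then dropped by the filtering step. A quick standard-library illustration of the new behaviour (the sample tokens are made up):

```python
import re

tokens = ["'positive'", "...", " yes", "1"]
pattern = r"^[^a-zæøåüöä0-9]+$"

# Tokens made up only of punctuation or whitespace become empty strings; any token
# containing a letter or digit is left untouched, surrounding punctuation included.
cleaned = [re.sub(pattern=pattern, repl="", string=token.lower()) for token in tokens]
print(cleaned)  # ["'positive'", '', ' yes', '1']

cleaned = [token for token in cleaned if token != ""]
print(cleaned)  # ["'positive'", ' yes', '1']
```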
@@ -199,7 +291,7 @@ def get_closest_logprobs_labels(
             if isinstance(first_label_token_mapping, dict):
                 if any(
                     candidate_label not in first_label_token_mapping
-                    for candidate_label in candidate_labels
+                    for candidate_label in candidate_labels[idx]
                 ):
                     raise InvalidBenchmark(
                         "There is a label not present in the first label token "
@@ -210,14 +302,14 @@ def get_closest_logprobs_labels(
 
                 candidate_output_labels = {
                     candidate_label
-                    for candidate_label in candidate_labels
+                    for candidate_label in candidate_labels[idx]
                     if generated_label == first_label_token_mapping[candidate_label]
                 }
             else:
                 candidate_output_labels = {
                     candidate_label
-                    for candidate_label in candidate_labels
-                    if candidate_label.startswith(generated_label)
+                    for candidate_label in candidate_labels[idx]
+                    if candidate_label.startswith(generated_label.strip())
                 }
 
             # If we can uniquely determine the output label, we break the loop.
@@ -250,33 +342,22 @@ def get_closest_logprobs_labels(
             elif len(candidate_output_labels) == 0:
                 candidate_output_labels_starting_with_generated_label = [
                     candidate_label
-                    for candidate_label in candidate_labels
+                    for candidate_label in candidate_labels[idx]
                     if candidate_label.startswith(generated_label)
                 ]
                 if candidate_output_labels_starting_with_generated_label:
                     log_once(
                         f"No candidate label found for the generated label "
-                        f"{generated_label!r}
-                        "
-                        "
-                        "
+                        f"{generated_label!r}, but there are candidate labels "
+                        f"starting with it: "
+                        f"{candidate_output_labels_starting_with_generated_label}. "
+                        "This means that the first label token mapping is not "
+                        "reliable, and we will instead fall back to extracting "
+                        "the labels using word edit distance.",
                         level=logging.DEBUG,
                     )
                     return None
 
-                # If we did not find any candidate label for any of the generated labels, we
-                # assume that something is wrong with the model output, and we fall back to
-                # using word edit distance to extract the labels
-                else:
-                    log_once(
-                        f"No candidate label found for any of the generated labels "
-                        f"{generated_labels}. This means that using logprobs to extract "
-                        "the labels is not reliable, and we will instead fall back to "
-                        "extracting the labels using word edit distance.",
-                        level=logging.DEBUG,
-                    )
-                    return None
-
             if output_label is not None:
                 output_labels.append(output_label)
                 break
@@ -284,18 +365,20 @@ def get_closest_logprobs_labels(
         if len(sample) == 0:
             log_once(
                 "The model outputted an empty string, so no candidate labels could "
-
-                "
+                "be determined. This means that using logprobs to extract the "
+                "labels is not reliable, and we will instead fall back to "
+                "extracting the labels using word edit distance.",
                 level=logging.DEBUG,
             )
         else:
             log_once(
-                "
-
-                "
+                "No candidate label found for any of the generated labels, which "
+                "means that using logprobs to extract the labels is not reliable, "
+                "and we will instead fall back to extracting the labels using "
+                "word edit distance.",
                 level=logging.DEBUG,
             )
-
+        return None
 
     assert len(output_labels) == len(generation_logprobs)
     return output_labels
euroeval/task_group_utils/text_to_text.py

@@ -1,5 +1,6 @@
 """Utility functions related to the text-to-text task group."""
 
+import collections.abc as c
 import logging
 import typing as t
 
@@ -7,23 +8,23 @@ import numpy as np
 
 from ..constants import METRIC_ATTRIBUTES_TAKING_UP_MEMORY
 from ..exceptions import InvalidBenchmark
+from ..logging_utils import log
 from ..metrics import HuggingFaceMetric
 from ..utils import raise_if_model_output_contains_nan_values
 
 if t.TYPE_CHECKING:
+    from datasets.arrow_dataset import Dataset
     from transformers.trainer_utils import EvalPrediction
 
     from ..data_models import BenchmarkConfig, DatasetConfig, GenerativeModelOutput
     from ..types import Labels, Predictions
 
 
-logger = logging.getLogger("euroeval")
-
-
 def compute_metrics(
     model_outputs_and_labels: "tuple[Predictions, Labels] | EvalPrediction",
     dataset_config: "DatasetConfig",
     benchmark_config: "BenchmarkConfig",
+    dataset: "Dataset",
 ) -> dict[str, float]:
     """Compute the metrics needed for evaluation.
 
@@ -35,10 +36,17 @@ def compute_metrics(
            The configuration of the dataset.
         benchmark_config:
            The configuration of the benchmark.
+        dataset:
+            The dataset used for evaluation. This is only used in case any additional
+            metadata is used to compute the metrics.
 
     Returns:
        A dictionary with the names of the metrics as keys and the metric values as
        values.
+
+    Raises:
+        InvalidBenchmark:
+            If the metric computation fails.
     """
     model_outputs, labels = model_outputs_and_labels
 
@@ -67,9 +75,15 @@ def compute_metrics(
         ):
             metric.compute_kwargs["device"] = benchmark_config.device.type
 
-
+        for _ in range(num_attempts := 5):
             try:
-                score: float | None = metric(
+                score: float | None = metric(
+                    predictions=predictions,
+                    references=labels,
+                    dataset=dataset,
+                    dataset_config=dataset_config,
+                    benchmark_config=benchmark_config,
+                )
                 break
             except Exception as e:
                 oom_error = [
@@ -78,28 +92,35 @@ def compute_metrics(
                     "MPS backend out of memory",
                 ]
                 if not any(error in str(e) for error in oom_error):
-                    raise InvalidBenchmark(str(e))
+                    raise InvalidBenchmark(str(e)) from e
 
                 if (
                     isinstance(metric, HuggingFaceMetric)
                     and metric.compute_kwargs.get("device", "cpu") != "cpu"
                 ):
                     metric.compute_kwargs["device"] = "cpu"
-
+                    log(
                         "Out of memory error occurred during the computation of "
                         f"the metric {metric.pretty_name}. Moving the computation to "
-                        "the CPU."
+                        "the CPU.",
+                        level=logging.DEBUG,
                     )
                 else:
-                    raise InvalidBenchmark(str(e))
+                    raise InvalidBenchmark(str(e)) from e
             finally:
                 for attribute in METRIC_ATTRIBUTES_TAKING_UP_MEMORY:
                     if hasattr(metric, attribute):
-
+                        log(
                             f"Deleting the {attribute!r} attribute of the metric "
-                            f"{metric.pretty_name} to free up memory."
+                            f"{metric.pretty_name} to free up memory.",
+                            level=logging.DEBUG,
                         )
                         delattr(metric, attribute)
+        else:
+            raise InvalidBenchmark(
+                f"Could not compute the metric {metric.pretty_name} after "
+                f"{num_attempts} attempts due to out of memory errors."
+            )
 
         # The metric returns None if we are running on multi-GPU and the current
        # process is not the main process
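The retry logic above relies on Python's `for`/`else` construct: the `else` branch only runs when the loop finishes without hitting `break`, i.e. when every attempt failed with an out-of-memory error. A stripped-down sketch of the same pattern, independent of EuroEval's metric classes (all names here are illustrative):

```python
import logging

logging.basicConfig(level=logging.DEBUG)


def flaky_computation(device: str) -> float:
    """Stand-in for a metric computation that may run out of memory on the GPU."""
    if device != "cpu":
        raise RuntimeError("CUDA out of memory")
    return 0.42


device = "cuda"
for _ in range(num_attempts := 5):
    try:
        score = flaky_computation(device=device)
        break  # success: skip the for-else below
    except Exception as e:
        if "out of memory" not in str(e):
            raise
        # Fall back to the CPU and retry, mirroring the fallback in the diff above
        logging.debug("Out of memory on %s, retrying on the CPU.", device)
        device = "cpu"
else:
    # Only reached if the loop never hit `break`, i.e. every attempt failed
    raise RuntimeError(f"Could not compute the score after {num_attempts} attempts.")

print(score)  # 0.42, computed on the CPU after the first attempt failed
```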
@@ -111,7 +132,7 @@ def compute_metrics(
 
 def extract_labels_from_generation(
     input_batch: dict[str, list], model_output: "GenerativeModelOutput"
-) ->
+) -> c.Sequence[t.Any]:
     """Extract the predicted labels from the generated output.
 
     Args: