EuroEval 15.10.1__py3-none-any.whl → 15.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of EuroEval might be problematic.

euroeval/metrics.py ADDED
@@ -0,0 +1,452 @@
+ """All the metrics used in EuroEval."""
+
+ import abc
+ import logging
+ import typing as t
+
+ import evaluate
+ import litellm
+ from litellm.types.utils import Choices, ModelResponse
+ from pydantic import BaseModel, Field
+ from tqdm.auto import tqdm
+
+ from .exceptions import InvalidBenchmark
+ from .utils import HiddenPrints
+
+ if t.TYPE_CHECKING:
+     from evaluate import EvaluationModule
+
+ logger = logging.getLogger(__name__)
+
+
+ class Metric(abc.ABC):
+     """Abstract base class for all metrics."""
+
+     def __init__(
+         self,
+         name: str,
+         pretty_name: str,
+         postprocessing_fn: t.Callable[[float], tuple[float, str]] | None = None,
+     ) -> None:
+         """Initialise the metric.
+
+         Args:
+             name:
+                 The name of the metric in snake_case.
+             pretty_name:
+                 The pretty name of the metric, used for display purposes.
+             postprocessing_fn:
+                 A function to apply to the metric scores after they are computed,
+                 taking the score to the postprocessed score along with its string
+                 representation. Defaults to x -> (100 * x, f"{x:.2%}").
+         """
+         self.name = name
+         self.pretty_name = pretty_name
+         self.postprocessing_fn = (
+             postprocessing_fn
+             if postprocessing_fn is not None
+             else lambda x: (100 * x, f"{x:.2%}")
+         )
+
+     @abc.abstractmethod
+     def __call__(self, predictions: t.Sequence, references: t.Sequence) -> float | None:
+         """Calculate the metric score.
+
+         Args:
+             predictions:
+                 The model predictions.
+             references:
+                 The ground truth references.
+
+         Returns:
+             The calculated metric score, or None if the score should be ignored.
+         """
+         ...
+
+     def __hash__(self) -> int:
+         """Return a hash of the metric configuration."""
+         return hash(self.name)
+
+
+ class HuggingFaceMetric(Metric):
+     """A metric which is implemented in the `evaluate` package.
+
+     Attributes:
+         name:
+             The name of the metric in snake_case.
+         pretty_name:
+             The pretty name of the metric, used for display purposes.
+         huggingface_id:
+             The Hugging Face ID of the metric.
+         results_key:
+             The name of the key used to extract the metric scores from the results
+             dictionary.
+         compute_kwargs:
+             Keyword arguments to pass to the metric's compute function. Defaults to
+             an empty dictionary.
+     """
+
+     def __init__(
+         self,
+         name: str,
+         pretty_name: str,
+         huggingface_id: str,
+         results_key: str,
+         compute_kwargs: dict[str, t.Any] | None = None,
+         postprocessing_fn: t.Callable[[float], tuple[float, str]] | None = None,
+     ) -> None:
+         """Initialise the Hugging Face metric.
+
+         Args:
+             name:
+                 The name of the metric in snake_case.
+             pretty_name:
+                 The pretty name of the metric, used for display purposes.
+             huggingface_id:
+                 The Hugging Face ID of the metric.
+             results_key:
+                 The name of the key used to extract the metric scores from the results
+                 dictionary.
+             compute_kwargs:
+                 Keyword arguments to pass to the metric's compute function. Defaults to
+                 an empty dictionary.
+             postprocessing_fn:
+                 A function to apply to the metric scores after they are computed, taking
+                 the score to the postprocessed score along with its string
+                 representation. Defaults to x -> (100 * x, f"{x:.2%}").
+         """
+         super().__init__(
+             name=name, pretty_name=pretty_name, postprocessing_fn=postprocessing_fn
+         )
+         self.huggingface_id = huggingface_id
+         self.results_key = results_key
+         self.compute_kwargs: dict[str, t.Any] = (
+             dict() if compute_kwargs is None else compute_kwargs
+         )
+         self.metric: "EvaluationModule | None" = None
+
+     def __call__(self, predictions: t.Sequence, references: t.Sequence) -> float | None:
+         """Calculate the metric score.
+
+         Args:
+             predictions:
+                 The model predictions.
+             references:
+                 The ground truth references.
+
+         Returns:
+             The calculated metric score, or None if the score should be ignored.
+         """
+         if self.metric is None:
+             self.metric = evaluate.load(path=self.huggingface_id)
+
+         with HiddenPrints():
+             results = self.metric.compute(
+                 predictions=predictions, references=references, **self.compute_kwargs
+             )
+
+         # The metric returns None if we are running on multi-GPU and the current
+         # process is not the main process
+         if results is None:
+             return None
+
+         score = results[self.results_key]
+         if isinstance(score, list):
+             score = sum(score) / len(score)
+
+         return score
+
+
+ class LLMAsAJudgeMetric(Metric):
+     """Use an LLM to judge the quality of the predictions."""
+
+     def __init__(
+         self,
+         name: str,
+         pretty_name: str,
+         judge_id: str,
+         judge_kwargs: dict[str, t.Any],
+         user_prompt: str,
+         response_format: t.Type[BaseModel],
+         scoring_fn: t.Callable[[BaseModel], float],
+         condition_formatting_fn: t.Callable[[str], str] = lambda x: x,
+         system_prompt: str | None = None,
+     ) -> None:
+         """Initialise the LLM as a judge metric.
+
+         Args:
+             name:
+                 The name of the metric in snake_case.
+             pretty_name:
+                 The pretty name of the metric, used for display purposes.
+             judge_id:
+                 The model ID of the LLM to use as a judge.
+             judge_kwargs:
+                 Generation parameters for the judge model, such as temperature.
+             user_prompt:
+                 The user prompt to use for the judge model. The prompt should be
+                 formatted with the variables `prediction` and `condition`, to
+                 include the model predictions and a description of what the prediction
+                 should be judged on, respectively. If the condition is not needed,
+                 it can be omitted from the prompt, but the `prediction` variable must
+                 still be present.
+             response_format:
+                 The response format to use for the judge model. This should be a
+                 Pydantic model that defines the expected structure of the judge's
+                 response.
+             scoring_fn:
+                 A function that takes the judge's response and returns a score.
+             condition_formatting_fn (optional):
+                 A function to format the condition string before it is included in the
+                 user prompt. Defaults to a no-op function that returns the input
+                 unchanged.
+             system_prompt (optional):
+                 The system prompt to use for the judge model. If not provided, no system
+                 prompt will be used.
+         """
+         super().__init__(name=name, pretty_name=pretty_name)
+         self.judge_id = judge_id
+         self.judge_kwargs = judge_kwargs
+         self.user_prompt = user_prompt
+         self.response_format = response_format
+         self.scoring_fn = scoring_fn
+         self.condition_formatting_fn = condition_formatting_fn
+         self.system_prompt = system_prompt
+
+     def __call__(self, predictions: t.Sequence, references: t.Sequence) -> float | None:
+         """Calculate the metric score using the judge model.
+
+         Args:
+             predictions:
+                 The model predictions.
+             references:
+                 The ground truth references.
+
+         Returns:
+             The calculated metric score, or None if the score should be ignored.
+
+         Raises:
+             InvalidBenchmark:
+                 If the number of predictions does not match the number of references,
+                 or if the user prompt requires a condition but none is provided.
+         """
+         if not predictions or not references:
+             return None
+         elif len(predictions) != len(references):
+             raise InvalidBenchmark(
+                 f"The number of predictions ({len(predictions):,}) does not match the "
+                 f"number of references ({len(references):,})."
+             )
+
+         # Prepare the messages for the LLM
+         conversations: list[list[dict[str, str]]] = [
+             [
+                 dict(
+                     role="user",
+                     content=self._apply_user_prompt(
+                         prediction=prediction, condition=condition
+                     ),
+                 )
+             ]
+             for prediction, condition in zip(predictions, references)
+         ]
+         if self.system_prompt:
+             conversations = [
+                 [dict(role="system", content=self.system_prompt), *conversation]
+                 for conversation in conversations
+             ]
+
+         # Get the judge generations
+         generations = [
+             litellm.completion(
+                 model=self.judge_id,
+                 messages=conversation,
+                 response_format=self.response_format,
+                 **self.judge_kwargs,
+             )
+             for conversation in tqdm(
+                 iterable=conversations,
+                 desc=f"Computing {self.pretty_name} scores",
+                 unit="sample",
+             )
+         ]
+
+         # Extract the outputs from the generations
+         outputs: list[BaseModel] = list()
+         for generation in generations:
+             assert isinstance(generation, ModelResponse), (
+                 f"The judge model did not return a valid response: {generation!r}"
+             )
+             choice = generation.choices[0]
+             assert isinstance(choice, Choices), (
+                 f"The judge model did not return a valid choice: {choice!r}"
+             )
+             json_content = choice.message.content
+             assert json_content is not None, (
+                 "The judge model returned a None content in the response message."
+             )
+             output = self.response_format.model_validate_json(json_data=json_content)
+             outputs.append(output)
+
+         # Calculate the scores using the scoring function
+         scores = [self.scoring_fn(output) for output in outputs]
+         if not scores:
+             logger.warning(f"No scores were calculated for {self.pretty_name}.")
+             return None
+         return sum(scores) / len(scores)
+
+     def _apply_user_prompt(self, prediction: str, condition: str | None = None) -> str:
+         """Apply the user prompt to the prediction and condition.
+
+         Args:
+             prediction:
+                 The model prediction.
+             condition (optional):
+                 A description of what the prediction should be judged on. If not
+                 provided, it will be omitted from the prompt.
+
+         Returns:
+             The formatted user prompt with the prediction and reference.
+
+         Raises:
+             InvalidBenchmark:
+                 If the user prompt requires a reference but none is provided.
+         """
+         condition_required = "{condition}" in self.user_prompt
+         if condition_required and condition is None:
+             raise InvalidBenchmark(
+                 f"The user prompt for the {self.pretty_name!r} metric requires a "
+                 "condition, but none was provided."
+             )
+         if condition is not None:
+             return self.user_prompt.format(
+                 prediction=prediction, condition=self.condition_formatting_fn(condition)
+             )
+         return self.user_prompt.format(prediction=prediction)
+
+
+ class SpeedMetric(Metric):
+     """Speed metric."""
+
+     def __init__(self, name: str, pretty_name: str) -> None:
+         """Initialise the speed metric.
+
+         Args:
+             name:
+                 The name of the metric in snake_case.
+             pretty_name:
+                 The pretty name of the metric, used for display purposes.
+         """
+         super().__init__(
+             name=name,
+             pretty_name=pretty_name,
+             postprocessing_fn=lambda raw_score: (raw_score, f"{raw_score:,.0f}"),
+         )
+
+     def __call__(self, _: t.Sequence, __: t.Sequence) -> float | None:
+         """Not used with the speed metric, but required for consistency."""
+         raise NotImplementedError
+
+
+ mcc_metric = HuggingFaceMetric(
+     name="mcc",
+     pretty_name="Matthew's Correlation Coefficient",
+     huggingface_id="matthews_correlation",
+     results_key="matthews_correlation",
+ )
+
+ macro_f1_metric = HuggingFaceMetric(
+     name="macro_f1",
+     pretty_name="Macro-average F1-score",
+     huggingface_id="f1",
+     results_key="f1",
+     compute_kwargs=dict(average="macro"),
+ )
+
+ micro_f1_metric = HuggingFaceMetric(
+     name="micro_f1",
+     pretty_name="Micro-average F1-score with MISC tags",
+     huggingface_id="seqeval",
+     results_key="overall_f1",
+ )
+
+ micro_f1_no_misc_metric = HuggingFaceMetric(
+     name="micro_f1_no_misc",
+     pretty_name="Micro-average F1-score without MISC tags",
+     huggingface_id="seqeval",
+     results_key="overall_f1",
+ )
+
+ f1_metric = HuggingFaceMetric(
+     name="f1",
+     pretty_name="F1-score",
+     huggingface_id="squad_v2",
+     results_key="f1",
+     postprocessing_fn=lambda x: (x, f"{x:.2f}%"),
+ )
+
+ em_metric = HuggingFaceMetric(
+     name="em",
+     pretty_name="Exact Match",
+     huggingface_id="squad_v2",
+     results_key="exact",
+     postprocessing_fn=lambda x: (x, f"{x:.2f}%"),
+ )
+
+ bert_score_metric = HuggingFaceMetric(
+     name="bertscore",
+     pretty_name="BERTScore",
+     huggingface_id="bertscore",
+     results_key="f1",
+     compute_kwargs=dict(
+         model_type="microsoft/mdeberta-v3-base", device="auto", batch_size=1
+     ),
+ )
+
+ rouge_l_metric = HuggingFaceMetric(
+     name="rouge_l", pretty_name="ROUGE-L", huggingface_id="rouge", results_key="rougeL"
+ )
+
+ accuracy_metric = HuggingFaceMetric(
+     name="accuracy",
+     pretty_name="Accuracy",
+     huggingface_id="accuracy",
+     results_key="accuracy",
+ )
+
+
+ class Fluency(BaseModel):
+     """Response format for the fluency metric.
+
+     Attributes:
+         fluency:
+             The fluency rating, an integer between 1 and 5.
+     """
+
+     fluency: t.Annotated[int, Field(ge=1, le=5)]
+
+
+ # Example LLM-as-a-judge metric, to measure the fluency of the LLM output
+ fluency_metric = LLMAsAJudgeMetric(
+     name="fluency",
+     pretty_name="Fluency",
+     judge_id="gpt-4o-mini",
+     judge_kwargs=dict(temperature=0.0),
+     user_prompt="Please rate the fluency of the following text on a scale from 1 to 5, "
+     "with the following definitions:\n"
+     "- 1: Very poor fluency, many grammatical errors\n"
+     "- 2: Poor fluency, several grammatical errors\n"
+     "- 3: Average fluency, a few grammatical errors\n"
+     "- 4: Good fluency, no grammatical errors but sounds a bit off\n"
+     "- 5: Excellent fluency, no grammatical errors and sounds natural\n\n"
+     "Text: {prediction!r}\n\n"
+     "Output your rating as a JSON object with a single key 'fluency'.",
+     response_format=Fluency,
+     scoring_fn=lambda output: (output.fluency - 1) / 4.0,
+ )
+
+ speed_metric = SpeedMetric(name="speed", pretty_name="Tokens per second")
+
+ speed_short_metric = SpeedMetric(
+     name="speed_short", pretty_name="Tokens per second on short documents"
+ )
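
For orientation, here is a minimal usage sketch (not part of the release itself) of the new Metric base class; the ExactMatch subclass and the example inputs below are hypothetical, and only the constructor arguments and __call__ signature shown in the diff above are assumed.

import typing as t

from euroeval.metrics import Metric


class ExactMatch(Metric):
    """Toy metric: the fraction of predictions that equal their reference."""

    def __call__(self, predictions: t.Sequence, references: t.Sequence) -> float | None:
        if not predictions:
            return None
        return sum(p == r for p, r in zip(predictions, references)) / len(predictions)


metric = ExactMatch(name="exact_match", pretty_name="Exact Match")
score = metric(predictions=["a", "b"], references=["a", "c"])  # 0.5
print(metric.postprocessing_fn(score))  # (50.0, '50.00%') via the default postprocessing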
euroeval/scores.py CHANGED
@@ -7,7 +7,7 @@ import warnings
  import numpy as np

  if t.TYPE_CHECKING:
-     from .data_models import MetricConfig
+     from .metrics import Metric
      from .types import ScoreDict

  logger = logging.getLogger("euroeval")
@@ -15,7 +15,7 @@ logger = logging.getLogger("euroeval")

  def log_scores(
      dataset_name: str,
-     metric_configs: list["MetricConfig"],
+     metrics: list["Metric"],
      scores: list[dict[str, float]],
      model_id: str,
      model_revision: str,
@@ -25,7 +25,7 @@
      Args:
          dataset_name:
              Name of the dataset.
-         metric_configs:
+         metrics:
              List of metrics to log.
          scores:
              The scores that are to be logged. This is a list of dictionaries full of
@@ -46,19 +46,19 @@
      logger.info(f"Finished evaluation of {model_id} on {dataset_name}.")

      total_dict: dict[str, float] = dict()
-     for metric_cfg in metric_configs:
-         test_score, test_se = aggregate_scores(scores=scores, metric_config=metric_cfg)
-         test_score, test_score_str = metric_cfg.postprocessing_fn(test_score)
-         test_se, test_se_str = metric_cfg.postprocessing_fn(test_se)
-         total_dict[f"test_{metric_cfg.name}"] = test_score
-         total_dict[f"test_{metric_cfg.name}_se"] = test_se
-         logger.info(f"{metric_cfg.pretty_name}: {test_score_str} ± {test_se_str}")
+     for metric in metrics:
+         test_score, test_se = aggregate_scores(scores=scores, metric=metric)
+         test_score, test_score_str = metric.postprocessing_fn(test_score)
+         test_se, test_se_str = metric.postprocessing_fn(test_se)
+         total_dict[f"test_{metric.name}"] = test_score
+         total_dict[f"test_{metric.name}_se"] = test_se
+         logger.info(f"{metric.pretty_name}: {test_score_str} ± {test_se_str}")

      return dict(raw=scores, total=total_dict)


  def aggregate_scores(
-     scores: list[dict[str, float]], metric_config: "MetricConfig"
+     scores: list[dict[str, float]], metric: "Metric"
  ) -> tuple[float, float]:
      """Helper function to compute the mean with confidence intervals.

@@ -66,9 +66,8 @@
          scores:
              Dictionary with the names of the metrics as keys, of the form
              "<split>_<metric_name>", such as "val_f1", and values the metric values.
-         metric_config:
-             The configuration of the metric, which is used to collect the correct
-             metric from `scores`.
+         metric:
+             The metric, which is used to collect the correct metric from `scores`.

      Returns:
          A pair of floats, containing the score and the radius of its 95% confidence
@@ -78,11 +77,7 @@
          warnings.simplefilter("ignore")

          test_scores = [
-             (
-                 dct[metric_config.name]
-                 if metric_config.name in dct
-                 else dct[f"test_{metric_config.name}"]
-             )
+             dct[metric.name] if metric.name in dct else dct[f"test_{metric.name}"]
              for dct in scores
          ]
          test_score = np.mean(test_scores).item()
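
The scores.py changes above mean that aggregate_scores and log_scores now receive Metric objects directly instead of MetricConfig objects. A hedged sketch of the new call path, using made-up scores:

from euroeval.metrics import accuracy_metric
from euroeval.scores import aggregate_scores

# Per-iteration scores, keyed by "test_<metric_name>" (or "<metric_name>") as expected above
scores = [{"test_accuracy": 0.81}, {"test_accuracy": 0.84}, {"test_accuracy": 0.79}]

mean_score, ci_radius = aggregate_scores(scores=scores, metric=accuracy_metric)
print(accuracy_metric.postprocessing_fn(mean_score))  # e.g. (81.33..., '81.33%')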
@@ -1,21 +1,20 @@
  """Benchmarking model inference speed."""

  import logging
+ import typing as t

  import pyinfer
  from tqdm.auto import tqdm
  from transformers.models.auto.tokenization_auto import AutoTokenizer

- from .benchmark_modules import (
-     BenchmarkModule,
-     HuggingFaceEncoderModel,
-     LiteLLMModel,
-     VLLMModel,
- )
- from .data_models import BenchmarkConfig
+ from .benchmark_modules import HuggingFaceEncoderModel, LiteLLMModel, VLLMModel
  from .exceptions import InvalidBenchmark
  from .utils import clear_memory

+ if t.TYPE_CHECKING:
+     from .benchmark_modules import BenchmarkModule
+     from .data_models import BenchmarkConfig
+
  logger = logging.getLogger("euroeval")


@@ -7,14 +7,15 @@ import typing as t
  from collections import defaultdict

  import numpy as np
- from datasets import Dataset
- from transformers.tokenization_utils import PreTrainedTokenizer
- from transformers.tokenization_utils_base import BatchEncoding
  from transformers.trainer import Trainer

  from ..exceptions import InvalidBenchmark

  if t.TYPE_CHECKING:
+     from datasets import Dataset
+     from transformers.tokenization_utils import PreTrainedTokenizer
+     from transformers.tokenization_utils_base import BatchEncoding
+
      from ..types import Labels, Predictions

  logger = logging.getLogger("euroeval")
@@ -147,7 +148,8 @@ def postprocess_predictions_and_labels(

      Args:
          predictions:
-             The model predictions, of shape (num_examples, 2).
+             The model predictions, of shape (num_examples, 2), corresponding to the
+             False/True probabilities for each example.
          dataset:
              The dataset containing the examples.

@@ -5,13 +5,10 @@ import logging
  import typing as t
  from collections import defaultdict

- import evaluate
  import numpy as np
- from evaluate import EvaluationModule
  from transformers.tokenization_utils_base import PreTrainedTokenizerBase
  from transformers.trainer import Trainer

- from ..data_models import BenchmarkConfig, DatasetConfig, GenerativeModelOutput
  from ..exceptions import InvalidBenchmark
  from ..tokenization_utils import get_special_token_metadata
  from ..utils import raise_if_model_output_contains_nan_values
@@ -26,6 +23,7 @@ if t.TYPE_CHECKING:
      from transformers.trainer_utils import EvalPrediction
      from transformers.training_args import TrainingArguments

+     from ..data_models import DatasetConfig, GenerativeModelOutput
      from ..types import Labels, Predictions

  logger = logging.getLogger("euroeval")
@@ -151,7 +149,6 @@ class QuestionAnsweringTrainer(Trainer):
  def compute_metrics(
      model_outputs_and_labels: "tuple[Predictions, Labels] | EvalPrediction",
      dataset_config: "DatasetConfig",
-     benchmark_config: "BenchmarkConfig",
  ) -> dict[str, float]:
      """Compute the metrics needed for evaluation.

@@ -161,8 +158,6 @@ def compute_metrics(
              contains the true labels.
          dataset_config:
              The configuration of the dataset.
-         benchmark_config:
-             The configuration of the benchmark.

      Returns:
          A dictionary with the names of the metrics as keys and the metric values as
@@ -178,17 +173,6 @@
      assert not isinstance(model_outputs, tuple)
      raise_if_model_output_contains_nan_values(model_output=model_outputs)

-     metrics = {
-         metric_cfg.name: (
-             evaluate.load(
-                 path=metric_cfg.huggingface_id, cache_dir=benchmark_config.cache_dir
-             )
-             if metric_cfg.huggingface_id != ""
-             else None
-         )
-         for metric_cfg in dataset_config.task.metrics
-     }
-
      model_output_dtype = np.asarray(model_outputs).dtype
      if model_output_dtype in [np.float16, np.float32, np.float64]:
          predictions = np.asarray(model_outputs).argmax(axis=-1)
@@ -196,20 +180,13 @@
          predictions = model_outputs

      results: dict[str, float] = dict()
-     for cfg in dataset_config.task.metrics:
-         metric = metrics[cfg.name]
-         assert isinstance(metric, EvaluationModule)
-         score_dict: dict[str, float] | None = metric.compute(
-             predictions=predictions, references=labels, **cfg.compute_kwargs
-         )
+     for metric in dataset_config.task.metrics:
+         score: float | None = metric(predictions=predictions, references=labels)

          # The metric returns None if we are running on multi-GPU and the current
          # process is not the main process
-         if score_dict is not None:
-             scores = score_dict[cfg.results_key]
-             if isinstance(scores, list):
-                 scores = sum(scores) / len(scores)
-             results[cfg.name] = scores
+         if score is not None:
+             results[metric.name] = score

      return results
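
The compute_metrics hunks above replace the ad-hoc evaluate.load(...) dictionary with direct calls to the metric objects. A rough sketch of the new loop in isolation, where task_metrics is a hypothetical stand-in for dataset_config.task.metrics:

from euroeval.metrics import accuracy_metric, mcc_metric

task_metrics = [accuracy_metric, mcc_metric]  # stands in for dataset_config.task.metrics
predictions = [0, 1, 1, 0]
labels = [0, 1, 0, 0]

results: dict[str, float] = {}
for metric in task_metrics:
    score = metric(predictions=predictions, references=labels)
    if score is not None:  # None only on non-main processes in multi-GPU runs
        results[metric.name] = score
print(results)  # e.g. {'accuracy': 0.75, 'mcc': ...}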