EuroEval 15.12.0__py3-none-any.whl → 16.7.1__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in those registries.
- euroeval/__init__.py +32 -14
- euroeval/benchmark_config_factory.py +92 -180
- euroeval/benchmark_modules/base.py +49 -39
- euroeval/benchmark_modules/fresh.py +35 -21
- euroeval/benchmark_modules/hf.py +280 -244
- euroeval/benchmark_modules/litellm.py +752 -312
- euroeval/benchmark_modules/vllm.py +570 -268
- euroeval/benchmarker.py +651 -528
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/cli.py +49 -38
- euroeval/constants.py +44 -25
- euroeval/data_loading.py +111 -55
- euroeval/data_models.py +490 -323
- euroeval/dataset_configs/__init__.py +26 -4
- euroeval/dataset_configs/bosnian.py +39 -0
- euroeval/dataset_configs/bulgarian.py +56 -0
- euroeval/dataset_configs/croatian.py +56 -0
- euroeval/dataset_configs/czech.py +75 -0
- euroeval/dataset_configs/danish.py +78 -50
- euroeval/dataset_configs/dutch.py +74 -44
- euroeval/dataset_configs/english.py +71 -36
- euroeval/dataset_configs/estonian.py +111 -0
- euroeval/dataset_configs/faroese.py +25 -18
- euroeval/dataset_configs/finnish.py +63 -26
- euroeval/dataset_configs/french.py +65 -32
- euroeval/dataset_configs/german.py +77 -36
- euroeval/dataset_configs/greek.py +64 -0
- euroeval/dataset_configs/icelandic.py +68 -57
- euroeval/dataset_configs/italian.py +68 -36
- euroeval/dataset_configs/latvian.py +87 -0
- euroeval/dataset_configs/lithuanian.py +64 -0
- euroeval/dataset_configs/norwegian.py +98 -72
- euroeval/dataset_configs/polish.py +96 -0
- euroeval/dataset_configs/portuguese.py +63 -40
- euroeval/dataset_configs/serbian.py +64 -0
- euroeval/dataset_configs/slovak.py +55 -0
- euroeval/dataset_configs/slovene.py +56 -0
- euroeval/dataset_configs/spanish.py +68 -34
- euroeval/dataset_configs/swedish.py +82 -41
- euroeval/dataset_configs/ukrainian.py +64 -0
- euroeval/enums.py +12 -6
- euroeval/exceptions.py +21 -1
- euroeval/finetuning.py +34 -26
- euroeval/generation.py +76 -41
- euroeval/generation_utils.py +169 -34
- euroeval/languages.py +1020 -188
- euroeval/logging_utils.py +268 -0
- euroeval/metrics/__init__.py +6 -0
- euroeval/metrics/base.py +85 -0
- euroeval/metrics/huggingface.py +216 -0
- euroeval/metrics/llm_as_a_judge.py +260 -0
- euroeval/metrics/pipeline.py +289 -0
- euroeval/metrics/speed.py +48 -0
- euroeval/model_cache.py +40 -21
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/__init__.py +2 -0
- euroeval/prompt_templates/classification.py +206 -0
- euroeval/prompt_templates/linguistic_acceptability.py +157 -22
- euroeval/prompt_templates/multiple_choice.py +159 -17
- euroeval/prompt_templates/named_entity_recognition.py +318 -21
- euroeval/prompt_templates/reading_comprehension.py +207 -16
- euroeval/prompt_templates/sentiment_classification.py +205 -22
- euroeval/prompt_templates/summarization.py +122 -22
- euroeval/prompt_templates/token_classification.py +279 -0
- euroeval/scores.py +20 -9
- euroeval/speed_benchmark.py +11 -12
- euroeval/task_group_utils/multiple_choice_classification.py +21 -12
- euroeval/task_group_utils/question_answering.py +101 -73
- euroeval/task_group_utils/sequence_classification.py +144 -61
- euroeval/task_group_utils/text_to_text.py +33 -12
- euroeval/task_group_utils/token_classification.py +86 -89
- euroeval/tasks.py +75 -16
- euroeval/tokenisation_utils.py +603 -0
- euroeval/types.py +17 -11
- euroeval/utils.py +332 -137
- euroeval-16.7.1.dist-info/METADATA +623 -0
- euroeval-16.7.1.dist-info/RECORD +84 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
- euroeval/human_evaluation.py +0 -737
- euroeval/metrics.py +0 -452
- euroeval/tokenization_utils.py +0 -498
- euroeval-15.12.0.dist-info/METADATA +0 -285
- euroeval-15.12.0.dist-info/RECORD +0 -63
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
euroeval/logging_utils.py
ADDED

@@ -0,0 +1,268 @@
+"""Utility functions related to logging."""
+
+import datetime as dt
+import logging
+import os
+import sys
+import warnings
+from io import TextIOWrapper
+
+import litellm
+from datasets.utils import disable_progress_bars as disable_datasets_progress_bars
+from evaluate import disable_progress_bar as disable_evaluate_progress_bar
+from huggingface_hub.utils.tqdm import (
+    disable_progress_bars as disable_hf_hub_progress_bars,
+)
+from termcolor import colored
+from tqdm.auto import tqdm
+from transformers import logging as tf_logging
+
+from .caching_utils import cache_arguments
+
+logger = logging.getLogger("euroeval")
+
+
+def get_pbar(*tqdm_args, **tqdm_kwargs) -> tqdm:
+    """Get a progress bar for vLLM with custom hard-coded arguments.
+
+    Args:
+        *tqdm_args:
+            Positional arguments to pass to tqdm.
+        **tqdm_kwargs:
+            Additional keyword arguments to pass to tqdm.
+
+    Returns:
+        A tqdm progress bar.
+    """
+    tqdm_kwargs = dict(colour="yellow", ascii="—▰", leave=False) | tqdm_kwargs
+    tqdm_kwargs["desc"] = colored(
+        text=tqdm_kwargs.get("desc", "Processing"), color="light_yellow"
+    )
+    return tqdm(*tqdm_args, **tqdm_kwargs)
+
+
+def log(message: str, level: int, colour: str | None = None) -> None:
+    """Log a message.
+
+    Args:
+        message:
+            The message to log.
+        level:
+            The logging level. Defaults to logging.INFO.
+        colour:
+            The colour to use for the message. If None, a default colour will be used
+            based on the logging level.
+
+    Raises:
+        ValueError:
+            If the logging level is invalid.
+    """
+    match level:
+        case logging.DEBUG:
+            message = colored(
+                text=(
+                    "[DEBUG] "
+                    + dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                    + f" · {message}"
+                ),
+                color=colour or "light_blue",
+            )
+            logger.debug(message)
+        case logging.INFO:
+            if colour is not None:
+                message = colored(text=message, color=colour)
+            logger.info(message)
+        case logging.WARNING:
+            message = colored(text=message, color=colour or "light_red")
+            logger.warning(message)
+        case logging.ERROR:
+            message = colored(text=message, color=colour or "red")
+            logger.error(message)
+        case logging.CRITICAL:
+            message = colored(text=message, color=colour or "red")
+            logger.critical(message)
+        case _:
+            raise ValueError(f"Invalid logging level: {level}")
+
+
+@cache_arguments("message")
+def log_once(message: str, level: int = logging.INFO, prefix: str = "") -> None:
+    """Log a message once.
+
+    This is ensured by caching the "message" argument and only logging it the first time
+    this function is called with that message.
+
+    Args:
+        message:
+            The message to log.
+        level:
+            The logging level. Defaults to logging.INFO.
+        prefix:
+            A prefix to add to the message, which is not considered when determining if
+            the message has been logged before.
+    """
+    log(message=prefix + message, level=level)
+
+
+def block_terminal_output() -> None:
+    """Blocks libraries from writing output to the terminal.
+
+    This filters warnings from some libraries, sets the logging level to ERROR for some
+    libraries, disabled tokeniser progress bars when using Hugging Face tokenisers, and
+    disables most of the logging from the `transformers` library.
+    """
+    if os.getenv("FULL_LOG") == "1":
+        return
+
+    # Ignore miscellaneous warnings
+    warnings.filterwarnings("ignore", category=UserWarning)
+    warnings.filterwarnings("ignore", category=FutureWarning)
+    logging.getLogger("absl").setLevel(logging.CRITICAL)
+
+    # Disable matplotlib logging
+    logging.getLogger("matplotlib.font_manager").setLevel(logging.CRITICAL)
+
+    # Disable PyTorch logging
+    logging.getLogger("torch.utils.cpp_extension").setLevel(logging.CRITICAL)
+    warnings.filterwarnings(action="ignore", module="torch*")
+    os.environ["TORCH_LOGS"] = "-all"
+
+    # Disable huggingface_hub logging
+    logging.getLogger("huggingface_hub").setLevel(logging.CRITICAL)
+    disable_hf_hub_progress_bars()
+
+    # Disable LiteLLM logging
+    logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
+    logging.getLogger("LiteLLM Router").setLevel(logging.CRITICAL)
+    logging.getLogger("LiteLLM Proxy").setLevel(logging.CRITICAL)
+    logging.getLogger("openai").setLevel(logging.CRITICAL)
+    logging.getLogger("httpx").setLevel(logging.CRITICAL)
+    litellm.suppress_debug_info = True
+
+    # Disable vLLM logging
+    logging.getLogger("vllm").setLevel(logging.CRITICAL)
+    logging.getLogger("vllm.engine.llm_engine").setLevel(logging.CRITICAL)
+    logging.getLogger("vllm.transformers_utils.tokenizer").setLevel(logging.CRITICAL)
+    logging.getLogger("vllm.core.scheduler").setLevel(logging.CRITICAL)
+    logging.getLogger("vllm.model_executor.weight_utils").setLevel(logging.CRITICAL)
+    logging.getLogger("vllm.platforms").setLevel(logging.CRITICAL)
+    logging.getLogger("mistral_common.tokens.tokenizers.tekken").setLevel(
+        logging.CRITICAL
+    )
+    os.environ["LOG_LEVEL"] = "CRITICAL"
+    os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
+
+    # Disable flashinfer logging
+    os.environ["FLASHINFER_LOGGING_LEVEL"] = "CRITICAL"
+
+    # Disable datasets logging
+    logging.getLogger("datasets").setLevel(logging.CRITICAL)
+    logging.getLogger("filelock").setLevel(logging.CRITICAL)
+    disable_datasets_progress_bars()
+
+    # Disable evaluate logging
+    warnings.filterwarnings("ignore", module="seqeval*")
+    disable_evaluate_progress_bar()
+
+    # Disable most of the `transformers` logging
+    tf_logging._default_log_level = logging.CRITICAL
+    tf_logging.set_verbosity(logging.CRITICAL)
+    logging.getLogger("transformers.trainer").setLevel(logging.CRITICAL)
+    logging.getLogger("accelerate").setLevel(logging.CRITICAL)
+
+
+class no_terminal_output:
+    """Context manager that suppresses all terminal output."""
+
+    def __init__(self, disable: bool = False) -> None:
+        """Initialise the context manager.
+
+        Args:
+            disable:
+                If True, this context manager does nothing.
+        """
+        self.disable = disable
+        self.devnull_file: TextIOWrapper | None = None
+        self._original_stdout_fd: int | None = None
+        self._original_stderr_fd: int | None = None
+
+    def _log_windows_warning(self) -> None:
+        """Log a warning about Windows not supporting blocking terminal output."""
+        log_once(
+            "Your operating system (probably Windows) does not support blocking "
+            "terminal output, so expect more messy output - sorry!",
+            level=logging.WARNING,
+        )
+
+    def __enter__(self) -> None:
+        """Suppress all terminal output."""
+        if self.disable:
+            return
+
+        try:
+            # Save original FDs by duplicating them
+            self._original_stdout_fd = os.dup(sys.stdout.fileno())
+            self._original_stderr_fd = os.dup(sys.stderr.fileno())
+
+            # Open /dev/null
+            self.devnull_file = open(os.devnull, "w")
+
+            # Redirect stdout/stderr to /dev/null
+            os.dup2(self.devnull_file.fileno(), sys.stdout.fileno())
+            os.dup2(self.devnull_file.fileno(), sys.stderr.fileno())
+
+        except OSError:
+            self._log_windows_warning()
+            # If setup fails, clean up any resources we might have acquired
+            self.__exit__(None, None, None)
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: type[BaseException] | None,
+    ) -> None:
+        """Re-enable terminal output."""
+        if self.disable:
+            return
+
+        # Restore stdout/stderr from our saved FDs
+        try:
+            if self._original_stdout_fd is not None:
+                os.dup2(self._original_stdout_fd, sys.stdout.fileno())
+            if self._original_stderr_fd is not None:
+                os.dup2(self._original_stderr_fd, sys.stderr.fileno())
+        except OSError:
+            self._log_windows_warning()
+        finally:
+            # Close the duplicated FDs we created
+            if self._original_stdout_fd is not None:
+                os.close(self._original_stdout_fd)
+            if self._original_stderr_fd is not None:
+                os.close(self._original_stderr_fd)
+
+            # Close the /dev/null file
+            if self.devnull_file is not None:
+                self.devnull_file.close()
+
+
+def adjust_logging_level(verbose: bool, ignore_testing: bool = False) -> int:
+    """Adjust the logging level based on verbosity.
+
+    Args:
+        verbose:
+            Whether to output additional output.
+        ignore_testing:
+            Whether to ignore the testing flag.
+
+    Returns:
+        The logging level that was set.
+    """
+    if hasattr(sys, "_called_from_test") and not ignore_testing:
+        logging_level = logging.CRITICAL
+    elif verbose:
+        logging_level = logging.DEBUG
+    else:
+        logging_level = logging.INFO
+    logger.setLevel(logging_level)
+    return logging_level
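For orientation, here is a minimal sketch of how these new logging helpers might be combined in calling code. It assumes EuroEval 16.7.1 is installed; the surrounding script itself is hypothetical and not part of the package.

    import logging

    from euroeval.logging_utils import (
        adjust_logging_level,
        block_terminal_output,
        log_once,
        no_terminal_output,
    )

    # Pick DEBUG or INFO from the verbosity flag and silence noisy third-party loggers
    adjust_logging_level(verbose=False)
    block_terminal_output()

    # Repeated calls with the same message are only logged the first time; the prefix
    # is ignored when deciding whether the message has been seen before
    for _ in range(3):
        log_once("Starting evaluation", level=logging.INFO, prefix="[run 1] ")

    # Temporarily redirect stdout/stderr to /dev/null around a chatty call
    with no_terminal_output():
        print("this never reaches the terminal")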
euroeval/metrics/base.py
ADDED
@@ -0,0 +1,85 @@
+"""The abstract base class for all metrics."""
+
+import abc
+import collections.abc as c
+import typing as t
+
+if t.TYPE_CHECKING:
+    from datasets.arrow_dataset import Dataset
+
+    from ..data_models import BenchmarkConfig, DatasetConfig
+
+
+class Metric(abc.ABC):
+    """Abstract base class for all metrics."""
+
+    def __init__(
+        self,
+        name: str,
+        pretty_name: str,
+        postprocessing_fn: t.Callable[[float], tuple[float, str]] | None = None,
+    ) -> None:
+        """Initialise the metric.
+
+        Args:
+            name:
+                The name of the metric in snake_case.
+            pretty_name:
+                The pretty name of the metric, used for display purposes.
+            postprocessing_fn:
+                A function to apply to the metric scores after they are computed,
+                taking the score to the postprocessed score along with its string
+                representation. Defaults to x -> (100 * x, f"{x:.2%}").
+        """
+        self.name = name
+        self.pretty_name = pretty_name
+        self.postprocessing_fn = (
+            postprocessing_fn
+            if postprocessing_fn is not None
+            else lambda x: (100 * x, f"{x:.2%}")
+        )
+
+    def download(self, cache_dir: str) -> "Metric":
+        """Initiates the download of the metric if needed.
+
+        Args:
+            cache_dir:
+                The directory where the metric will be downloaded to.
+
+        Returns:
+            The metric object itself.
+        """
+        return self
+
+    @abc.abstractmethod
+    def __call__(
+        self,
+        predictions: c.Sequence,
+        references: c.Sequence,
+        dataset: "Dataset",
+        dataset_config: "DatasetConfig",
+        benchmark_config: "BenchmarkConfig",
+    ) -> float | None:
+        """Calculate the metric score.
+
+        Args:
+            predictions:
+                The model predictions.
+            references:
+                The ground truth references.
+            dataset:
+                The dataset used for evaluation. This is only used in case any
+                additional metadata is used to compute the metrics.
+            dataset_config:
+                The dataset configuration.
+            benchmark_config:
+                The benchmark configuration.
+
+        Returns:
+            The calculated metric score, or None if the score should be ignored.
+        """
+        ...
+
+    def __hash__(self) -> int:
+        """Return a hash of the metric configuration."""
+        return hash(self.name)
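To make the abstract interface concrete, the following is a minimal sketch of a hypothetical subclass (not part of the package) that implements `__call__` as a plain exact-match rate. The unused `dataset`, `dataset_config` and `benchmark_config` arguments are kept only to match the signature of the abstract method, and the import assumes the package is installed.

    import collections.abc as c

    from euroeval.metrics.base import Metric


    class ExactMatchRate(Metric):
        """Hypothetical metric: fraction of predictions identical to their reference."""

        def __call__(
            self,
            predictions: c.Sequence,
            references: c.Sequence,
            dataset,
            dataset_config,
            benchmark_config,
        ) -> float | None:
            # Return a score in [0, 1]; the default postprocessing_fn scales it to 0-100
            if not references:
                return None
            hits = sum(pred == ref for pred, ref in zip(predictions, references))
            return hits / len(references)


    metric = ExactMatchRate(name="exact_match_rate", pretty_name="Exact match rate")
    print(metric.postprocessing_fn(0.42))  # (42.0, '42.00%')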
euroeval/metrics/huggingface.py
ADDED

@@ -0,0 +1,216 @@
+"""All the Hugging Face metrics used in EuroEval."""
+
+import collections.abc as c
+import typing as t
+from pathlib import Path
+
+import evaluate
+import numpy as np
+from datasets import DownloadConfig, DownloadMode
+
+from ..logging_utils import no_terminal_output
+from .base import Metric
+
+if t.TYPE_CHECKING:
+    from datasets.arrow_dataset import Dataset
+    from evaluate import EvaluationModule
+
+    from ..data_models import BenchmarkConfig, DatasetConfig
+
+
+class HuggingFaceMetric(Metric):
+    """A metric which is implemented in the `evaluate` package.
+
+    Attributes:
+        name:
+            The name of the metric in snake_case.
+        pretty_name:
+            The pretty name of the metric, used for display purposes.
+        huggingface_id:
+            The Hugging Face ID of the metric.
+        results_key:
+            The name of the key used to extract the metric scores from the results
+            dictionary.
+        compute_kwargs:
+            Keyword arguments to pass to the metric's compute function. Defaults to
+            an empty dictionary.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        pretty_name: str,
+        huggingface_id: str,
+        results_key: str,
+        compute_kwargs: dict[str, t.Any] | None = None,
+        postprocessing_fn: t.Callable[[float], tuple[float, str]] | None = None,
+    ) -> None:
+        """Initialise the Hugging Face metric.
+
+        Args:
+            name:
+                The name of the metric in snake_case.
+            pretty_name:
+                The pretty name of the metric, used for display purposes.
+            huggingface_id:
+                The Hugging Face ID of the metric.
+            results_key:
+                The name of the key used to extract the metric scores from the results
+                dictionary.
+            compute_kwargs:
+                Keyword arguments to pass to the metric's compute function. Defaults to
+                an empty dictionary.
+            postprocessing_fn:
+                A function to apply to the metric scores after they are computed, taking
+                the score to the postprocessed score along with its string
+                representation. Defaults to x -> (100 * x, f"{x:.2%}").
+        """
+        super().__init__(
+            name=name, pretty_name=pretty_name, postprocessing_fn=postprocessing_fn
+        )
+        self.huggingface_id = huggingface_id
+        self.results_key = results_key
+        self.compute_kwargs: dict[str, t.Any] = (
+            dict() if compute_kwargs is None else compute_kwargs
+        )
+        self.metric: "EvaluationModule | None" = None
+
+    def download(self, cache_dir: str) -> "HuggingFaceMetric":
+        """Initiates the download of the metric if needed.
+
+        Args:
+            cache_dir:
+                The directory where the metric will be downloaded to.
+
+        Returns:
+            The metric object itself.
+        """
+        metric_cache_dir = Path(cache_dir) / "metrics"
+        download_config = DownloadConfig(cache_dir=metric_cache_dir)
+        self.metric = evaluate.load(
+            path=self.huggingface_id,
+            download_config=download_config,
+            download_mode=DownloadMode.REUSE_CACHE_IF_EXISTS,
+            cache_dir=metric_cache_dir.as_posix(),
+        )
+        return self
+
+    def __call__(
+        self,
+        predictions: c.Sequence,
+        references: c.Sequence,
+        dataset: "Dataset",
+        dataset_config: "DatasetConfig",
+        benchmark_config: "BenchmarkConfig",
+    ) -> float | None:
+        """Calculate the metric score.
+
+        Args:
+            predictions:
+                The model predictions.
+            references:
+                The ground truth references.
+            dataset:
+                The dataset used for evaluation. This is only used in case any
+                additional metadata is used to compute the metrics.
+            dataset_config:
+                The dataset configuration.
+            benchmark_config:
+                The benchmark configuration.
+
+        Returns:
+            The calculated metric score, or None if the score should be ignored.
+        """
+        if self.metric is None:
+            self.download(cache_dir=benchmark_config.cache_dir)
+
+        assert self.metric is not None, (
+            "Metric has not been downloaded. Please call download() before using the "
+            "__call__ method."
+        )
+
+        with no_terminal_output(disable=benchmark_config.verbose):
+            results = self.metric.compute(
+                predictions=predictions, references=references, **self.compute_kwargs
+            )
+
+        # The metric returns None if we are running on multi-GPU and the current
+        # process is not the main process
+        if results is None:
+            return None
+
+        # Convert the results to a float score
+        score = results[self.results_key]
+        if isinstance(score, list):
+            score = sum(score) / len(score)
+        if isinstance(score, np.floating):
+            score = float(score)
+
+        return score
+
+
+mcc_metric = HuggingFaceMetric(
+    name="mcc",
+    pretty_name="Matthew's Correlation Coefficient",
+    huggingface_id="matthews_correlation",
+    results_key="matthews_correlation",
+)
+
+macro_f1_metric = HuggingFaceMetric(
+    name="macro_f1",
+    pretty_name="Macro-average F1-score",
+    huggingface_id="f1",
+    results_key="f1",
+    compute_kwargs=dict(average="macro"),
+)
+
+micro_f1_metric = HuggingFaceMetric(
+    name="micro_f1",
+    pretty_name="Micro-average F1-score with MISC tags",
+    huggingface_id="seqeval",
+    results_key="overall_f1",
+)
+
+micro_f1_no_misc_metric = HuggingFaceMetric(
+    name="micro_f1_no_misc",
+    pretty_name="Micro-average F1-score without MISC tags",
+    huggingface_id="seqeval",
+    results_key="overall_f1",
+)
+
+f1_metric = HuggingFaceMetric(
+    name="f1",
+    pretty_name="F1-score",
+    huggingface_id="squad_v2",
+    results_key="f1",
+    postprocessing_fn=lambda x: (x, f"{x:.2f}%"),
+)
+
+em_metric = HuggingFaceMetric(
+    name="em",
+    pretty_name="Exact Match",
+    huggingface_id="squad_v2",
+    results_key="exact",
+    postprocessing_fn=lambda x: (x, f"{x:.2f}%"),
+)
+
+bert_score_metric = HuggingFaceMetric(
+    name="bertscore",
+    pretty_name="BERTScore",
+    huggingface_id="bertscore",
+    results_key="f1",
+    compute_kwargs=dict(
+        model_type="microsoft/mdeberta-v3-base", device="auto", batch_size=1
+    ),
+)
+
+rouge_l_metric = HuggingFaceMetric(
+    name="rouge_l", pretty_name="ROUGE-L", huggingface_id="rouge", results_key="rougeL"
+)
+
+accuracy_metric = HuggingFaceMetric(
+    name="accuracy",
+    pretty_name="Accuracy",
+    huggingface_id="accuracy",
+    results_key="accuracy",
+)