EuroEval 16.3.0__py3-none-any.whl → 16.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- euroeval/__init__.py +9 -2
- euroeval/benchmark_config_factory.py +51 -50
- euroeval/benchmark_modules/base.py +9 -21
- euroeval/benchmark_modules/fresh.py +2 -1
- euroeval/benchmark_modules/hf.py +101 -71
- euroeval/benchmark_modules/litellm.py +115 -53
- euroeval/benchmark_modules/vllm.py +107 -92
- euroeval/benchmarker.py +144 -121
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/cli.py +86 -8
- euroeval/constants.py +9 -0
- euroeval/data_loading.py +80 -29
- euroeval/data_models.py +338 -330
- euroeval/dataset_configs/__init__.py +12 -3
- euroeval/dataset_configs/bulgarian.py +56 -0
- euroeval/dataset_configs/czech.py +75 -0
- euroeval/dataset_configs/danish.py +55 -93
- euroeval/dataset_configs/dutch.py +48 -87
- euroeval/dataset_configs/english.py +45 -77
- euroeval/dataset_configs/estonian.py +42 -34
- euroeval/dataset_configs/faroese.py +19 -60
- euroeval/dataset_configs/finnish.py +36 -69
- euroeval/dataset_configs/french.py +39 -75
- euroeval/dataset_configs/german.py +45 -82
- euroeval/dataset_configs/greek.py +64 -0
- euroeval/dataset_configs/icelandic.py +54 -91
- euroeval/dataset_configs/italian.py +42 -79
- euroeval/dataset_configs/latvian.py +28 -35
- euroeval/dataset_configs/lithuanian.py +28 -26
- euroeval/dataset_configs/norwegian.py +72 -115
- euroeval/dataset_configs/polish.py +33 -61
- euroeval/dataset_configs/portuguese.py +33 -66
- euroeval/dataset_configs/serbian.py +64 -0
- euroeval/dataset_configs/slovak.py +55 -0
- euroeval/dataset_configs/spanish.py +42 -77
- euroeval/dataset_configs/swedish.py +52 -90
- euroeval/dataset_configs/ukrainian.py +64 -0
- euroeval/exceptions.py +1 -1
- euroeval/finetuning.py +24 -17
- euroeval/generation.py +15 -14
- euroeval/generation_utils.py +8 -8
- euroeval/languages.py +395 -323
- euroeval/logging_utils.py +250 -0
- euroeval/metrics/base.py +0 -3
- euroeval/metrics/huggingface.py +21 -6
- euroeval/metrics/llm_as_a_judge.py +6 -4
- euroeval/metrics/pipeline.py +17 -9
- euroeval/metrics/speed.py +0 -3
- euroeval/model_cache.py +17 -19
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/__init__.py +2 -0
- euroeval/prompt_templates/classification.py +206 -0
- euroeval/prompt_templates/linguistic_acceptability.py +99 -42
- euroeval/prompt_templates/multiple_choice.py +102 -38
- euroeval/prompt_templates/named_entity_recognition.py +172 -51
- euroeval/prompt_templates/reading_comprehension.py +119 -42
- euroeval/prompt_templates/sentiment_classification.py +110 -40
- euroeval/prompt_templates/summarization.py +85 -40
- euroeval/prompt_templates/token_classification.py +279 -0
- euroeval/scores.py +11 -10
- euroeval/speed_benchmark.py +5 -6
- euroeval/task_group_utils/multiple_choice_classification.py +2 -4
- euroeval/task_group_utils/question_answering.py +24 -16
- euroeval/task_group_utils/sequence_classification.py +48 -35
- euroeval/task_group_utils/text_to_text.py +19 -9
- euroeval/task_group_utils/token_classification.py +21 -17
- euroeval/tasks.py +44 -1
- euroeval/tokenisation_utils.py +33 -22
- euroeval/types.py +10 -9
- euroeval/utils.py +35 -149
- {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/METADATA +196 -39
- euroeval-16.5.0.dist-info/RECORD +81 -0
- euroeval-16.3.0.dist-info/RECORD +0 -71
- {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/WHEEL +0 -0
- {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/entry_points.txt +0 -0
- {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/licenses/LICENSE +0 -0
euroeval/logging_utils.py
ADDED
@@ -0,0 +1,250 @@
+"""Utility functions related to logging."""
+
+import datetime as dt
+import logging
+import os
+import sys
+import warnings
+from io import TextIOWrapper
+
+import litellm
+from datasets.utils import disable_progress_bars as disable_datasets_progress_bars
+from evaluate import disable_progress_bar as disable_evaluate_progress_bar
+from huggingface_hub.utils.tqdm import (
+    disable_progress_bars as disable_hf_hub_progress_bars,
+)
+from termcolor import colored
+from tqdm.auto import tqdm
+from transformers import logging as tf_logging
+
+from .caching_utils import cache_arguments
+
+logger = logging.getLogger("euroeval")
+
+
+def get_pbar(*tqdm_args, **tqdm_kwargs) -> tqdm:
+    """Get a progress bar for vLLM with custom hard-coded arguments.
+
+    Args:
+        *tqdm_args:
+            Positional arguments to pass to tqdm.
+        **tqdm_kwargs:
+            Additional keyword arguments to pass to tqdm.
+
+    Returns:
+        A tqdm progress bar.
+    """
+    tqdm_kwargs = dict(colour="yellow", ascii="—▰", leave=False) | tqdm_kwargs
+    tqdm_kwargs["desc"] = colored(
+        text=tqdm_kwargs.get("desc", "Processing"), color="light_yellow"
+    )
+    return tqdm(*tqdm_args, **tqdm_kwargs)
+
+
+def log(message: str, level: int, colour: str | None = None) -> None:
+    """Log a message.
+
+    Args:
+        message:
+            The message to log.
+        level:
+            The logging level. Defaults to logging.INFO.
+        colour:
+            The colour to use for the message. If None, a default colour will be used
+            based on the logging level.
+
+    Raises:
+        ValueError:
+            If the logging level is invalid.
+    """
+    match level:
+        case logging.DEBUG:
+            message = colored(
+                text=(
+                    "[DEBUG] "
+                    + dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                    + f" · {message}"
+                ),
+                color=colour or "light_blue",
+            )
+            logger.debug(message)
+        case logging.INFO:
+            if colour is not None:
+                message = colored(text=message, color=colour)
+            logger.info(message)
+        case logging.WARNING:
+            message = colored(text=message, color=colour or "light_red")
+            logger.warning(message)
+        case logging.ERROR:
+            message = colored(text=message, color=colour or "red")
+            logger.error(message)
+        case logging.CRITICAL:
+            message = colored(text=message, color=colour or "red")
+            logger.critical(message)
+        case _:
+            raise ValueError(f"Invalid logging level: {level}")
+
+
+@cache_arguments("message")
+def log_once(message: str, level: int = logging.INFO, prefix: str = "") -> None:
+    """Log a message once.
+
+    This is ensured by caching the "message" argument and only logging it the first time
+    this function is called with that message.
+
+    Args:
+        message:
+            The message to log.
+        level:
+            The logging level. Defaults to logging.INFO.
+        prefix:
+            A prefix to add to the message, which is not considered when determining if
+            the message has been logged before.
+    """
+    log(message=prefix + message, level=level)
+
+
+def block_terminal_output() -> None:
+    """Blocks libraries from writing output to the terminal.
+
+    This filters warnings from some libraries, sets the logging level to ERROR for some
+    libraries, disabled tokeniser progress bars when using Hugging Face tokenisers, and
+    disables most of the logging from the `transformers` library.
+    """
+    if os.getenv("FULL_LOG") == "1":
+        return
+
+    # Ignore miscellaneous warnings
+    warnings.filterwarnings("ignore", category=UserWarning)
+    warnings.filterwarnings("ignore", category=FutureWarning)
+    logging.getLogger("absl").setLevel(logging.CRITICAL)
+
+    # Disable matplotlib logging
+    logging.getLogger("matplotlib.font_manager").setLevel(logging.CRITICAL)
+
+    # Disable PyTorch logging
+    logging.getLogger("torch.utils.cpp_extension").setLevel(logging.CRITICAL)
+    warnings.filterwarnings(action="ignore", module="torch*")
+    os.environ["TORCH_LOGS"] = "-all"
+
+    # Disable huggingface_hub logging
+    logging.getLogger("huggingface_hub").setLevel(logging.CRITICAL)
+    disable_hf_hub_progress_bars()
+
+    # Disable LiteLLM logging
+    logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
+    logging.getLogger("LiteLLM Router").setLevel(logging.CRITICAL)
+    logging.getLogger("LiteLLM Proxy").setLevel(logging.CRITICAL)
+    logging.getLogger("openai").setLevel(logging.CRITICAL)
+    logging.getLogger("httpx").setLevel(logging.CRITICAL)
+    litellm.suppress_debug_info = True
+
+    # Disable vLLM logging
+    logging.getLogger("vllm").setLevel(logging.CRITICAL)
+    logging.getLogger("vllm.engine.llm_engine").setLevel(logging.CRITICAL)
+    logging.getLogger("vllm.transformers_utils.tokenizer").setLevel(logging.CRITICAL)
+    logging.getLogger("vllm.core.scheduler").setLevel(logging.CRITICAL)
+    logging.getLogger("vllm.model_executor.weight_utils").setLevel(logging.CRITICAL)
+    logging.getLogger("vllm.platforms").setLevel(logging.CRITICAL)
+    logging.getLogger("mistral_common.tokens.tokenizers.tekken").setLevel(
+        logging.CRITICAL
+    )
+    os.environ["LOG_LEVEL"] = "CRITICAL"
+    os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
+
+    # Disable flashinfer logging
+    os.environ["FLASHINFER_LOGGING_LEVEL"] = "CRITICAL"
+
+    # Disable datasets logging
+    logging.getLogger("datasets").setLevel(logging.CRITICAL)
+    logging.getLogger("filelock").setLevel(logging.CRITICAL)
+    disable_datasets_progress_bars()
+
+    # Disable evaluate logging
+    warnings.filterwarnings("ignore", module="seqeval*")
+    disable_evaluate_progress_bar()
+
+    # Disable most of the `transformers` logging
+    tf_logging._default_log_level = logging.CRITICAL
+    tf_logging.set_verbosity(logging.CRITICAL)
+    logging.getLogger("transformers.trainer").setLevel(logging.CRITICAL)
+    logging.getLogger("accelerate").setLevel(logging.CRITICAL)
+
+
+class no_terminal_output:
+    """Context manager that suppresses all terminal output."""
+
+    def __init__(self, disable: bool = False) -> None:
+        """Initialise the context manager.
+
+        Args:
+            disable:
+                If True, this context manager does nothing.
+        """
+        self.disable = disable
+        self.nothing_file: TextIOWrapper | None = None
+        self._cpp_stdout_file: int | None = None
+        self._cpp_stderr_file: int | None = None
+        try:
+            self._cpp_stdout_file = os.dup(sys.stdout.fileno())
+            self._cpp_stderr_file = os.dup(sys.stderr.fileno())
+        except OSError:
+            self._log_windows_warning()
+
+    def _log_windows_warning(self) -> None:
+        """Log a warning about Windows not supporting blocking terminal output."""
+        log_once(
+            "Your operating system (probably Windows) does not support blocking "
+            "terminal output, so expect more messy output - sorry!",
+            level=logging.WARNING,
+        )
+
+    def __enter__(self) -> None:
+        """Suppress all terminal output."""
+        if not self.disable:
+            self.nothing_file = open(os.devnull, "w")
+            try:
+                os.dup2(fd=self.nothing_file.fileno(), fd2=sys.stdout.fileno())
+                os.dup2(fd=self.nothing_file.fileno(), fd2=sys.stderr.fileno())
+            except OSError:
+                self._log_windows_warning()

+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: type[BaseException] | None,
+    ) -> None:
+        """Re-enable terminal output."""
+        if not self.disable:
+            if self.nothing_file is not None:
+                self.nothing_file.close()
+            try:
+                if self._cpp_stdout_file is not None:
+                    os.dup2(fd=self._cpp_stdout_file, fd2=sys.stdout.fileno())
+                if self._cpp_stderr_file is not None:
+                    os.dup2(fd=self._cpp_stderr_file, fd2=sys.stderr.fileno())
+            except OSError:
+                self._log_windows_warning()
+
+
+def adjust_logging_level(verbose: bool, ignore_testing: bool = False) -> int:
+    """Adjust the logging level based on verbosity.
+
+    Args:
+        verbose:
+            Whether to output additional output.
+        ignore_testing:
+            Whether to ignore the testing flag.
+
+    Returns:
+        The logging level that was set.
+    """
+    if hasattr(sys, "_called_from_test") and not ignore_testing:
+        logging_level = logging.CRITICAL
+    elif verbose:
+        logging_level = logging.DEBUG
+    else:
+        logging_level = logging.INFO
+    logger.setLevel(logging_level)
+    return logging_level
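
The new euroeval/logging_utils.py module centralises the logging helpers that the rest of this diff switches over to (log, log_once, get_pbar and the no_terminal_output context manager). A minimal usage sketch, based only on the call sites visible in the file diffs below; the message strings and loop body here are illustrative, not taken from the package:

import logging

from euroeval.logging_utils import (
    adjust_logging_level,
    get_pbar,
    log,
    log_once,
    no_terminal_output,
)

# Pick DEBUG/INFO/CRITICAL depending on verbosity and test mode.
adjust_logging_level(verbose=True)

# Plain and deduplicated log messages, coloured according to the level.
log("Cache could not be read, re-initialising it.", level=logging.WARNING)
log_once("Loading the model...")  # only logged the first time it is seen

# Silence stdout/stderr around a noisy call.
with no_terminal_output():
    print("this goes to /dev/null")

# Styled tqdm progress bar, as used when caching model outputs.
with get_pbar(iterable=range(10), desc="Caching model outputs") as pbar:
    for _ in pbar:
        pass
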
euroeval/metrics/base.py
CHANGED
@@ -2,7 +2,6 @@
 
 import abc
 import collections.abc as c
-import logging
 import typing as t
 
 if t.TYPE_CHECKING:
@@ -10,8 +9,6 @@ if t.TYPE_CHECKING:
 
     from ..data_models import BenchmarkConfig, DatasetConfig
 
-logger: logging.Logger = logging.getLogger("euroeval")
-
 
 class Metric(abc.ABC):
     """Abstract base class for all metrics."""
euroeval/metrics/huggingface.py
CHANGED
@@ -1,7 +1,6 @@
 """All the Hugging Face metrics used in EuroEval."""
 
 import collections.abc as c
-import logging
 import typing as t
 from pathlib import Path
 
@@ -9,7 +8,7 @@ import evaluate
 import numpy as np
 from datasets import DownloadConfig
 
-from ..
+from ..logging_utils import no_terminal_output
 from .base import Metric
 
 if t.TYPE_CHECKING:
@@ -18,8 +17,6 @@ if t.TYPE_CHECKING:
 
     from ..data_models import BenchmarkConfig, DatasetConfig
 
-logger: logging.Logger = logging.getLogger("euroeval")
-
 
 class HuggingFaceMetric(Metric):
     """A metric which is implemented in the `evaluate` package.
@@ -124,9 +121,12 @@ class HuggingFaceMetric(Metric):
         if self.metric is None:
             self.download(cache_dir=benchmark_config.cache_dir)
 
-        assert self.metric is not None
+        assert self.metric is not None, (
+            "Metric has not been downloaded. Please call download() before using the "
+            "__call__ method."
+        )
 
-        with
+        with no_terminal_output(disable=benchmark_config.verbose):
            results = self.metric.compute(
                 predictions=predictions, references=references, **self.compute_kwargs
             )
@@ -143,8 +143,23 @@ class HuggingFaceMetric(Metric):
         if isinstance(score, np.floating):
             score = float(score)
 
+        self.close()
         return score
 
+    def close(self) -> None:
+        """Close any resources held by the metric."""
+        if self.metric is not None:
+            if self.metric.filelock is not None:
+                self.metric.filelock.release(force=True)
+            if self.metric.writer is not None:
+                self.metric.writer.finalize(close_stream=True)
+
+    def __del__(self) -> None:
+        """Clean up the metric from memory."""
+        if self.metric is not None:
+            self.close()
+            del self.metric
+
 
 mcc_metric = HuggingFaceMetric(
     name="mcc",
euroeval/metrics/llm_as_a_judge.py
CHANGED
@@ -8,7 +8,7 @@ from pathlib import Path
 from pydantic import BaseModel, Field
 
 from ..exceptions import InvalidBenchmark
-from ..
+from ..logging_utils import log
 from ..utils import extract_json_dict_from_string
 from .base import Metric
 
@@ -17,8 +17,6 @@ if t.TYPE_CHECKING:
 
     from ..data_models import BenchmarkConfig, DatasetConfig
 
-logger: logging.Logger = logging.getLogger("euroeval")
-
 
 class LLMAsAJudgeMetric(Metric):
     """Use an LLM to judge the quality of the predictions."""
@@ -112,6 +110,7 @@ class LLMAsAJudgeMetric(Metric):
         """
         # Importing here to avoid circular imports
         from ..benchmark_modules import LiteLLMModel
+        from ..model_cache import ModelCache
 
         if not predictions or not references:
             return None
@@ -190,7 +189,10 @@ class LLMAsAJudgeMetric(Metric):
         # Calculate the scores using the scoring function
         scores = [self.scoring_fn(output) for output in outputs]
         if not scores:
-
+            log(
+                f"No scores were calculated for {self.pretty_name}.",
+                level=logging.WARNING,
+            )
             return None
         return sum(scores) / len(scores)
 
euroeval/metrics/pipeline.py
CHANGED
@@ -11,6 +11,7 @@ import numpy as np
 from scipy.special import expit as sigmoid
 
 from ..exceptions import InvalidBenchmark
+from ..logging_utils import log, no_terminal_output
 from ..utils import unscramble
 from .base import Metric
 
@@ -20,8 +21,6 @@ if t.TYPE_CHECKING:
 
     from ..data_models import BenchmarkConfig, DatasetConfig
 
-logger: logging.Logger = logging.getLogger("euroeval")
-
 
 T = t.TypeVar("T", bound=int | float | str | bool)
 
@@ -121,16 +120,22 @@ class PipelineMetric(Metric):
             The calculated metric score, or None if the score should be ignored.
         """
         if self.pipeline is None:
-            self.pipeline = self._download_pipeline(
+            self.pipeline = self._download_pipeline(
+                cache_dir=benchmark_config.cache_dir
+            )
         if self.preprocessing_fn is not None:
             predictions = self.preprocessing_fn(
                 predictions=predictions, dataset=dataset
             )
         return self.pipeline_scoring_function(self.pipeline, predictions)
 
-    def _download_pipeline(self) -> "Pipeline":
+    def _download_pipeline(self, cache_dir: str) -> "Pipeline":
         """Download the scikit-learn pipeline from the given URL.
 
+        Args:
+            cache_dir:
+                The directory to use for caching the downloaded pipeline.
+
         Returns:
             The downloaded scikit-learn pipeline.
 
@@ -138,10 +143,13 @@ class PipelineMetric(Metric):
             InvalidBenchmark:
                 If the loading of the pipeline fails for any reason.
         """
-
-
-
-
+        log(f"Loading pipeline from {self.pipeline_repo}...", level=logging.DEBUG)
+        with no_terminal_output():
+            folder_path = hf_hub.HfApi(
+                token=unscramble("XbjeOLhwebEaSaDUMqqaPaPIhgOcyOfDpGnX_")
+            ).snapshot_download(
+                repo_id=self.pipeline_repo, repo_type="model", cache_dir=cache_dir
+            )
         model_path = Path(folder_path, self.pipeline_file_name)
         try:
             with model_path.open(mode="rb") as f:
@@ -150,7 +158,7 @@ class PipelineMetric(Metric):
             raise InvalidBenchmark(
                 f"Failed to load pipeline from {self.pipeline_repo!r}: {e}"
             ) from e
-
+        log(f"Successfully loaded pipeline: {pipeline}", level=logging.DEBUG)
         return pipeline
 
 
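
The PipelineMetric change above threads the benchmark's cache directory into the Hugging Face Hub snapshot download. A hedged sketch of that download-and-load pattern using the public huggingface_hub.snapshot_download helper; the repo id, file name and the use of pickle are placeholders and assumptions, not taken from the package (the actual deserialisation step is not visible in the hunk):

import pickle
from pathlib import Path

from huggingface_hub import snapshot_download

# Hypothetical values standing in for self.pipeline_repo, self.pipeline_file_name
# and benchmark_config.cache_dir.
folder_path = snapshot_download(
    repo_id="example-org/example-sklearn-pipeline",
    repo_type="model",
    cache_dir=".euroeval_cache",
)
model_path = Path(folder_path, "pipeline.pkl")
with model_path.open(mode="rb") as f:
    pipeline = pickle.load(f)
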
euroeval/metrics/speed.py
CHANGED
@@ -1,7 +1,6 @@
 """Inference speed metric."""
 
 import collections.abc as c
-import logging
 import typing as t
 
 from .base import Metric
@@ -11,8 +10,6 @@ if t.TYPE_CHECKING:
 
     from ..data_models import BenchmarkConfig, DatasetConfig
 
-logger: logging.Logger = logging.getLogger("euroeval")
-
 
 class SpeedMetric(Metric):
     """Speed metric."""
euroeval/model_cache.py
CHANGED
@@ -1,5 +1,6 @@
 """ModelCache class for caching model outputs."""
 
+import collections.abc as c
 import hashlib
 import json
 import logging
@@ -8,11 +9,9 @@ import typing as t
 from collections import defaultdict
 from dataclasses import asdict
 
-from tqdm.auto import tqdm
-
 from .constants import NUM_GENERATION_TOKENS_FOR_CLASSIFICATION
 from .data_models import GenerativeModelOutput, SingleGenerativeModelOutput
-from .
+from .logging_utils import get_pbar, log, log_once
 
 if t.TYPE_CHECKING:
     from pathlib import Path
@@ -20,9 +19,6 @@ if t.TYPE_CHECKING:
     from datasets import Dataset
 
 
-logger = logging.getLogger("euroeval")
-
-
 class ModelCache:
     """A cache for model outputs.
 
@@ -65,9 +61,10 @@ class ModelCache:
             with self.cache_path.open() as f:
                 json_cache = json.load(f)
         except json.JSONDecodeError:
-
+            log(
                 f"Failed to load the cache from {self.cache_path}. The cache will be "
-                f"re-initialised."
+                f"re-initialised.",
+                level=logging.WARNING,
             )
             json_cache = dict()
             with self.cache_path.open("w") as f:
@@ -89,15 +86,16 @@ class ModelCache:
             with self.cache_path.open("w") as f:
                 json.dump(dumpable_cache, f)
         except KeyError:
-
+            log(
                 f"Failed to load the cache from {self.cache_path}. The cache will be "
-                f"re-initialised."
+                f"re-initialised.",
+                level=logging.WARNING,
             )
             self.cache = dict()
             with self.cache_path.open("w") as f:
                 json.dump(dict(), f)
 
-    def _hash_key(self, key: str |
+    def _hash_key(self, key: str | c.Sequence[dict[str, str]]) -> str:
         """Hash the key to use as an index in the cache.
 
         Args:
@@ -110,7 +108,7 @@ class ModelCache:
         return hashlib.md5(string=str(key).encode()).hexdigest()
 
     def __getitem__(
-        self, key: str |
+        self, key: str | c.Sequence[dict[str, str]]
     ) -> SingleGenerativeModelOutput:
         """Get an item from the cache.
 
@@ -125,7 +123,7 @@ class ModelCache:
         return self.cache[hashed_key]
 
     def __setitem__(
-        self, key: str |
+        self, key: str | c.Sequence[dict[str, str]], value: SingleGenerativeModelOutput
     ) -> None:
         """Set an item in the cache.
 
@@ -143,7 +141,7 @@ class ModelCache:
         self.cache_path.unlink()
         del self.cache
 
-    def __contains__(self, key: str |
+    def __contains__(self, key: str | c.Sequence[dict[str, str]]) -> bool:
         """Check if a key is in the cache.
 
         Args:
@@ -172,18 +170,18 @@ class ModelCache:
 
         # Double check that the number of inputs and outputs match
         if not len(model_inputs) == len(model_output.sequences):
-
+            log(
                 f"Number of model inputs ({len(model_inputs)}) does not match the "
                 f"number of model outputs ({len(model_output.sequences)}). We will not "
-                f"cache the model outputs."
+                f"cache the model outputs.",
+                level=logging.WARNING,
             )
             return
 
         # Store the generated sequences in the cache, one by one
-        with
+        with get_pbar(
             iterable=model_inputs,
             desc="Caching model outputs",
-            leave=False,
             disable=hasattr(sys, "_called_from_test"),
         ) as pbar:
             for sample_idx, model_input in enumerate(pbar):
@@ -261,7 +259,7 @@ def load_cached_model_outputs(
         The model output containing the cached sequences.
     """
     input_column = "messages" if "messages" in cached_dataset.column_names else "text"
-    cached_model_outputs:
+    cached_model_outputs: c.Sequence[SingleGenerativeModelOutput] = [
         cache[prompt] for prompt in cached_dataset[input_column]
     ]
 
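
The ModelCache changes above widen the cache-key type from plain prompt strings to also accept chat-message sequences (c.Sequence[dict[str, str]]); _hash_key collapses either form to the MD5 hex digest of its string representation. A small standalone sketch of that hashing idea; hash_cache_key and the message contents are illustrative names, not part of the package:

import hashlib

def hash_cache_key(key: str | list[dict[str, str]]) -> str:
    # Same idea as ModelCache._hash_key: stringify the key (works for both a raw
    # prompt and a list of chat messages) and take the MD5 hex digest.
    return hashlib.md5(str(key).encode()).hexdigest()

hash_cache_key("Translate to Danish: Hello!")
hash_cache_key([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Translate to Danish: Hello!"},
])
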
euroeval/model_config.py
CHANGED
@@ -5,14 +5,12 @@ import typing as t
 
 from . import benchmark_modules
 from .exceptions import InvalidModel, NeedsEnvironmentVariable, NeedsExtraInstalled
+from .logging_utils import log
 
 if t.TYPE_CHECKING:
     from .data_models import BenchmarkConfig, ModelConfig
 
 
-logger = logging.getLogger("euroeval")
-
-
 def get_model_config(
     model_id: str, benchmark_config: "BenchmarkConfig"
 ) -> "ModelConfig":
@@ -51,9 +49,10 @@ def get_model_config(
         elif isinstance(exists_or_err, NeedsEnvironmentVariable):
             needs_env_vars.append(exists_or_err.env_var)
         elif exists_or_err is True:
-
+            log(
                 f"The model {model_id!r} was identified by the "
-                f"{benchmark_module.__name__} benchmark module."
+                f"{benchmark_module.__name__} benchmark module.",
+                logging.DEBUG,
             )
             model_config = benchmark_module.get_model_config(
                 model_id=model_id, benchmark_config=benchmark_config
euroeval/model_loading.py
CHANGED
@@ -10,6 +10,7 @@ from .benchmark_modules import (
 )
 from .enums import InferenceBackend, ModelType
 from .exceptions import InvalidModel
+from .logging_utils import log_once
 
 if t.TYPE_CHECKING:
     from .benchmark_modules import BenchmarkModule
@@ -34,6 +35,8 @@ def load_model(
     Returns:
         The model.
     """
+    log_once(f"\nLoading the model {model_config.model_id}...")
+
     # The order matters; the first model type that matches will be used. For this
     # reason, they have been ordered in terms of the most common model types.
     model_class: t.Type[BenchmarkModule]
euroeval/prompt_templates/__init__.py
CHANGED
@@ -1,8 +1,10 @@
 """The different prompt templates used in EuroEval."""
 
+from .classification import CLASSIFICATION_TEMPLATES
 from .linguistic_acceptability import LA_TEMPLATES
 from .multiple_choice import MULTIPLE_CHOICE_TEMPLATES
 from .named_entity_recognition import NER_TEMPLATES
 from .reading_comprehension import RC_TEMPLATES
 from .sentiment_classification import SENT_TEMPLATES
 from .summarization import SUMM_TEMPLATES
+from .token_classification import TOKEN_CLASSIFICATION_TEMPLATES