EuroEval 16.3.0__py3-none-any.whl → 16.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of EuroEval might be problematic.
- euroeval/__init__.py +3 -2
- euroeval/benchmark_config_factory.py +0 -4
- euroeval/benchmark_modules/base.py +3 -16
- euroeval/benchmark_modules/fresh.py +2 -1
- euroeval/benchmark_modules/hf.py +99 -62
- euroeval/benchmark_modules/litellm.py +101 -41
- euroeval/benchmark_modules/vllm.py +91 -83
- euroeval/benchmarker.py +84 -78
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/constants.py +6 -0
- euroeval/data_loading.py +14 -11
- euroeval/data_models.py +12 -4
- euroeval/dataset_configs/__init__.py +2 -0
- euroeval/dataset_configs/czech.py +79 -0
- euroeval/dataset_configs/danish.py +10 -11
- euroeval/dataset_configs/dutch.py +0 -1
- euroeval/dataset_configs/english.py +0 -1
- euroeval/dataset_configs/estonian.py +11 -1
- euroeval/dataset_configs/finnish.py +0 -1
- euroeval/dataset_configs/french.py +0 -1
- euroeval/dataset_configs/german.py +0 -1
- euroeval/dataset_configs/italian.py +0 -1
- euroeval/dataset_configs/latvian.py +0 -1
- euroeval/dataset_configs/lithuanian.py +9 -3
- euroeval/dataset_configs/norwegian.py +0 -1
- euroeval/dataset_configs/polish.py +0 -1
- euroeval/dataset_configs/portuguese.py +0 -1
- euroeval/dataset_configs/slovak.py +60 -0
- euroeval/dataset_configs/spanish.py +0 -1
- euroeval/dataset_configs/swedish.py +10 -12
- euroeval/finetuning.py +21 -15
- euroeval/generation.py +10 -10
- euroeval/generation_utils.py +2 -3
- euroeval/logging_utils.py +250 -0
- euroeval/metrics/base.py +0 -3
- euroeval/metrics/huggingface.py +9 -5
- euroeval/metrics/llm_as_a_judge.py +5 -3
- euroeval/metrics/pipeline.py +17 -9
- euroeval/metrics/speed.py +0 -3
- euroeval/model_cache.py +11 -14
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/linguistic_acceptability.py +21 -3
- euroeval/prompt_templates/multiple_choice.py +25 -1
- euroeval/prompt_templates/named_entity_recognition.py +51 -11
- euroeval/prompt_templates/reading_comprehension.py +31 -3
- euroeval/prompt_templates/sentiment_classification.py +23 -1
- euroeval/prompt_templates/summarization.py +26 -6
- euroeval/scores.py +7 -7
- euroeval/speed_benchmark.py +3 -5
- euroeval/task_group_utils/multiple_choice_classification.py +0 -3
- euroeval/task_group_utils/question_answering.py +0 -3
- euroeval/task_group_utils/sequence_classification.py +43 -31
- euroeval/task_group_utils/text_to_text.py +17 -8
- euroeval/task_group_utils/token_classification.py +10 -9
- euroeval/tokenisation_utils.py +14 -12
- euroeval/utils.py +29 -146
- {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/METADATA +4 -4
- euroeval-16.4.0.dist-info/RECORD +75 -0
- euroeval-16.3.0.dist-info/RECORD +0 -71
- {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/WHEEL +0 -0
- {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/entry_points.txt +0 -0
- {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/licenses/LICENSE +0 -0
euroeval/logging_utils.py
ADDED
@@ -0,0 +1,250 @@
+ """Utility functions related to logging."""
+
+ import datetime as dt
+ import logging
+ import os
+ import sys
+ import warnings
+ from io import TextIOWrapper
+
+ import litellm
+ from datasets.utils import disable_progress_bars as disable_datasets_progress_bars
+ from evaluate import disable_progress_bar as disable_evaluate_progress_bar
+ from huggingface_hub.utils.tqdm import (
+     disable_progress_bars as disable_hf_hub_progress_bars,
+ )
+ from termcolor import colored
+ from tqdm.auto import tqdm
+ from transformers import logging as tf_logging
+
+ from .caching_utils import cache_arguments
+
+ logger = logging.getLogger("euroeval")
+
+
+ def get_pbar(*tqdm_args, **tqdm_kwargs) -> tqdm:
+     """Get a progress bar for vLLM with custom hard-coded arguments.
+
+     Args:
+         *tqdm_args:
+             Positional arguments to pass to tqdm.
+         **tqdm_kwargs:
+             Additional keyword arguments to pass to tqdm.
+
+     Returns:
+         A tqdm progress bar.
+     """
+     tqdm_kwargs = dict(colour="yellow", ascii="—▰", leave=False) | tqdm_kwargs
+     tqdm_kwargs["desc"] = colored(
+         text=tqdm_kwargs.get("desc", "Processing"), color="light_yellow"
+     )
+     return tqdm(*tqdm_args, **tqdm_kwargs)
+
+
+ def log(message: str, level: int, colour: str | None = None) -> None:
+     """Log a message.
+
+     Args:
+         message:
+             The message to log.
+         level:
+             The logging level. Defaults to logging.INFO.
+         colour:
+             The colour to use for the message. If None, a default colour will be used
+             based on the logging level.
+
+     Raises:
+         ValueError:
+             If the logging level is invalid.
+     """
+     match level:
+         case logging.DEBUG:
+             message = colored(
+                 text=(
+                     "[DEBUG] "
+                     + dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                     + f" · {message}"
+                 ),
+                 color=colour or "light_blue",
+             )
+             logger.debug(message)
+         case logging.INFO:
+             if colour is not None:
+                 message = colored(text=message, color=colour)
+             logger.info(message)
+         case logging.WARNING:
+             message = colored(text=message, color=colour or "light_red")
+             logger.warning(message)
+         case logging.ERROR:
+             message = colored(text=message, color=colour or "red")
+             logger.error(message)
+         case logging.CRITICAL:
+             message = colored(text=message, color=colour or "red")
+             logger.critical(message)
+         case _:
+             raise ValueError(f"Invalid logging level: {level}")
+
+
+ @cache_arguments("message")
+ def log_once(message: str, level: int = logging.INFO, prefix: str = "") -> None:
+     """Log a message once.
+
+     This is ensured by caching the "message" argument and only logging it the first time
+     this function is called with that message.
+
+     Args:
+         message:
+             The message to log.
+         level:
+             The logging level. Defaults to logging.INFO.
+         prefix:
+             A prefix to add to the message, which is not considered when determining if
+             the message has been logged before.
+     """
+     log(message=prefix + message, level=level)
+
+
+ def block_terminal_output() -> None:
+     """Blocks libraries from writing output to the terminal.
+
+     This filters warnings from some libraries, sets the logging level to ERROR for some
+     libraries, disabled tokeniser progress bars when using Hugging Face tokenisers, and
+     disables most of the logging from the `transformers` library.
+     """
+     if os.getenv("FULL_LOG") == "1":
+         return
+
+     # Ignore miscellaneous warnings
+     warnings.filterwarnings("ignore", category=UserWarning)
+     warnings.filterwarnings("ignore", category=FutureWarning)
+     logging.getLogger("absl").setLevel(logging.CRITICAL)
+
+     # Disable matplotlib logging
+     logging.getLogger("matplotlib.font_manager").setLevel(logging.CRITICAL)
+
+     # Disable PyTorch logging
+     logging.getLogger("torch.utils.cpp_extension").setLevel(logging.CRITICAL)
+     warnings.filterwarnings(action="ignore", module="torch*")
+     os.environ["TORCH_LOGS"] = "-all"
+
+     # Disable huggingface_hub logging
+     logging.getLogger("huggingface_hub").setLevel(logging.CRITICAL)
+     disable_hf_hub_progress_bars()
+
+     # Disable LiteLLM logging
+     logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
+     logging.getLogger("LiteLLM Router").setLevel(logging.CRITICAL)
+     logging.getLogger("LiteLLM Proxy").setLevel(logging.CRITICAL)
+     logging.getLogger("openai").setLevel(logging.CRITICAL)
+     logging.getLogger("httpx").setLevel(logging.CRITICAL)
+     litellm.suppress_debug_info = True
+
+     # Disable vLLM logging
+     logging.getLogger("vllm").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.engine.llm_engine").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.transformers_utils.tokenizer").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.core.scheduler").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.model_executor.weight_utils").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.platforms").setLevel(logging.CRITICAL)
+     logging.getLogger("mistral_common.tokens.tokenizers.tekken").setLevel(
+         logging.CRITICAL
+     )
+     os.environ["LOG_LEVEL"] = "CRITICAL"
+     os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
+
+     # Disable flashinfer logging
+     os.environ["FLASHINFER_LOGGING_LEVEL"] = "CRITICAL"
+
+     # Disable datasets logging
+     logging.getLogger("datasets").setLevel(logging.CRITICAL)
+     logging.getLogger("filelock").setLevel(logging.CRITICAL)
+     disable_datasets_progress_bars()
+
+     # Disable evaluate logging
+     warnings.filterwarnings("ignore", module="seqeval*")
+     disable_evaluate_progress_bar()
+
+     # Disable most of the `transformers` logging
+     tf_logging._default_log_level = logging.CRITICAL
+     tf_logging.set_verbosity(logging.CRITICAL)
+     logging.getLogger("transformers.trainer").setLevel(logging.CRITICAL)
+     logging.getLogger("accelerate").setLevel(logging.CRITICAL)
+
+
+ class no_terminal_output:
+     """Context manager that suppresses all terminal output."""
+
+     def __init__(self, disable: bool = False) -> None:
+         """Initialise the context manager.
+
+         Args:
+             disable:
+                 If True, this context manager does nothing.
+         """
+         self.disable = disable
+         self.nothing_file: TextIOWrapper | None = None
+         self._cpp_stdout_file: int | None = None
+         self._cpp_stderr_file: int | None = None
+         try:
+             self._cpp_stdout_file = os.dup(sys.stdout.fileno())
+             self._cpp_stderr_file = os.dup(sys.stderr.fileno())
+         except OSError:
+             self._log_windows_warning()
+
+     def _log_windows_warning(self) -> None:
+         """Log a warning about Windows not supporting blocking terminal output."""
+         log_once(
+             "Your operating system (probably Windows) does not support blocking "
+             "terminal output, so expect more messy output - sorry!",
+             level=logging.WARNING,
+         )
+
+     def __enter__(self) -> None:
+         """Suppress all terminal output."""
+         if not self.disable:
+             self.nothing_file = open(os.devnull, "w")
+             try:
+                 os.dup2(fd=self.nothing_file.fileno(), fd2=sys.stdout.fileno())
+                 os.dup2(fd=self.nothing_file.fileno(), fd2=sys.stderr.fileno())
+             except OSError:
+                 self._log_windows_warning()
+
+     def __exit__(
+         self,
+         exc_type: type[BaseException] | None,
+         exc_val: BaseException | None,
+         exc_tb: type[BaseException] | None,
+     ) -> None:
+         """Re-enable terminal output."""
+         if not self.disable:
+             if self.nothing_file is not None:
+                 self.nothing_file.close()
+             try:
+                 if self._cpp_stdout_file is not None:
+                     os.dup2(fd=self._cpp_stdout_file, fd2=sys.stdout.fileno())
+                 if self._cpp_stderr_file is not None:
+                     os.dup2(fd=self._cpp_stderr_file, fd2=sys.stderr.fileno())
+             except OSError:
+                 self._log_windows_warning()
+
+
+ def adjust_logging_level(verbose: bool, ignore_testing: bool = False) -> int:
+     """Adjust the logging level based on verbosity.
+
+     Args:
+         verbose:
+             Whether to output additional output.
+         ignore_testing:
+             Whether to ignore the testing flag.
+
+     Returns:
+         The logging level that was set.
+     """
+     if hasattr(sys, "_called_from_test") and not ignore_testing:
+         logging_level = logging.CRITICAL
+     elif verbose:
+         logging_level = logging.DEBUG
+     else:
+         logging_level = logging.INFO
+     logger.setLevel(logging_level)
+     return logging_level
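
As orientation only (not part of the diff), here is a minimal usage sketch of the helpers defined in the new module, assuming the installed wheel exposes them as euroeval.logging_utils:

import logging

from euroeval.logging_utils import (
    adjust_logging_level,
    block_terminal_output,
    log,
    log_once,
    no_terminal_output,
)

block_terminal_output()              # silence third-party loggers and progress bars
adjust_logging_level(verbose=False)  # INFO by default, DEBUG when verbose=True

log("Starting evaluation", level=logging.INFO)
log_once("Legacy option in use", level=logging.WARNING)  # logged on the first call
log_once("Legacy option in use", level=logging.WARNING)  # message is cached, so skipped

with no_terminal_output():
    print("this is swallowed")       # stdout/stderr are redirected to os.devnull here

The remaining file diffs below mostly replace per-module logger objects with these central helpers.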
euroeval/metrics/base.py
CHANGED
@@ -2,7 +2,6 @@

  import abc
  import collections.abc as c
- import logging
  import typing as t

  if t.TYPE_CHECKING:
@@ -10,8 +9,6 @@ if t.TYPE_CHECKING:

      from ..data_models import BenchmarkConfig, DatasetConfig

- logger: logging.Logger = logging.getLogger("euroeval")
-

  class Metric(abc.ABC):
      """Abstract base class for all metrics."""
euroeval/metrics/huggingface.py
CHANGED
@@ -1,7 +1,6 @@
  """All the Hugging Face metrics used in EuroEval."""

  import collections.abc as c
- import logging
  import typing as t
  from pathlib import Path

@@ -9,7 +8,7 @@ import evaluate
  import numpy as np
  from datasets import DownloadConfig

- from ..
+ from ..logging_utils import no_terminal_output
  from .base import Metric

  if t.TYPE_CHECKING:
@@ -18,8 +17,6 @@ if t.TYPE_CHECKING:

      from ..data_models import BenchmarkConfig, DatasetConfig

- logger: logging.Logger = logging.getLogger("euroeval")
-

  class HuggingFaceMetric(Metric):
      """A metric which is implemented in the `evaluate` package.
@@ -126,7 +123,7 @@ class HuggingFaceMetric(Metric):

          assert self.metric is not None

-         with
+         with no_terminal_output(disable=benchmark_config.verbose):
              results = self.metric.compute(
                  predictions=predictions, references=references, **self.compute_kwargs
              )
@@ -145,6 +142,13 @@ class HuggingFaceMetric(Metric):

          return score

+     def __del__(self) -> None:
+         """Clean up the metric from memory."""
+         if self.metric is not None:
+             if self.metric.writer is not None:
+                 self.metric.writer.close()
+             del self.metric
+

  mcc_metric = HuggingFaceMetric(
      name="mcc",
euroeval/metrics/llm_as_a_judge.py
CHANGED
@@ -8,6 +8,7 @@ from pathlib import Path
  from pydantic import BaseModel, Field

  from ..exceptions import InvalidBenchmark
+ from ..logging_utils import log
  from ..model_cache import ModelCache
  from ..utils import extract_json_dict_from_string
  from .base import Metric
@@ -17,8 +18,6 @@ if t.TYPE_CHECKING:

      from ..data_models import BenchmarkConfig, DatasetConfig

- logger: logging.Logger = logging.getLogger("euroeval")
-

  class LLMAsAJudgeMetric(Metric):
      """Use an LLM to judge the quality of the predictions."""
@@ -190,7 +189,10 @@ class LLMAsAJudgeMetric(Metric):
          # Calculate the scores using the scoring function
          scores = [self.scoring_fn(output) for output in outputs]
          if not scores:
-
+             log(
+                 f"No scores were calculated for {self.pretty_name}.",
+                 level=logging.WARNING,
+             )
              return None
          return sum(scores) / len(scores)

euroeval/metrics/pipeline.py
CHANGED
@@ -11,6 +11,7 @@ import numpy as np
  from scipy.special import expit as sigmoid

  from ..exceptions import InvalidBenchmark
+ from ..logging_utils import log, no_terminal_output
  from ..utils import unscramble
  from .base import Metric

@@ -20,8 +21,6 @@ if t.TYPE_CHECKING:

      from ..data_models import BenchmarkConfig, DatasetConfig

- logger: logging.Logger = logging.getLogger("euroeval")
-

  T = t.TypeVar("T", bound=int | float | str | bool)

@@ -121,16 +120,22 @@ class PipelineMetric(Metric):
          The calculated metric score, or None if the score should be ignored.
          """
          if self.pipeline is None:
-             self.pipeline = self._download_pipeline(
+             self.pipeline = self._download_pipeline(
+                 cache_dir=benchmark_config.cache_dir
+             )
          if self.preprocessing_fn is not None:
              predictions = self.preprocessing_fn(
                  predictions=predictions, dataset=dataset
              )
          return self.pipeline_scoring_function(self.pipeline, predictions)

-     def _download_pipeline(self) -> "Pipeline":
+     def _download_pipeline(self, cache_dir: str) -> "Pipeline":
          """Download the scikit-learn pipeline from the given URL.

+         Args:
+             cache_dir:
+                 The directory to use for caching the downloaded pipeline.
+
          Returns:
              The downloaded scikit-learn pipeline.

@@ -138,10 +143,13 @@ class PipelineMetric(Metric):
              InvalidBenchmark:
                  If the loading of the pipeline fails for any reason.
          """
-
-
-
-
+         log(f"Loading pipeline from {self.pipeline_repo}...", level=logging.DEBUG)
+         with no_terminal_output():
+             folder_path = hf_hub.HfApi(
+                 token=unscramble("XbjeOLhwebEaSaDUMqqaPaPIhgOcyOfDpGnX_")
+             ).snapshot_download(
+                 repo_id=self.pipeline_repo, repo_type="model", cache_dir=cache_dir
+             )
          model_path = Path(folder_path, self.pipeline_file_name)
          try:
              with model_path.open(mode="rb") as f:
@@ -150,7 +158,7 @@ class PipelineMetric(Metric):
              raise InvalidBenchmark(
                  f"Failed to load pipeline from {self.pipeline_repo!r}: {e}"
              ) from e
-
+         log(f"Successfully loaded pipeline: {pipeline}", level=logging.DEBUG)
          return pipeline

euroeval/metrics/speed.py
CHANGED
@@ -1,7 +1,6 @@
  """Inference speed metric."""

  import collections.abc as c
- import logging
  import typing as t

  from .base import Metric
@@ -11,8 +10,6 @@ if t.TYPE_CHECKING:

      from ..data_models import BenchmarkConfig, DatasetConfig

- logger: logging.Logger = logging.getLogger("euroeval")
-

  class SpeedMetric(Metric):
      """Speed metric."""
euroeval/model_cache.py
CHANGED
@@ -8,11 +8,9 @@ import typing as t
  from collections import defaultdict
  from dataclasses import asdict

- from tqdm.auto import tqdm
-
  from .constants import NUM_GENERATION_TOKENS_FOR_CLASSIFICATION
  from .data_models import GenerativeModelOutput, SingleGenerativeModelOutput
- from .
+ from .logging_utils import get_pbar, log, log_once

  if t.TYPE_CHECKING:
      from pathlib import Path
@@ -20,9 +18,6 @@ if t.TYPE_CHECKING:
      from datasets import Dataset


- logger = logging.getLogger("euroeval")
-
-
  class ModelCache:
      """A cache for model outputs.

@@ -65,9 +60,10 @@ class ModelCache:
              with self.cache_path.open() as f:
                  json_cache = json.load(f)
          except json.JSONDecodeError:
-
+             log(
                  f"Failed to load the cache from {self.cache_path}. The cache will be "
-                 f"re-initialised."
+                 f"re-initialised.",
+                 level=logging.WARNING,
              )
              json_cache = dict()
              with self.cache_path.open("w") as f:
@@ -89,9 +85,10 @@ class ModelCache:
              with self.cache_path.open("w") as f:
                  json.dump(dumpable_cache, f)
          except KeyError:
-
+             log(
                  f"Failed to load the cache from {self.cache_path}. The cache will be "
-                 f"re-initialised."
+                 f"re-initialised.",
+                 level=logging.WARNING,
              )
              self.cache = dict()
              with self.cache_path.open("w") as f:
@@ -172,18 +169,18 @@ class ModelCache:

          # Double check that the number of inputs and outputs match
          if not len(model_inputs) == len(model_output.sequences):
-
+             log(
                  f"Number of model inputs ({len(model_inputs)}) does not match the "
                  f"number of model outputs ({len(model_output.sequences)}). We will not "
-                 f"cache the model outputs."
+                 f"cache the model outputs.",
+                 level=logging.WARNING,
              )
              return

          # Store the generated sequences in the cache, one by one
-         with
+         with get_pbar(
              iterable=model_inputs,
              desc="Caching model outputs",
-             leave=False,
              disable=hasattr(sys, "_called_from_test"),
          ) as pbar:
              for sample_idx, model_input in enumerate(pbar):
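
As a side note (not part of the diff), the get_pbar helper now wraps tqdm with the package's styling, so the caching loop above boils down to this pattern. A minimal sketch, assuming euroeval.logging_utils is importable from the installed wheel and using hypothetical placeholder data:

from euroeval.logging_utils import get_pbar

model_inputs = ["input 1", "input 2", "input 3"]  # hypothetical placeholder data
with get_pbar(iterable=model_inputs, desc="Caching model outputs") as pbar:
    for sample_idx, model_input in enumerate(pbar):
        pass  # store each generated sequence in the cache here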
euroeval/model_config.py
CHANGED
@@ -5,14 +5,12 @@ import typing as t

  from . import benchmark_modules
  from .exceptions import InvalidModel, NeedsEnvironmentVariable, NeedsExtraInstalled
+ from .logging_utils import log

  if t.TYPE_CHECKING:
      from .data_models import BenchmarkConfig, ModelConfig


- logger = logging.getLogger("euroeval")
-
-
  def get_model_config(
      model_id: str, benchmark_config: "BenchmarkConfig"
  ) -> "ModelConfig":
@@ -51,9 +49,10 @@ def get_model_config(
          elif isinstance(exists_or_err, NeedsEnvironmentVariable):
              needs_env_vars.append(exists_or_err.env_var)
          elif exists_or_err is True:
-
+             log(
                  f"The model {model_id!r} was identified by the "
-                 f"{benchmark_module.__name__} benchmark module."
+                 f"{benchmark_module.__name__} benchmark module.",
+                 logging.DEBUG,
              )
              model_config = benchmark_module.get_model_config(
                  model_id=model_id, benchmark_config=benchmark_config
euroeval/model_loading.py
CHANGED
@@ -10,6 +10,7 @@ from .benchmark_modules import (
  )
  from .enums import InferenceBackend, ModelType
  from .exceptions import InvalidModel
+ from .logging_utils import log_once

  if t.TYPE_CHECKING:
      from .benchmark_modules import BenchmarkModule
@@ -34,6 +35,8 @@ def load_model(
      Returns:
          The model.
      """
+     log_once(f"Loading the model {model_config.model_id}...")
+
      # The order matters; the first model type that matches will be used. For this
      # reason, they have been ordered in terms of the most common model types.
      model_class: t.Type[BenchmarkModule]
euroeval/prompt_templates/linguistic_acceptability.py
CHANGED
@@ -4,6 +4,7 @@ import typing as t

  from ..data_models import PromptConfig
  from ..languages import (
+     CS,
      DA,
      DE,
      EN,
@@ -22,6 +23,7 @@ from ..languages import (
      NO,
      PL,
      PT,
+     SK,
      SV,
  )

@@ -29,6 +31,13 @@ if t.TYPE_CHECKING:
      from ..data_models import Language

  LA_TEMPLATES: dict["Language", PromptConfig] = {
+     CS: PromptConfig(
+         default_prompt_label_mapping=dict(correct="ano", incorrect="ne"),
+         default_prompt_prefix="Následující jsou věty a zda jsou gramaticky správné.",
+         default_prompt_template="Věta: {text}\nGramaticky správná: {label}",
+         default_instruction_prompt="Věta: {text}\n\nUrčete, zda je věta gramaticky "
+         "správná nebo ne. Odpovězte {labels_str}, a nic jiné.",
+     ),
      DA: PromptConfig(
          default_prompt_label_mapping=dict(correct="ja", incorrect="nej"),
          default_prompt_prefix="Følgende er sætninger og om de er grammatisk korrekte.",
@@ -71,11 +80,11 @@ LA_TEMPLATES: dict["Language", PromptConfig] = {
      ),
      PL: PromptConfig(
          default_prompt_label_mapping=dict(correct="tak", incorrect="nie"),
-         default_prompt_prefix="Poniżej znajdują się teksty i czy są "
+         default_prompt_prefix="Poniżej znajdują się teksty i informacja, czy są "
          "gramatycznie poprawne.",
          default_prompt_template="Tekst: {text}\nGramatycznie poprawny: {label}",
-         default_instruction_prompt="Tekst: {text}\n\nOkreśl czy tekst jest "
-         "gramatycznie poprawny
+         default_instruction_prompt="Tekst: {text}\n\nOkreśl, czy tekst jest "
+         "gramatycznie poprawny. Odpowiedz używając wyłącznie {labels_str}.",
      ),
      PT: PromptConfig(
          default_prompt_label_mapping=dict(correct="sim", incorrect="não"),
@@ -174,6 +183,15 @@ LA_TEMPLATES: dict["Language", PromptConfig] = {
          default_instruction_prompt="Setning: {text}\n\nBestem om setningen er "
          "grammatisk korrekt eller ikke. Svar med {labels_str}, og ikke noe annet.",
      ),
+     SK: PromptConfig(
+         default_prompt_label_mapping=dict(correct="áno", incorrect="nie"),
+         default_prompt_prefix="Nasledujú vety a či sú gramaticky správne.",
+         default_prompt_template="Veta: {text}\nGramaticky správna: {label}",
+         default_instruction_prompt=(
+             "Veta: {text}\n\nUrčite, či je veta gramaticky správna alebo nie. "
+             "Odpovedzte so {labels_str}, a nič iné."
+         ),
+     ),
      SV: PromptConfig(
          default_prompt_label_mapping=dict(correct="ja", incorrect="nej"),
          default_prompt_prefix="Följande är meningar och huruvida de är grammatiskt "
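
For orientation (not part of the diff), a minimal sketch of how one of the newly added Czech templates fills in, assuming the PromptConfig fields above are plain Python format strings:

# Hypothetical illustration using the new Czech linguistic-acceptability template.
template = "Věta: {text}\nGramaticky správná: {label}"
label_mapping = dict(correct="ano", incorrect="ne")

prompt = template.format(text="Toto je věta.", label=label_mapping["correct"])
print(prompt)
# Věta: Toto je věta.
# Gramaticky správná: ano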
euroeval/prompt_templates/multiple_choice.py
CHANGED
@@ -4,6 +4,7 @@ import typing as t

  from ..data_models import PromptConfig
  from ..languages import (
+     CS,
      DA,
      DE,
      EN,
@@ -21,6 +22,7 @@ from ..languages import (
      NO,
      PL,
      PT,
+     SK,
      SV,
  )

@@ -29,6 +31,17 @@ if t.TYPE_CHECKING:

  # TODO: Missing Faroese
  MULTIPLE_CHOICE_TEMPLATES: dict["Language", PromptConfig] = {
+     CS: PromptConfig(
+         default_prompt_prefix=(
+             "Následující jsou otázky s výběrem z více možností (s odpověďmi)."
+         ),
+         default_prompt_template="Otázka: {text}\nOdpověď: {label}",
+         default_instruction_prompt=(
+             "Otázka: {text}\n\nOdpovězte na výše uvedenou otázku "
+             "pomocí {labels_str}, a nic jiného."
+         ),
+         default_prompt_label_mapping="auto",
+     ),
      DA: PromptConfig(
          default_prompt_prefix="Følgende er multiple choice spørgsmål (med svar).",
          default_prompt_template="Spørgsmål: {text}\nSvar: {label}",
@@ -155,7 +168,18 @@ MULTIPLE_CHOICE_TEMPLATES: dict["Language", PromptConfig] = {
          "(z odpowiedziami).",
          default_prompt_template="Pytanie: {text}\nOdpowiedź: {label}",
          default_instruction_prompt="Pytanie: {text}\n\nOdpowiedz na powyższe pytanie, "
-         "
+         "używając {labels_str} i niczego więcej.",
+         default_prompt_label_mapping="auto",
+     ),
+     SK: PromptConfig(
+         default_prompt_prefix=(
+             "Nasledujú otázky s viacerými možnosťami (s odpoveďami)."
+         ),
+         default_prompt_template="Otázka: {text}\nOdpoveď: {label}",
+         default_instruction_prompt=(
+             "Otázka: {text}\n\n"
+             "Odpovedzte na nasledujúcu otázku použitím {labels_str}, a nič iné."
+         ),
          default_prompt_label_mapping="auto",
      ),
      SV: PromptConfig(