EuroEval 15.12.0__py3-none-any.whl → 16.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries, and is provided for informational purposes only.
- euroeval/__init__.py +32 -14
- euroeval/benchmark_config_factory.py +92 -180
- euroeval/benchmark_modules/base.py +49 -39
- euroeval/benchmark_modules/fresh.py +35 -21
- euroeval/benchmark_modules/hf.py +280 -244
- euroeval/benchmark_modules/litellm.py +752 -312
- euroeval/benchmark_modules/vllm.py +570 -268
- euroeval/benchmarker.py +651 -528
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/cli.py +49 -38
- euroeval/constants.py +44 -25
- euroeval/data_loading.py +111 -55
- euroeval/data_models.py +490 -323
- euroeval/dataset_configs/__init__.py +26 -4
- euroeval/dataset_configs/bosnian.py +39 -0
- euroeval/dataset_configs/bulgarian.py +56 -0
- euroeval/dataset_configs/croatian.py +56 -0
- euroeval/dataset_configs/czech.py +75 -0
- euroeval/dataset_configs/danish.py +78 -50
- euroeval/dataset_configs/dutch.py +74 -44
- euroeval/dataset_configs/english.py +71 -36
- euroeval/dataset_configs/estonian.py +111 -0
- euroeval/dataset_configs/faroese.py +25 -18
- euroeval/dataset_configs/finnish.py +63 -26
- euroeval/dataset_configs/french.py +65 -32
- euroeval/dataset_configs/german.py +77 -36
- euroeval/dataset_configs/greek.py +64 -0
- euroeval/dataset_configs/icelandic.py +68 -57
- euroeval/dataset_configs/italian.py +68 -36
- euroeval/dataset_configs/latvian.py +87 -0
- euroeval/dataset_configs/lithuanian.py +64 -0
- euroeval/dataset_configs/norwegian.py +98 -72
- euroeval/dataset_configs/polish.py +96 -0
- euroeval/dataset_configs/portuguese.py +63 -40
- euroeval/dataset_configs/serbian.py +64 -0
- euroeval/dataset_configs/slovak.py +55 -0
- euroeval/dataset_configs/slovene.py +56 -0
- euroeval/dataset_configs/spanish.py +68 -34
- euroeval/dataset_configs/swedish.py +82 -41
- euroeval/dataset_configs/ukrainian.py +64 -0
- euroeval/enums.py +12 -6
- euroeval/exceptions.py +21 -1
- euroeval/finetuning.py +34 -26
- euroeval/generation.py +76 -41
- euroeval/generation_utils.py +169 -34
- euroeval/languages.py +1020 -188
- euroeval/logging_utils.py +268 -0
- euroeval/metrics/__init__.py +6 -0
- euroeval/metrics/base.py +85 -0
- euroeval/metrics/huggingface.py +216 -0
- euroeval/metrics/llm_as_a_judge.py +260 -0
- euroeval/metrics/pipeline.py +289 -0
- euroeval/metrics/speed.py +48 -0
- euroeval/model_cache.py +40 -21
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/__init__.py +2 -0
- euroeval/prompt_templates/classification.py +206 -0
- euroeval/prompt_templates/linguistic_acceptability.py +157 -22
- euroeval/prompt_templates/multiple_choice.py +159 -17
- euroeval/prompt_templates/named_entity_recognition.py +318 -21
- euroeval/prompt_templates/reading_comprehension.py +207 -16
- euroeval/prompt_templates/sentiment_classification.py +205 -22
- euroeval/prompt_templates/summarization.py +122 -22
- euroeval/prompt_templates/token_classification.py +279 -0
- euroeval/scores.py +20 -9
- euroeval/speed_benchmark.py +11 -12
- euroeval/task_group_utils/multiple_choice_classification.py +21 -12
- euroeval/task_group_utils/question_answering.py +101 -73
- euroeval/task_group_utils/sequence_classification.py +144 -61
- euroeval/task_group_utils/text_to_text.py +33 -12
- euroeval/task_group_utils/token_classification.py +86 -89
- euroeval/tasks.py +75 -16
- euroeval/tokenisation_utils.py +603 -0
- euroeval/types.py +17 -11
- euroeval/utils.py +332 -137
- euroeval-16.7.1.dist-info/METADATA +623 -0
- euroeval-16.7.1.dist-info/RECORD +84 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
- euroeval/human_evaluation.py +0 -737
- euroeval/metrics.py +0 -452
- euroeval/tokenization_utils.py +0 -498
- euroeval-15.12.0.dist-info/METADATA +0 -285
- euroeval-15.12.0.dist-info/RECORD +0 -63
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
euroeval/__init__.py
CHANGED
@@ -12,14 +12,17 @@ import warnings
 from termcolor import colored
 
 # Block specific warnings before importing anything else, as they can be noisy
-
-
-
-logging.getLogger("
-
+if os.getenv("FULL_LOG") != "1":
+    warnings.filterwarnings("ignore", category=UserWarning)
+    warnings.filterwarnings("ignore", category=FutureWarning)
+    logging.getLogger("httpx").setLevel(logging.CRITICAL)
+    logging.getLogger("datasets").setLevel(logging.CRITICAL)
+    logging.getLogger("vllm").setLevel(logging.CRITICAL)
+    os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
 
 # Set up logging
-fmt = colored("%(asctime)s", "light_blue") + " ⋅ " + colored("%(message)s", "green")
+# fmt = colored("%(asctime)s", "light_blue") + " ⋅ " + colored("%(message)s", "green")
+fmt = colored("%(message)s", "light_yellow")
 logging.basicConfig(
     level=logging.CRITICAL if hasattr(sys, "_called_from_test") else logging.INFO,
     format=fmt,
@@ -48,7 +51,13 @@ import importlib.metadata  # noqa: E402
 from dotenv import load_dotenv  # noqa: E402
 
 from .benchmarker import Benchmarker  # noqa: E402
-from .
+from .data_models import DatasetConfig  # noqa: E402
+from .logging_utils import block_terminal_output  # noqa: E402
+from .tasks import (  # noqa: E402
+    MULTIPLE_CHOICE,
+    TEXT_CLASSIFICATION,
+    TOKEN_CLASSIFICATION,
+)
 
 # Block unwanted terminal outputs. This blocks way more than the above, but since it
 # relies on importing from the `utils` module, external modules are already imported
@@ -77,15 +86,18 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 os.environ["OMP_NUM_THREADS"] = "1"
 
 
-# Disable a warning from Ray regarding the detection of the number of CPUs
-os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
-
-
 # Avoid the "Cannot re-initialize CUDA in forked subprocess" error - see
 # https://github.com/vllm-project/vllm/issues/6152 for more
 os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
 
+# Allow long max model length in vLLM. This happens when vLLM registers that the model
+# has a shorter context length than the value we are inserting. But since we do a
+# thorough check of the model's config before setting the context length, we trust our
+# own checks and ignore the internal vLLM check.
+os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
+
+
 # Avoid the "Unclosed client session" error when evaluating Ollama models with LiteLLM.
 # The error comes from the `aiohttp` package, and this environment variable forces the
 # use of `httpx` instead.
@@ -93,9 +105,15 @@ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 os.environ["DISABLE_AIOHTTP_TRANSPORT"] = "True"
 
 
-#
-#
-os.environ["VLLM_USE_V1"] = "
+# Enable the newer vLLM V1 engine, which is faster and offers more compatibility with
+# newer models
+os.environ["VLLM_USE_V1"] = "1"
+
+
+# Use the FlashInfer flash-attention backend for vLLM, unless the user has already
+# specified a different backend.
+if os.getenv("VLLM_ATTENTION_BACKEND") is None:
+    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
 
 
 # Set the HF_TOKEN env var to copy the HUGGINGFACE_API_KEY env var, as vLLM uses the
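For readers skimming the `__init__.py` hunks above: the new version silences third-party noise at import time unless `FULL_LOG=1` is set, and only fills in vLLM-related environment variables that the user has not already set. A condensed sketch of that pattern follows; the `FULL_LOG` and `VLLM_ATTENTION_BACKEND` names come from the diff, the rest is illustrative rather than the exact module code.

# Condensed sketch of the import-time defaults added above (illustrative).
import logging
import os
import warnings

if os.getenv("FULL_LOG") != "1":
    # Silence noisy third-party warnings and loggers by default.
    warnings.filterwarnings("ignore", category=UserWarning)
    logging.getLogger("httpx").setLevel(logging.CRITICAL)

# Only default the vLLM attention backend when the user has not chosen one.
os.environ.setdefault("VLLM_ATTENTION_BACKEND", "FLASHINFER")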
euroeval/benchmark_config_factory.py
CHANGED
@@ -1,173 +1,82 @@
 """Factory class for creating dataset configurations."""
 
-import
+import collections.abc as c
 import sys
 import typing as t
 
 import torch
 
-from .data_models import BenchmarkConfig
+from .data_models import BenchmarkConfig, BenchmarkConfigParams, DatasetConfig, Task
 from .dataset_configs import get_all_dataset_configs
 from .enums import Device
 from .exceptions import InvalidBenchmark
 from .languages import get_all_languages
-from .tasks import SPEED, get_all_tasks
 
 if t.TYPE_CHECKING:
-    from .data_models import Language
-
-
-logger = logging.getLogger("euroeval")
+    from .data_models import Language
 
 
 def build_benchmark_config(
-
-    save_results: bool,
-    task: str | list[str] | None,
-    dataset: str | list[str] | None,
-    language: str | list[str],
-    model_language: str | list[str] | None,
-    dataset_language: str | list[str] | None,
-    device: Device | None,
-    batch_size: int,
-    raise_errors: bool,
-    cache_dir: str,
-    api_key: str | None,
-    force: bool,
-    verbose: bool,
-    trust_remote_code: bool,
-    clear_model_cache: bool,
-    evaluate_test_split: bool,
-    few_shot: bool,
-    num_iterations: int,
-    api_base: str | None,
-    api_version: str | None,
-    gpu_memory_utilization: float,
-    debug: bool,
-    run_with_cli: bool,
-    only_allow_safetensors: bool,
-    first_time: bool = False,
+    benchmark_config_params: BenchmarkConfigParams,
 ) -> BenchmarkConfig:
     """Create a benchmark configuration.
 
     Args:
-
-
-        save_results:
-            Whether to save the benchmark results to a file.
-        task:
-            The tasks to include for dataset. If None then datasets will not be
-            filtered based on their task.
-        dataset:
-            The datasets to include for task. If None then all datasets will be
-            included, limited by the `task` parameter.
-        language:
-            The language codes of the languages to include, both for models and
-            datasets. Here 'no' means both Bokmål (nb) and Nynorsk (nn). Set this
-            to 'all' if all languages should be considered.
-        model_language:
-            The language codes of the languages to include for models. If None then
-            the `language` parameter will be used.
-        dataset_language:
-            The language codes of the languages to include for datasets. If None then
-            the `language` parameter will be used.
-        device:
-            The device to use for running the models. If None then the device will be
-            set automatically.
-        batch_size:
-            The batch size to use for running the models.
-        raise_errors:
-            Whether to raise errors when running the benchmark.
-        cache_dir:
-            The directory to use for caching the models.
-        api_key:
-            The API key to use for a given inference server.
-        force:
-            Whether to force the benchmark to run even if the results are already
-            cached.
-        verbose:
-            Whether to print verbose output when running the benchmark. This is
-            automatically set if `debug` is True.
-        trust_remote_code:
-            Whether to trust remote code when running the benchmark.
-        clear_model_cache:
-            Whether to clear the model cache before running the benchmark.
-        evaluate_test_split:
-            Whether to use the test split for the datasets.
-        few_shot:
-            Whether to use few-shot learning for the models.
-        num_iterations:
-            The number of iterations each model should be evaluated for.
-        api_base:
-            The base URL for a given inference API. Only relevant if `model` refers to a
-            model on an inference API.
-        api_version:
-            The version of the API to use for a given inference API.
-        gpu_memory_utilization:
-            The GPU memory utilization to use for vLLM. A larger value will result in
-            faster evaluation, but at the risk of running out of GPU memory. Only reduce
-            this if you are running out of GPU memory. Only relevant if the model is
-            generative.
-        debug:
-            Whether to run the benchmark in debug mode.
-        run_with_cli:
-            Whether the benchmark is being run with the CLI.
-        only_allow_safetensors:
-            Whether to only allow evaluations of models stored as safetensors.
-        first_time:
-            Whether this is the first time the benchmark configuration is being created.
-            Defaults to False.
+        benchmark_config_params:
+            The parameters for creating the benchmark configuration.
 
     Returns:
         The benchmark configuration.
     """
-    language_codes = get_correct_language_codes(
-
-        language_codes=model_language, default_language_codes=language_codes
+    language_codes = get_correct_language_codes(
+        language_codes=benchmark_config_params.language
     )
-
-        language_codes=
+    languages = prepare_languages(
+        language_codes=benchmark_config_params.language,
+        default_language_codes=language_codes,
    )
 
-
-        task=task,
+    dataset_configs = prepare_dataset_configs(
+        task=benchmark_config_params.task,
+        dataset=benchmark_config_params.dataset,
+        languages=languages,
    )
 
-    torch_device = prepare_device(device=device)
-
-    # Set variable with number of iterations
-    if hasattr(sys, "_called_from_test"):
-        num_iterations = 1
-
     return BenchmarkConfig(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        datasets=dataset_configs,
+        languages=languages,
+        finetuning_batch_size=benchmark_config_params.finetuning_batch_size,
+        raise_errors=benchmark_config_params.raise_errors,
+        cache_dir=benchmark_config_params.cache_dir,
+        api_key=benchmark_config_params.api_key,
+        force=benchmark_config_params.force,
+        progress_bar=benchmark_config_params.progress_bar,
+        save_results=benchmark_config_params.save_results,
+        verbose=benchmark_config_params.verbose or benchmark_config_params.debug,
+        device=prepare_device(device=benchmark_config_params.device),
+        trust_remote_code=benchmark_config_params.trust_remote_code,
+        clear_model_cache=benchmark_config_params.clear_model_cache,
+        evaluate_test_split=benchmark_config_params.evaluate_test_split,
+        few_shot=benchmark_config_params.few_shot,
+        num_iterations=(
+            1
+            if hasattr(sys, "_called_from_test")
+            else benchmark_config_params.num_iterations
+        ),
+        api_base=benchmark_config_params.api_base,
+        api_version=benchmark_config_params.api_version,
+        gpu_memory_utilization=benchmark_config_params.gpu_memory_utilization,
+        generative_type=benchmark_config_params.generative_type,
+        debug=benchmark_config_params.debug,
+        run_with_cli=benchmark_config_params.run_with_cli,
+        requires_safetensors=benchmark_config_params.requires_safetensors,
+        download_only=benchmark_config_params.download_only,
     )
 
 
-def get_correct_language_codes(
+def get_correct_language_codes(
+    language_codes: str | c.Sequence[str],
+) -> c.Sequence[str]:
     """Get correct language code(s).
 
     Args:
@@ -188,7 +97,7 @@ def get_correct_language_codes(language_codes: str | list[str]) -> list[str]:
     elif isinstance(language_codes, str):
         languages = [language_codes]
     else:
-        languages = language_codes
+        languages = list(language_codes)
 
     # If `languages` contains 'no' then also include 'nb' and 'nn'. Conversely, if
     # either 'nb' or 'nn' are specified then also include 'no'.
@@ -201,8 +110,9 @@ def get_correct_language_codes(language_codes: str | list[str]) -> list[str]:
 
 
 def prepare_languages(
-    language_codes: str |
-
+    language_codes: str | c.Sequence[str] | None,
+    default_language_codes: c.Sequence[str],
+) -> c.Sequence["Language"]:
     """Prepare language(s) for benchmarking.
 
     Args:
@@ -220,7 +130,7 @@ def prepare_languages(
     language_mapping = get_all_languages()
 
     # Create the list `languages_str` of language codes to use for models or datasets
-    languages_str:
+    languages_str: c.Sequence[str]
     if language_codes is None:
         languages_str = default_language_codes
     elif isinstance(language_codes, str):
@@ -237,74 +147,76 @@ def prepare_languages(
     return prepared_languages
 
 
-def
-    task: str |
-
-    dataset: str |
-) ->
-    """Prepare
+def prepare_dataset_configs(
+    task: "str | Task | c.Sequence[str | Task] | None",
+    languages: c.Sequence["Language"],
+    dataset: "str | DatasetConfig | c.Sequence[str | DatasetConfig] | None",
+) -> c.Sequence["DatasetConfig"]:
+    """Prepare dataset config(s) for benchmarking.
 
     Args:
         task:
             The tasks to include for dataset. If None then datasets will not be
            filtered based on their task.
-
+        languages:
            The languages of the datasets in the benchmark.
        dataset:
            The datasets to include for task. If None then all datasets will be
-            included, limited by the `task` and `
+            included, limited by the `task` and `languages` parameters.
 
     Returns:
-        The prepared
+        The prepared dataset configs.
 
     Raises:
        InvalidBenchmark:
            If the task or dataset is not found in the benchmark tasks or datasets.
    """
-    # Create
-    # task objects, and a dictionary that maps dataset names to their associated
-    # dataset configuration objects
-    task_mapping = get_all_tasks()
+    # Create the list of dataset configs
     all_dataset_configs = get_all_dataset_configs()
+    all_official_dataset_configs: c.Sequence[DatasetConfig] = [
+        dataset_config
+        for dataset_config in all_dataset_configs.values()
+        if not dataset_config.unofficial
+    ]
+    try:
+        if dataset is None:
+            datasets = all_official_dataset_configs
+        elif isinstance(dataset, str):
+            datasets = [all_dataset_configs[dataset]]
+        elif isinstance(dataset, DatasetConfig):
+            datasets = [dataset]
+        else:
+            datasets = [
+                all_dataset_configs[d] if isinstance(d, str) else d for d in dataset
+            ]
+    except KeyError as e:
+        raise InvalidBenchmark(
+            f"Dataset {e} not found in the benchmark datasets."
+        ) from e
 
     # Create the list of dataset tasks
+    task_mapping = {cfg.task.name: cfg.task for cfg in all_dataset_configs.values()}
     try:
         if task is None:
-            tasks =
+            tasks = None
         elif isinstance(task, str):
             tasks = [task_mapping[task]]
+        elif isinstance(task, Task):
+            tasks = [task]
         else:
-            tasks = [task_mapping[t] for t in task]
+            tasks = [task_mapping[t] if isinstance(t, str) else t for t in task]
     except KeyError as e:
         raise InvalidBenchmark(f"Task {e} not found in the benchmark tasks.") from e
 
-
-        dataset_name
-        for dataset_name, dataset_config in all_dataset_configs.items()
-        if not dataset_config.unofficial
-    ]
-    if dataset is None:
-        dataset = all_official_datasets
-    elif isinstance(dataset, str):
-        dataset = [dataset]
-
-    all_datasets = list(all_dataset_configs.keys())
-    invalid_datasets = set(dataset) - set(all_datasets)
-    if invalid_datasets:
-        raise InvalidBenchmark(
-            f"Dataset(s) {', '.join(invalid_datasets)} not found in the benchmark "
-            "datasets."
-        )
-
+    # Filter the dataset configs based on the specified tasks and languages
     datasets = [
-
-        for
-        if
-        and
-        and set(dataset_config.languages).intersection(dataset_languages)
+        ds
+        for ds in datasets
+        if (tasks is None or ds.task in tasks)
+        and any(lang in languages for lang in ds.languages)
     ]
 
-    return
+    return datasets
 
 
 def prepare_device(device: Device | None) -> torch.device:
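The net effect of the `benchmark_config_factory.py` changes is that dataset selection now works on `DatasetConfig` objects directly and filters them by task and by language overlap. Below is a small self-contained sketch of that filtering logic; the stand-in classes, dataset names, and plain-string languages are illustrative only, not the real `euroeval.data_models` types.

# Illustrative stand-ins for euroeval's Task/DatasetConfig; languages are plain
# strings here rather than Language objects, purely to keep the sketch short.
from dataclasses import dataclass


@dataclass(frozen=True)
class Task:
    name: str


@dataclass(frozen=True)
class DatasetConfig:
    name: str
    task: Task
    languages: tuple[str, ...]


def filter_dataset_configs(configs, tasks, languages):
    # Keep a config if no task filter is given or its task is requested, and it
    # covers at least one of the requested languages.
    return [
        cfg
        for cfg in configs
        if (tasks is None or cfg.task in tasks)
        and any(lang in languages for lang in cfg.languages)
    ]


configs = [
    DatasetConfig("dataset-a", Task("sentiment-classification"), ("da",)),
    DatasetConfig("dataset-b", Task("named-entity-recognition"), ("de",)),
]
print(filter_dataset_configs(configs, tasks=None, languages=["da"]))
# -> only the Danish config, since no task filter was given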
euroeval/benchmark_modules/base.py
CHANGED
@@ -2,24 +2,23 @@
 
 import collections.abc as c
 import logging
-import
+import re
 import typing as t
 from abc import ABC, abstractmethod
 from functools import cached_property, partial
 
-from datasets import DatasetDict
+from datasets import Dataset, DatasetDict
 from torch import nn
-from tqdm.auto import tqdm
 
 from ..enums import TaskGroup
-from ..exceptions import NeedsEnvironmentVariable, NeedsExtraInstalled
+from ..exceptions import InvalidBenchmark, NeedsEnvironmentVariable, NeedsExtraInstalled
+from ..logging_utils import get_pbar, log_once
 from ..task_group_utils import (
     question_answering,
     sequence_classification,
     text_to_text,
     token_classification,
 )
-from ..utils import log_once
 
 if t.TYPE_CHECKING:
     from transformers.tokenization_utils import PreTrainedTokenizer
@@ -35,8 +34,6 @@ if t.TYPE_CHECKING:
     from ..enums import BatchingPreference, GenerativeType
     from ..types import ComputeMetricsFunction, ExtractLabelsFunction
 
-logger = logging.getLogger("euroeval")
-
 
 class BenchmarkModule(ABC):
     """Abstract class for a benchmark module.
@@ -55,12 +52,14 @@ class BenchmarkModule(ABC):
     fresh_model: bool
     batching_preference: "BatchingPreference"
     high_priority: bool
+    allowed_params: dict[re.Pattern, c.Sequence[str]] = {re.compile(r".*"): []}
 
     def __init__(
         self,
         model_config: "ModelConfig",
         dataset_config: "DatasetConfig",
         benchmark_config: "BenchmarkConfig",
+        log_metadata: bool = True,
     ) -> None:
         """Initialise the benchmark module.
 
@@ -71,29 +70,25 @@ class BenchmarkModule(ABC):
                 The dataset configuration.
             benchmark_config:
                 The benchmark configuration.
+            log_metadata:
+                Whether to log the metadata of the model.
         """
         self.model_config = model_config
         self.dataset_config = dataset_config
         self.benchmark_config = benchmark_config
+        self.log_metadata = log_metadata
         self.buffer: dict[str, t.Any] = dict()
-        self.
+        if self.log_metadata:
+            self._log_metadata()
 
     def _log_metadata(self) -> None:
         """Log the metadata of the model."""
-
-
-            logging_level = logging.CRITICAL
-        elif self.benchmark_config.verbose:
-            logging_level = logging.DEBUG
-        else:
-            logging_level = logging.INFO
-        logger.setLevel(logging_level)
-
-        logging_msg: str = ""
+        model_id = self.model_config.model_id
+        logging_msg: str = " ↳ "
         if self.num_params < 0:
-            logging_msg += "The model has an unknown number of parameters, "
+            logging_msg += f"The model {model_id} has an unknown number of parameters, "
         else:
-            logging_msg += f"The model has {self.num_params:,} parameters, "
+            logging_msg += f"The model {model_id} has {self.num_params:,} parameters, "
         if self.vocab_size < 0:
             logging_msg += "an unknown vocabulary size, "
         else:
@@ -117,16 +112,16 @@ class BenchmarkModule(ABC):
             f"{self.__class__.__name__}."
         )
 
-    def
-        """Get the underlying
+    def get_tokeniser(self) -> "PreTrainedTokenizer":
+        """Get the underlying tokeniser.
 
         Returns:
-            The
+            The tokeniser.
         """
-        if hasattr(self, "
-            return self.
+        if hasattr(self, "_tokeniser"):
+            return self._tokeniser
         raise NotImplementedError(
-            "The `
+            "The `get_tokeniser` method has not been implemented for "
             f"{self.__class__.__name__}."
         )
 
@@ -172,7 +167,7 @@ class BenchmarkModule(ABC):
 
     @property
     @abstractmethod
-    def data_collator(self) -> c.Callable[[
+    def data_collator(self) -> c.Callable[[c.Sequence[t.Any]], dict[str, t.Any]]:
         """The data collator used to prepare samples during finetuning.
 
         Returns:
@@ -192,11 +187,13 @@ class BenchmarkModule(ABC):
                 return partial(
                     sequence_classification.compute_metrics,
                     dataset_config=self.dataset_config,
+                    benchmark_config=self.benchmark_config,
                 )
             case TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION:
                 return partial(
                     sequence_classification.compute_metrics,
                     dataset_config=self.dataset_config,
+                    benchmark_config=self.benchmark_config,
                 )
             case TaskGroup.TEXT_TO_TEXT:
                 return partial(
@@ -209,11 +206,13 @@ class BenchmarkModule(ABC):
                     token_classification.compute_metrics,
                     has_misc_tags=self.buffer.get("has_misc_tags", True),
                     dataset_config=self.dataset_config,
+                    benchmark_config=self.benchmark_config,
                 )
             case TaskGroup.QUESTION_ANSWERING:
                 return partial(
                     question_answering.compute_metrics,
                     dataset_config=self.dataset_config,
+                    benchmark_config=self.benchmark_config,
                 )
             case _:
                 raise NotImplementedError(
@@ -242,7 +241,7 @@ class BenchmarkModule(ABC):
 
     def prepare_datasets(
         self, datasets: list[DatasetDict], task: "Task"
-    ) ->
+    ) -> c.Sequence[DatasetDict]:
         """Prepare the datasets for the model.
 
         This includes things like tokenisation.
@@ -255,30 +254,41 @@ class BenchmarkModule(ABC):
 
         Returns:
             The prepared datasets.
+
+        Raises:
+            InvalidBenchmark:
+                If the dataset does not have a 'train' split for token classification
+                tasks.
         """
         for idx, dataset in enumerate(
-
+            get_pbar(
+                iterable=datasets,
+                desc="Preparing datasets",
+                disable=not self.benchmark_config.progress_bar,
+            )
         ):
             prepared_dataset = self.prepare_dataset(
                 dataset=dataset, task=task, itr_idx=idx
             )
             if self.dataset_config.task.task_group == TaskGroup.TOKEN_CLASSIFICATION:
+                if "train" not in dataset:
+                    raise InvalidBenchmark(
+                        "The dataset does not have a 'train' split, which is required "
+                        "for token classification tasks."
+                    )
                 labels_in_train: set[str] = {
                     tag for tag_list in dataset["train"]["labels"] for tag in tag_list
                 }
                 self.buffer["has_misc_tags"] = (
                     "B-MISC" in labels_in_train or "I-MISC" in labels_in_train
                 )
-
-
-
-
-
-
-
-                original_test=dataset["test"],
-            )
-        )
+
+            datasets_dict: dict[str, Dataset] = dict()
+            for split_name, split in prepared_dataset.items():
+                datasets_dict[split_name] = split
+            for split_name, split in dataset.items():
+                datasets_dict[f"original_{split_name}"] = split
+            datasets[idx] = DatasetDict(datasets_dict)
         return datasets
 
     @abstractmethod