EuroEval 15.12.0-py3-none-any.whl → 16.7.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- euroeval/__init__.py +32 -14
- euroeval/benchmark_config_factory.py +92 -180
- euroeval/benchmark_modules/base.py +49 -39
- euroeval/benchmark_modules/fresh.py +35 -21
- euroeval/benchmark_modules/hf.py +280 -244
- euroeval/benchmark_modules/litellm.py +752 -312
- euroeval/benchmark_modules/vllm.py +570 -268
- euroeval/benchmarker.py +651 -528
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/cli.py +49 -38
- euroeval/constants.py +44 -25
- euroeval/data_loading.py +111 -55
- euroeval/data_models.py +490 -323
- euroeval/dataset_configs/__init__.py +26 -4
- euroeval/dataset_configs/bosnian.py +39 -0
- euroeval/dataset_configs/bulgarian.py +56 -0
- euroeval/dataset_configs/croatian.py +56 -0
- euroeval/dataset_configs/czech.py +75 -0
- euroeval/dataset_configs/danish.py +78 -50
- euroeval/dataset_configs/dutch.py +74 -44
- euroeval/dataset_configs/english.py +71 -36
- euroeval/dataset_configs/estonian.py +111 -0
- euroeval/dataset_configs/faroese.py +25 -18
- euroeval/dataset_configs/finnish.py +63 -26
- euroeval/dataset_configs/french.py +65 -32
- euroeval/dataset_configs/german.py +77 -36
- euroeval/dataset_configs/greek.py +64 -0
- euroeval/dataset_configs/icelandic.py +68 -57
- euroeval/dataset_configs/italian.py +68 -36
- euroeval/dataset_configs/latvian.py +87 -0
- euroeval/dataset_configs/lithuanian.py +64 -0
- euroeval/dataset_configs/norwegian.py +98 -72
- euroeval/dataset_configs/polish.py +96 -0
- euroeval/dataset_configs/portuguese.py +63 -40
- euroeval/dataset_configs/serbian.py +64 -0
- euroeval/dataset_configs/slovak.py +55 -0
- euroeval/dataset_configs/slovene.py +56 -0
- euroeval/dataset_configs/spanish.py +68 -34
- euroeval/dataset_configs/swedish.py +82 -41
- euroeval/dataset_configs/ukrainian.py +64 -0
- euroeval/enums.py +12 -6
- euroeval/exceptions.py +21 -1
- euroeval/finetuning.py +34 -26
- euroeval/generation.py +76 -41
- euroeval/generation_utils.py +169 -34
- euroeval/languages.py +1020 -188
- euroeval/logging_utils.py +268 -0
- euroeval/metrics/__init__.py +6 -0
- euroeval/metrics/base.py +85 -0
- euroeval/metrics/huggingface.py +216 -0
- euroeval/metrics/llm_as_a_judge.py +260 -0
- euroeval/metrics/pipeline.py +289 -0
- euroeval/metrics/speed.py +48 -0
- euroeval/model_cache.py +40 -21
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/__init__.py +2 -0
- euroeval/prompt_templates/classification.py +206 -0
- euroeval/prompt_templates/linguistic_acceptability.py +157 -22
- euroeval/prompt_templates/multiple_choice.py +159 -17
- euroeval/prompt_templates/named_entity_recognition.py +318 -21
- euroeval/prompt_templates/reading_comprehension.py +207 -16
- euroeval/prompt_templates/sentiment_classification.py +205 -22
- euroeval/prompt_templates/summarization.py +122 -22
- euroeval/prompt_templates/token_classification.py +279 -0
- euroeval/scores.py +20 -9
- euroeval/speed_benchmark.py +11 -12
- euroeval/task_group_utils/multiple_choice_classification.py +21 -12
- euroeval/task_group_utils/question_answering.py +101 -73
- euroeval/task_group_utils/sequence_classification.py +144 -61
- euroeval/task_group_utils/text_to_text.py +33 -12
- euroeval/task_group_utils/token_classification.py +86 -89
- euroeval/tasks.py +75 -16
- euroeval/tokenisation_utils.py +603 -0
- euroeval/types.py +17 -11
- euroeval/utils.py +332 -137
- euroeval-16.7.1.dist-info/METADATA +623 -0
- euroeval-16.7.1.dist-info/RECORD +84 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
- euroeval/human_evaluation.py +0 -737
- euroeval/metrics.py +0 -452
- euroeval/tokenization_utils.py +0 -498
- euroeval-15.12.0.dist-info/METADATA +0 -285
- euroeval-15.12.0.dist-info/RECORD +0 -63
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0

euroeval/caching_utils.py
ADDED

@@ -0,0 +1,79 @@
+"""Caching utility functions."""
+
+import typing as t
+from functools import wraps
+
+from .constants import T
+
+
+def cache_arguments(
+    *arguments: str, disable_condition: t.Callable[[], bool] = lambda: False
+) -> t.Callable[[t.Callable[..., T]], t.Callable[..., T]]:
+    """Cache specified arguments of a function.
+
+    Args:
+        arguments:
+            The list of argument names to cache. If empty, all arguments are cached.
+        disable_condition:
+            A function that checks if cache should be disabled.
+
+    Returns:
+        A decorator that caches the specified arguments of a function.
+    """
+
+    def caching_decorator(func: t.Callable[..., T]) -> t.Callable[..., T]:
+        """Decorator that caches the specified arguments of a function.
+
+        Args:
+            func:
+                The function to decorate.
+
+        Returns:
+            The decorated function.
+        """
+        cache: dict[tuple, T] = dict()
+
+        @wraps(func)
+        def wrapper(*args, **kwargs) -> T:
+            """Wrapper function that caches the specified arguments.
+
+            Args:
+                *args:
+                    The positional arguments to the function.
+                **kwargs:
+                    The keyword arguments to the function.
+
+            Returns:
+                The result of the function.
+
+            Raises:
+                ValueError:
+                    If an argument name is not found in the function parameters.
+            """
+            if not arguments:
+                key = args + tuple(kwargs[k] for k in sorted(kwargs.keys()))
+            else:
+                func_params = func.__code__.co_varnames
+                key_items: list[t.Any] = list()
+                for arg_name in arguments:
+                    if arg_name in kwargs:
+                        key_items.append(kwargs[arg_name])
+                    else:
+                        try:
+                            arg_index = func_params.index(arg_name)
+                            key_items.append(args[arg_index])
+                        except (ValueError, IndexError):
+                            raise ValueError(
+                                f"Argument {arg_name} not found in function "
+                                f"{func.__name__} parameters."
+                            )
+                key = tuple(key_items)
+
+            # Do not cache if the condition is met
+            if key not in cache or disable_condition():
+                cache[key] = func(*args, **kwargs)
+            return cache[key]
+
+        return wrapper
+
+    return caching_decorator
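The new `cache_arguments` decorator builds its cache key from the named arguments only (or from all arguments when none are named). A minimal usage sketch, with a hypothetical function and argument name that are not part of the package:

```python
from euroeval.caching_utils import cache_arguments


@cache_arguments("model_id")
def fetch_model_metadata(model_id: str, verbose: bool = False) -> dict:
    """Hypothetical expensive lookup, cached solely on `model_id`."""
    print(f"Fetching metadata for {model_id}...")  # only runs on a cache miss
    return {"model_id": model_id}


fetch_model_metadata("my-org/my-model")                # calls the function
fetch_model_metadata("my-org/my-model", verbose=True)  # returns the cached result
```

Passing `disable_condition` (for example a lambda checking a refresh flag) forces re-execution even on a hit, matching the `if key not in cache or disable_condition()` branch above.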
euroeval/callbacks.py
CHANGED

@@ -7,6 +7,8 @@ from collections.abc import Sized
 from tqdm.auto import tqdm
 from transformers.trainer_callback import ProgressCallback
 
+from .logging_utils import get_pbar
+
 if t.TYPE_CHECKING:
     from torch.utils.data import DataLoader
     from transformers.trainer_callback import TrainerControl, TrainerState
@@ -32,11 +34,8 @@ class NeverLeaveProgressCallback(ProgressCallback):
         """Callback actions when training begins."""
         if state.is_local_process_zero:
             desc = "Finetuning model"
-            self.training_bar =
-                total=None,
-                leave=False,
-                desc=desc,
-                disable=hasattr(sys, "_called_from_test"),
+            self.training_bar = get_pbar(
+                total=None, desc=desc, disable=hasattr(sys, "_called_from_test")
             )
             self.current_step = 0
 
@@ -67,9 +66,8 @@ class NeverLeaveProgressCallback(ProgressCallback):
         if state.is_local_process_zero and correct_dtype:
             if self.prediction_bar is None:
                 desc = "Evaluating model"
-                self.prediction_bar =
+                self.prediction_bar = get_pbar(
                     total=len(eval_dataloader),
-                    leave=False,
                     desc=desc,
                     disable=hasattr(sys, "_called_from_test"),
                 )
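The `get_pbar` helper comes from the new `euroeval/logging_utils.py`, which this diff does not show. Judging purely from the call sites above, a stand-in with compatible behaviour could look like the sketch below; this is an assumption, not the package's actual implementation:

```python
from tqdm.auto import tqdm


def get_pbar(*args, **kwargs) -> tqdm:
    """Hypothetical stand-in for euroeval.logging_utils.get_pbar: a tqdm factory
    with `leave=False` baked in, which would explain why the explicit
    `leave=False` arguments disappear from the call sites above."""
    kwargs.setdefault("leave", False)
    return tqdm(*args, **kwargs)
```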
euroeval/cli.py
CHANGED

@@ -3,10 +3,9 @@
 import click
 
 from .benchmarker import Benchmarker
-from .
-from .enums import Device
+from .data_models import DatasetConfig
+from .enums import Device, GenerativeType
 from .languages import get_all_languages
-from .tasks import get_all_tasks
 
 
 @click.command()
@@ -23,7 +22,6 @@ from .tasks import get_all_tasks
     default=None,
     show_default=True,
     multiple=True,
-    type=click.Choice(list(get_all_tasks().keys())),
     help="The dataset tasks to benchmark the model(s) on.",
 )
 @click.option(
@@ -45,8 +43,7 @@ from .tasks import get_all_tasks
     multiple=True,
     metavar="ISO 639-1 LANGUAGE CODE",
     type=click.Choice(["all"] + list(get_all_languages().keys())),
-    help="""
-    `language` value.""",
+    help="""This option is deprecated - please use --language instead.""",
 )
 @click.option(
     "--dataset-language",
@@ -56,24 +53,28 @@ from .tasks import get_all_tasks
     multiple=True,
     metavar="ISO 639-1 LANGUAGE CODE",
     type=click.Choice(["all"] + list(get_all_languages().keys())),
-    help="""
-    benchmarked on all datasets. If not specified then this will use the `language`
-    value.""",
+    help="""This option is deprecated - please use --language instead.""",
 )
 @click.option(
     "--dataset",
     default=None,
     show_default=True,
     multiple=True,
-    type=click.Choice(list(get_all_dataset_configs().keys())),
     help="""The name of the benchmark dataset. We recommend to use the `task` and
     `language` options instead of this option.""",
 )
 @click.option(
     "--batch-size",
+    default=None,
+    type=click.Choice(["1", "2", "4", "8", "16", "32"]),
+    help="This option is deprecated - please use --finetuning-batch-size instead.",
+    deprecated=True,
+)
+@click.option(
+    "--finetuning-batch-size",
     default="32",
     type=click.Choice(["1", "2", "4", "8", "16", "32"]),
-    help="The batch size to use.",
+    help="The batch size to use for finetuning.",
 )
 @click.option(
     "--progress-bar/--no-progress-bar",
@@ -188,7 +189,7 @@ from .tasks import get_all_tasks
 )
 @click.option(
     "--gpu-memory-utilization",
-    default=0.
+    default=0.8,
     show_default=True,
     help="The GPU memory utilization to use for vLLM. A larger value will result in "
     "faster evaluation, but at the risk of running out of GPU memory. Only reduce this "
@@ -203,20 +204,35 @@ from .tasks import get_all_tasks
     "relevant if the model is generative.",
 )
 @click.option(
-    "--
+    "--requires-safetensors",
     is_flag=True,
     help="Only allow loading models that have safetensors weights available",
     default=False,
 )
+@click.option(
+    "--generative-type",
+    type=click.Choice(["base", "instruction_tuned", "reasoning"]),
+    default=None,
+    show_default=True,
+    help="The type of generative model. Only relevant if the model is generative. If "
+    "not specified, the type will be inferred automatically.",
+)
+@click.option(
+    "--download-only",
+    is_flag=True,
+    help="Only download the requested model weights and datasets, and exit.",
+    default=False,
+)
 def benchmark(
     model: tuple[str],
-    dataset: tuple[str],
+    dataset: tuple[str | DatasetConfig],
     language: tuple[str],
     model_language: tuple[str],
     dataset_language: tuple[str],
     raise_errors: bool,
     task: tuple[str],
-    batch_size: str,
+    batch_size: str | None,
+    finetuning_batch_size: str,
     progress_bar: bool,
     save_results: bool,
     cache_dir: str,
@@ -233,25 +249,16 @@ def benchmark(
     api_version: str | None,
     gpu_memory_utilization: float,
     debug: bool,
-
+    requires_safetensors: bool,
+    generative_type: str | None,
+    download_only: bool,
 ) -> None:
     """Benchmark pretrained language models on language tasks."""
-
-
-
-
-
-    tasks = None if len(task) == 0 else list(task)
-    batch_size_int = int(batch_size)
-    device = Device[device.upper()] if device is not None else None
-
-    benchmarker = Benchmarker(
-        language=languages,
-        model_language=model_languages,
-        dataset_language=dataset_languages,
-        task=tasks,
-        dataset=datasets,
-        batch_size=batch_size_int,
+    Benchmarker(
+        language=list(language),
+        task=None if len(task) == 0 else list(task),
+        dataset=None if len(dataset) == 0 else list(dataset),
+        finetuning_batch_size=int(finetuning_batch_size),
         progress_bar=progress_bar,
         save_results=save_results,
         raise_errors=raise_errors,
@@ -259,7 +266,7 @@ def benchmark(
         api_key=api_key,
         force=force,
         cache_dir=cache_dir,
-        device=device,
+        device=Device[device.upper()] if device is not None else None,
         trust_remote_code=trust_remote_code,
         clear_model_cache=clear_model_cache,
         evaluate_test_split=evaluate_test_split,
@@ -268,13 +275,17 @@ def benchmark(
         api_base=api_base,
         api_version=api_version,
         gpu_memory_utilization=gpu_memory_utilization,
+        generative_type=GenerativeType[generative_type.upper()]
+        if generative_type
+        else None,
        debug=debug,
         run_with_cli=True,
-
-
-
-
-
+        requires_safetensors=requires_safetensors,
+        download_only=download_only,
+        model_language=None if len(model_language) == 0 else list(model_language),
+        dataset_language=None if len(dataset_language) == 0 else list(dataset_language),
+        batch_size=int(batch_size) if batch_size is not None else None,
+    ).benchmark(model=list(model))
 
 
 if __name__ == "__main__":
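The rewritten command body forwards the parsed options straight to `Benchmarker`. A hedged sketch of the equivalent programmatic call, using the new 16.7.1 keyword names that appear in the diff (the model ID and language are placeholders, and the sketch assumes the remaining `Benchmarker` arguments have defaults):

```python
from euroeval.benchmarker import Benchmarker
from euroeval.enums import GenerativeType

# Placeholder model ID and language; keyword names mirror the new CLI wiring above.
benchmarker = Benchmarker(
    language=["da"],
    finetuning_batch_size=32,
    requires_safetensors=True,
    generative_type=GenerativeType.REASONING,
    download_only=False,
)
benchmarker.benchmark(model=["my-org/my-model"])
```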
euroeval/constants.py
CHANGED

@@ -1,23 +1,25 @@
 """Constants used throughout the project."""
 
+import re
+from typing import TypeVar
+
 from .enums import TaskGroup
-
+
+# Type variable used for generic typing
+T = TypeVar("T", bound=object)
 
 # This is used as input to generative models; it cannot be a special token
 DUMMY_FILL_VALUE = 100
 
-
 # This is the maximum allowed context length for models for the purpose of this
 # benchmark. We will still report the models' true maximum context length in the
 # metadata, but we won't use it for evaluation, as vLLM needs to allocate memory for
 # all tokens in the context.
-MAX_CONTEXT_LENGTH =
-
+MAX_CONTEXT_LENGTH = 8_192
 
 # We need to raise the amount of tokens generated for reasoning models, to give them
 # time to think
-REASONING_MAX_TOKENS =
-
+REASONING_MAX_TOKENS = 8_192
 
 # The Hugging Face Hub pipeline tags used to classify models as generative
 GENERATIVE_PIPELINE_TAGS = [
@@ -28,48 +30,49 @@ GENERATIVE_PIPELINE_TAGS = [
     "video-text-to-text",
 ]
 
-
 # Used to disallow non-generative models to be evaluated on these task groups
 GENERATIVE_DATASET_TASK_GROUPS = [TaskGroup.TEXT_TO_TEXT]
 
-
 # Local models are required to have these files in their directory
 LOCAL_MODELS_REQUIRED_FILES = ["config.json"]
 
-
-# Tasks where we use structured generation for generative models
-TASKS_USING_JSON = [NER]
-
-
-# Tasks where we use log probabilities for generative models, rather than the raw
-# completion
-TASK_GROUPS_USING_LOGPROBS = [
-    TaskGroup.SEQUENCE_CLASSIFICATION,
-    TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION,
-]
-
-
 # The number of top log probabilities to return for generative models. For several APIs
 # this is the maximum number of log probabilities that can be returned
-
-
+MAX_VLLM_LOGPROBS = 20
+MAX_LITELLM_LOGPROBS = 8
 
 # We make sure to remove these metric attributes after each iteration, to avoid memory
 # leaks
 METRIC_ATTRIBUTES_TAKING_UP_MEMORY = ["cached_bertscorer"]
 
-
 # Hugging Face Hub tags used to classify models as merge models
 MERGE_TAGS = ["merge", "mergekit"]
 
 # The minimum required CUDA compute capability for using bfloat16 in vLLM
 VLLM_BF16_MIN_CUDA_COMPUTE_CAPABILITY = 8.0
 
+# The candidates for end-of-sequence, beginning-of-sequence and padding tokens
+EOS_TOKENS = ["</s>", "<|end_of_text|>", "<|endoftext|>", "[SEP]", "<|return|>"]
+BOS_TOKENS = ["<s>", "<|begin_of_text|>", "<|startoftext|>", "[CLS]"]
+PAD_TOKENS = [
+    "<pad>",
+    "<PAD>",
+    "[pad]",
+    "[PAD]",
+    "<|endoftext|>",
+    "<|end▁of▁sentence|>",
+    "<|im_end|>",
+]
+
 # Used to detect whether a model is a reasoning model
-REASONING_TOKENS = [
+REASONING_TOKENS: list[tuple[str | re.Pattern, str | re.Pattern]] = [
     ("<think>", "</think>"),
     ("<reason>", "</reason>"),
     ("<reasoning>", "</reasoning>"),
+    (
+        re.compile(pattern=r"<\|channel\|>(analysis|commentary)<\|message\|>"),
+        "<|channel|>final<|message|>",
+    ),
 ]
 
 # These tokens are sometimes used by models to indicate the end of a generated
@@ -77,3 +80,19 @@ REASONING_TOKENS = [
 # manually. We only use them as stop tokens if they actually appear in the model's
 # output
 CUSTOM_STOP_TOKENS = ["<sep>"]
+
+# For classification tasks we force LiteLLM models to output a JSON dictionary with a
+# single key and the values being restricted to the allowed labels. This is the key we
+# use
+LITELLM_CLASSIFICATION_OUTPUT_KEY = "label"
+
+# These characters are stripped from JSON output when trying to identify the label
+JSON_STRIP_CHARACTERS = ' {}\n\r":'
+
+# The number of tokens we generate when evaluating generative models on classification
+# tasks. We also use this to determine whether we should store logprobs in the model
+# outputs (and cache).
+NUM_GENERATION_TOKENS_FOR_CLASSIFICATION = 10
+
+# We only allow loading local datasets in these file formats
+SUPPORTED_FILE_FORMATS_FOR_LOCAL_DATASETS = ["csv"]
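The `REASONING_TOKENS` pairs now mix plain strings with compiled regexes, with the new entry covering `<|channel|>analysis/commentary<|message|>`-style markers. One way such (start, end) pairs can be applied, e.g. to cut a reasoning span out of a completion, is sketched below; this is an illustrative helper, not the package's own handling of reasoning output:

```python
import re

from euroeval.constants import REASONING_TOKENS


def strip_reasoning(text: str) -> str:
    """Illustrative helper (not part of the package): remove each reasoning span
    delimited by a (start, end) pair in REASONING_TOKENS, markers included."""
    for start, end in REASONING_TOKENS:
        # Entries can be plain strings or pre-compiled regex patterns
        start_pat = start.pattern if isinstance(start, re.Pattern) else re.escape(start)
        end_pat = end.pattern if isinstance(end, re.Pattern) else re.escape(end)
        text = re.sub(f"{start_pat}.*?{end_pat}", "", text, flags=re.DOTALL)
    return text.strip()


print(strip_reasoning("<think>Working it out...</think>The answer is 42."))
# Prints: The answer is 42.
```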
euroeval/data_loading.py
CHANGED

@@ -1,5 +1,6 @@
 """Functions related to the loading of the data."""
 
+import collections.abc as c
 import logging
 import sys
 import time
@@ -11,7 +12,10 @@ from datasets.exceptions import DatasetsError
 from huggingface_hub.errors import HfHubHTTPError
 from numpy.random import Generator
 
+from .constants import SUPPORTED_FILE_FORMATS_FOR_LOCAL_DATASETS
 from .exceptions import HuggingFaceHubDown, InvalidBenchmark
+from .logging_utils import log, no_terminal_output
+from .tasks import EUROPEAN_VALUES
 from .utils import unscramble
 
 if t.TYPE_CHECKING:
@@ -19,8 +23,6 @@ if t.TYPE_CHECKING:
 
     from .data_models import BenchmarkConfig, DatasetConfig
 
-logger = logging.getLogger("euroeval")
-
 
 def load_data(
     rng: Generator, dataset_config: "DatasetConfig", benchmark_config: "BenchmarkConfig"
@@ -48,40 +50,45 @@
         dataset_config=dataset_config, cache_dir=benchmark_config.cache_dir
     )
 
-    if not benchmark_config.evaluate_test_split:
+    if not benchmark_config.evaluate_test_split and "val" in dataset:
         dataset["test"] = dataset["val"]
 
     # Remove empty examples from the datasets
     for text_feature in ["tokens", "text"]:
-
-
+        for split in dataset_config.splits:
+            if text_feature in dataset[split].features:
+                dataset = dataset.filter(lambda x: len(x[text_feature]) > 0)
 
-    # If we are testing then truncate the test set
-
+    # If we are testing then truncate the test set, unless we need the full set for
+    # evaluation
+    if hasattr(sys, "_called_from_test") and dataset_config.task != EUROPEAN_VALUES:
        dataset["test"] = dataset["test"].select(range(1))
 
-    # Bootstrap the splits
-
-
-
-
-
-
-
-
-
+    # Bootstrap the splits, if applicable
+    if dataset_config.bootstrap_samples:
+        bootstrapped_splits: dict[str, c.Sequence["Dataset"]] = dict()
+        for split in dataset_config.splits:
+            bootstrap_indices = rng.integers(
+                0,
+                len(dataset[split]),
+                size=(benchmark_config.num_iterations, len(dataset[split])),
+            )
+            bootstrapped_splits[split] = [
+                dataset[split].select(bootstrap_indices[idx])
+                for idx in range(benchmark_config.num_iterations)
+            ]
+        datasets = [
+            DatasetDict(
+                {
+                    split: bootstrapped_splits[split][idx]
+                    for split in dataset_config.splits
+                }
+            )
             for idx in range(benchmark_config.num_iterations)
         ]
+    else:
+        datasets = [dataset] * benchmark_config.num_iterations
 
-    datasets = [
-        DatasetDict(
-            {
-                split: bootstrapped_splits[split][idx]
-                for split in ["train", "val", "test"]
-            }
-        )
-        for idx in range(benchmark_config.num_iterations)
-    ]
     return datasets
 
 
@@ -97,40 +104,89 @@ def load_raw_data(dataset_config: "DatasetConfig", cache_dir: str) -> "DatasetDict":
     Returns:
         The dataset.
     """
-
-
-
-
-
-
-
+    # Case where the dataset source is a Hugging Face ID
+    if isinstance(dataset_config.source, str):
+        num_attempts = 5
+        for _ in range(num_attempts):
+            try:
+                with no_terminal_output():
+                    dataset = load_dataset(
+                        path=dataset_config.source.split("::")[0],
+                        name=(
+                            dataset_config.source.split("::")[1]
+                            if "::" in dataset_config.source
+                            else None
+                        ),
+                        cache_dir=cache_dir,
+                        token=unscramble("XbjeOLhwebEaSaDUMqqaPaPIhgOcyOfDpGnX_"),
+                    )
+                break
+            except (
+                FileNotFoundError,
+                ConnectionError,
+                DatasetsError,
+                requests.ConnectionError,
+                requests.ReadTimeout,
+            ) as e:
+                log(
+                    f"Failed to load dataset {dataset_config.source!r}, due to "
+                    f"the following error: {e}. Retrying...",
+                    level=logging.DEBUG,
+                )
+                time.sleep(1)
+                continue
+            except HfHubHTTPError:
+                raise HuggingFaceHubDown()
+        else:
+            raise InvalidBenchmark(
+                f"Failed to load dataset {dataset_config.source!r} after "
+                f"{num_attempts} attempts. Run with verbose mode to see the individual "
+                "errors."
             )
-
-
-
-            ConnectionError,
-            DatasetsError,
-            requests.ConnectionError,
-            requests.ReadTimeout,
-        ):
-            logger.warning(
-                f"Failed to load dataset {dataset_config.huggingface_id!r}. Retrying..."
-            )
-            time.sleep(1)
-            continue
-        except HfHubHTTPError:
-            raise HuggingFaceHubDown()
+
+    # Case where the dataset source is a dictionary with keys "train", "val" and "test",
+    # with the values pointing to local CSV files
     else:
-
-
-
-
+        data_files = {
+            split: dataset_config.source[split]
+            for split in dataset_config.splits
+            if split in dataset_config.source
+        }
+
+        # Get the file extension and ensure that all files have the same extension
+        file_extensions = {
+            split: dataset_config.source[split].split(".")[-1]
+            for split in dataset_config.splits
+            if split in dataset_config.source
+        }
+        if len(set(file_extensions.values())) != 1:
+            raise InvalidBenchmark(
+                "All data files in a custom dataset must have the same file extension. "
+                f"Got the extensions {', '.join(file_extensions.values())} for the "
+                f"dataset {dataset_config.name!r}."
+            )
+        file_extension = list(file_extensions.values())[0]
+
+        # Check that the file extension is supported
+        if file_extension not in SUPPORTED_FILE_FORMATS_FOR_LOCAL_DATASETS:
+            raise InvalidBenchmark(
+                "Unsupported file extension for custom dataset. Supported file "
+                "extensions are "
+                f"{', '.join(SUPPORTED_FILE_FORMATS_FOR_LOCAL_DATASETS)}, but got "
+                f"{file_extension!r}."
+            )
+
+        # Load the dataset
+        with no_terminal_output():
+            dataset = load_dataset(
+                path=file_extension, data_files=data_files, cache_dir=cache_dir
+            )
+
     assert isinstance(dataset, DatasetDict)  # type: ignore[used-before-def]
-
-    missing_keys = [key for key in required_keys if key not in dataset]
+    missing_keys = [key for key in dataset_config.splits if key not in dataset]
     if missing_keys:
         raise InvalidBenchmark(
             "The dataset is missing the following required splits: "
            f"{', '.join(missing_keys)}"
         )
-    return DatasetDict({key: dataset[key] for key in
+    return DatasetDict({key: dataset[key] for key in dataset_config.splits})