EuroEval 15.12.0__py3-none-any.whl → 16.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. euroeval/__init__.py +32 -14
  2. euroeval/benchmark_config_factory.py +92 -180
  3. euroeval/benchmark_modules/base.py +49 -39
  4. euroeval/benchmark_modules/fresh.py +35 -21
  5. euroeval/benchmark_modules/hf.py +280 -244
  6. euroeval/benchmark_modules/litellm.py +752 -312
  7. euroeval/benchmark_modules/vllm.py +570 -268
  8. euroeval/benchmarker.py +651 -528
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/cli.py +49 -38
  12. euroeval/constants.py +44 -25
  13. euroeval/data_loading.py +111 -55
  14. euroeval/data_models.py +490 -323
  15. euroeval/dataset_configs/__init__.py +26 -4
  16. euroeval/dataset_configs/bosnian.py +39 -0
  17. euroeval/dataset_configs/bulgarian.py +56 -0
  18. euroeval/dataset_configs/croatian.py +56 -0
  19. euroeval/dataset_configs/czech.py +75 -0
  20. euroeval/dataset_configs/danish.py +78 -50
  21. euroeval/dataset_configs/dutch.py +74 -44
  22. euroeval/dataset_configs/english.py +71 -36
  23. euroeval/dataset_configs/estonian.py +111 -0
  24. euroeval/dataset_configs/faroese.py +25 -18
  25. euroeval/dataset_configs/finnish.py +63 -26
  26. euroeval/dataset_configs/french.py +65 -32
  27. euroeval/dataset_configs/german.py +77 -36
  28. euroeval/dataset_configs/greek.py +64 -0
  29. euroeval/dataset_configs/icelandic.py +68 -57
  30. euroeval/dataset_configs/italian.py +68 -36
  31. euroeval/dataset_configs/latvian.py +87 -0
  32. euroeval/dataset_configs/lithuanian.py +64 -0
  33. euroeval/dataset_configs/norwegian.py +98 -72
  34. euroeval/dataset_configs/polish.py +96 -0
  35. euroeval/dataset_configs/portuguese.py +63 -40
  36. euroeval/dataset_configs/serbian.py +64 -0
  37. euroeval/dataset_configs/slovak.py +55 -0
  38. euroeval/dataset_configs/slovene.py +56 -0
  39. euroeval/dataset_configs/spanish.py +68 -34
  40. euroeval/dataset_configs/swedish.py +82 -41
  41. euroeval/dataset_configs/ukrainian.py +64 -0
  42. euroeval/enums.py +12 -6
  43. euroeval/exceptions.py +21 -1
  44. euroeval/finetuning.py +34 -26
  45. euroeval/generation.py +76 -41
  46. euroeval/generation_utils.py +169 -34
  47. euroeval/languages.py +1020 -188
  48. euroeval/logging_utils.py +268 -0
  49. euroeval/metrics/__init__.py +6 -0
  50. euroeval/metrics/base.py +85 -0
  51. euroeval/metrics/huggingface.py +216 -0
  52. euroeval/metrics/llm_as_a_judge.py +260 -0
  53. euroeval/metrics/pipeline.py +289 -0
  54. euroeval/metrics/speed.py +48 -0
  55. euroeval/model_cache.py +40 -21
  56. euroeval/model_config.py +4 -5
  57. euroeval/model_loading.py +3 -0
  58. euroeval/prompt_templates/__init__.py +2 -0
  59. euroeval/prompt_templates/classification.py +206 -0
  60. euroeval/prompt_templates/linguistic_acceptability.py +157 -22
  61. euroeval/prompt_templates/multiple_choice.py +159 -17
  62. euroeval/prompt_templates/named_entity_recognition.py +318 -21
  63. euroeval/prompt_templates/reading_comprehension.py +207 -16
  64. euroeval/prompt_templates/sentiment_classification.py +205 -22
  65. euroeval/prompt_templates/summarization.py +122 -22
  66. euroeval/prompt_templates/token_classification.py +279 -0
  67. euroeval/scores.py +20 -9
  68. euroeval/speed_benchmark.py +11 -12
  69. euroeval/task_group_utils/multiple_choice_classification.py +21 -12
  70. euroeval/task_group_utils/question_answering.py +101 -73
  71. euroeval/task_group_utils/sequence_classification.py +144 -61
  72. euroeval/task_group_utils/text_to_text.py +33 -12
  73. euroeval/task_group_utils/token_classification.py +86 -89
  74. euroeval/tasks.py +75 -16
  75. euroeval/tokenisation_utils.py +603 -0
  76. euroeval/types.py +17 -11
  77. euroeval/utils.py +332 -137
  78. euroeval-16.7.1.dist-info/METADATA +623 -0
  79. euroeval-16.7.1.dist-info/RECORD +84 -0
  80. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
  81. euroeval/human_evaluation.py +0 -737
  82. euroeval/metrics.py +0 -452
  83. euroeval/tokenization_utils.py +0 -498
  84. euroeval-15.12.0.dist-info/METADATA +0 -285
  85. euroeval-15.12.0.dist-info/RECORD +0 -63
  86. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
  87. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
euroeval/__init__.py CHANGED
@@ -12,14 +12,17 @@ import warnings
 from termcolor import colored
 
 # Block specific warnings before importing anything else, as they can be noisy
-warnings.filterwarnings("ignore", category=UserWarning)
-logging.getLogger("httpx").setLevel(logging.CRITICAL)
-logging.getLogger("datasets").setLevel(logging.CRITICAL)
-logging.getLogger("vllm").setLevel(logging.CRITICAL)
-os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
+if os.getenv("FULL_LOG") != "1":
+    warnings.filterwarnings("ignore", category=UserWarning)
+    warnings.filterwarnings("ignore", category=FutureWarning)
+    logging.getLogger("httpx").setLevel(logging.CRITICAL)
+    logging.getLogger("datasets").setLevel(logging.CRITICAL)
+    logging.getLogger("vllm").setLevel(logging.CRITICAL)
+    os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
 
 # Set up logging
-fmt = colored("%(asctime)s", "light_blue") + " ⋅ " + colored("%(message)s", "green")
+# fmt = colored("%(asctime)s", "light_blue") + " ⋅ " + colored("%(message)s", "green")
+fmt = colored("%(message)s", "light_yellow")
 logging.basicConfig(
     level=logging.CRITICAL if hasattr(sys, "_called_from_test") else logging.INFO,
     format=fmt,
@@ -48,7 +51,13 @@ import importlib.metadata # noqa: E402
 from dotenv import load_dotenv # noqa: E402
 
 from .benchmarker import Benchmarker # noqa: E402
-from .utils import block_terminal_output # noqa: E402
+from .data_models import DatasetConfig # noqa: E402
+from .logging_utils import block_terminal_output # noqa: E402
+from .tasks import ( # noqa: E402
+    MULTIPLE_CHOICE,
+    TEXT_CLASSIFICATION,
+    TOKEN_CLASSIFICATION,
+)
 
 # Block unwanted terminal outputs. This blocks way more than the above, but since it
 # relies on importing from the `utils` module, external modules are already imported
@@ -77,15 +86,18 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 os.environ["OMP_NUM_THREADS"] = "1"
 
 
-# Disable a warning from Ray regarding the detection of the number of CPUs
-os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
-
-
 # Avoid the "Cannot re-initialize CUDA in forked subprocess" error - see
 # https://github.com/vllm-project/vllm/issues/6152 for more
 os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
 
+# Allow long max model length in vLLM. This happens when vLLM registers that the model
+# has a shorter context length than the value we are inserting. But since we do a
+# thorough check of the model's config before setting the context length, we trust our
+# own checks and ignore the internal vLLM check.
+os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
+
+
 # Avoid the "Unclosed client session" error when evaluating Ollama models with LiteLLM.
 # The error comes from the `aiohttp` package, and this environment variable forces the
 # use of `httpx` instead.
@@ -93,9 +105,15 @@ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 os.environ["DISABLE_AIOHTTP_TRANSPORT"] = "True"
 
 
-# Use older version v0 of vLLM, as the newer one requires XGrammar as decoding backend,
-# but XGrammar does not support having a maximal amount of elements in lists
-os.environ["VLLM_USE_V1"] = "0"
+# Enable the newer vLLM V1 engine, which is faster and offers more compatibility with
+# newer models
+os.environ["VLLM_USE_V1"] = "1"
+
+
+# Use the FlashInfer flash-attention backend for vLLM, unless the user has already
+# specified a different backend.
+if os.getenv("VLLM_ATTENTION_BACKEND") is None:
+    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
 
 
 # Set the HF_TOKEN env var to copy the HUGGINGFACE_API_KEY env var, as vLLM uses the
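
Both of the new toggles above are evaluated at import time in euroeval/__init__.py, so they only take effect if they are set before `import euroeval`. A minimal usage sketch; the `FLASH_ATTN` value is just an illustrative alternative backend, since EuroEval itself only checks whether `VLLM_ATTENTION_BACKEND` is already set:

    import os

    # Keep all warnings and third-party log output: euroeval only silences them
    # when FULL_LOG is not "1".
    os.environ["FULL_LOG"] = "1"

    # Pre-select a vLLM attention backend; if left unset, euroeval defaults it to
    # "FLASHINFER". "FLASH_ATTN" here is only an example value.
    os.environ.setdefault("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")

    import euroeval  # the toggles above must be set before this import
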
euroeval/benchmark_config_factory.py CHANGED
@@ -1,173 +1,82 @@
 """Factory class for creating dataset configurations."""
 
-import logging
+import collections.abc as c
 import sys
 import typing as t
 
 import torch
 
-from .data_models import BenchmarkConfig
+from .data_models import BenchmarkConfig, BenchmarkConfigParams, DatasetConfig, Task
 from .dataset_configs import get_all_dataset_configs
 from .enums import Device
 from .exceptions import InvalidBenchmark
 from .languages import get_all_languages
-from .tasks import SPEED, get_all_tasks
 
 if t.TYPE_CHECKING:
-    from .data_models import Language, Task
-
-
-logger = logging.getLogger("euroeval")
+    from .data_models import Language
 
 
 def build_benchmark_config(
-    progress_bar: bool,
-    save_results: bool,
-    task: str | list[str] | None,
-    dataset: str | list[str] | None,
-    language: str | list[str],
-    model_language: str | list[str] | None,
-    dataset_language: str | list[str] | None,
-    device: Device | None,
-    batch_size: int,
-    raise_errors: bool,
-    cache_dir: str,
-    api_key: str | None,
-    force: bool,
-    verbose: bool,
-    trust_remote_code: bool,
-    clear_model_cache: bool,
-    evaluate_test_split: bool,
-    few_shot: bool,
-    num_iterations: int,
-    api_base: str | None,
-    api_version: str | None,
-    gpu_memory_utilization: float,
-    debug: bool,
-    run_with_cli: bool,
-    only_allow_safetensors: bool,
-    first_time: bool = False,
+    benchmark_config_params: BenchmarkConfigParams,
 ) -> BenchmarkConfig:
     """Create a benchmark configuration.
 
     Args:
-        progress_bar:
-            Whether to show a progress bar when running the benchmark.
-        save_results:
-            Whether to save the benchmark results to a file.
-        task:
-            The tasks to include for dataset. If None then datasets will not be
-            filtered based on their task.
-        dataset:
-            The datasets to include for task. If None then all datasets will be
-            included, limited by the `task` parameter.
-        language:
-            The language codes of the languages to include, both for models and
-            datasets. Here 'no' means both Bokmål (nb) and Nynorsk (nn). Set this
-            to 'all' if all languages should be considered.
-        model_language:
-            The language codes of the languages to include for models. If None then
-            the `language` parameter will be used.
-        dataset_language:
-            The language codes of the languages to include for datasets. If None then
-            the `language` parameter will be used.
-        device:
-            The device to use for running the models. If None then the device will be
-            set automatically.
-        batch_size:
-            The batch size to use for running the models.
-        raise_errors:
-            Whether to raise errors when running the benchmark.
-        cache_dir:
-            The directory to use for caching the models.
-        api_key:
-            The API key to use for a given inference server.
-        force:
-            Whether to force the benchmark to run even if the results are already
-            cached.
-        verbose:
-            Whether to print verbose output when running the benchmark. This is
-            automatically set if `debug` is True.
-        trust_remote_code:
-            Whether to trust remote code when running the benchmark.
-        clear_model_cache:
-            Whether to clear the model cache before running the benchmark.
-        evaluate_test_split:
-            Whether to use the test split for the datasets.
-        few_shot:
-            Whether to use few-shot learning for the models.
-        num_iterations:
-            The number of iterations each model should be evaluated for.
-        api_base:
-            The base URL for a given inference API. Only relevant if `model` refers to a
-            model on an inference API.
-        api_version:
-            The version of the API to use for a given inference API.
-        gpu_memory_utilization:
-            The GPU memory utilization to use for vLLM. A larger value will result in
-            faster evaluation, but at the risk of running out of GPU memory. Only reduce
-            this if you are running out of GPU memory. Only relevant if the model is
-            generative.
-        debug:
-            Whether to run the benchmark in debug mode.
-        run_with_cli:
-            Whether the benchmark is being run with the CLI.
-        only_allow_safetensors:
-            Whether to only allow evaluations of models stored as safetensors.
-        first_time:
-            Whether this is the first time the benchmark configuration is being created.
-            Defaults to False.
+        benchmark_config_params:
+            The parameters for creating the benchmark configuration.
 
     Returns:
         The benchmark configuration.
     """
-    language_codes = get_correct_language_codes(language_codes=language)
-    model_languages = prepare_languages(
-        language_codes=model_language, default_language_codes=language_codes
+    language_codes = get_correct_language_codes(
+        language_codes=benchmark_config_params.language
     )
-    dataset_languages = prepare_languages(
-        language_codes=dataset_language, default_language_codes=language_codes
+    languages = prepare_languages(
+        language_codes=benchmark_config_params.language,
+        default_language_codes=language_codes,
     )
 
-    tasks, datasets = prepare_tasks_and_datasets(
-        task=task, dataset=dataset, dataset_languages=dataset_languages
+    dataset_configs = prepare_dataset_configs(
+        task=benchmark_config_params.task,
+        dataset=benchmark_config_params.dataset,
+        languages=languages,
    )
 
-    torch_device = prepare_device(device=device)
-
-    # Set variable with number of iterations
-    if hasattr(sys, "_called_from_test"):
-        num_iterations = 1
-
     return BenchmarkConfig(
-        model_languages=model_languages,
-        dataset_languages=dataset_languages,
-        tasks=tasks,
-        datasets=datasets,
-        batch_size=batch_size,
-        raise_errors=raise_errors,
-        cache_dir=cache_dir,
-        api_key=api_key,
-        force=force,
-        progress_bar=progress_bar,
-        save_results=save_results,
-        verbose=verbose or debug,
-        device=torch_device,
-        trust_remote_code=trust_remote_code,
-        clear_model_cache=clear_model_cache,
-        evaluate_test_split=evaluate_test_split,
-        few_shot=few_shot,
-        num_iterations=num_iterations,
-        api_base=api_base,
-        api_version=api_version,
-        gpu_memory_utilization=gpu_memory_utilization,
-        debug=debug,
-        run_with_cli=run_with_cli,
-        only_allow_safetensors=only_allow_safetensors,
+        datasets=dataset_configs,
+        languages=languages,
+        finetuning_batch_size=benchmark_config_params.finetuning_batch_size,
+        raise_errors=benchmark_config_params.raise_errors,
+        cache_dir=benchmark_config_params.cache_dir,
+        api_key=benchmark_config_params.api_key,
+        force=benchmark_config_params.force,
+        progress_bar=benchmark_config_params.progress_bar,
+        save_results=benchmark_config_params.save_results,
+        verbose=benchmark_config_params.verbose or benchmark_config_params.debug,
+        device=prepare_device(device=benchmark_config_params.device),
+        trust_remote_code=benchmark_config_params.trust_remote_code,
+        clear_model_cache=benchmark_config_params.clear_model_cache,
+        evaluate_test_split=benchmark_config_params.evaluate_test_split,
+        few_shot=benchmark_config_params.few_shot,
+        num_iterations=(
+            1
+            if hasattr(sys, "_called_from_test")
+            else benchmark_config_params.num_iterations
+        ),
+        api_base=benchmark_config_params.api_base,
+        api_version=benchmark_config_params.api_version,
+        gpu_memory_utilization=benchmark_config_params.gpu_memory_utilization,
+        generative_type=benchmark_config_params.generative_type,
+        debug=benchmark_config_params.debug,
+        run_with_cli=benchmark_config_params.run_with_cli,
+        requires_safetensors=benchmark_config_params.requires_safetensors,
+        download_only=benchmark_config_params.download_only,
     )
 
 
-def get_correct_language_codes(language_codes: str | list[str]) -> list[str]:
+def get_correct_language_codes(
+    language_codes: str | c.Sequence[str],
+) -> c.Sequence[str]:
     """Get correct language code(s).
 
     Args:
@@ -188,7 +97,7 @@ def get_correct_language_codes(language_codes: str | list[str]) -> list[str]:
     elif isinstance(language_codes, str):
         languages = [language_codes]
     else:
-        languages = language_codes
+        languages = list(language_codes)
 
     # If `languages` contains 'no' then also include 'nb' and 'nn'. Conversely, if
     # either 'nb' or 'nn' are specified then also include 'no'.
@@ -201,8 +110,9 @@ def get_correct_language_codes(language_codes: str | list[str]) -> list[str]:
 
 
 def prepare_languages(
-    language_codes: str | list[str] | None, default_language_codes: list[str]
-) -> list["Language"]:
+    language_codes: str | c.Sequence[str] | None,
+    default_language_codes: c.Sequence[str],
+) -> c.Sequence["Language"]:
     """Prepare language(s) for benchmarking.
 
     Args:
@@ -220,7 +130,7 @@ def prepare_languages(
     language_mapping = get_all_languages()
 
     # Create the list `languages_str` of language codes to use for models or datasets
-    languages_str: list[str]
+    languages_str: c.Sequence[str]
     if language_codes is None:
         languages_str = default_language_codes
     elif isinstance(language_codes, str):
@@ -237,74 +147,76 @@ def prepare_languages(
     return prepared_languages
 
 
-def prepare_tasks_and_datasets(
-    task: str | list[str] | None,
-    dataset_languages: list["Language"],
-    dataset: str | list[str] | None,
-) -> tuple[list["Task"], list[str]]:
-    """Prepare task(s) and dataset(s) for benchmarking.
+def prepare_dataset_configs(
+    task: "str | Task | c.Sequence[str | Task] | None",
+    languages: c.Sequence["Language"],
+    dataset: "str | DatasetConfig | c.Sequence[str | DatasetConfig] | None",
+) -> c.Sequence["DatasetConfig"]:
+    """Prepare dataset config(s) for benchmarking.
 
     Args:
         task:
             The tasks to include for dataset. If None then datasets will not be
             filtered based on their task.
-        dataset_languages:
+        languages:
             The languages of the datasets in the benchmark.
         dataset:
             The datasets to include for task. If None then all datasets will be
-            included, limited by the `task` and `dataset_languages` parameters.
+            included, limited by the `task` and `languages` parameters.
 
     Returns:
-        The prepared tasks and datasets.
+        The prepared dataset configs.
 
     Raises:
         InvalidBenchmark:
            If the task or dataset is not found in the benchmark tasks or datasets.
     """
-    # Create a dictionary that maps benchmark tasks to their associated benchmark
-    # task objects, and a dictionary that maps dataset names to their associated
-    # dataset configuration objects
-    task_mapping = get_all_tasks()
+    # Create the list of dataset configs
     all_dataset_configs = get_all_dataset_configs()
+    all_official_dataset_configs: c.Sequence[DatasetConfig] = [
+        dataset_config
+        for dataset_config in all_dataset_configs.values()
+        if not dataset_config.unofficial
+    ]
+    try:
+        if dataset is None:
+            datasets = all_official_dataset_configs
+        elif isinstance(dataset, str):
+            datasets = [all_dataset_configs[dataset]]
+        elif isinstance(dataset, DatasetConfig):
+            datasets = [dataset]
+        else:
+            datasets = [
+                all_dataset_configs[d] if isinstance(d, str) else d for d in dataset
+            ]
    except KeyError as e:
+        raise InvalidBenchmark(
+            f"Dataset {e} not found in the benchmark datasets."
+        ) from e
 
     # Create the list of dataset tasks
+    task_mapping = {cfg.task.name: cfg.task for cfg in all_dataset_configs.values()}
     try:
         if task is None:
-            tasks = [t for t in task_mapping.values() if t != SPEED]
+            tasks = None
         elif isinstance(task, str):
             tasks = [task_mapping[task]]
+        elif isinstance(task, Task):
+            tasks = [task]
         else:
-            tasks = [task_mapping[t] for t in task]
+            tasks = [task_mapping[t] if isinstance(t, str) else t for t in task]
     except KeyError as e:
         raise InvalidBenchmark(f"Task {e} not found in the benchmark tasks.") from e
 
-    all_official_datasets = [
-        dataset_name
-        for dataset_name, dataset_config in all_dataset_configs.items()
-        if not dataset_config.unofficial
-    ]
-    if dataset is None:
-        dataset = all_official_datasets
-    elif isinstance(dataset, str):
-        dataset = [dataset]
-
-    all_datasets = list(all_dataset_configs.keys())
-    invalid_datasets = set(dataset) - set(all_datasets)
-    if invalid_datasets:
-        raise InvalidBenchmark(
-            f"Dataset(s) {', '.join(invalid_datasets)} not found in the benchmark "
-            "datasets."
-        )
-
+    # Filter the dataset configs based on the specified tasks and languages
     datasets = [
-        dataset_name
-        for dataset_name, dataset_config in all_dataset_configs.items()
-        if dataset_name in dataset
-        and dataset_config.task in tasks
-        and set(dataset_config.languages).intersection(dataset_languages)
+        ds
+        for ds in datasets
+        if (tasks is None or ds.task in tasks)
+        and any(lang in languages for lang in ds.languages)
     ]
 
-    return tasks, datasets
+    return datasets
 
 
 def prepare_device(device: Device | None) -> torch.device:
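
The factory refactor above collapses the long keyword list of build_benchmark_config into a single BenchmarkConfigParams object and replaces prepare_tasks_and_datasets with prepare_dataset_configs, which returns DatasetConfig objects directly. A minimal sketch of the new helpers as they appear in this hunk; the task name is illustrative, and the other arguments keep the semantics documented in the docstrings above:

    from euroeval.benchmark_config_factory import (
        get_correct_language_codes,
        prepare_dataset_configs,
        prepare_languages,
    )

    # Mirror what build_benchmark_config now does internally: normalise the
    # language codes, then filter the dataset configs by task and language.
    codes = get_correct_language_codes(language_codes=["da", "no"])
    languages = prepare_languages(language_codes=codes, default_language_codes=codes)
    dataset_configs = prepare_dataset_configs(
        task="sentiment-classification",  # illustrative task name
        dataset=None,  # None keeps every official dataset for the task
        languages=languages,
    )
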
euroeval/benchmark_modules/base.py CHANGED
@@ -2,24 +2,23 @@
 
 import collections.abc as c
 import logging
-import sys
+import re
 import typing as t
 from abc import ABC, abstractmethod
 from functools import cached_property, partial
 
-from datasets import DatasetDict
+from datasets import Dataset, DatasetDict
 from torch import nn
-from tqdm.auto import tqdm
 
 from ..enums import TaskGroup
-from ..exceptions import NeedsEnvironmentVariable, NeedsExtraInstalled
+from ..exceptions import InvalidBenchmark, NeedsEnvironmentVariable, NeedsExtraInstalled
+from ..logging_utils import get_pbar, log_once
 from ..task_group_utils import (
     question_answering,
     sequence_classification,
     text_to_text,
     token_classification,
 )
-from ..utils import log_once
 
 if t.TYPE_CHECKING:
     from transformers.tokenization_utils import PreTrainedTokenizer
@@ -35,8 +34,6 @@ if t.TYPE_CHECKING:
     from ..enums import BatchingPreference, GenerativeType
     from ..types import ComputeMetricsFunction, ExtractLabelsFunction
 
-logger = logging.getLogger("euroeval")
-
 
 class BenchmarkModule(ABC):
     """Abstract class for a benchmark module.
@@ -55,12 +52,14 @@ class BenchmarkModule(ABC):
     fresh_model: bool
     batching_preference: "BatchingPreference"
     high_priority: bool
+    allowed_params: dict[re.Pattern, c.Sequence[str]] = {re.compile(r".*"): []}
 
     def __init__(
         self,
         model_config: "ModelConfig",
         dataset_config: "DatasetConfig",
         benchmark_config: "BenchmarkConfig",
+        log_metadata: bool = True,
     ) -> None:
         """Initialise the benchmark module.
 
@@ -71,29 +70,25 @@ class BenchmarkModule(ABC):
                 The dataset configuration.
             benchmark_config:
                 The benchmark configuration.
+            log_metadata:
+                Whether to log the metadata of the model.
         """
         self.model_config = model_config
         self.dataset_config = dataset_config
         self.benchmark_config = benchmark_config
+        self.log_metadata = log_metadata
         self.buffer: dict[str, t.Any] = dict()
-        self._log_metadata()
+        if self.log_metadata:
+            self._log_metadata()
 
     def _log_metadata(self) -> None:
         """Log the metadata of the model."""
-        # Set logging level based on verbosity
-        if hasattr(sys, "_called_from_test"):
-            logging_level = logging.CRITICAL
-        elif self.benchmark_config.verbose:
-            logging_level = logging.DEBUG
-        else:
-            logging_level = logging.INFO
-        logger.setLevel(logging_level)
-
-        logging_msg: str = ""
+        model_id = self.model_config.model_id
+        logging_msg: str = ""
         if self.num_params < 0:
-            logging_msg += "The model has an unknown number of parameters, "
+            logging_msg += f"The model {model_id} has an unknown number of parameters, "
         else:
-            logging_msg += f"The model has {self.num_params:,} parameters, "
+            logging_msg += f"The model {model_id} has {self.num_params:,} parameters, "
         if self.vocab_size < 0:
             logging_msg += "an unknown vocabulary size, "
         else:
@@ -117,16 +112,16 @@ class BenchmarkModule(ABC):
             f"{self.__class__.__name__}."
         )
 
-    def get_tokenizer(self) -> "PreTrainedTokenizer":
-        """Get the underlying tokenizer.
+    def get_tokeniser(self) -> "PreTrainedTokenizer":
+        """Get the underlying tokeniser.
 
         Returns:
-            The tokenizer.
+            The tokeniser.
         """
-        if hasattr(self, "_tokenizer"):
-            return self._tokenizer
+        if hasattr(self, "_tokeniser"):
+            return self._tokeniser
         raise NotImplementedError(
-            "The `get_tokenizer` method has not been implemented for "
+            "The `get_tokeniser` method has not been implemented for "
             f"{self.__class__.__name__}."
         )
 
@@ -172,7 +167,7 @@ class BenchmarkModule(ABC):
 
     @property
     @abstractmethod
-    def data_collator(self) -> c.Callable[[list[t.Any]], dict[str, t.Any]]:
+    def data_collator(self) -> c.Callable[[c.Sequence[t.Any]], dict[str, t.Any]]:
         """The data collator used to prepare samples during finetuning.
 
         Returns:
@@ -192,11 +187,13 @@ class BenchmarkModule(ABC):
                 return partial(
                     sequence_classification.compute_metrics,
                     dataset_config=self.dataset_config,
+                    benchmark_config=self.benchmark_config,
                 )
             case TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION:
                 return partial(
                     sequence_classification.compute_metrics,
                     dataset_config=self.dataset_config,
+                    benchmark_config=self.benchmark_config,
                 )
             case TaskGroup.TEXT_TO_TEXT:
                 return partial(
@@ -209,11 +206,13 @@ class BenchmarkModule(ABC):
                     token_classification.compute_metrics,
                     has_misc_tags=self.buffer.get("has_misc_tags", True),
                     dataset_config=self.dataset_config,
+                    benchmark_config=self.benchmark_config,
                 )
             case TaskGroup.QUESTION_ANSWERING:
                 return partial(
                     question_answering.compute_metrics,
                     dataset_config=self.dataset_config,
+                    benchmark_config=self.benchmark_config,
                 )
             case _:
                 raise NotImplementedError(
@@ -242,7 +241,7 @@ class BenchmarkModule(ABC):
 
     def prepare_datasets(
         self, datasets: list[DatasetDict], task: "Task"
-    ) -> list[DatasetDict]:
+    ) -> c.Sequence[DatasetDict]:
         """Prepare the datasets for the model.
 
         This includes things like tokenisation.
@@ -255,30 +254,41 @@ class BenchmarkModule(ABC):
 
         Returns:
            The prepared datasets.
+
+        Raises:
+            InvalidBenchmark:
+                If the dataset does not have a 'train' split for token classification
+                tasks.
         """
         for idx, dataset in enumerate(
-            tqdm(iterable=datasets, desc="Preparing datasets")
+            get_pbar(
+                iterable=datasets,
+                desc="Preparing datasets",
+                disable=not self.benchmark_config.progress_bar,
+            )
         ):
             prepared_dataset = self.prepare_dataset(
                 dataset=dataset, task=task, itr_idx=idx
             )
             if self.dataset_config.task.task_group == TaskGroup.TOKEN_CLASSIFICATION:
+                if "train" not in dataset:
+                    raise InvalidBenchmark(
+                        "The dataset does not have a 'train' split, which is required "
+                        "for token classification tasks."
+                    )
                 labels_in_train: set[str] = {
                     tag for tag_list in dataset["train"]["labels"] for tag in tag_list
                 }
                 self.buffer["has_misc_tags"] = (
                     "B-MISC" in labels_in_train or "I-MISC" in labels_in_train
                 )
-            datasets[idx] = DatasetDict(
-                dict(
-                    train=prepared_dataset["train"],
-                    val=prepared_dataset["val"],
-                    test=prepared_dataset["test"],
-                    original_train=dataset["train"],
-                    original_val=dataset["val"],
-                    original_test=dataset["test"],
-                )
-            )
+
+            datasets_dict: dict[str, Dataset] = dict()
+            for split_name, split in prepared_dataset.items():
+                datasets_dict[split_name] = split
+            for split_name, split in dataset.items():
+                datasets_dict[f"original_{split_name}"] = split
+            datasets[idx] = DatasetDict(datasets_dict)
         return datasets
 
     @abstractmethod
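
The final hunk above swaps the hard-coded train/val/test keys for a loop over whichever splits are present, storing the untouched copies under original_* keys. A small standalone sketch of that merge pattern, using toy single-row splits only to show the resulting keys:

    from datasets import Dataset, DatasetDict

    # Stand-ins for the prepared and original splits of one dataset iteration.
    prepared = DatasetDict({"val": Dataset.from_dict({"x": [1]}), "test": Dataset.from_dict({"x": [2]})})
    original = DatasetDict({"val": Dataset.from_dict({"x": [1]}), "test": Dataset.from_dict({"x": [2]})})

    # Merge: keep the prepared splits under their own names and the untouched
    # splits under an "original_" prefix, whatever splits happen to exist.
    merged: dict[str, Dataset] = {}
    for split_name, split in prepared.items():
        merged[split_name] = split
    for split_name, split in original.items():
        merged[f"original_{split_name}"] = split

    combined = DatasetDict(merged)
    print(list(combined.keys()))  # ['val', 'test', 'original_val', 'original_test']
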