ScandEval 16.12.0__py3-none-any.whl → 16.13.0__py3-none-any.whl

This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
Files changed (61)
  1. scandeval/async_utils.py +46 -0
  2. scandeval/benchmark_config_factory.py +26 -2
  3. scandeval/benchmark_modules/fresh.py +2 -1
  4. scandeval/benchmark_modules/hf.py +50 -12
  5. scandeval/benchmark_modules/litellm.py +25 -15
  6. scandeval/benchmark_modules/vllm.py +3 -3
  7. scandeval/benchmarker.py +15 -33
  8. scandeval/cli.py +2 -4
  9. scandeval/constants.py +5 -0
  10. scandeval/custom_dataset_configs.py +152 -0
  11. scandeval/data_loading.py +87 -31
  12. scandeval/data_models.py +396 -225
  13. scandeval/dataset_configs/__init__.py +51 -25
  14. scandeval/dataset_configs/albanian.py +1 -1
  15. scandeval/dataset_configs/belarusian.py +47 -0
  16. scandeval/dataset_configs/bulgarian.py +1 -1
  17. scandeval/dataset_configs/catalan.py +1 -1
  18. scandeval/dataset_configs/croatian.py +1 -1
  19. scandeval/dataset_configs/danish.py +3 -2
  20. scandeval/dataset_configs/dutch.py +7 -6
  21. scandeval/dataset_configs/english.py +4 -3
  22. scandeval/dataset_configs/estonian.py +8 -7
  23. scandeval/dataset_configs/faroese.py +1 -1
  24. scandeval/dataset_configs/finnish.py +5 -4
  25. scandeval/dataset_configs/french.py +6 -5
  26. scandeval/dataset_configs/german.py +4 -3
  27. scandeval/dataset_configs/greek.py +1 -1
  28. scandeval/dataset_configs/hungarian.py +1 -1
  29. scandeval/dataset_configs/icelandic.py +4 -3
  30. scandeval/dataset_configs/italian.py +4 -3
  31. scandeval/dataset_configs/latvian.py +2 -2
  32. scandeval/dataset_configs/lithuanian.py +1 -1
  33. scandeval/dataset_configs/norwegian.py +6 -5
  34. scandeval/dataset_configs/polish.py +4 -3
  35. scandeval/dataset_configs/portuguese.py +5 -4
  36. scandeval/dataset_configs/romanian.py +2 -2
  37. scandeval/dataset_configs/serbian.py +1 -1
  38. scandeval/dataset_configs/slovene.py +1 -1
  39. scandeval/dataset_configs/spanish.py +4 -3
  40. scandeval/dataset_configs/swedish.py +4 -3
  41. scandeval/dataset_configs/ukrainian.py +1 -1
  42. scandeval/generation_utils.py +6 -6
  43. scandeval/metrics/llm_as_a_judge.py +1 -1
  44. scandeval/metrics/pipeline.py +1 -1
  45. scandeval/model_cache.py +34 -4
  46. scandeval/prompt_templates/linguistic_acceptability.py +9 -0
  47. scandeval/prompt_templates/multiple_choice.py +9 -0
  48. scandeval/prompt_templates/named_entity_recognition.py +21 -0
  49. scandeval/prompt_templates/reading_comprehension.py +10 -0
  50. scandeval/prompt_templates/sentiment_classification.py +11 -0
  51. scandeval/string_utils.py +157 -0
  52. scandeval/task_group_utils/sequence_classification.py +2 -5
  53. scandeval/task_group_utils/token_classification.py +2 -4
  54. scandeval/utils.py +6 -323
  55. scandeval-16.13.0.dist-info/METADATA +334 -0
  56. scandeval-16.13.0.dist-info/RECORD +94 -0
  57. scandeval-16.12.0.dist-info/METADATA +0 -667
  58. scandeval-16.12.0.dist-info/RECORD +0 -90
  59. {scandeval-16.12.0.dist-info → scandeval-16.13.0.dist-info}/WHEEL +0 -0
  60. {scandeval-16.12.0.dist-info → scandeval-16.13.0.dist-info}/entry_points.txt +0 -0
  61. {scandeval-16.12.0.dist-info → scandeval-16.13.0.dist-info}/licenses/LICENSE +0 -0
scandeval/async_utils.py ADDED
@@ -0,0 +1,46 @@
+"""Utility functions for asyncronous tasks."""
+
+import asyncio
+import typing as t
+
+from .constants import T
+
+
+def safe_run(coroutine: t.Coroutine[t.Any, t.Any, T]) -> T:
+    """Run a coroutine, ensuring that the event loop is always closed when we're done.
+
+    Args:
+        coroutine:
+            The coroutine to run.
+
+    Returns:
+        The result of the coroutine.
+    """
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:  # If the current event loop is closed
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+    response = loop.run_until_complete(coroutine)
+    return response
+
+
+async def add_semaphore_and_catch_exception(
+    coroutine: t.Coroutine[t.Any, t.Any, T], semaphore: asyncio.Semaphore
+) -> T | Exception:
+    """Run a coroutine with a semaphore.
+
+    Args:
+        coroutine:
+            The coroutine to run.
+        semaphore:
+            The semaphore to use.
+
+    Returns:
+        The result of the coroutine.
+    """
+    async with semaphore:
+        try:
+            return await coroutine
+        except Exception as exc:
+            return exc
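These two helpers are consumed together in `litellm.py` (see below), which now imports them from `..async_utils` instead of `..utils`. A minimal sketch of how they compose; the `fetch` coroutine is illustrative only, not part of the package:

import asyncio

from scandeval.async_utils import add_semaphore_and_catch_exception, safe_run


async def fetch(i: int) -> int:
    # Illustrative stand-in for a rate-limited API call.
    await asyncio.sleep(0.01)
    return i * 2


async def run_all() -> list[int | Exception]:
    # At most five coroutines run concurrently; any exception is returned as a
    # value instead of propagating, so one failure doesn't abort the batch.
    semaphore = asyncio.Semaphore(5)
    tasks = [
        add_semaphore_and_catch_exception(fetch(i), semaphore=semaphore)
        for i in range(20)
    ]
    return await asyncio.gather(*tasks)


results = safe_run(run_all())
failures = [r for r in results if isinstance(r, Exception)]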
scandeval/benchmark_config_factory.py CHANGED
@@ -46,6 +46,8 @@ def build_benchmark_config(
         dataset=benchmark_config_params.dataset,
         languages=languages,
         custom_datasets_file=benchmark_config_params.custom_datasets_file,
+        api_key=benchmark_config_params.api_key,
+        cache_dir=Path(benchmark_config_params.cache_dir),
     )
 
     return BenchmarkConfig(
@@ -159,7 +161,9 @@ def prepare_dataset_configs(
     languages: c.Sequence["Language"],
     dataset: "str | DatasetConfig | c.Sequence[str | DatasetConfig] | None",
     custom_datasets_file: Path,
-) -> c.Sequence["DatasetConfig"]:
+    api_key: str | None,
+    cache_dir: Path,
+) -> list["DatasetConfig"]:
     """Prepare dataset config(s) for benchmarking.
 
     Args:
@@ -173,6 +177,10 @@ def prepare_dataset_configs(
             included, limited by the `task` and `languages` parameters.
         custom_datasets_file:
             A path to a Python file containing custom dataset configurations.
+        api_key:
+            The API key to use for accessing the Hugging Face Hub.
+        cache_dir:
+            The directory to store the cache in.
 
     Returns:
         The prepared dataset configs.
@@ -181,9 +189,25 @@ def prepare_dataset_configs(
         InvalidBenchmark:
             If the task or dataset is not found in the benchmark tasks or datasets.
     """
+    # Extract the dataset IDs from the `dataset` argument
+    dataset_ids: list[str] = list()
+    if isinstance(dataset, str):
+        dataset_ids.append(dataset)
+    elif isinstance(dataset, DatasetConfig):
+        dataset_ids.append(dataset.name)
+    elif isinstance(dataset, list):
+        for d in dataset:
+            if isinstance(d, str):
+                dataset_ids.append(d)
+            elif isinstance(d, DatasetConfig):
+                dataset_ids.append(d.name)
+
     # Create the list of dataset configs
     all_dataset_configs = get_all_dataset_configs(
-        custom_datasets_file=custom_datasets_file
+        custom_datasets_file=custom_datasets_file,
+        dataset_ids=dataset_ids,
+        api_key=api_key,
+        cache_dir=cache_dir,
     )
     all_official_dataset_configs: c.Sequence[DatasetConfig] = [
         dataset_config
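The new ID-extraction block means the `dataset` argument can arrive as a single name, a single `DatasetConfig`, or a mixed list of the two, all reducing to a flat list of dataset IDs before `get_all_dataset_configs` is called. A standalone mirror of that normalisation, for illustration only (the dataset names are placeholders):

from scandeval.data_models import DatasetConfig


def to_dataset_ids(
    dataset: "str | DatasetConfig | list[str | DatasetConfig] | None",
) -> list[str]:
    # Same shape-handling as the block added above.
    dataset_ids: list[str] = []
    if isinstance(dataset, str):
        dataset_ids.append(dataset)
    elif isinstance(dataset, DatasetConfig):
        dataset_ids.append(dataset.name)
    elif isinstance(dataset, list):
        for d in dataset:
            dataset_ids.append(d if isinstance(d, str) else d.name)
    return dataset_ids


assert to_dataset_ids("angry-tweets") == ["angry-tweets"]
assert to_dataset_ids(["angry-tweets", "scala-da"]) == ["angry-tweets", "scala-da"]
assert to_dataset_ids(None) == []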
scandeval/benchmark_modules/fresh.py CHANGED
@@ -28,8 +28,9 @@ from ..exceptions import (
 )
 from ..generation_utils import raise_if_wrong_params
 from ..logging_utils import block_terminal_output
+from ..model_cache import create_model_cache_dir
 from ..types import Tokeniser
-from ..utils import create_model_cache_dir, get_hf_token
+from ..utils import get_hf_token
 from .hf import (
     HuggingFaceEncoderModel,
     align_model_and_tokeniser,
scandeval/benchmark_modules/hf.py CHANGED
@@ -1,6 +1,7 @@
 """Encoder models from the Hugging Face Hub."""
 
 import collections.abc as c
+import importlib
 import logging
 import re
 import typing as t
@@ -63,6 +64,8 @@ from ..exceptions import (
 from ..generation_utils import raise_if_wrong_params
 from ..languages import get_all_languages
 from ..logging_utils import block_terminal_output, log, log_once
+from ..model_cache import create_model_cache_dir
+from ..string_utils import split_model_id
 from ..task_group_utils import (
     multiple_choice_classification,
     question_answering,
@@ -70,13 +73,7 @@ from ..task_group_utils import (
 )
 from ..tokenisation_utils import get_bos_token, get_eos_token
 from ..types import Tokeniser
-from ..utils import (
-    create_model_cache_dir,
-    get_class_by_name,
-    get_hf_token,
-    internet_connection_available,
-    split_model_id,
-)
+from ..utils import get_hf_token, internet_connection_available
 from .base import BenchmarkModule
 
 try:
@@ -381,7 +378,7 @@ class HuggingFaceEncoderModel(BenchmarkModule):
         if "label" in examples:
             try:
                 examples["label"] = [
-                    self._model.config.label2id[lbl.lower()]
+                    self._model.config.label2id[str(lbl).lower()]
                     if self._model.config.label2id is not None
                     else lbl
                     for lbl in examples["label"]
@@ -817,8 +814,8 @@ def get_model_repo_info(
         log(
             f"Could not access the model {model_id} with the revision "
             f"{revision}. The error was {str(e)!r}. Please set the "
-            "`HUGGINGFACE_API_KEY` environment variable or use the "
-            "`--api-key` argument.",
+            "`HUGGINGFACE_API_KEY` or `HF_TOKEN` environment variable or "
+            "use the `--api-key` argument.",
             level=logging.DEBUG,
         )
         return None
@@ -1095,8 +1092,8 @@ def load_hf_model_config(
             f"The model {model_id!r} is a gated repository. Please ensure "
             "that you are logged in with `hf auth login` or have provided a "
             "valid Hugging Face access token with the `HUGGINGFACE_API_KEY` "
-            "environment variable or the `--api-key` argument. Also check that "
-            "your account has access to this model."
+            "or `HF_TOKEN` environment variable or the `--api-key` argument. "
+            "Also check that your account has access to this model."
         ) from e
     raise InvalidModel(
         f"Couldn't load model config for {model_id!r}. The error was "
@@ -1334,3 +1331,44 @@ def task_group_to_class_name(task_group: TaskGroup) -> str:
     )
     pascal_case = special_case_mapping.get(pascal_case, pascal_case)
     return f"AutoModelFor{pascal_case}"
+
+
+def get_class_by_name(
+    class_name: str | c.Sequence[str], module_name: str
+) -> t.Type | None:
+    """Get a class by its name.
+
+    Args:
+        class_name:
+            The name of the class, written in kebab-case. The corresponding class name
+            must be the same, but written in PascalCase, and lying in a module with the
+            same name, but written in snake_case. If a list of strings is passed, the
+            first class that is found is returned.
+        module_name:
+            The name of the module where the class is located.
+
+    Returns:
+        The class. If the class is not found, None is returned.
+    """
+    if isinstance(class_name, str):
+        class_name = [class_name]
+
+    error_messages = list()
+    for name in class_name:
+        try:
+            module = importlib.import_module(name=module_name)
+            class_: t.Type = getattr(module, name)
+            return class_
+        except (ModuleNotFoundError, AttributeError) as e:
+            error_messages.append(str(e))
+
+    if error_messages:
+        errors = "\n- " + "\n- ".join(error_messages)
+        log(
+            f"Could not find the class with the name(s) {', '.join(class_name)}. The "
+            f"following error messages were raised: {errors}",
+            level=logging.DEBUG,
+        )
+
+    # If the class could not be found, return None
+    return None
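`get_class_by_name` moved here from `utils.py`: it imports `module_name`, tries each candidate name with `getattr`, and returns the first hit, or `None` after logging the failures at DEBUG level. A small usage sketch; the module and class names below are only illustrative:

from scandeval.benchmark_modules.hf import get_class_by_name

# The first lookup that succeeds wins, so the nonexistent name is skipped and
# transformers.AutoModelForSequenceClassification is returned (assuming the
# transformers package is installed).
cls = get_class_by_name(
    class_name=["NonexistentModel", "AutoModelForSequenceClassification"],
    module_name="transformers",
)
if cls is None:
    print("No matching class found; the errors were logged at DEBUG level.")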
scandeval/benchmark_modules/litellm.py CHANGED
@@ -40,7 +40,7 @@ from pydantic import ValidationError, conlist, create_model
 from requests.exceptions import RequestException
 from tqdm.asyncio import tqdm as tqdm_async
 
-from ..caching_utils import cache_arguments
+from ..async_utils import add_semaphore_and_catch_exception, safe_run
 from ..constants import (
     JSON_STRIP_CHARACTERS,
     LITELLM_CLASSIFICATION_OUTPUT_KEY,
@@ -74,6 +74,8 @@ from ..generation_utils import (
     raise_if_wrong_params,
 )
 from ..logging_utils import get_pbar, log, log_once
+from ..model_cache import create_model_cache_dir
+from ..string_utils import split_model_id
 from ..task_group_utils import (
     question_answering,
     sequence_classification,
@@ -83,13 +85,7 @@ from ..task_group_utils import (
 from ..tasks import NER
 from ..tokenisation_utils import get_first_label_token_mapping
 from ..types import ExtractLabelsFunction
-from ..utils import (
-    add_semaphore_and_catch_exception,
-    create_model_cache_dir,
-    get_hf_token,
-    safe_run,
-    split_model_id,
-)
+from ..utils import get_hf_token
 from .base import BenchmarkModule
 from .hf import HuggingFaceEncoderModel, load_hf_model_config, load_tokeniser
 
@@ -700,10 +696,10 @@ class LiteLLMModel(BenchmarkModule):
         elif isinstance(
             error, (Timeout, ServiceUnavailableError, InternalServerError, SystemError)
         ):
-            log_once(
-                f"Service temporarily unavailable. The error message was: {error}. "
-                "Retrying in 10 seconds...",
-                level=logging.DEBUG,
+            log(
+                "Service temporarily unavailable during generation. The error "
+                f"message was: {error}. Retrying in 10 seconds...",
+                level=logging.INFO,
             )
             return generation_kwargs, 10
         elif isinstance(error, UnsupportedParamsError):
@@ -764,6 +760,20 @@ class LiteLLMModel(BenchmarkModule):
                 run_with_cli=self.benchmark_config.run_with_cli,
             ) from error
 
+        if (
+            isinstance(error, (BadRequestError, NotFoundError))
+            and self.benchmark_config.api_base is not None
+            and not self.benchmark_config.api_base.endswith("/v1")
+        ):
+            log_once(
+                f"The API base {self.benchmark_config.api_base!r} is not valid. We "
+                "will try appending '/v1' to it and try again.",
+                level=logging.DEBUG,
+            )
+            self.benchmark_config.api_base += "/v1"
+            generation_kwargs["api_base"] = self.benchmark_config.api_base
+            return generation_kwargs, 0
+
         raise InvalidBenchmark(
             f"Failed to generate text. The error message was: {error}"
         ) from error
@@ -1390,9 +1400,10 @@ class LiteLLMModel(BenchmarkModule):
                 InternalServerError,
             ) as e:
                 log(
-                    f"Service temporarily unavailable. The error message was: {e}. "
+                    "Service temporarily unavailable while checking for model "
+                    f"existence of the model {model_id!r}. The error message was: {e}. "
                     "Retrying in 10 seconds...",
-                    level=logging.DEBUG,
+                    level=logging.INFO,
                 )
                 sleep(10)
             except APIError as e:
@@ -1567,7 +1578,6 @@ class LiteLLMModel(BenchmarkModule):
 
         return dataset
 
-    @cache_arguments()
     def get_generation_kwargs(self, dataset_config: DatasetConfig) -> dict[str, t.Any]:
         """Get the generation arguments for the model.
 
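The new `BadRequestError`/`NotFoundError` branch gives OpenAI-compatible servers a second chance: if the configured `api_base` lacks the `/v1` suffix that most such servers expect, it is appended once and the request retried immediately (the zero in the returned tuple is the sleep time). A standalone sketch of the same normalisation idea, not the package's code path:

def normalise_api_base(api_base: str) -> str:
    # Append the conventional OpenAI-compatible suffix at most once.
    return api_base if api_base.endswith("/v1") else api_base + "/v1"


assert normalise_api_base("http://localhost:8000") == "http://localhost:8000/v1"
assert normalise_api_base("http://localhost:8000/v1") == "http://localhost:8000/v1"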
scandeval/benchmark_modules/vllm.py CHANGED
@@ -54,6 +54,8 @@ from ..generation_utils import (
 )
 from ..languages import get_all_languages
 from ..logging_utils import get_pbar, log, log_once, no_terminal_output
+from ..model_cache import create_model_cache_dir
+from ..string_utils import split_model_id
 from ..task_group_utils import (
     question_answering,
     sequence_classification,
@@ -73,12 +75,10 @@ from ..tokenisation_utils import (
 from ..types import ExtractLabelsFunction, Tokeniser
 from ..utils import (
     clear_memory,
-    create_model_cache_dir,
     get_hf_token,
     get_min_cuda_compute_capability,
     internet_connection_available,
     resolve_model_path,
-    split_model_id,
 )
 from .hf import HuggingFaceEncoderModel, get_model_repo_info, load_hf_model_config
 
@@ -1144,7 +1144,7 @@ def load_model_and_tokeniser(
         pipeline_parallel_size=pipeline_parallel_size,
         disable_custom_all_reduce=True,
         quantization=quantization,
-        dtype=dtype,
+        dtype=dtype,  # pyrefly: ignore[bad-argument-type]
         enforce_eager=True,
         # TEMP: Prefix caching isn't supported with sliding window in vLLM yet,
         # so we disable it for now
scandeval/benchmarker.py CHANGED
@@ -18,7 +18,6 @@ from .benchmark_config_factory import build_benchmark_config
 from .constants import ATTENTION_BACKENDS, GENERATIVE_PIPELINE_TAGS
 from .data_loading import load_data, load_raw_data
 from .data_models import BenchmarkConfigParams, BenchmarkResult
-from .dataset_configs import get_all_dataset_configs
 from .enums import Device, GenerativeType, ModelType
 from .exceptions import HuggingFaceHubDown, InvalidBenchmark, InvalidModel
 from .finetuning import finetune
@@ -28,12 +27,9 @@ from .model_config import get_model_config
 from .model_loading import load_model
 from .scores import log_scores
 from .speed_benchmark import benchmark_speed
+from .string_utils import split_model_id
 from .tasks import SPEED
-from .utils import (
-    enforce_reproducibility,
-    internet_connection_available,
-    split_model_id,
-)
+from .utils import enforce_reproducibility, internet_connection_available
 
 if t.TYPE_CHECKING:
     from .benchmark_modules import BenchmarkModule
@@ -79,7 +75,9 @@ class Benchmarker:
         api_base: str | None = None,
         api_version: str | None = None,
         gpu_memory_utilization: float = 0.8,
-        attention_backend: str = "FLASHINFER",
+        attention_backend: t.Literal[
+            *ATTENTION_BACKENDS  # pyrefly: ignore[invalid-literal]
+        ] = "FLASHINFER",
         generative_type: GenerativeType | None = None,
         custom_datasets_file: Path | str = Path("custom_datasets.py"),
         debug: bool = False,
@@ -346,7 +344,9 @@ class Benchmarker:
             f"Loading data for {dataset_config.logging_string}", level=logging.INFO
         )
         dataset = load_raw_data(
-            dataset_config=dataset_config, cache_dir=benchmark_config.cache_dir
+            dataset_config=dataset_config,
+            cache_dir=benchmark_config.cache_dir,
+            api_key=benchmark_config.api_key,
         )
         del dataset
 
@@ -513,6 +513,11 @@ class Benchmarker:
             ValueError:
                 If both `task` and `dataset` are specified.
         """
+        log(
+            "Started EuroEval run. Run with `--verbose` for more information.",
+            level=logging.INFO,
+        )
+
         if task is not None and dataset is not None:
             raise ValueError("Only one of `task` and `dataset` can be specified.")
 
@@ -790,7 +795,7 @@ class Benchmarker:
 
         # Update the benchmark config if the dataset requires it
         if (
-            "val" not in dataset_config.splits
+            dataset_config.val_split is None
             and not benchmark_config.evaluate_test_split
         ):
             log(
@@ -1066,7 +1071,7 @@ class Benchmarker:
             ),
             validation_split=(
                 None
-                if "val" not in dataset_config.splits
+                if dataset_config.val_split is None
                 else not benchmark_config.evaluate_test_split
             ),
         )
@@ -1181,29 +1186,6 @@ def clear_model_cache_fn(cache_dir: str) -> None:
             rmtree(sub_model_dir)
 
 
-def prepare_dataset_configs(
-    dataset_names: c.Sequence[str], custom_datasets_file: Path
-) -> c.Sequence["DatasetConfig"]:
-    """Prepare the dataset configuration(s) to be benchmarked.
-
-    Args:
-        dataset_names:
-            The dataset names to benchmark.
-        custom_datasets_file:
-            A path to a Python file containing custom dataset configurations.
-
-    Returns:
-        The prepared list of model IDs.
-    """
-    return [
-        cfg
-        for cfg in get_all_dataset_configs(
-            custom_datasets_file=custom_datasets_file
-        ).values()
-        if cfg.name in dataset_names
-    ]
-
-
 def initial_logging(
     model_config: "ModelConfig",
     dataset_config: "DatasetConfig",
scandeval/cli.py CHANGED
@@ -5,6 +5,7 @@ from pathlib import Path
 import click
 
 from .benchmarker import Benchmarker
+from .constants import ATTENTION_BACKENDS
 from .data_models import DatasetConfig
 from .enums import Device, GenerativeType
 from .languages import get_all_languages
@@ -174,10 +175,7 @@ from .languages import get_all_languages
     "--attention-backend",
     default="FLASHINFER",
     show_default=True,
-    type=click.Choice(
-        ["FLASHINFER", "FLASH_ATTN", "TRITON_ATTN", "FLEX_ATTENTION"],
-        case_sensitive=True,
-    ),
+    type=click.Choice(ATTENTION_BACKENDS, case_sensitive=True),
     help="The attention backend to use for vLLM. Only relevant if the model is "
     "generative.",
 )
scandeval/constants.py CHANGED
@@ -134,3 +134,8 @@ ATTENTION_BACKENDS: list[str] = [
     "CPU_ATTN",
     "CUSTOM",
 ]
+
+# If a dataset configuration has more than this number of languages, we won't log any of
+# the languages. This is for instance the case for the speed benchmark, which has all
+# the languages. The threshold of 5 is somewhat arbitrary.
+MAX_NUMBER_OF_LOGGING_LANGUAGES = 5
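A hypothetical guard showing how such a threshold is typically applied when logging; the `languages` value is a placeholder and this is not the package's logging code:

from scandeval.constants import MAX_NUMBER_OF_LOGGING_LANGUAGES

languages = ["da", "sv", "no", "is", "fo", "de", "nl"]  # placeholder value
if len(languages) > MAX_NUMBER_OF_LOGGING_LANGUAGES:
    print("Languages: (too many to list)")
else:
    print(f"Languages: {', '.join(languages)}")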
scandeval/custom_dataset_configs.py ADDED
@@ -0,0 +1,152 @@
+"""Load custom dataset configs."""
+
+import importlib.util
+import logging
+from pathlib import Path
+from types import ModuleType
+
+from huggingface_hub import HfApi
+
+from .data_models import DatasetConfig
+from .logging_utils import log_once
+from .utils import get_hf_token
+
+
+def load_custom_datasets_module(custom_datasets_file: Path) -> ModuleType | None:
+    """Load the custom datasets module if it exists.
+
+    Args:
+        custom_datasets_file:
+            The path to the custom datasets module.
+
+    Raises:
+        RuntimeError:
+            If the custom datasets module cannot be loaded.
+    """
+    if custom_datasets_file.exists():
+        spec = importlib.util.spec_from_file_location(
+            name="custom_datasets_module", location=str(custom_datasets_file.resolve())
+        )
+        if spec is None:
+            log_once(
+                "Could not load the spec for the custom datasets file from "
+                f"{custom_datasets_file.resolve()}.",
+                level=logging.ERROR,
+            )
+            return None
+        module = importlib.util.module_from_spec(spec=spec)
+        if spec.loader is None:
+            log_once(
+                "Could not load the module for the custom datasets file from "
+                f"{custom_datasets_file.resolve()}.",
+                level=logging.ERROR,
+            )
+            return None
+        spec.loader.exec_module(module)
+        return module
+    return None
+
+
+def try_get_dataset_config_from_repo(
+    dataset_id: str, api_key: str | None, cache_dir: Path
+) -> DatasetConfig | None:
+    """Try to get a dataset config from a Hugging Face dataset repository.
+
+    Args:
+        dataset_id:
+            The ID of the dataset to get the config for.
+        api_key:
+            The Hugging Face API key to use to check if the repositories have custom
+            dataset configs.
+        cache_dir:
+            The directory to store the cache in.
+
+    Returns:
+        The dataset config if it exists, otherwise None.
+    """
+    # Check if the dataset ID is a Hugging Face dataset ID, abort if it isn't
+    token = get_hf_token(api_key=api_key)
+    hf_api = HfApi(token=token)
+    if not hf_api.repo_exists(repo_id=dataset_id, repo_type="dataset"):
+        return None
+
+    # Check if the repository has a euroeval_config.py file, abort if it doesn't
+    repo_files = hf_api.list_repo_files(
+        repo_id=dataset_id, repo_type="dataset", revision="main"
+    )
+    if "euroeval_config.py" not in repo_files:
+        log_once(
+            f"Dataset {dataset_id} does not have a euroeval_config.py file, so we "
+            "cannot load it. Skipping.",
+            level=logging.WARNING,
+        )
+        return None
+
+    # Fetch the euroeval_config.py file, abort if loading failed
+    external_config_path = cache_dir / "external_dataset_configs" / dataset_id
+    external_config_path.mkdir(parents=True, exist_ok=True)
+    hf_api.hf_hub_download(
+        repo_id=dataset_id,
+        repo_type="dataset",
+        filename="euroeval_config.py",
+        local_dir=external_config_path,
+        local_dir_use_symlinks=False,
+    )
+    module = load_custom_datasets_module(
+        custom_datasets_file=external_config_path / "euroeval_config.py"
+    )
+    if module is None:
+        return None
+
+    # Check that there is exactly one dataset config, abort if there isn't
+    repo_dataset_configs = [
+        cfg for cfg in vars(module).values() if isinstance(cfg, DatasetConfig)
+    ]
+    if not repo_dataset_configs:
+        return None  # Already warned the user in this case, so we just skip
+    elif len(repo_dataset_configs) > 1:
+        log_once(
+            f"Dataset {dataset_id} has multiple dataset configurations. Please ensure "
+            "that only a single DatasetConfig is defined in the `euroeval_config.py` "
+            "file.",
+            level=logging.WARNING,
+        )
+        return None
+
+    # Get the dataset split names
+    splits = [
+        split["name"]
+        for split in hf_api.dataset_info(repo_id=dataset_id).card_data.dataset_info[
+            "splits"
+        ]
+    ]
+    train_split_candidates = sorted(
+        [split for split in splits if "train" in split.lower()], key=len
+    )
+    val_split_candidates = sorted(
+        [split for split in splits if "val" in split.lower()], key=len
+    )
+    test_split_candidates = sorted(
+        [split for split in splits if "test" in split.lower()], key=len
+    )
+    train_split = train_split_candidates[0] if train_split_candidates else None
+    val_split = val_split_candidates[0] if val_split_candidates else None
+    test_split = test_split_candidates[0] if test_split_candidates else None
+    if test_split is None:
+        log_once(
+            f"Dataset {dataset_id} does not have a test split, so we cannot load it. "
+            "Please ensure that the dataset has a test split.",
+            level=logging.ERROR,
+        )
+        return None
+
+    # Set up the config with the repo information
+    repo_dataset_config = repo_dataset_configs[0]
+    repo_dataset_config.name = dataset_id
+    repo_dataset_config.pretty_name = dataset_id
+    repo_dataset_config.source = dataset_id
+    repo_dataset_config.train_split = train_split
+    repo_dataset_config.val_split = val_split
+    repo_dataset_config.test_split = test_split
+
+    return repo_dataset_config
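Putting the new module together: a dataset repository opts in by shipping a `euroeval_config.py` that defines exactly one `DatasetConfig`, and the loader fills in the repo-derived fields (name, source, splits), insisting on a test split. A hedged usage sketch with a made-up repo ID:

from pathlib import Path

from scandeval.custom_dataset_configs import try_get_dataset_config_from_repo

# "my-org/my-dataset" is a placeholder; the repo must exist on the Hugging Face
# Hub, contain a `euroeval_config.py` defining a single DatasetConfig, and have
# a test split, otherwise None is returned.
config = try_get_dataset_config_from_repo(
    dataset_id="my-org/my-dataset",
    api_key=None,  # per the messages above, HUGGINGFACE_API_KEY / HF_TOKEN also work
    cache_dir=Path(".scandeval_cache"),
)
if config is not None:
    print(config.name, config.train_split, config.val_split, config.test_split)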