EuroEval 15.12.0-py3-none-any.whl → 16.7.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. euroeval/__init__.py +32 -14
  2. euroeval/benchmark_config_factory.py +92 -180
  3. euroeval/benchmark_modules/base.py +49 -39
  4. euroeval/benchmark_modules/fresh.py +35 -21
  5. euroeval/benchmark_modules/hf.py +280 -244
  6. euroeval/benchmark_modules/litellm.py +752 -312
  7. euroeval/benchmark_modules/vllm.py +570 -268
  8. euroeval/benchmarker.py +651 -528
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/cli.py +49 -38
  12. euroeval/constants.py +44 -25
  13. euroeval/data_loading.py +111 -55
  14. euroeval/data_models.py +490 -323
  15. euroeval/dataset_configs/__init__.py +26 -4
  16. euroeval/dataset_configs/bosnian.py +39 -0
  17. euroeval/dataset_configs/bulgarian.py +56 -0
  18. euroeval/dataset_configs/croatian.py +56 -0
  19. euroeval/dataset_configs/czech.py +75 -0
  20. euroeval/dataset_configs/danish.py +78 -50
  21. euroeval/dataset_configs/dutch.py +74 -44
  22. euroeval/dataset_configs/english.py +71 -36
  23. euroeval/dataset_configs/estonian.py +111 -0
  24. euroeval/dataset_configs/faroese.py +25 -18
  25. euroeval/dataset_configs/finnish.py +63 -26
  26. euroeval/dataset_configs/french.py +65 -32
  27. euroeval/dataset_configs/german.py +77 -36
  28. euroeval/dataset_configs/greek.py +64 -0
  29. euroeval/dataset_configs/icelandic.py +68 -57
  30. euroeval/dataset_configs/italian.py +68 -36
  31. euroeval/dataset_configs/latvian.py +87 -0
  32. euroeval/dataset_configs/lithuanian.py +64 -0
  33. euroeval/dataset_configs/norwegian.py +98 -72
  34. euroeval/dataset_configs/polish.py +96 -0
  35. euroeval/dataset_configs/portuguese.py +63 -40
  36. euroeval/dataset_configs/serbian.py +64 -0
  37. euroeval/dataset_configs/slovak.py +55 -0
  38. euroeval/dataset_configs/slovene.py +56 -0
  39. euroeval/dataset_configs/spanish.py +68 -34
  40. euroeval/dataset_configs/swedish.py +82 -41
  41. euroeval/dataset_configs/ukrainian.py +64 -0
  42. euroeval/enums.py +12 -6
  43. euroeval/exceptions.py +21 -1
  44. euroeval/finetuning.py +34 -26
  45. euroeval/generation.py +76 -41
  46. euroeval/generation_utils.py +169 -34
  47. euroeval/languages.py +1020 -188
  48. euroeval/logging_utils.py +268 -0
  49. euroeval/metrics/__init__.py +6 -0
  50. euroeval/metrics/base.py +85 -0
  51. euroeval/metrics/huggingface.py +216 -0
  52. euroeval/metrics/llm_as_a_judge.py +260 -0
  53. euroeval/metrics/pipeline.py +289 -0
  54. euroeval/metrics/speed.py +48 -0
  55. euroeval/model_cache.py +40 -21
  56. euroeval/model_config.py +4 -5
  57. euroeval/model_loading.py +3 -0
  58. euroeval/prompt_templates/__init__.py +2 -0
  59. euroeval/prompt_templates/classification.py +206 -0
  60. euroeval/prompt_templates/linguistic_acceptability.py +157 -22
  61. euroeval/prompt_templates/multiple_choice.py +159 -17
  62. euroeval/prompt_templates/named_entity_recognition.py +318 -21
  63. euroeval/prompt_templates/reading_comprehension.py +207 -16
  64. euroeval/prompt_templates/sentiment_classification.py +205 -22
  65. euroeval/prompt_templates/summarization.py +122 -22
  66. euroeval/prompt_templates/token_classification.py +279 -0
  67. euroeval/scores.py +20 -9
  68. euroeval/speed_benchmark.py +11 -12
  69. euroeval/task_group_utils/multiple_choice_classification.py +21 -12
  70. euroeval/task_group_utils/question_answering.py +101 -73
  71. euroeval/task_group_utils/sequence_classification.py +144 -61
  72. euroeval/task_group_utils/text_to_text.py +33 -12
  73. euroeval/task_group_utils/token_classification.py +86 -89
  74. euroeval/tasks.py +75 -16
  75. euroeval/tokenisation_utils.py +603 -0
  76. euroeval/types.py +17 -11
  77. euroeval/utils.py +332 -137
  78. euroeval-16.7.1.dist-info/METADATA +623 -0
  79. euroeval-16.7.1.dist-info/RECORD +84 -0
  80. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
  81. euroeval/human_evaluation.py +0 -737
  82. euroeval/metrics.py +0 -452
  83. euroeval/tokenization_utils.py +0 -498
  84. euroeval-15.12.0.dist-info/METADATA +0 -285
  85. euroeval-15.12.0.dist-info/RECORD +0 -63
  86. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
  87. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
euroeval/caching_utils.py ADDED
@@ -0,0 +1,79 @@
+"""Caching utility functions."""
+
+import typing as t
+from functools import wraps
+
+from .constants import T
+
+
+def cache_arguments(
+    *arguments: str, disable_condition: t.Callable[[], bool] = lambda: False
+) -> t.Callable[[t.Callable[..., T]], t.Callable[..., T]]:
+    """Cache specified arguments of a function.
+
+    Args:
+        arguments:
+            The list of argument names to cache. If empty, all arguments are cached.
+        disable_condition:
+            A function that checks if cache should be disabled.
+
+    Returns:
+        A decorator that caches the specified arguments of a function.
+    """
+
+    def caching_decorator(func: t.Callable[..., T]) -> t.Callable[..., T]:
+        """Decorator that caches the specified arguments of a function.
+
+        Args:
+            func:
+                The function to decorate.
+
+        Returns:
+            The decorated function.
+        """
+        cache: dict[tuple, T] = dict()
+
+        @wraps(func)
+        def wrapper(*args, **kwargs) -> T:
+            """Wrapper function that caches the specified arguments.
+
+            Args:
+                *args:
+                    The positional arguments to the function.
+                **kwargs:
+                    The keyword arguments to the function.
+
+            Returns:
+                The result of the function.
+
+            Raises:
+                ValueError:
+                    If an argument name is not found in the function parameters.
+            """
+            if not arguments:
+                key = args + tuple(kwargs[k] for k in sorted(kwargs.keys()))
+            else:
+                func_params = func.__code__.co_varnames
+                key_items: list[t.Any] = list()
+                for arg_name in arguments:
+                    if arg_name in kwargs:
+                        key_items.append(kwargs[arg_name])
+                    else:
+                        try:
+                            arg_index = func_params.index(arg_name)
+                            key_items.append(args[arg_index])
+                        except (ValueError, IndexError):
+                            raise ValueError(
+                                f"Argument {arg_name} not found in function "
+                                f"{func.__name__} parameters."
+                            )
+                key = tuple(key_items)
+
+            # Do not cache if the condition is met
+            if key not in cache or disable_condition():
+                cache[key] = func(*args, **kwargs)
+            return cache[key]
+
+        return wrapper
+
+    return caching_decorator
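
Note: a minimal usage sketch of the new cache_arguments decorator above. The decorated function and its arguments are hypothetical, for illustration only; the decorator itself and its import path come from this diff.

from euroeval.caching_utils import cache_arguments

# Only the "model_id" argument forms the cache key, so repeated calls with the
# same model_id reuse the first result, regardless of the other arguments.
@cache_arguments("model_id")
def fetch_model_metadata(model_id: str, verbose: bool = False) -> dict:
    # Hypothetical expensive lookup
    return {"model_id": model_id}

fetch_model_metadata("my-org/my-model")        # computed and cached
fetch_model_metadata("my-org/my-model", True)  # served from the cache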
euroeval/callbacks.py CHANGED
@@ -7,6 +7,8 @@ from collections.abc import Sized
 from tqdm.auto import tqdm
 from transformers.trainer_callback import ProgressCallback
 
+from .logging_utils import get_pbar
+
 if t.TYPE_CHECKING:
     from torch.utils.data import DataLoader
     from transformers.trainer_callback import TrainerControl, TrainerState
@@ -32,11 +34,8 @@ class NeverLeaveProgressCallback(ProgressCallback):
         """Callback actions when training begins."""
         if state.is_local_process_zero:
             desc = "Finetuning model"
-            self.training_bar = tqdm(
-                total=None,
-                leave=False,
-                desc=desc,
-                disable=hasattr(sys, "_called_from_test"),
+            self.training_bar = get_pbar(
+                total=None, desc=desc, disable=hasattr(sys, "_called_from_test")
             )
             self.current_step = 0
 
@@ -67,9 +66,8 @@ class NeverLeaveProgressCallback(ProgressCallback):
         if state.is_local_process_zero and correct_dtype:
             if self.prediction_bar is None:
                 desc = "Evaluating model"
-                self.prediction_bar = tqdm(
+                self.prediction_bar = get_pbar(
                     total=len(eval_dataloader),
-                    leave=False,
                     desc=desc,
                     disable=hasattr(sys, "_called_from_test"),
                 )
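
Note: euroeval/logging_utils.py itself is not shown in this excerpt, so the following is only a sketch of a get_pbar helper consistent with the call sites above, assuming it simply wraps tqdm with leave=False baked in:

from tqdm.auto import tqdm

def get_pbar(total: int | None, desc: str, disable: bool = False) -> tqdm:
    # Assumed behaviour: same as the old inline tqdm calls, with leave=False
    # now centralised in one place instead of being passed by every caller.
    return tqdm(total=total, desc=desc, leave=False, disable=disable)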
euroeval/cli.py CHANGED
@@ -3,10 +3,9 @@
 import click
 
 from .benchmarker import Benchmarker
-from .dataset_configs import get_all_dataset_configs
-from .enums import Device
+from .data_models import DatasetConfig
+from .enums import Device, GenerativeType
 from .languages import get_all_languages
-from .tasks import get_all_tasks
 
 
 @click.command()
@@ -23,7 +22,6 @@ from .tasks import get_all_tasks
     default=None,
     show_default=True,
     multiple=True,
-    type=click.Choice(list(get_all_tasks().keys())),
     help="The dataset tasks to benchmark the model(s) on.",
 )
 @click.option(
@@ -45,8 +43,7 @@ from .tasks import get_all_tasks
     multiple=True,
     metavar="ISO 639-1 LANGUAGE CODE",
     type=click.Choice(["all"] + list(get_all_languages().keys())),
-    help="""The model languages to benchmark. If not specified then this will use the
-    `language` value.""",
+    help="""This option is deprecated - please use --language instead.""",
 )
 @click.option(
     "--dataset-language",
@@ -56,24 +53,28 @@ from .tasks import get_all_tasks
     multiple=True,
     metavar="ISO 639-1 LANGUAGE CODE",
     type=click.Choice(["all"] + list(get_all_languages().keys())),
-    help="""The dataset languages to benchmark. If "all" then the models will be
-    benchmarked on all datasets. If not specified then this will use the `language`
-    value.""",
+    help="""This option is deprecated - please use --language instead.""",
 )
 @click.option(
     "--dataset",
     default=None,
     show_default=True,
     multiple=True,
-    type=click.Choice(list(get_all_dataset_configs().keys())),
     help="""The name of the benchmark dataset. We recommend to use the `task` and
     `language` options instead of this option.""",
 )
 @click.option(
     "--batch-size",
+    default=None,
+    type=click.Choice(["1", "2", "4", "8", "16", "32"]),
+    help="This option is deprecated - please use --finetuning-batch-size instead.",
+    deprecated=True,
+)
+@click.option(
+    "--finetuning-batch-size",
     default="32",
     type=click.Choice(["1", "2", "4", "8", "16", "32"]),
-    help="The batch size to use.",
+    help="The batch size to use for finetuning.",
 )
 @click.option(
     "--progress-bar/--no-progress-bar",
@@ -188,7 +189,7 @@ from .tasks import get_all_tasks
 )
 @click.option(
     "--gpu-memory-utilization",
-    default=0.9,
+    default=0.8,
     show_default=True,
     help="The GPU memory utilization to use for vLLM. A larger value will result in "
     "faster evaluation, but at the risk of running out of GPU memory. Only reduce this "
@@ -203,20 +204,35 @@ from .tasks import get_all_tasks
     "relevant if the model is generative.",
 )
 @click.option(
-    "--only-allow-safetensors",
+    "--requires-safetensors",
     is_flag=True,
    help="Only allow loading models that have safetensors weights available",
     default=False,
 )
+@click.option(
+    "--generative-type",
+    type=click.Choice(["base", "instruction_tuned", "reasoning"]),
+    default=None,
+    show_default=True,
+    help="The type of generative model. Only relevant if the model is generative. If "
+    "not specified, the type will be inferred automatically.",
+)
+@click.option(
+    "--download-only",
+    is_flag=True,
+    help="Only download the requested model weights and datasets, and exit.",
+    default=False,
+)
 def benchmark(
     model: tuple[str],
-    dataset: tuple[str],
+    dataset: tuple[str | DatasetConfig],
     language: tuple[str],
     model_language: tuple[str],
     dataset_language: tuple[str],
     raise_errors: bool,
     task: tuple[str],
-    batch_size: str,
+    batch_size: str | None,
+    finetuning_batch_size: str,
     progress_bar: bool,
     save_results: bool,
     cache_dir: str,
@@ -233,25 +249,16 @@ def benchmark(
     api_version: str | None,
     gpu_memory_utilization: float,
     debug: bool,
-    only_allow_safetensors: bool,
+    requires_safetensors: bool,
+    generative_type: str | None,
+    download_only: bool,
 ) -> None:
     """Benchmark pretrained language models on language tasks."""
-    models = list(model)
-    datasets = None if len(dataset) == 0 else list(dataset)
-    languages: list[str] = list(language)
-    model_languages = None if len(model_language) == 0 else list(model_language)
-    dataset_languages = None if len(dataset_language) == 0 else list(dataset_language)
-    tasks = None if len(task) == 0 else list(task)
-    batch_size_int = int(batch_size)
-    device = Device[device.upper()] if device is not None else None
-
-    benchmarker = Benchmarker(
-        language=languages,
-        model_language=model_languages,
-        dataset_language=dataset_languages,
-        task=tasks,
-        dataset=datasets,
-        batch_size=batch_size_int,
+    Benchmarker(
+        language=list(language),
+        task=None if len(task) == 0 else list(task),
+        dataset=None if len(dataset) == 0 else list(dataset),
+        finetuning_batch_size=int(finetuning_batch_size),
         progress_bar=progress_bar,
         save_results=save_results,
         raise_errors=raise_errors,
@@ -259,7 +266,7 @@ def benchmark(
         api_key=api_key,
         force=force,
         cache_dir=cache_dir,
-        device=device,
+        device=Device[device.upper()] if device is not None else None,
         trust_remote_code=trust_remote_code,
         clear_model_cache=clear_model_cache,
         evaluate_test_split=evaluate_test_split,
@@ -268,13 +275,17 @@ def benchmark(
         api_base=api_base,
         api_version=api_version,
         gpu_memory_utilization=gpu_memory_utilization,
+        generative_type=GenerativeType[generative_type.upper()]
+        if generative_type
+        else None,
         debug=debug,
         run_with_cli=True,
-        only_allow_safetensors=only_allow_safetensors,
-    )
-
-    # Perform the benchmark evaluation
-    benchmarker.benchmark(model=models)
+        requires_safetensors=requires_safetensors,
+        download_only=download_only,
+        model_language=None if len(model_language) == 0 else list(model_language),
+        dataset_language=None if len(dataset_language) == 0 else list(dataset_language),
+        batch_size=int(batch_size) if batch_size is not None else None,
+    ).benchmark(model=list(model))
 
 
 if __name__ == "__main__":
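
Note: the CLI entry point above now reduces to a single chained Benchmarker(...).benchmark(...) call. Calling the Python API directly with the renamed options would look roughly as follows; all argument values are illustrative and the model ID is hypothetical.

from euroeval.benchmarker import Benchmarker
from euroeval.enums import GenerativeType

Benchmarker(
    language=["da"],
    finetuning_batch_size=16,                          # replaces the deprecated batch_size
    requires_safetensors=True,                         # renamed from only_allow_safetensors
    generative_type=GenerativeType.INSTRUCTION_TUNED,  # added in this version range
    download_only=False,                               # added in this version range
).benchmark(model=["some-org/some-model"])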
euroeval/constants.py CHANGED
@@ -1,23 +1,25 @@
 """Constants used throughout the project."""
 
+import re
+from typing import TypeVar
+
 from .enums import TaskGroup
-from .tasks import NER
+
+# Type variable used for generic typing
+T = TypeVar("T", bound=object)
 
 # This is used as input to generative models; it cannot be a special token
 DUMMY_FILL_VALUE = 100
 
-
 # This is the maximum allowed context length for models for the purpose of this
 # benchmark. We will still report the models' true maximum context length in the
 # metadata, but we won't use it for evaluation, as vLLM needs to allocate memory for
 # all tokens in the context.
-MAX_CONTEXT_LENGTH = 5_000
-
+MAX_CONTEXT_LENGTH = 8_192
 
 # We need to raise the amount of tokens generated for reasoning models, to give them
 # time to think
-REASONING_MAX_TOKENS = 32_768
-
+REASONING_MAX_TOKENS = 8_192
 
 # The Hugging Face Hub pipeline tags used to classify models as generative
 GENERATIVE_PIPELINE_TAGS = [
@@ -28,48 +30,49 @@ GENERATIVE_PIPELINE_TAGS = [
     "video-text-to-text",
 ]
 
-
 # Used to disallow non-generative models to be evaluated on these task groups
 GENERATIVE_DATASET_TASK_GROUPS = [TaskGroup.TEXT_TO_TEXT]
 
-
 # Local models are required to have these files in their directory
 LOCAL_MODELS_REQUIRED_FILES = ["config.json"]
 
-
-# Tasks where we use structured generation for generative models
-TASKS_USING_JSON = [NER]
-
-
-# Tasks where we use log probabilities for generative models, rather than the raw
-# completion
-TASK_GROUPS_USING_LOGPROBS = [
-    TaskGroup.SEQUENCE_CLASSIFICATION,
-    TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION,
-]
-
-
 # The number of top log probabilities to return for generative models. For several APIs
 # this is the maximum number of log probabilities that can be returned
-MAX_LOGPROBS = 8
-
+MAX_VLLM_LOGPROBS = 20
+MAX_LITELLM_LOGPROBS = 8
 
 # We make sure to remove these metric attributes after each iteration, to avoid memory
 # leaks
 METRIC_ATTRIBUTES_TAKING_UP_MEMORY = ["cached_bertscorer"]
 
-
 # Hugging Face Hub tags used to classify models as merge models
 MERGE_TAGS = ["merge", "mergekit"]
 
 # The minimum required CUDA compute capability for using bfloat16 in vLLM
 VLLM_BF16_MIN_CUDA_COMPUTE_CAPABILITY = 8.0
 
+# The candidates for end-of-sequence, beginning-of-sequence and padding tokens
+EOS_TOKENS = ["</s>", "<|end_of_text|>", "<|endoftext|>", "[SEP]", "<|return|>"]
+BOS_TOKENS = ["<s>", "<|begin_of_text|>", "<|startoftext|>", "[CLS]"]
+PAD_TOKENS = [
+    "<pad>",
+    "<PAD>",
+    "[pad]",
+    "[PAD]",
+    "<|endoftext|>",
+    "<|end▁of▁sentence|>",
+    "<|im_end|>",
+]
+
 # Used to detect whether a model is a reasoning model
-REASONING_TOKENS = [
+REASONING_TOKENS: list[tuple[str | re.Pattern, str | re.Pattern]] = [
     ("<think>", "</think>"),
     ("<reason>", "</reason>"),
     ("<reasoning>", "</reasoning>"),
+    (
+        re.compile(pattern=r"<\|channel\|>(analysis|commentary)<\|message\|>"),
+        "<|channel|>final<|message|>",
+    ),
 ]
 
 # These tokens are sometimes used by models to indicate the end of a generated
@@ -77,3 +80,19 @@ REASONING_TOKENS = [
 # manually. We only use them as stop tokens if they actually appear in the model's
 # output
 CUSTOM_STOP_TOKENS = ["<sep>"]
+
+# For classification tasks we force LiteLLM models to output a JSON dictionary with a
+# single key and the values being restricted to the allowed labels. This is the key we
+# use
+LITELLM_CLASSIFICATION_OUTPUT_KEY = "label"
+
+# These characters are stripped from JSON output when trying to identify the label
+JSON_STRIP_CHARACTERS = ' {}\n\r":'
+
+# The number of tokens we generate when evaluating generative models on classification
+# tasks. We also use this to determine whether we should store logprobs in the model
+# outputs (and cache).
+NUM_GENERATION_TOKENS_FOR_CLASSIFICATION = 10
+
+# We only allow loading local datasets in these file formats
+SUPPORTED_FILE_FORMATS_FOR_LOCAL_DATASETS = ["csv"]
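
Note: REASONING_TOKENS entries can now be either literal strings or compiled regular expressions. The following is only an illustration of how such mixed entries can be matched against model output; it is not EuroEval's actual detection logic.

import re

from euroeval.constants import REASONING_TOKENS

def contains_reasoning_marker(text: str) -> bool:
    # Check each (start, end) marker pair; a start marker may be a plain
    # substring or a compiled re.Pattern, hence the widened type annotation.
    for start_marker, _end_marker in REASONING_TOKENS:
        if isinstance(start_marker, re.Pattern):
            if start_marker.search(text):
                return True
        elif start_marker in text:
            return True
    return False

print(contains_reasoning_marker("<think>Working it out...</think>42"))  # True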
euroeval/data_loading.py CHANGED
@@ -1,5 +1,6 @@
 """Functions related to the loading of the data."""
 
+import collections.abc as c
 import logging
 import sys
 import time
@@ -11,7 +12,10 @@ from datasets.exceptions import DatasetsError
 from huggingface_hub.errors import HfHubHTTPError
 from numpy.random import Generator
 
+from .constants import SUPPORTED_FILE_FORMATS_FOR_LOCAL_DATASETS
 from .exceptions import HuggingFaceHubDown, InvalidBenchmark
+from .logging_utils import log, no_terminal_output
+from .tasks import EUROPEAN_VALUES
 from .utils import unscramble
 
 if t.TYPE_CHECKING:
@@ -19,8 +23,6 @@ if t.TYPE_CHECKING:
 
     from .data_models import BenchmarkConfig, DatasetConfig
 
-logger = logging.getLogger("euroeval")
-
 
 def load_data(
     rng: Generator, dataset_config: "DatasetConfig", benchmark_config: "BenchmarkConfig"
@@ -48,40 +50,45 @@
         dataset_config=dataset_config, cache_dir=benchmark_config.cache_dir
     )
 
-    if not benchmark_config.evaluate_test_split:
+    if not benchmark_config.evaluate_test_split and "val" in dataset:
         dataset["test"] = dataset["val"]
 
     # Remove empty examples from the datasets
     for text_feature in ["tokens", "text"]:
-        if text_feature in dataset["train"].features:
-            dataset = dataset.filter(lambda x: len(x[text_feature]) > 0)
+        for split in dataset_config.splits:
+            if text_feature in dataset[split].features:
+                dataset = dataset.filter(lambda x: len(x[text_feature]) > 0)
 
-    # If we are testing then truncate the test set
-    if hasattr(sys, "_called_from_test"):
+    # If we are testing then truncate the test set, unless we need the full set for
+    # evaluation
+    if hasattr(sys, "_called_from_test") and dataset_config.task != EUROPEAN_VALUES:
         dataset["test"] = dataset["test"].select(range(1))
 
-    # Bootstrap the splits
-    bootstrapped_splits: dict[str, list["Dataset"]] = dict()
-    for split in ["train", "val", "test"]:
-        bootstrap_indices = rng.integers(
-            0,
-            len(dataset[split]),
-            size=(benchmark_config.num_iterations, len(dataset[split])),
-        )
-        bootstrapped_splits[split] = [
-            dataset[split].select(bootstrap_indices[idx])
+    # Bootstrap the splits, if applicable
+    if dataset_config.bootstrap_samples:
+        bootstrapped_splits: dict[str, c.Sequence["Dataset"]] = dict()
+        for split in dataset_config.splits:
+            bootstrap_indices = rng.integers(
+                0,
+                len(dataset[split]),
+                size=(benchmark_config.num_iterations, len(dataset[split])),
+            )
+            bootstrapped_splits[split] = [
+                dataset[split].select(bootstrap_indices[idx])
+                for idx in range(benchmark_config.num_iterations)
+            ]
+        datasets = [
+            DatasetDict(
+                {
+                    split: bootstrapped_splits[split][idx]
+                    for split in dataset_config.splits
+                }
+            )
             for idx in range(benchmark_config.num_iterations)
         ]
+    else:
+        datasets = [dataset] * benchmark_config.num_iterations
 
-    datasets = [
-        DatasetDict(
-            {
-                split: bootstrapped_splits[split][idx]
-                for split in ["train", "val", "test"]
-            }
-        )
-        for idx in range(benchmark_config.num_iterations)
-    ]
     return datasets
 
 
@@ -97,40 +104,89 @@ def load_raw_data(dataset_config: "DatasetConfig", cache_dir: str) -> "DatasetDi
     Returns:
         The dataset.
     """
-    num_attempts = 5
-    for _ in range(num_attempts):
-        try:
-            dataset = load_dataset(
-                path=dataset_config.huggingface_id,
-                cache_dir=cache_dir,
-                token=unscramble("HjccJFhIozVymqXDVqTUTXKvYhZMTbfIjMxG_"),
+    # Case where the dataset source is a Hugging Face ID
+    if isinstance(dataset_config.source, str):
+        num_attempts = 5
+        for _ in range(num_attempts):
+            try:
+                with no_terminal_output():
+                    dataset = load_dataset(
+                        path=dataset_config.source.split("::")[0],
+                        name=(
+                            dataset_config.source.split("::")[1]
+                            if "::" in dataset_config.source
+                            else None
+                        ),
+                        cache_dir=cache_dir,
+                        token=unscramble("XbjeOLhwebEaSaDUMqqaPaPIhgOcyOfDpGnX_"),
+                    )
+                break
+            except (
+                FileNotFoundError,
+                ConnectionError,
+                DatasetsError,
+                requests.ConnectionError,
+                requests.ReadTimeout,
+            ) as e:
+                log(
+                    f"Failed to load dataset {dataset_config.source!r}, due to "
+                    f"the following error: {e}. Retrying...",
+                    level=logging.DEBUG,
+                )
+                time.sleep(1)
+                continue
+            except HfHubHTTPError:
+                raise HuggingFaceHubDown()
+        else:
+            raise InvalidBenchmark(
+                f"Failed to load dataset {dataset_config.source!r} after "
+                f"{num_attempts} attempts. Run with verbose mode to see the individual "
+                "errors."
            )
-            break
-        except (
-            FileNotFoundError,
-            ConnectionError,
-            DatasetsError,
-            requests.ConnectionError,
-            requests.ReadTimeout,
-        ):
-            logger.warning(
-                f"Failed to load dataset {dataset_config.huggingface_id!r}. Retrying..."
-            )
-            time.sleep(1)
-            continue
-        except HfHubHTTPError:
-            raise HuggingFaceHubDown()
+
+    # Case where the dataset source is a dictionary with keys "train", "val" and "test",
+    # with the values pointing to local CSV files
     else:
-        raise InvalidBenchmark(
-            f"Failed to load dataset {dataset_config.huggingface_id!r} after "
-            f"{num_attempts} attempts."
-        )
+        data_files = {
+            split: dataset_config.source[split]
+            for split in dataset_config.splits
+            if split in dataset_config.source
+        }
+
+        # Get the file extension and ensure that all files have the same extension
+        file_extensions = {
+            split: dataset_config.source[split].split(".")[-1]
+            for split in dataset_config.splits
+            if split in dataset_config.source
+        }
+        if len(set(file_extensions.values())) != 1:
+            raise InvalidBenchmark(
+                "All data files in a custom dataset must have the same file extension. "
+                f"Got the extensions {', '.join(file_extensions.values())} for the "
+                f"dataset {dataset_config.name!r}."
+            )
+        file_extension = list(file_extensions.values())[0]
+
+        # Check that the file extension is supported
+        if file_extension not in SUPPORTED_FILE_FORMATS_FOR_LOCAL_DATASETS:
+            raise InvalidBenchmark(
+                "Unsupported file extension for custom dataset. Supported file "
+                "extensions are "
+                f"{', '.join(SUPPORTED_FILE_FORMATS_FOR_LOCAL_DATASETS)}, but got "
+                f"{file_extension!r}."
+            )
+
+        # Load the dataset
+        with no_terminal_output():
+            dataset = load_dataset(
+                path=file_extension, data_files=data_files, cache_dir=cache_dir
            )
+
     assert isinstance(dataset, DatasetDict)  # type: ignore[used-before-def]
-    required_keys = ["train", "val", "test"]
-    missing_keys = [key for key in required_keys if key not in dataset]
+    missing_keys = [key for key in dataset_config.splits if key not in dataset]
     if missing_keys:
         raise InvalidBenchmark(
             "The dataset is missing the following required splits: "
             f"{', '.join(missing_keys)}"
         )
-    return DatasetDict({key: dataset[key] for key in required_keys})
+    return DatasetDict({key: dataset[key] for key in dataset_config.splits})
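
Note: the new else-branch in load_raw_data supports custom local datasets. The following sketch shows the kind of source mapping it expects; the paths and split names are illustrative, and CSV is currently the only format listed in SUPPORTED_FILE_FORMATS_FOR_LOCAL_DATASETS.

from datasets import load_dataset

# Illustrative source mapping: one file per split, all sharing the same extension.
local_source = {
    "train": "data/my_task/train.csv",
    "val": "data/my_task/val.csv",
    "test": "data/my_task/test.csv",
}

# With dataset_config.source set to such a mapping and dataset_config.splits equal
# to ["train", "val", "test"], the branch above effectively performs:
dataset = load_dataset(path="csv", data_files=local_source)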