ScandEval 16.11.0-py3-none-any.whl → 16.13.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. scandeval/__init__.py +0 -9
  2. scandeval/async_utils.py +46 -0
  3. scandeval/benchmark_config_factory.py +31 -2
  4. scandeval/benchmark_modules/fresh.py +2 -1
  5. scandeval/benchmark_modules/hf.py +76 -23
  6. scandeval/benchmark_modules/litellm.py +33 -15
  7. scandeval/benchmark_modules/vllm.py +97 -44
  8. scandeval/benchmarker.py +29 -33
  9. scandeval/cli.py +11 -0
  10. scandeval/constants.py +36 -2
  11. scandeval/custom_dataset_configs.py +152 -0
  12. scandeval/data_loading.py +87 -31
  13. scandeval/data_models.py +405 -224
  14. scandeval/dataset_configs/__init__.py +51 -25
  15. scandeval/dataset_configs/albanian.py +1 -1
  16. scandeval/dataset_configs/belarusian.py +47 -0
  17. scandeval/dataset_configs/bulgarian.py +1 -1
  18. scandeval/dataset_configs/catalan.py +1 -1
  19. scandeval/dataset_configs/croatian.py +1 -1
  20. scandeval/dataset_configs/danish.py +3 -2
  21. scandeval/dataset_configs/dutch.py +16 -5
  22. scandeval/dataset_configs/english.py +4 -3
  23. scandeval/dataset_configs/estonian.py +8 -7
  24. scandeval/dataset_configs/faroese.py +1 -1
  25. scandeval/dataset_configs/finnish.py +5 -4
  26. scandeval/dataset_configs/french.py +6 -5
  27. scandeval/dataset_configs/german.py +4 -3
  28. scandeval/dataset_configs/greek.py +1 -1
  29. scandeval/dataset_configs/hungarian.py +1 -1
  30. scandeval/dataset_configs/icelandic.py +4 -3
  31. scandeval/dataset_configs/italian.py +4 -3
  32. scandeval/dataset_configs/latvian.py +2 -2
  33. scandeval/dataset_configs/lithuanian.py +1 -1
  34. scandeval/dataset_configs/norwegian.py +6 -5
  35. scandeval/dataset_configs/polish.py +4 -3
  36. scandeval/dataset_configs/portuguese.py +5 -4
  37. scandeval/dataset_configs/romanian.py +2 -2
  38. scandeval/dataset_configs/serbian.py +1 -1
  39. scandeval/dataset_configs/slovene.py +1 -1
  40. scandeval/dataset_configs/spanish.py +4 -3
  41. scandeval/dataset_configs/swedish.py +4 -3
  42. scandeval/dataset_configs/ukrainian.py +1 -1
  43. scandeval/generation_utils.py +6 -6
  44. scandeval/metrics/__init__.py +1 -0
  45. scandeval/metrics/bias.py +237 -0
  46. scandeval/metrics/huggingface.py +2 -1
  47. scandeval/metrics/llm_as_a_judge.py +1 -1
  48. scandeval/metrics/pipeline.py +1 -1
  49. scandeval/model_cache.py +34 -4
  50. scandeval/prompt_templates/linguistic_acceptability.py +9 -0
  51. scandeval/prompt_templates/multiple_choice.py +9 -0
  52. scandeval/prompt_templates/named_entity_recognition.py +21 -0
  53. scandeval/prompt_templates/reading_comprehension.py +10 -0
  54. scandeval/prompt_templates/sentiment_classification.py +11 -0
  55. scandeval/string_utils.py +157 -0
  56. scandeval/task_group_utils/sequence_classification.py +2 -5
  57. scandeval/task_group_utils/token_classification.py +2 -4
  58. scandeval/tasks.py +22 -0
  59. scandeval/tokenisation_utils.py +12 -1
  60. scandeval/utils.py +13 -383
  61. scandeval-16.13.0.dist-info/METADATA +334 -0
  62. scandeval-16.13.0.dist-info/RECORD +94 -0
  63. scandeval-16.11.0.dist-info/METADATA +0 -649
  64. scandeval-16.11.0.dist-info/RECORD +0 -89
  65. {scandeval-16.11.0.dist-info → scandeval-16.13.0.dist-info}/WHEEL +0 -0
  66. {scandeval-16.11.0.dist-info → scandeval-16.13.0.dist-info}/entry_points.txt +0 -0
  67. {scandeval-16.11.0.dist-info → scandeval-16.13.0.dist-info}/licenses/LICENSE +0 -0
scandeval/__init__.py CHANGED
@@ -110,15 +110,6 @@ os.environ["DISABLE_AIOHTTP_TRANSPORT"] = "True"
  os.environ["VLLM_USE_V1"] = "1"


- # Use the FlashInfer flash-attention backend for vLLM, unless the user has already
- # specified a different backend.
- if os.getenv("VLLM_ATTENTION_BACKEND") is None:
-     os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
-     os.environ["USER_HAS_SET_VLLM_ATTENTION_BACKEND"] = "0"
- else:
-     os.environ["USER_HAS_SET_VLLM_ATTENTION_BACKEND"] = "1"
-
-
  # Set the HF_TOKEN env var to copy the HUGGINGFACE_API_KEY env var, as vLLM uses the
  # former and LiteLLM uses the latter
  if os.getenv("HUGGINGFACE_API_KEY"):
scandeval/async_utils.py ADDED
@@ -0,0 +1,46 @@
+ """Utility functions for asyncronous tasks."""
+
+ import asyncio
+ import typing as t
+
+ from .constants import T
+
+
+ def safe_run(coroutine: t.Coroutine[t.Any, t.Any, T]) -> T:
+     """Run a coroutine, ensuring that the event loop is always closed when we're done.
+
+     Args:
+         coroutine:
+             The coroutine to run.
+
+     Returns:
+         The result of the coroutine.
+     """
+     try:
+         loop = asyncio.get_event_loop()
+     except RuntimeError:  # If the current event loop is closed
+         loop = asyncio.new_event_loop()
+         asyncio.set_event_loop(loop)
+     response = loop.run_until_complete(coroutine)
+     return response
+
+
+ async def add_semaphore_and_catch_exception(
+     coroutine: t.Coroutine[t.Any, t.Any, T], semaphore: asyncio.Semaphore
+ ) -> T | Exception:
+     """Run a coroutine with a semaphore.
+
+     Args:
+         coroutine:
+             The coroutine to run.
+         semaphore:
+             The semaphore to use.
+
+     Returns:
+         The result of the coroutine.
+     """
+     async with semaphore:
+         try:
+             return await coroutine
+         except Exception as exc:
+             return exc
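A minimal usage sketch of the two new helpers, with a hypothetical `fetch` coroutine standing in for a real API call; `safe_run` drives the event loop from synchronous code, while the semaphore wrapper bounds concurrency and returns per-task exceptions as values instead of raising:

```python
import asyncio

from scandeval.async_utils import add_semaphore_and_catch_exception, safe_run


async def fetch(i: int) -> int:
    """Hypothetical workload standing in for an API request."""
    await asyncio.sleep(0.01)
    return i * 2


async def run_all() -> list[int | Exception]:
    semaphore = asyncio.Semaphore(5)  # at most 5 requests in flight at once
    tasks = [
        add_semaphore_and_catch_exception(fetch(i), semaphore=semaphore)
        for i in range(20)
    ]
    return await asyncio.gather(*tasks)


results = safe_run(run_all())  # failures come back as Exception values
```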
scandeval/benchmark_config_factory.py CHANGED
@@ -1,6 +1,7 @@
  """Factory class for creating dataset configurations."""

  import collections.abc as c
+ import importlib.util
  import sys
  import typing as t
  from pathlib import Path
@@ -13,6 +14,9 @@ from .enums import Device
  from .exceptions import InvalidBenchmark
  from .languages import get_all_languages

+ if importlib.util.find_spec("vllm") is not None:
+     pass
+
  if t.TYPE_CHECKING:
      from .data_models import Language

@@ -42,6 +46,8 @@ def build_benchmark_config(
          dataset=benchmark_config_params.dataset,
          languages=languages,
          custom_datasets_file=benchmark_config_params.custom_datasets_file,
+         api_key=benchmark_config_params.api_key,
+         cache_dir=Path(benchmark_config_params.cache_dir),
      )

      return BenchmarkConfig(
@@ -68,6 +74,7 @@ def build_benchmark_config(
          api_base=benchmark_config_params.api_base,
          api_version=benchmark_config_params.api_version,
          gpu_memory_utilization=benchmark_config_params.gpu_memory_utilization,
+         attention_backend=benchmark_config_params.attention_backend,
          generative_type=benchmark_config_params.generative_type,
          debug=benchmark_config_params.debug,
          run_with_cli=benchmark_config_params.run_with_cli,
@@ -154,7 +161,9 @@ def prepare_dataset_configs(
      languages: c.Sequence["Language"],
      dataset: "str | DatasetConfig | c.Sequence[str | DatasetConfig] | None",
      custom_datasets_file: Path,
- ) -> c.Sequence["DatasetConfig"]:
+     api_key: str | None,
+     cache_dir: Path,
+ ) -> list["DatasetConfig"]:
      """Prepare dataset config(s) for benchmarking.

      Args:
@@ -168,6 +177,10 @@ def prepare_dataset_configs(
              included, limited by the `task` and `languages` parameters.
          custom_datasets_file:
              A path to a Python file containing custom dataset configurations.
+         api_key:
+             The API key to use for accessing the Hugging Face Hub.
+         cache_dir:
+             The directory to store the cache in.

      Returns:
          The prepared dataset configs.
@@ -176,9 +189,25 @@
          InvalidBenchmark:
              If the task or dataset is not found in the benchmark tasks or datasets.
      """
+     # Extract the dataset IDs from the `dataset` argument
+     dataset_ids: list[str] = list()
+     if isinstance(dataset, str):
+         dataset_ids.append(dataset)
+     elif isinstance(dataset, DatasetConfig):
+         dataset_ids.append(dataset.name)
+     elif isinstance(dataset, list):
+         for d in dataset:
+             if isinstance(d, str):
+                 dataset_ids.append(d)
+             elif isinstance(d, DatasetConfig):
+                 dataset_ids.append(d.name)
+
      # Create the list of dataset configs
      all_dataset_configs = get_all_dataset_configs(
-         custom_datasets_file=custom_datasets_file
+         custom_datasets_file=custom_datasets_file,
+         dataset_ids=dataset_ids,
+         api_key=api_key,
+         cache_dir=cache_dir,
      )
      all_official_dataset_configs: c.Sequence[DatasetConfig] = [
          dataset_config
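The new ID-extraction block normalises the polymorphic `dataset` argument into a flat list of names before configs are resolved. A standalone sketch of the same branching, with a stub standing in for the real `DatasetConfig` data model (dataset names here are illustrative):

```python
from dataclasses import dataclass


@dataclass
class DatasetConfig:
    """Stub for scandeval.data_models.DatasetConfig; only `name` is needed here."""

    name: str


def extract_dataset_ids(dataset) -> list[str]:
    """Mirror of the branching in `prepare_dataset_configs` (sketch)."""
    dataset_ids: list[str] = []
    if isinstance(dataset, str):
        dataset_ids.append(dataset)
    elif isinstance(dataset, DatasetConfig):
        dataset_ids.append(dataset.name)
    elif isinstance(dataset, list):
        for d in dataset:
            if isinstance(d, str):
                dataset_ids.append(d)
            elif isinstance(d, DatasetConfig):
                dataset_ids.append(d.name)
    return dataset_ids


assert extract_dataset_ids("angry-tweets") == ["angry-tweets"]
assert extract_dataset_ids([DatasetConfig(name="scala-da"), "angry-tweets"]) == [
    "scala-da",
    "angry-tweets",
]
```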
scandeval/benchmark_modules/fresh.py CHANGED
@@ -28,8 +28,9 @@ from ..exceptions import (
  )
  from ..generation_utils import raise_if_wrong_params
  from ..logging_utils import block_terminal_output
+ from ..model_cache import create_model_cache_dir
  from ..types import Tokeniser
- from ..utils import create_model_cache_dir, get_hf_token
+ from ..utils import get_hf_token
  from .hf import (
      HuggingFaceEncoderModel,
      align_model_and_tokeniser,
scandeval/benchmark_modules/hf.py CHANGED
@@ -1,6 +1,7 @@
  """Encoder models from the Hugging Face Hub."""

  import collections.abc as c
+ import importlib
  import logging
  import re
  import typing as t
@@ -63,6 +64,8 @@ from ..exceptions import (
  from ..generation_utils import raise_if_wrong_params
  from ..languages import get_all_languages
  from ..logging_utils import block_terminal_output, log, log_once
+ from ..model_cache import create_model_cache_dir
+ from ..string_utils import split_model_id
  from ..task_group_utils import (
      multiple_choice_classification,
      question_answering,
@@ -70,13 +73,7 @@ from ..task_group_utils import (
  )
  from ..tokenisation_utils import get_bos_token, get_eos_token
  from ..types import Tokeniser
- from ..utils import (
-     create_model_cache_dir,
-     get_class_by_name,
-     get_hf_token,
-     internet_connection_available,
-     split_model_id,
- )
+ from ..utils import get_hf_token, internet_connection_available
  from .base import BenchmarkModule

  try:
@@ -381,7 +378,7 @@ class HuggingFaceEncoderModel(BenchmarkModule):
          if "label" in examples:
              try:
                  examples["label"] = [
-                     self._model.config.label2id[lbl.lower()]
+                     self._model.config.label2id[str(lbl).lower()]
                      if self._model.config.label2id is not None
                      else lbl
                      for lbl in examples["label"]
@@ -758,20 +755,30 @@ def get_model_repo_info(
      # model info object.
      model_info: HfApiModelInfo | None = None
      if Path(model_id).is_dir():
-         if all(
-             (Path(model_id) / required_file).exists()
-             for required_file in LOCAL_MODELS_REQUIRED_FILES
-         ):
+         if Path(model_id, "config.json").exists():
              log_once(
-                 f"The local model directory {model_id!r} has all the required model "
-                 f"files ({LOCAL_MODELS_REQUIRED_FILES}), so we're skipping looking up "
-                 "model information from the Hugging Face Hub.",
+                 f"The local model directory {model_id!r} has a 'config.json' file, so "
+                 "we're skipping looking up model information from the Hugging Face "
+                 "Hub.",
                  level=logging.DEBUG,
              )
              model_info = HfApiModelInfo(id=model_id, tags=None, pipeline_tag=None)
+         elif Path(model_id, "adapter_config.json").exists():
+             log_once(
+                 f"The local model directory {model_id!r} has an 'adapter_config.json' "
+                 "file, so we're skipping looking up model information from the Hugging "
+                 "Face Hub.",
+                 level=logging.DEBUG,
+             )
+             model_info = HfApiModelInfo(
+                 id=model_id,
+                 tags=None,
+                 pipeline_tag=None,
+                 siblings=[dict(rfilename="adapter_config.json")],
+             )
          else:
              log_once(
-                 f"The local model directory {model_id} does not contain all the "
+                 f"The local model directory {model_id} does not contain any of the "
                  f"required files: {LOCAL_MODELS_REQUIRED_FILES}. Skipping this "
                  f"model.",
                  level=logging.WARNING,
@@ -807,8 +814,8 @@
              log(
                  f"Could not access the model {model_id} with the revision "
                  f"{revision}. The error was {str(e)!r}. Please set the "
-                 "`HUGGINGFACE_API_KEY` environment variable or use the "
-                 "`--api-key` argument.",
+                 "`HUGGINGFACE_API_KEY` or `HF_TOKEN` environment variable or "
+                 "use the `--api-key` argument.",
                  level=logging.DEBUG,
              )
              return None
@@ -876,8 +883,9 @@
          for tag in GENERATIVE_PIPELINE_TAGS
          for class_name in TASK_MAPPING.get(tag, dict()).values()  # type: ignore[attr-defined]
      ]
-     if class_names is not None and any(
-         class_name in generative_class_names for class_name in class_names
+     if class_names is not None and (
+         any(class_name in generative_class_names for class_name in class_names)
+         or any("ForCausalLM" in class_name for class_name in class_names)
      ):
          pipeline_tag = "text-generation"
      else:
@@ -1084,8 +1092,8 @@
              f"The model {model_id!r} is a gated repository. Please ensure "
              "that you are logged in with `hf auth login` or have provided a "
              "valid Hugging Face access token with the `HUGGINGFACE_API_KEY` "
-             "environment variable or the `--api-key` argument. Also check that "
-             "your account has access to this model."
+             "or `HF_TOKEN` environment variable or the `--api-key` argument. "
+             "Also check that your account has access to this model."
          ) from e
      raise InvalidModel(
          f"Couldn't load model config for {model_id!r}. The error was "
@@ -1121,7 +1129,11 @@
      )

      # Ensure that the PAD token ID is set
-     if config.eos_token_id is not None and config.pad_token_id is None:
+     if (
+         hasattr(config, "eos_token_id")
+         and config.eos_token_id is not None
+         and (not hasattr(config, "pad_token_id") or config.pad_token_id is None)
+     ):
          if isinstance(config.eos_token_id, list):
              config.pad_token_id = config.eos_token_id[0]
          else:
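The widened guard also covers configs that lack the token-ID attributes altogether, and the list branch picks the first EOS token as padding. A self-contained sketch of that fallback rule, using a plain namespace in place of a real Hugging Face config object:

```python
from types import SimpleNamespace


def ensure_pad_token_id(config) -> None:
    """Set `pad_token_id` from `eos_token_id` when missing (sketch of the new guard)."""
    if (
        hasattr(config, "eos_token_id")
        and config.eos_token_id is not None
        and (not hasattr(config, "pad_token_id") or config.pad_token_id is None)
    ):
        if isinstance(config.eos_token_id, list):
            config.pad_token_id = config.eos_token_id[0]
        else:
            config.pad_token_id = config.eos_token_id


config = SimpleNamespace(eos_token_id=[128001, 128009], pad_token_id=None)
ensure_pad_token_id(config)
assert config.pad_token_id == 128001  # first EOS token becomes the pad token
```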
@@ -1319,3 +1331,44 @@ def task_group_to_class_name(task_group: TaskGroup) -> str:
      )
      pascal_case = special_case_mapping.get(pascal_case, pascal_case)
      return f"AutoModelFor{pascal_case}"
+
+
+ def get_class_by_name(
+     class_name: str | c.Sequence[str], module_name: str
+ ) -> t.Type | None:
+     """Get a class by its name.
+
+     Args:
+         class_name:
+             The name of the class, written in kebab-case. The corresponding class name
+             must be the same, but written in PascalCase, and lying in a module with the
+             same name, but written in snake_case. If a list of strings is passed, the
+             first class that is found is returned.
+         module_name:
+             The name of the module where the class is located.
+
+     Returns:
+         The class. If the class is not found, None is returned.
+     """
+     if isinstance(class_name, str):
+         class_name = [class_name]
+
+     error_messages = list()
+     for name in class_name:
+         try:
+             module = importlib.import_module(name=module_name)
+             class_: t.Type = getattr(module, name)
+             return class_
+         except (ModuleNotFoundError, AttributeError) as e:
+             error_messages.append(str(e))
+
+     if error_messages:
+         errors = "\n- " + "\n- ".join(error_messages)
+         log(
+             f"Could not find the class with the name(s) {', '.join(class_name)}. The "
+             f"following error messages were raised: {errors}",
+             level=logging.DEBUG,
+         )
+
+     # If the class could not be found, return None
+     return None
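Note that despite the kebab-case wording in the docstring, the lookup is a plain `getattr` on the imported module, so candidates should be passed exactly as they appear there. A usage sketch against the standard library rather than ScandEval's own modules:

```python
# Resolves the first matching candidate, or returns None if none exist.
cls = get_class_by_name(
    class_name=["NoSuchClass", "OrderedDict"], module_name="collections"
)
assert cls is not None and cls.__name__ == "OrderedDict"

# A miss is logged at DEBUG level and surfaces as None rather than an exception.
assert get_class_by_name("AlsoMissing", module_name="collections") is None
```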
scandeval/benchmark_modules/litellm.py CHANGED
@@ -40,7 +40,7 @@ from pydantic import ValidationError, conlist, create_model
  from requests.exceptions import RequestException
  from tqdm.asyncio import tqdm as tqdm_async

- from ..caching_utils import cache_arguments
+ from ..async_utils import add_semaphore_and_catch_exception, safe_run
  from ..constants import (
      JSON_STRIP_CHARACTERS,
      LITELLM_CLASSIFICATION_OUTPUT_KEY,
@@ -74,6 +74,8 @@ from ..generation_utils import (
      raise_if_wrong_params,
  )
  from ..logging_utils import get_pbar, log, log_once
+ from ..model_cache import create_model_cache_dir
+ from ..string_utils import split_model_id
  from ..task_group_utils import (
      question_answering,
      sequence_classification,
@@ -83,13 +85,7 @@ from ..task_group_utils import (
  from ..tasks import NER
  from ..tokenisation_utils import get_first_label_token_mapping
  from ..types import ExtractLabelsFunction
- from ..utils import (
-     add_semaphore_and_catch_exception,
-     create_model_cache_dir,
-     get_hf_token,
-     safe_run,
-     split_model_id,
- )
+ from ..utils import get_hf_token
  from .base import BenchmarkModule
  from .hf import HuggingFaceEncoderModel, load_hf_model_config, load_tokeniser

@@ -700,10 +696,10 @@
          elif isinstance(
              error, (Timeout, ServiceUnavailableError, InternalServerError, SystemError)
          ):
-             log_once(
-                 f"Service temporarily unavailable. The error message was: {error}. "
-                 "Retrying in 10 seconds...",
-                 level=logging.DEBUG,
+             log(
+                 "Service temporarily unavailable during generation. The error "
+                 f"message was: {error}. Retrying in 10 seconds...",
+                 level=logging.INFO,
              )
              return generation_kwargs, 10
          elif isinstance(error, UnsupportedParamsError):
@@ -764,6 +760,20 @@
                  run_with_cli=self.benchmark_config.run_with_cli,
              ) from error

+         if (
+             isinstance(error, (BadRequestError, NotFoundError))
+             and self.benchmark_config.api_base is not None
+             and not self.benchmark_config.api_base.endswith("/v1")
+         ):
+             log_once(
+                 f"The API base {self.benchmark_config.api_base!r} is not valid. We "
+                 "will try appending '/v1' to it and try again.",
+                 level=logging.DEBUG,
+             )
+             self.benchmark_config.api_base += "/v1"
+             generation_kwargs["api_base"] = self.benchmark_config.api_base
+             return generation_kwargs, 0
+
          raise InvalidBenchmark(
              f"Failed to generate text. The error message was: {error}"
          ) from error
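This retry path fixes up the configured base URL once and returns a zero-second delay, so the request is replayed immediately; it helps with OpenAI-compatible servers that expose their routes under `/v1`. The heuristic in isolation, as a sketch:

```python
def normalise_api_base(api_base: str) -> str:
    """Append '/v1' when it is missing (sketch of the retry fix-up)."""
    return api_base if api_base.endswith("/v1") else api_base + "/v1"


assert normalise_api_base("http://localhost:8000") == "http://localhost:8000/v1"
assert normalise_api_base("http://localhost:8000/v1") == "http://localhost:8000/v1"
```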
@@ -1390,9 +1400,10 @@
              InternalServerError,
          ) as e:
              log(
-                 f"Service temporarily unavailable. The error message was: {e}. "
+                 "Service temporarily unavailable while checking for model "
+                 f"existence of the model {model_id!r}. The error message was: {e}. "
                  "Retrying in 10 seconds...",
-                 level=logging.DEBUG,
+                 level=logging.INFO,
              )
              sleep(10)
          except APIError as e:
@@ -1567,7 +1578,6 @@

          return dataset

-     @cache_arguments()
      def get_generation_kwargs(self, dataset_config: DatasetConfig) -> dict[str, t.Any]:
          """Get the generation arguments for the model.

@@ -1865,6 +1875,14 @@ def clean_model_id(model_id: str, benchmark_config: BenchmarkConfig) -> str:
          else:
              prefix = "openai/"
          model_id = prefix + model_id
+
+     # When we want to evaluate an OpenAI model on a custom inference server, such as HF
+     # inference endpoints, LiteLLM gets confused since it's already using the `openai/`
+     # prefix. We thus have to add it twice, and this hack here is to ensure that we
+     # don't store the results with model ID `openai/openai/...`.
+     elif benchmark_config.api_base is not None and model_id.startswith("openai/"):
+         model_id = "openai/openai/" + re.sub(r"(openai/)*", "", model_id)
+
      return model_id

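The `re.sub` strips however many `openai/` prefixes are already present before exactly two are re-applied, so the result is stable no matter how often the branch runs. A quick check of that behaviour, with the branch mirrored as a standalone sketch:

```python
import re


def double_openai_prefix(model_id: str) -> str:
    """Mirror of the new branch in `clean_model_id` (sketch)."""
    return "openai/openai/" + re.sub(r"(openai/)*", "", model_id)


assert double_openai_prefix("openai/gpt-4o-mini") == "openai/openai/gpt-4o-mini"
assert double_openai_prefix("openai/openai/gpt-4o-mini") == "openai/openai/gpt-4o-mini"
```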