ScandEval 16.10.0__py3-none-any.whl → 16.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -758,12 +758,25 @@ def get_model_repo_info(
  # model info object.
  model_info: HfApiModelInfo | None = None
  if Path(model_id).is_dir():
- log(f"Checking for local model in {model_id}.", level=logging.DEBUG)
  if all(
  (Path(model_id) / required_file).exists()
  for required_file in LOCAL_MODELS_REQUIRED_FILES
  ):
+ log_once(
+ f"The local model directory {model_id!r} has all the required model "
+ f"files ({LOCAL_MODELS_REQUIRED_FILES}), so we're skipping looking up "
+ "model information from the Hugging Face Hub.",
+ level=logging.DEBUG,
+ )
  model_info = HfApiModelInfo(id=model_id, tags=None, pipeline_tag=None)
+ else:
+ log_once(
+ f"The local model directory {model_id} does not contain all the "
+ f"required files: {LOCAL_MODELS_REQUIRED_FILES}. Skipping this "
+ f"model.",
+ level=logging.WARNING,
+ )
+ return None

  # If we have not internet, and the model_id is not a directory for a local model
  # we also just create a dummy model info object.
@@ -4,6 +4,7 @@ import asyncio
  import collections.abc as c
  import json
  import logging
+ import os
  import re
  import typing as t
  from functools import cached_property, partial
@@ -32,9 +33,10 @@ from litellm.exceptions import (
  )
  from litellm.llms.vertex_ai.common_utils import VertexAIError
  from litellm.router import Router
+ from litellm.types.router import RouterRateLimitError
  from litellm.types.utils import ChoiceLogprobs, Logprobs
  from litellm.utils import supports_reasoning, supports_response_schema
- from pydantic import conlist, create_model
+ from pydantic import ValidationError, conlist, create_model
  from requests.exceptions import RequestException
  from tqdm.asyncio import tqdm as tqdm_async

@@ -99,12 +101,13 @@ if t.TYPE_CHECKING:

  VOCAB_SIZE_MAPPING = {
  # OpenAI models
+ r"gpt-5\.2.*": -1,
  r"gpt-5-.*": 100_256,
  r"gpt-4-(32k)?(-[0-9]{4})?": 100_256,
  r"gpt-4-[0-9]{4}-preview": 100_256,
  r"gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 100_256,
  r"gpt-4-(vision|turbo)(-preview)?": 100_256,
- r"gpt-3.5-turbo-instruct(-[0-9]{4})?": 100_256,
+ r"gpt-3\.5-turbo-instruct(-[0-9]{4})?": 100_256,
  r"gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_019,
  r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
  # Anthropic models
@@ -113,23 +116,27 @@ VOCAB_SIZE_MAPPING = {
  r"(gemini/)?gemini-[1-9](\.[0-9])?-(flash|pro).*": 256_128,
  # xAI models
  r"(xai/)?grok.*": -1,
+ # Chat.dk models
+ r"(ordbogen/)?odin-medium.*": -1,
+ r"(ordbogen/)?odin-large.*": -1,
  }


  MODEL_MAX_LENGTH_MAPPING = {
  # OpenAI models
+ r"gpt-5\.2.*": 400_000,
  r"gpt-5-.*": 272_000,
  r"gpt-4(-[0-9]{4})?": 8_191,
  r"gpt-4-32k(-[0-9]{4})?": 32_767,
  r"gpt-4-[0-9]{4}-preview": 128_000,
  r"gpt-4-turbo(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
  r"gpt-4-(vision|turbo)(-preview)?": 128_000,
- r"gpt-3.5-turbo-instruct(-[0-9]{4})?": 4_095,
+ r"gpt-3\.5-turbo-instruct(-[0-9]{4})?": 4_095,
  r"gpt-4o(-mini)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
  r"o1-(mini|preview)(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 128_000,
  r"o1(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
  r"o[2-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": 200_000,
- r"gpt-4.1.*": 1_047_576,
+ r"gpt-4\.1.*": 1_047_576,
  # Anthropic models
  r"(anthropic/)?claude-[1-9](-[1-9])?-(opus|sonnet|haiku)-[0-9]{8}": 200_000,
  r"(anthropic/)?claude-(opus|sonnet|haiku)-[1-9](-[1-9])?-[0-9]{8}": 200_000,
@@ -139,12 +146,15 @@ MODEL_MAX_LENGTH_MAPPING = {
  r"(gemini/)?gemini-[23](\.[05])?.*": 1_048_576,
  # xAI models
  r"(xai/)?grok.*": 131_072,
+ # Chat.dk models
+ r"(ordbogen/)?odin-medium.*": 131_072,
+ r"(ordbogen/)?odin-large.*": 202_752,
  }


  NUM_PARAMS_MAPPING = {
  # OpenAI models
- r"gpt-5-.*": -1,
+ r"gpt-5.*": -1,
  r"gpt-4.*": -1,
  r"o[1-9](-mini|-preview)?(-[0-9]{4}-[0-9]{2}-[0-9]{2})?": -1,
  # Anthropic models
@@ -155,6 +165,9 @@ NUM_PARAMS_MAPPING = {
  r"(gemini/)?gemini-[23](.[05])?.*": -1,
  # xAI models
  r"(xai/)?grok.*": -1,
+ # Chat.dk models
+ r"(ordbogen/)?odin-medium.*": -1,
+ r"(ordbogen/)?odin-large.*": -1,
  }


@@ -164,6 +177,7 @@ REASONING_MODELS = [
  r"(gemini/)?gemini-2.5.*",
  r"(xai/)?grok-3-mini.*",
  r".*gpt-oss.*",
+ r"(ordbogen/)?odin-.*",
  ]

  BASE_DECODER_MODELS = [
@@ -186,6 +200,8 @@ CUSTOM_INFERENCE_API_PREFIXES = [
  "openai/",
  ]

+ UNOFFICIAL_INFERENCE_API_PREFIXES = ["ordbogen/"]
+

  class LiteLLMModel(BenchmarkModule):
  """A generative model from LiteLLM."""
@@ -220,7 +236,7 @@ class LiteLLMModel(BenchmarkModule):
  dataset_config: DatasetConfig,
  benchmark_config: BenchmarkConfig,
  log_metadata: bool = True,
- **generation_kwargs: dict[str, t.Any],
+ **generation_kwargs,
  ) -> None:
  """Initialise the model.

@@ -241,6 +257,10 @@ class LiteLLMModel(BenchmarkModule):
  model_config=model_config, allowed_params=self.allowed_params
  )

+ set_up_benchmark_config_for_model(
+ benchmark_config=benchmark_config, model_id=model_config.model_id
+ )
+
  # Detect whether the model is an Ollama model, as we need to extract metadata
  # differently for these models
  self.is_ollama = model_config.model_id.startswith(
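A brief aside on the `**generation_kwargs: dict[str, t.Any]` signature changed a couple of hunks above: in Python, an annotation on `**kwargs` describes each individual keyword value rather than the collected dict, so the old annotation told type checkers that every value had to be a dict. A minimal illustration (hypothetical functions, not part of the package):

    import typing as t

    def old_style(**generation_kwargs: dict[str, t.Any]) -> None:
        # Type checkers read this as "each value is a dict[str, Any]", so a call
        # like old_style(temperature=0.0) is reported as a type error.
        ...

    def new_style(**generation_kwargs: t.Any) -> None:
        # Matches the bare **generation_kwargs in the diff: values may be anything,
        # and generation_kwargs itself is a dict[str, Any].
        ...

    new_style(temperature=0.0, max_tokens=32)  # accepted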
@@ -401,7 +421,7 @@ class LiteLLMModel(BenchmarkModule):
  http_429_errors = [
  idx
  for idx, (_, error) in enumerate(failures)
- if isinstance(error, RateLimitError) and "Error code: 429" in str(error)
+ if isinstance(error, RateLimitError)
  ]
  if http_429_errors and self.buffer["max_concurrent_calls"] > 1:
  failures = [
@@ -417,7 +437,6 @@ class LiteLLMModel(BenchmarkModule):
  f"{self.buffer['max_concurrent_calls']:,} due to rate limiting.",
  level=logging.DEBUG,
  )
- continue

  # Attempt to handle the exceptions, to improve the chance of getting
  # successful generations next time around
@@ -483,11 +502,13 @@ class LiteLLMModel(BenchmarkModule):
  "you've reached the maximum number of requests with logprobs",
  "logprobs is not supported",
  "logprobs is not enabled",
- "Invalid value at 'generation_config.response_logprobs' (TYPE_BOOL)",
  ]
  logprobs_pattern = re.compile(
  r"does not support parameters: \[.*'logprobs'.*\]"
  )
+ logprobs_argument_should_be_bool_messages = [
+ "Invalid value at 'generation_config.response_logprobs' (TYPE_BOOL)"
+ ]
  top_logprobs_messages = ["got an unexpected keyword argument 'top_logprobs'"]
  top_logprobs_pattern = re.compile(
  r"does not support parameters: \[.*'top_logprobs'.*\]"
@@ -548,6 +569,17 @@ class LiteLLMModel(BenchmarkModule):
  generation_kwargs.pop("top_logprobs", None)
  generation_kwargs.pop("response_format", None)
  return generation_kwargs, 0
+ elif any(
+ msg.lower() in error_msg
+ for msg in logprobs_argument_should_be_bool_messages
+ ):
+ log_once(
+ f"The model {model_id!r} requires the `logprobs` argument to be a "
+ "Boolean, so setting it to True.",
+ level=logging.DEBUG,
+ )
+ generation_kwargs["logprobs"] = True
+ return generation_kwargs, 0
  elif (
  any(msg.lower() in error_msg for msg in top_logprobs_messages)
  or top_logprobs_pattern.search(string=error_msg) is not None
@@ -700,23 +732,25 @@ class LiteLLMModel(BenchmarkModule):
  ) from error

  if (
- isinstance(error, (RateLimitError, BadRequestError))
+ isinstance(error, (RateLimitError, RouterRateLimitError, BadRequestError))
  and (
  retry_match := re.search(
- pattern=r"\bretry in ([0-9]+(.[0-9]+)?) ?(s|seconds)\b",
+ pattern=(
+ r"\b(try( again)?|retry) in ([0-9]+(\.[0-9]+)?) ?(s|seconds?)\b"
+ ),
  string=error_msg,
  flags=re.IGNORECASE,
  )
  )
  is not None
  ):
- retry_seconds = float(retry_match.group(1))
+ retry_seconds = float(retry_match.group(3))
  log_once(
  f"You have encountered your rate limit for model {model_id!r}.",
  level=logging.DEBUG,
  )
  return generation_kwargs, int(retry_seconds)
- elif isinstance(error, RateLimitError):
+ elif isinstance(error, (RateLimitError, RouterRateLimitError)):
  log_once(
  f"You have encountered your rate limit for model {model_id!r}.",
  level=logging.DEBUG,
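The switch from `group(1)` to `group(3)` follows from the new pattern's structure: the leading `(try( again)?|retry)` alternation contributes two capture groups before the number, so the seconds value is now group 3. A quick check with plain `re`, using a made-up error message:

    import re

    pattern = r"\b(try( again)?|retry) in ([0-9]+(\.[0-9]+)?) ?(s|seconds?)\b"
    message = "Rate limit reached, please try again in 2.5 seconds"

    match = re.search(pattern, message, flags=re.IGNORECASE)
    assert match is not None
    # Group 1 is the verb, group 2 the optional " again", group 3 the number.
    assert match.group(3) == "2.5"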
@@ -919,12 +953,37 @@ class LiteLLMModel(BenchmarkModule):
  logprobs_obj = model_response_choices.logprobs

  if not isinstance(logprobs_obj, (Logprobs, ChoiceLogprobs)):
- log_once(
- "The logprobs object is malformed, so we won't use logprobs to "
- "determine the labels.",
- level=logging.WARNING,
+ error_msg = (
+ "The logprobs object is malformed, so we won't use logprobs "
+ "to determine the labels."
+ )
+ if not isinstance(logprobs_obj, list):
+ log_once(error_msg, level=logging.WARNING)
+ continue
+
+ # Some APIs have implemented the logprobs differently, being a list
+ # of ChoiceLogprobs dictionaries rather than having that list being
+ # under the 'content' key, so we deal with that here.
+ # TODO: Maybe remove this in future if all APIs standardise this
+ try:
+ choice_logprobs_list = [
+ ChoiceLogprobs.model_validate(item) for item in logprobs_obj
+ ]
+ except ValidationError:
+ log_once(error_msg, level=logging.WARNING)
+ continue
+ if not all(
+ len(item.content or []) == 1 for item in choice_logprobs_list
+ ):
+ log_once(error_msg, level=logging.WARNING)
+ continue
+ logprobs_obj = ChoiceLogprobs(
+ content=[
+ item.content[0]
+ for item in choice_logprobs_list
+ if item.content
+ ]
  )
- continue

  logprobs_list: c.Sequence[c.Sequence[tuple[str, float]]]
  if isinstance(logprobs_obj, ChoiceLogprobs):
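The normalisation above leans on pydantic's `model_validate`, which turns a plain dict into a validated model instance or raises `ValidationError`. A rough standalone sketch of the same idea, using a simplified stand-in model rather than litellm's actual `ChoiceLogprobs`:

    from pydantic import BaseModel, ValidationError

    class FakeChoiceLogprobs(BaseModel):
        """Simplified stand-in for litellm's ChoiceLogprobs."""
        content: list[dict] | None = None

    # One single-token entry per label, as some APIs return it.
    raw = [
        {"content": [{"token": "a", "logprob": -0.1}]},
        {"content": [{"token": "b", "logprob": -2.3}]},
    ]

    try:
        items = [FakeChoiceLogprobs.model_validate(entry) for entry in raw]
    except ValidationError:
        items = []

    # Merge into a single object with one token entry per choice, mirroring the
    # flattening done in the diff above.
    merged = FakeChoiceLogprobs(content=[item.content[0] for item in items if item.content])
    print(merged.content)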
@@ -964,10 +1023,9 @@ class LiteLLMModel(BenchmarkModule):

  if not sequences:
  log(
- "No sequences were generated by the model "
- f"{model_id!r}. This may be due to the "
- "model running out of tokens or an issue with the input data. "
- "Returning an empty GenerativeModelOutput.",
+ f"No sequences were generated by the model {model_id!r}. This may be "
+ "due to the model running out of tokens or an issue with the input "
+ "data. Returning an empty GenerativeModelOutput.",
  level=logging.WARNING,
  )
  return GenerativeModelOutput(sequences=[], scores=None)
@@ -1295,6 +1353,10 @@ class LiteLLMModel(BenchmarkModule):
  if model_id in litellm.model_list:
  return True

+ set_up_benchmark_config_for_model(
+ benchmark_config=benchmark_config, model_id=model_id
+ )
+
  # Separate check for Ollama models
  if model_id.startswith("ollama/") or model_id.startswith("ollama_chat/"):
  ollama_model_exists = try_download_ollama_model(
@@ -1596,6 +1658,11 @@ class LiteLLMModel(BenchmarkModule):
  level=logging.DEBUG,
  )

+ # If the model is a Chat.dk model, we make sure reasoning traces are not
+ # included in the output
+ if self.model_config.model_id.startswith("ordbogen/"):
+ generation_kwargs["include_reasoning"] = False
+
  # Handle manually set parameters
  if self.buffer["first_label_token_mapping"]:
  generation_kwargs["logprobs"] = True
@@ -1784,6 +1851,12 @@ def clean_model_id(model_id: str, benchmark_config: BenchmarkConfig) -> str:
  Returns:
  The cleaned model ID.
  """
+ # Remove unofficial prefixes
+ for unofficial_prefix in UNOFFICIAL_INFERENCE_API_PREFIXES:
+ model_id = re.sub(
+ pattern=rf"^{re.escape(unofficial_prefix)}", repl="", string=model_id
+ )
+
  if benchmark_config.api_base is not None and not any(
  model_id.startswith(prefix) for prefix in CUSTOM_INFERENCE_API_PREFIXES
  ):
@@ -1793,3 +1866,19 @@ def clean_model_id(model_id: str, benchmark_config: BenchmarkConfig) -> str:
  prefix = "openai/"
  model_id = prefix + model_id
  return model_id
+
+
+ def set_up_benchmark_config_for_model(
+ benchmark_config: BenchmarkConfig, model_id: str
+ ) -> None:
+ """Set up the benchmark configuration for the model.
+
+ Args:
+ benchmark_config:
+ The benchmark configuration to set up.
+ model_id:
+ The model ID.
+ """
+ if model_id.startswith("ordbogen/"):
+ benchmark_config.api_key = os.getenv("ORDBOGEN_API_KEY")
+ benchmark_config.api_base = "https://api.ordbogen.ai/v1"
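Taken together, `UNOFFICIAL_INFERENCE_API_PREFIXES`, `set_up_benchmark_config_for_model` and the prefix-stripping in `clean_model_id` route `ordbogen/...` model IDs through an OpenAI-compatible endpoint. A condensed, self-contained sketch of that flow, using a stand-in config class rather than ScandEval's `BenchmarkConfig` and making no claim about the exact call order in the package:

    import os
    import re
    from dataclasses import dataclass

    UNOFFICIAL_INFERENCE_API_PREFIXES = ["ordbogen/"]

    @dataclass
    class FakeBenchmarkConfig:
        api_key: str | None = None
        api_base: str | None = None

    def set_up(config: FakeBenchmarkConfig, model_id: str) -> None:
        # Mirrors set_up_benchmark_config_for_model from the diff above.
        if model_id.startswith("ordbogen/"):
            config.api_key = os.getenv("ORDBOGEN_API_KEY")
            config.api_base = "https://api.ordbogen.ai/v1"

    def clean(model_id: str, config: FakeBenchmarkConfig) -> str:
        # Strip the unofficial prefix, then prepend "openai/" so the custom
        # api_base is treated as an OpenAI-compatible server.
        for prefix in UNOFFICIAL_INFERENCE_API_PREFIXES:
            model_id = re.sub(rf"^{re.escape(prefix)}", "", model_id)
        if config.api_base is not None and not model_id.startswith("openai/"):
            model_id = "openai/" + model_id
        return model_id

    config = FakeBenchmarkConfig()
    set_up(config, "ordbogen/odin-large")
    print(clean("ordbogen/odin-large", config))  # openai/odin-large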
@@ -500,7 +500,8 @@ class VLLMModel(HuggingFaceEncoderModel):
  log_once(
  f"Using temperature={temperature} with the model "
  f"{self.model_config.model_id!r} as specified in its "
- "generation configuration."
+ "generation configuration.",
+ level=logging.DEBUG,
  )
  if "top_p" in changed_params:
  top_p = changed_params["top_p"]
@@ -508,7 +509,8 @@ class VLLMModel(HuggingFaceEncoderModel):
  log_once(
  f"Using top_p={top_p} with the model "
  f"{self.model_config.model_id!r} as specified in its "
- "generation configuration."
+ "generation configuration.",
+ level=logging.DEBUG,
  )
  if "top_k" in changed_params:
  top_k = changed_params["top_k"]
@@ -516,7 +518,8 @@ class VLLMModel(HuggingFaceEncoderModel):
  log_once(
  f"Using top_k={top_k} with the model "
  f"{self.model_config.model_id!r} as specified in its "
- "generation configuration."
+ "generation configuration.",
+ level=logging.DEBUG,
  )
  if "repetition_penalty" in changed_params:
  repetition_penalty = changed_params["repetition_penalty"]
@@ -524,8 +527,10 @@ class VLLMModel(HuggingFaceEncoderModel):
  log_once(
  f"Using repetition_penalty={repetition_penalty} with the model "
  f"{self.model_config.model_id!r} as specified in its "
- "generation configuration."
+ "generation configuration.",
+ level=logging.DEBUG,
  )
+
  max_tokens: int = (
  REASONING_MAX_TOKENS
  if self.generative_type == GenerativeType.REASONING
@@ -567,16 +572,76 @@ class VLLMModel(HuggingFaceEncoderModel):
  )
  prompts = [prompt.strip() for prompt in prompts]

- # Truncate the prompts if needed, but only if it's not a reasoning model
- if self.generative_type != GenerativeType.REASONING:
- max_tokens_per_prompt = (
- min(self._tokeniser.model_max_length, MAX_CONTEXT_LENGTH) - max_tokens
- )
- tokenized_prompts = self._tokeniser(
- text=list(prompts), truncation=True, max_length=max_tokens_per_prompt
+ # Truncate the prompts if needed
+ max_tokens_per_prompt = min(
+ self._tokeniser.model_max_length, MAX_CONTEXT_LENGTH
+ )
+ max_tokens_per_prompt -= min(
+ self.dataset_config.max_generated_tokens, max_tokens_per_prompt - 1
+ )
+ tokenized_prompts = self._tokeniser(
+ text=prompts, max_length=max_tokens_per_prompt
+ )
+ if any(
+ len(input_ids) > max_tokens_per_prompt
+ for input_ids in tokenized_prompts.input_ids
+ ):
+ log(
+ f"Truncating prompts for the model {self.model_config.model_id!r} "
+ f"to a maximum of {max_tokens_per_prompt:,} tokens.",
+ level=logging.DEBUG,
  )
- prompts = self._tokeniser.batch_decode(
- sequences=tokenized_prompts.input_ids, skip_special_tokens=True
+ match self.generative_type:
+ case GenerativeType.BASE:
+ truncated_tokenized_prompts = self._tokeniser(
+ text=prompts, max_length=max_tokens_per_prompt, truncation=True
+ )
+ prompts = self._tokeniser.batch_decode(
+ sequences=truncated_tokenized_prompts.input_ids,
+ skip_special_tokens=True,
+ )
+ case GenerativeType.INSTRUCTION_TUNED | GenerativeType.REASONING:
+ assert self.end_of_chat_token_ids is not None, (
+ "The end-of-chat token IDs should be set for instruction-tuned "
+ "and reasoning models."
+ )
+ end_of_chat_token = self._tokeniser.decode(
+ list(self.end_of_chat_token_ids)
+ )
+ prompt_segments: list[list[str]] = [
+ prompt.replace(self._tokeniser.bos_token, "").split(
+ end_of_chat_token
+ )
+ for prompt in prompts
+ ]
+ for num_few_shots_to_remove in range(
+ 0, self.dataset_config.num_few_shot_examples + 1
+ ):
+ new_prompts = [
+ end_of_chat_token.join(
+ prompt_segment[2 * num_few_shots_to_remove :]
+ )
+ for prompt_segment in prompt_segments
+ ]
+ tokenized_prompts = self._tokeniser(
+ text=new_prompts, max_length=max_tokens_per_prompt
+ )
+ if all(
+ len(input_ids) <= max_tokens_per_prompt
+ for input_ids in tokenized_prompts.input_ids
+ ):
+ prompts = new_prompts
+ break
+ else:
+ raise InvalidBenchmark(
+ "Truncation of prompts failed, some prompts are still too "
+ "long."
+ )
+ else:
+ log(
+ f"Truncation of prompts for model {self.model_config.model_id!r} is "
+ "not needed, so skipping truncation.",
+ level=logging.DEBUG,
  )

  # Generate sequences using vLLM
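The instruction-tuned/reasoning branch above truncates by dropping whole few-shot examples rather than cutting tokens mid-prompt: the chat-formatted prompt is split on the end-of-chat token, and each removed example discards two segments (one user turn and one assistant turn). A simplified, string-only sketch of that idea, with a hypothetical end-of-chat token:

    EOT = "<|eot|>"  # hypothetical end-of-chat token

    def drop_few_shot_examples(prompt: str, num_to_remove: int) -> str:
        """Remove the first `num_to_remove` few-shot examples from a chat prompt."""
        segments = prompt.split(EOT)
        # Each few-shot example contributes a user segment and an assistant segment.
        return EOT.join(segments[2 * num_to_remove:])

    prompt = EOT.join([
        "user: example question 1", "assistant: example answer 1",
        "user: example question 2", "assistant: example answer 2",
        "user: real question",
    ])
    print(drop_few_shot_examples(prompt, num_to_remove=1))
    # user: example question 2<|eot|>assistant: example answer 2<|eot|>user: real question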
@@ -598,10 +663,11 @@ class VLLMModel(HuggingFaceEncoderModel):
  level=logging.DEBUG,
  )
  sleep(1)
- except ValueError as e:
+ except (ValueError, RuntimeError) as e:
  # Truncate the prompts if they are too long for the model
  truncate_error_messages = [
- r"prompt \(length [0-9]+\) is longer than the maximum model length"
+ r"prompt \(length [0-9]+\) is longer than the maximum model length",
+ "Sampled token IDs exceed the max model length",
  ]
  if any(
  re.search(pattern, str(e), flags=re.IGNORECASE) is not None
@@ -905,19 +971,6 @@ def load_model_and_tokeniser(
  run_with_cli=benchmark_config.run_with_cli,
  )

- quantization = None
- if hasattr(hf_model_config, "quantization_config"):
- quantization = hf_model_config.quantization_config.get("quant_method")
-
- # The quantised models require extra dependencies
- if quantization == "gptq" and (
- importlib.util.find_spec("auto_gptq") is None
- or importlib.util.find_spec("optimum") is None
- ):
- raise NeedsExtraInstalled(extra="quantization")
- if quantization == "awq" and importlib.util.find_spec("awq") is None:
- raise NeedsExtraInstalled(extra="quantization")
-
  # Start with dtype being the "auto" vLLM dtype
  dtype: str | torch.dtype = "auto"

@@ -940,23 +993,6 @@ def load_model_and_tokeniser(
  )
  dtype = torch.float16

- # If the model is a quantized model, we might need to change the dtype
- if quantization == "mxfp4" and hf_model_config.dtype is None:
- dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
- log(
- "You are loading a quantized model where `dtype` has not been set. "
- f"Setting dtype to {dtype!r}.",
- level=logging.DEBUG,
- )
- elif quantization is not None and hf_model_config.dtype != torch.float16:
- log(
- "You are loading a quantized model with dtype "
- f"{hf_model_config.dtype}, which vLLM does not support. Setting "
- "dtype to float16 instead.",
- level=logging.WARNING,
- )
- dtype = torch.float16
-
  # If the model is a bf16 model, we need to check the CUDA compute capability
  if hf_model_config.dtype == torch.bfloat16:
  min_cuda_compute_capability = get_min_cuda_compute_capability()
@@ -974,6 +1010,28 @@ def load_model_and_tokeniser(
  )
  dtype = torch.float16

+ quantization = None
+ if hasattr(hf_model_config, "quantization_config"):
+ quantization = hf_model_config.quantization_config.get("quant_method")
+
+ # The quantised models require extra dependencies
+ if quantization == "gptq" and (
+ importlib.util.find_spec("auto_gptq") is None
+ or importlib.util.find_spec("optimum") is None
+ ):
+ raise NeedsExtraInstalled(extra="quantization")
+ if quantization == "awq" and importlib.util.find_spec("awq") is None:
+ raise NeedsExtraInstalled(extra="quantization")
+
+ # If the model is a quantized model, let vLLM decide the dtype
+ if quantization is not None:
+ log(
+ f"You are loading a quantized model with quantization {quantization}. "
+ "Forcing the vLLM dtype to 'auto'",
+ level=logging.WARNING,
+ )
+ dtype = "auto"
+
  if model_config.adapter_base_model_id is not None:
  download_dir = str(Path(model_config.model_cache_dir) / "base_model")
  else:
@@ -1017,17 +1075,14 @@ def load_model_and_tokeniser(
  )

  try:
+ model_location = (
+ model_id
+ if internet_connection_available() or Path(model_id).is_dir()
+ else resolve_model_path(download_dir=download_dir)
+ )
  model = LLM(
- model=(
- model_id
- if internet_connection_available()
- else resolve_model_path(download_dir=download_dir)
- ),
- tokenizer=(
- model_id
- if internet_connection_available()
- else resolve_model_path(download_dir=download_dir)
- ),
+ model=model_location,
+ tokenizer=model_location,
  gpu_memory_utilization=benchmark_config.gpu_memory_utilization,
  max_model_len=min(true_max_model_len, MAX_CONTEXT_LENGTH),
  download_dir=download_dir,
@@ -1454,10 +1509,11 @@ def select_backend_and_parallelism() -> tuple[str, int, int]:
  try:
  ray.init(address="auto", ignore_reinit_error=True)
  except Exception as e:
- log_once(
- f"Ray initialisation failed with a {type(e)} exception: {e}",
- level=logging.DEBUG,
- )
+ if "could not find any running ray instance" not in str(e).lower():
+ log_once(
+ f"Ray initialisation failed with a {type(e)} exception: {e}",
+ level=logging.DEBUG,
+ )

  is_ray = ray.is_initialized()
  local_gpu_count = torch.cuda.device_count()
@@ -1475,7 +1531,7 @@ def select_backend_and_parallelism() -> tuple[str, int, int]:
  pipeline_parallel_size = max(1, total_gpus // tensor_parallel_size)
  log_once(
  f"Detected a multi-node setup with {pipeline_parallel_size:,} nodes, each "
- "with {tensor_parallel_size:,} GPUs, so using `ray` as the "
+ f"with {tensor_parallel_size:,} GPUs, so using `ray` as the "
  "distributed backend.",
  level=logging.DEBUG,
  )
scandeval/benchmarker.py CHANGED
@@ -1045,8 +1045,16 @@ class Benchmarker:
  if model.generative_type is not None
  else None
  ),
- few_shot=benchmark_config.few_shot,
- validation_split=not benchmark_config.evaluate_test_split,
+ few_shot=(
+ None
+ if dataset_config.task.requires_zero_shot
+ else benchmark_config.few_shot
+ ),
+ validation_split=(
+ None
+ if "val" not in dataset_config.splits
+ else not benchmark_config.evaluate_test_split
+ ),
  )
  log(f"Results:\n{results}", level=logging.DEBUG)
  return record
@@ -1122,12 +1130,10 @@ def get_record(
  same_revision = model_id_components.revision == model_config.revision
  same_param = model_id_components.param == model_config.param
  same_dataset = record.dataset == dataset_config.name
- same_split = (
- record.validation_split != benchmark_config.evaluate_test_split
- or "val" not in dataset_config.splits
- )
+ same_split = record.validation_split != benchmark_config.evaluate_test_split
  same_num_shots = (
  record.few_shot == benchmark_config.few_shot
+ or record.few_shot is None
  or not record.generative
  or dataset_config.task.requires_zero_shot
  )
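With `few_shot` now stored as `None` for tasks that require zero-shot evaluation, the updated `same_num_shots` check treats a stored `None` as matching any run. A simplified sketch of that condition in isolation, with a stand-in record object rather than the real `BenchmarkResult`:

    from types import SimpleNamespace

    record = SimpleNamespace(few_shot=None, generative=True)
    few_shot_flag = False       # e.g. the run was started without few-shot examples
    requires_zero_shot = True   # the task only supports zero-shot evaluation

    # Mirrors the updated check: a stored None means the number of shots is
    # irrelevant for this record, so it matches regardless of the flag.
    same_num_shots = (
        record.few_shot == few_shot_flag
        or record.few_shot is None
        or not record.generative
        or requires_zero_shot
    )
    print(same_num_shots)  # True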
@@ -1225,6 +1231,7 @@ def initial_logging(
  f"{dataset_config.logging_string} ({num_finished_benchmarks + 1}/"
  f"{num_total_benchmarks} benchmarks)...",
  prefix=f"\n[{dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]",
+ level=logging.INFO,
  )

  if dataset_config.unofficial:
scandeval/data_models.py CHANGED
@@ -623,8 +623,8 @@ class BenchmarkResult(pydantic.BaseModel):
  merge: bool
  generative: bool
  generative_type: str | None
- few_shot: bool
- validation_split: bool
+ few_shot: bool | None
+ validation_split: bool | None
  euroeval_version: str | None = get_package_version("euroeval")
  transformers_version: str | None = get_package_version("transformers")
  torch_version: str | None = get_package_version("torch")