EuroEval 16.2.2-py3-none-any.whl → 16.4.0-py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.

Files changed (65)
  1. euroeval/__init__.py +7 -4
  2. euroeval/benchmark_config_factory.py +0 -4
  3. euroeval/benchmark_modules/base.py +3 -16
  4. euroeval/benchmark_modules/fresh.py +5 -2
  5. euroeval/benchmark_modules/hf.py +107 -66
  6. euroeval/benchmark_modules/litellm.py +103 -55
  7. euroeval/benchmark_modules/vllm.py +155 -82
  8. euroeval/benchmarker.py +184 -129
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/cli.py +1 -1
  12. euroeval/constants.py +9 -0
  13. euroeval/data_loading.py +14 -11
  14. euroeval/data_models.py +12 -4
  15. euroeval/dataset_configs/__init__.py +3 -0
  16. euroeval/dataset_configs/czech.py +79 -0
  17. euroeval/dataset_configs/danish.py +10 -13
  18. euroeval/dataset_configs/dutch.py +0 -3
  19. euroeval/dataset_configs/english.py +0 -3
  20. euroeval/dataset_configs/estonian.py +11 -1
  21. euroeval/dataset_configs/finnish.py +0 -3
  22. euroeval/dataset_configs/french.py +0 -3
  23. euroeval/dataset_configs/german.py +0 -3
  24. euroeval/dataset_configs/italian.py +0 -3
  25. euroeval/dataset_configs/latvian.py +2 -4
  26. euroeval/dataset_configs/lithuanian.py +68 -0
  27. euroeval/dataset_configs/norwegian.py +0 -3
  28. euroeval/dataset_configs/polish.py +0 -3
  29. euroeval/dataset_configs/portuguese.py +0 -3
  30. euroeval/dataset_configs/slovak.py +60 -0
  31. euroeval/dataset_configs/spanish.py +0 -3
  32. euroeval/dataset_configs/swedish.py +10 -15
  33. euroeval/finetuning.py +21 -15
  34. euroeval/generation.py +10 -10
  35. euroeval/generation_utils.py +2 -3
  36. euroeval/logging_utils.py +250 -0
  37. euroeval/metrics/base.py +0 -3
  38. euroeval/metrics/huggingface.py +10 -6
  39. euroeval/metrics/llm_as_a_judge.py +5 -3
  40. euroeval/metrics/pipeline.py +22 -9
  41. euroeval/metrics/speed.py +0 -3
  42. euroeval/model_cache.py +11 -14
  43. euroeval/model_config.py +4 -5
  44. euroeval/model_loading.py +3 -0
  45. euroeval/prompt_templates/linguistic_acceptability.py +30 -3
  46. euroeval/prompt_templates/multiple_choice.py +34 -1
  47. euroeval/prompt_templates/named_entity_recognition.py +71 -11
  48. euroeval/prompt_templates/reading_comprehension.py +41 -3
  49. euroeval/prompt_templates/sentiment_classification.py +34 -1
  50. euroeval/prompt_templates/summarization.py +26 -6
  51. euroeval/scores.py +7 -7
  52. euroeval/speed_benchmark.py +3 -5
  53. euroeval/task_group_utils/multiple_choice_classification.py +0 -3
  54. euroeval/task_group_utils/question_answering.py +0 -3
  55. euroeval/task_group_utils/sequence_classification.py +43 -31
  56. euroeval/task_group_utils/text_to_text.py +17 -8
  57. euroeval/task_group_utils/token_classification.py +10 -9
  58. euroeval/tokenisation_utils.py +22 -20
  59. euroeval/utils.py +30 -147
  60. {euroeval-16.2.2.dist-info → euroeval-16.4.0.dist-info}/METADATA +182 -61
  61. euroeval-16.4.0.dist-info/RECORD +75 -0
  62. euroeval-16.2.2.dist-info/RECORD +0 -70
  63. {euroeval-16.2.2.dist-info → euroeval-16.4.0.dist-info}/WHEEL +0 -0
  64. {euroeval-16.2.2.dist-info → euroeval-16.4.0.dist-info}/entry_points.txt +0 -0
  65. {euroeval-16.2.2.dist-info → euroeval-16.4.0.dist-info}/licenses/LICENSE +0 -0
euroeval/benchmark_modules/vllm.py

@@ -14,10 +14,9 @@ from time import sleep
 import torch
 from huggingface_hub import snapshot_download
 from pydantic import conlist, create_model
-from tqdm.auto import tqdm
-from transformers import MistralCommonTokenizer
 from transformers.models.auto.configuration_auto import AutoConfig
 from transformers.models.auto.tokenization_auto import AutoTokenizer
+from transformers.tokenization_mistral_common import MistralCommonTokenizer
 from urllib3.exceptions import RequestError

 from ..constants import (
@@ -30,7 +29,7 @@ from ..constants import (
     REASONING_TOKENS,
     VLLM_BF16_MIN_CUDA_COMPUTE_CAPABILITY,
 )
-from ..data_models import GenerativeModelOutput, ModelConfig
+from ..data_models import GenerativeModelOutput, HashableDict, ModelConfig
 from ..enums import (
     BatchingPreference,
     GenerativeType,
@@ -50,6 +49,7 @@ from ..generation_utils import (
     raise_if_wrong_params,
 )
 from ..languages import get_all_languages
+from ..logging_utils import get_pbar, log, log_once, no_terminal_output
 from ..task_group_utils import (
     question_answering,
     sequence_classification,
@@ -73,7 +73,6 @@ from ..utils import (
     get_hf_token,
     get_min_cuda_compute_capability,
     internet_connection_available,
-    log_once,
     resolve_model_path,
     split_model_id,
 )
@@ -86,7 +85,7 @@ if t.TYPE_CHECKING or importlib.util.find_spec("vllm") is not None:
         destroy_model_parallel,
     )
     from vllm.lora.request import LoRARequest
-    from vllm.sampling_params import GuidedDecodingParams
+    from vllm.sampling_params import StructuredOutputsParams

 if t.TYPE_CHECKING:
     from datasets import DatasetDict
@@ -95,8 +94,6 @@ if t.TYPE_CHECKING:

     from ..data_models import BenchmarkConfig, DatasetConfig, Task

-logger = logging.getLogger("euroeval")
-

 class VLLMModel(HuggingFaceEncoderModel):
     """A generative model using the vLLM inference framework."""
@@ -104,7 +101,7 @@ class VLLMModel(HuggingFaceEncoderModel):
     fresh_model = False
     batching_preference = BatchingPreference.ALL_AT_ONCE
     high_priority = True
-    allowed_params = {re.compile(r".*"): ["thinking", "no-thinking"]}
+    allowed_params = {re.compile(r".*"): ["thinking", "no-thinking", "slow-tokenizer"]}

     def __init__(
         self,
@@ -132,9 +129,10 @@ class VLLMModel(HuggingFaceEncoderModel):
             model_config=model_config, allowed_params=self.allowed_params
         )

-        model, tokeniser = load_model_and_tokeniser(
-            model_config=model_config, benchmark_config=benchmark_config
-        )
+        with no_terminal_output(disable=benchmark_config.verbose):
+            model, tokeniser = load_model_and_tokeniser(
+                model_config=model_config, benchmark_config=benchmark_config
+            )
         self._model: "LLM" = model
         self._tokeniser: "PreTrainedTokenizer" = tokeniser

@@ -245,6 +243,7 @@ class VLLMModel(HuggingFaceEncoderModel):
                 return partial(
                     sequence_classification.extract_labels_from_generation,
                     dataset_config=self.dataset_config,
+                    model_config=self.model_config,
                     first_label_token_mapping=self.buffer["first_label_token_mapping"],
                 )
             case TaskGroup.TEXT_TO_TEXT:
@@ -394,10 +393,11 @@ class VLLMModel(HuggingFaceEncoderModel):
             self.dataset_config.task.uses_structured_output
             or (self.dataset_config.task.uses_logprobs and self.dataset_config.labels)
         ) and self.generative_type == GenerativeType.REASONING:
-            guided_decoding = None
-            logger.debug(
+            structured_outputs = None
+            log(
                 "The dataset uses structured output, but we are not using it as the "
-                "model is a reasoning model."
+                "model is a reasoning model.",
+                level=logging.DEBUG,
             )
         elif self.dataset_config.task.uses_structured_output:
             ner_tag_names = list(self.dataset_config.prompt_label_mapping.values())
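Throughout this file, calls to a module-level logger (logger.debug, logger.info) are replaced by the log and log_once helpers from the new euroeval/logging_utils.py module, with the level passed explicitly. A minimal sketch of the calling convention as it appears in these hunks; the helpers may accept further parameters not shown in this diff:

    import logging

    from euroeval.logging_utils import log, log_once

    # Emit a message at an explicit level instead of calling logger.debug/info.
    log("Prompts are too long, so truncating them and trying again...", level=logging.WARNING)

    # log_once is used for messages that should not be repeated (as the name suggests).
    log_once("Not using structured generation as the dataset does not require it.", level=logging.DEBUG)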
@@ -412,9 +412,11 @@ class VLLMModel(HuggingFaceEncoderModel):
                 f"{json.dumps(structured_generation_schema)}",
                 level=logging.DEBUG,
             )
-            guided_decoding = GuidedDecodingParams(json=structured_generation_schema)
+            structured_outputs = StructuredOutputsParams(
+                json=structured_generation_schema
+            )
         elif self.dataset_config.task.uses_logprobs and self.dataset_config.labels:
-            guided_decoding = GuidedDecodingParams(
+            structured_outputs = StructuredOutputsParams(
                 choice=[
                     self.dataset_config.prompt_label_mapping[label]
                     for label in self.dataset_config.labels
@@ -422,11 +424,11 @@ class VLLMModel(HuggingFaceEncoderModel):
             )
             log_once(
                 "Using structured generation with the choices: "
-                f"{guided_decoding.choice!r}.",
+                f"{structured_outputs.choice!r}.",
                 level=logging.DEBUG,
             )
         else:
-            guided_decoding = None
+            structured_outputs = None
             log_once(
                 "Not using structured generation as the dataset does not require it.",
                 level=logging.DEBUG,
@@ -445,14 +447,14 @@ class VLLMModel(HuggingFaceEncoderModel):
             else None,
             temperature=0.0,
             stop=[stop_token for stop_token in stop_tokens if stop_token],
-            guided_decoding=guided_decoding,
+            structured_outputs=structured_outputs,
         )

         # If any of the prompts are empty then we need to replace them with a BOS token
         # so that the vLLM model can generate from them
         prompts: list[str] = inputs["text"]
         if any(len(prompt) == 0 for prompt in prompts):
-            logger.debug("Found empty prompts, replacing with BOS token.")
+            log("Found empty prompts, replacing with BOS token.", level=logging.DEBUG)
             prompts = [
                 prompt if len(prompt) > 0 else str(self._tokeniser.bos_token)
                 for prompt in prompts
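The guided_decoding → structured_outputs rename above follows the newer vLLM structured-outputs API, with StructuredOutputsParams replacing GuidedDecodingParams. A minimal sketch of how the new sampling parameters fit together, assuming a vLLM version that ships StructuredOutputsParams; the model ID, prompt, and label choices below are hypothetical:

    from vllm import LLM, SamplingParams
    from vllm.sampling_params import StructuredOutputsParams

    # Constrain generation to a fixed set of label strings (hypothetical labels).
    sampling_params = SamplingParams(
        max_tokens=32,
        temperature=0.0,
        structured_outputs=StructuredOutputsParams(choice=["positive", "neutral", "negative"]),
    )

    llm = LLM(model="example-org/example-model")  # hypothetical model ID
    outputs = llm.generate(prompts=["The film was great. Sentiment:"], sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)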
@@ -480,13 +482,14 @@ class VLLMModel(HuggingFaceEncoderModel):
                 raw_outputs = self._model.generate(
                     prompts=prompts,
                     sampling_params=sampling_params,
-                    use_tqdm=False if input_is_a_test else get_pbar_without_leave,
+                    use_tqdm=False if input_is_a_test else get_pbar,
                     lora_request=self.buffer.get("lora_request"),
                 )
                 break
             except TypeError as e:
-                logger.debug(
-                    f"Encountered error during vLLM generation: {str(e)}. Retrying..."
+                log(
+                    f"Encountered error during vLLM generation: {str(e)}. Retrying...",
+                    level=logging.DEBUG,
                 )
                 sleep(1)
             except ValueError as e:
@@ -498,10 +501,11 @@ class VLLMModel(HuggingFaceEncoderModel):
                     re.search(pattern, str(e), flags=re.IGNORECASE) is not None
                     for pattern in truncate_error_messages
                 ):
-                    logger.info(
-                        "Prompts are too long, so truncating them and trying again..."
+                    log(
+                        "Prompts are too long, so truncating them and trying again...",
+                        level=logging.WARNING,
                     )
-                    logger.debug(f"The error message was: {str(e)}")
+                    log(f"The error message was: {str(e)}", level=logging.DEBUG)

                     # If we have already tried truncating the prompts a few times, then
                     # we truncate a bit more aggressively
@@ -544,26 +548,49 @@ class VLLMModel(HuggingFaceEncoderModel):
                     f"{num_extra_outputs!r} extra outputs."
                 )
             else:
-                logger.debug(
+                log(
                     f"Filtered out {num_extra_outputs:,} extra outputs from the model, "
                     "which occured as we interupted the generation when we truncated "
-                    "the prompts."
+                    "the prompts.",
+                    level=logging.DEBUG,
                 )

         # Parse the raw model outputs
         completion_ids: list[list[int]] = [
-            output.outputs[0].token_ids for output in raw_outputs
+            list(output.outputs[0].token_ids) for output in raw_outputs
         ]
         completions = self._tokeniser.batch_decode(
             sequences=[
                 torch.LongTensor(completion_id) for completion_id in completion_ids
             ]
         )
-        if self.end_of_reasoning_token is not None:
-            completions = [
-                completion.split(self.end_of_reasoning_token)[-1]
-                for completion in completions
-            ]
+        if (
+            self.end_of_reasoning_token is not None
+            and self.generative_type == GenerativeType.REASONING
+        ):
+            num_samples_without_eor_token = 0
+            for idx in range(len(completions)):
+                if self.end_of_reasoning_token in completions[idx]:
+                    completions[idx] = completions[idx].split(
+                        self.end_of_reasoning_token
+                    )[-1]
+                else:
+                    num_samples_without_eor_token += 1
+                    completions[idx] = ""
+            if num_samples_without_eor_token > 0:
+                log_once(
+                    f"The model {self.model_config.model_id!r} is a reasoning "
+                    "model, but the generated output did not contain the end of "
+                    f"reasoning token ({self.end_of_reasoning_token!r}) in "
+                    f"{num_samples_without_eor_token:,}/{len(completions):,} of "
+                    "the samples. Using an empty string for all these samples "
+                    "instead.",
+                    level=(
+                        logging.WARNING
+                        if num_samples_without_eor_token / len(completions) > 0.5
+                        else logging.DEBUG
+                    ),
+                )
         stop_token_pattern = re.compile(
             "|".join(re.escape(stop_token) for stop_token in stop_tokens)
         )
@@ -584,10 +611,10 @@ class VLLMModel(HuggingFaceEncoderModel):
         scores: list[list[list[tuple[str, float]]]] = [
             [
                 [
-                    (obj.decoded_token, obj.logprob)
+                    (obj.decoded_token or "", obj.logprob)
                     for obj in token_logprobs_dict.values()
                 ]
-                for token_logprobs_dict in raw_output.outputs[0].logprobs
+                for token_logprobs_dict in raw_output.outputs[0].logprobs or list()
             ]
             for raw_output in raw_outputs
         ]
@@ -625,7 +652,13 @@ class VLLMModel(HuggingFaceEncoderModel):
         revision = model_id_components.revision

         model_info = get_model_repo_info(
-            model_id=model_id, revision=revision, benchmark_config=benchmark_config
+            model_id=model_id,
+            revision=revision,
+            api_key=benchmark_config.api_key,
+            cache_dir=benchmark_config.cache_dir,
+            trust_remote_code=benchmark_config.trust_remote_code,
+            requires_safetensors=benchmark_config.requires_safetensors,
+            run_with_cli=benchmark_config.run_with_cli,
         )
         return (
             model_info is not None
@@ -651,7 +684,11 @@ class VLLMModel(HuggingFaceEncoderModel):
         model_info = get_model_repo_info(
             model_id=model_id_components.model_id,
             revision=model_id_components.revision,
-            benchmark_config=benchmark_config,
+            api_key=benchmark_config.api_key,
+            cache_dir=benchmark_config.cache_dir,
+            trust_remote_code=benchmark_config.trust_remote_code,
+            requires_safetensors=benchmark_config.requires_safetensors,
+            run_with_cli=benchmark_config.run_with_cli,
         )
         if model_info is None:
             raise InvalidModel(f"The model {model_id!r} could not be found.")
@@ -728,8 +765,8 @@ def load_model_and_tokeniser(
     hf_model_config = load_hf_model_config(
         model_id=model_id,
         num_labels=0,
-        id2label=dict(),
-        label2id=dict(),
+        id2label=HashableDict(),
+        label2id=HashableDict(),
         revision=revision,
         model_cache_dir=model_config.model_cache_dir,
         api_key=benchmark_config.api_key,
@@ -756,32 +793,36 @@ def load_model_and_tokeniser(
     # Choose bf16 over fp16 if the model is a fp32 model and the GPU supports it
     if hf_model_config.dtype == torch.float32:
         if torch.cuda.is_bf16_supported():
-            logger.info(
+            log(
                 "You are loading a model with dtype FP32, which we will convert to "
                 "BF16 as FP32 is not supported by vLLM and BF16 is supported by your "
-                "GPU."
+                "GPU.",
+                level=logging.WARNING,
             )
             dtype = torch.bfloat16
         else:
-            logger.info(
+            log(
                 "You are loading a model with dtype FP32, which we will convert to "
                 "FP16 as FP32 is not supported by vLLM and BF16 is not supported by "
-                "your GPU."
+                "your GPU.",
+                level=logging.WARNING,
             )
             dtype = torch.float16

     # If the model is a quantized model, we might need to change the dtype
     if quantization == "mxfp4" and hf_model_config.dtype is None:
         dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-        logger.debug(
+        log(
             "You are loading a quantized model where `dtype` has not been set. "
-            f"Setting dtype to {dtype!r}."
+            f"Setting dtype to {dtype!r}.",
+            level=logging.DEBUG,
         )
     elif quantization is not None and hf_model_config.dtype != torch.float16:
-        logger.info(
+        log(
             "You are loading a quantized model with dtype "
             f"{hf_model_config.dtype}, which vLLM does not support. Setting "
-            "dtype to float16 instead."
+            "dtype to float16 instead.",
+            level=logging.WARNING,
         )
         dtype = torch.float16

@@ -792,12 +833,13 @@ def load_model_and_tokeniser(

     if min_cuda_compute_capability is not None:
         if min_cuda_compute_capability < required_capability:
-            logger.info(
+            log(
                 f"You are loading a model with dtype {hf_model_config.dtype}, "
                 "which vLLM only supports for CUDA devices with CUDA compute "
                 f"capability >={required_capability}. You are using one or more "
                 f"devices with compute capability {min_cuda_compute_capability}. "
-                "Setting dtype to float16 instead."
+                "Setting dtype to float16 instead.",
+                level=logging.WARNING,
             )
             dtype = torch.float16

@@ -830,9 +872,12 @@ def load_model_and_tokeniser(
         adapter_base_model_id=model_config.adapter_base_model_id,
         trust_remote_code=benchmark_config.trust_remote_code,
         model_max_length=true_max_model_len,
-        model_cache_dir=model_config.model_cache_dir,
+        model_config=model_config,
         token=get_hf_token(api_key=benchmark_config.api_key),
     )
+    vllm_tokenisation_params = get_vllm_tokenisation_params(
+        tokeniser=tokeniser, model_config=model_config
+    )

     clear_vllm()

@@ -865,16 +910,7 @@ def load_model_and_tokeniser(
             enable_prefix_caching=False,
             enable_lora=model_config.adapter_base_model_id is not None,
             max_lora_rank=256,
-            # Special arguments in case we are dealing with a Mistral model
-            tokenizer_mode="mistral"
-            if isinstance(tokeniser, MistralCommonTokenizer)
-            else "auto",
-            config_format="mistral"
-            if isinstance(tokeniser, MistralCommonTokenizer)
-            else "auto",
-            load_format="mistral"
-            if isinstance(tokeniser, MistralCommonTokenizer)
-            else "auto",
+            **vllm_tokenisation_params,
         )
     except (RuntimeError, ValueError, OSError) as e:
         if "awaiting a review from the repo authors" in str(e):
@@ -903,7 +939,7 @@ def load_tokeniser(
     adapter_base_model_id: str | None,
     trust_remote_code: bool,
     model_max_length: int,
-    model_cache_dir: str,
+    model_config: "ModelConfig",
     token: str | bool,
 ) -> "PreTrainedTokenizer":
     """Load the tokeniser.
@@ -920,8 +956,8 @@ def load_tokeniser(
             Whether to trust remote code.
         model_max_length:
             The maximum length of the model.
-        model_cache_dir:
-            The cache directory for the model.
+        model_config:
+            The model configuration.
         token:
             The Hugging Face API token.

@@ -932,7 +968,7 @@ def load_tokeniser(
     config = AutoConfig.from_pretrained(
         adapter_base_model_id or model_id,
         revision=revision,
-        cache_dir=model_cache_dir,
+        cache_dir=model_config.model_cache_dir,
         token=token,
         trust_remote_code=trust_remote_code,
         local_files_only=not internet_connection_available(),
@@ -940,15 +976,25 @@ def load_tokeniser(
     num_retries = 5
     for _ in range(num_retries):
         try:
+            # Mistral instruction-tuned models need a custom tokeniser
+            if model_id.startswith("mistralai/") and "base" not in model_id.lower():
+                tokeniser = MistralCommonTokenizer.from_pretrained(
+                    model_id,
+                    padding_side="left",
+                    truncation_side="left",
+                    model_max_length=model_max_length,
+                    token=token,
+                )
+                break
             tokeniser = AutoTokenizer.from_pretrained(
                 model_id,
-                use_fast=True,
+                use_fast=False if model_config.param == "slow-tokenizer" else True,
                 verbose=False,
                 trust_remote_code=trust_remote_code,
                 padding_side="left",
                 truncation_side="left",
                 model_max_length=model_max_length,
-                cache_dir=model_cache_dir,
+                cache_dir=model_config.model_cache_dir,
                 config=config,
                 token=token,
                 local_files_only=not internet_connection_available(),
@@ -960,13 +1006,17 @@ def load_tokeniser(
                     f"Could not load tokeniser for model {model_id!r}. The error was "
                     f"{str(e)}."
                 ) from e
-            logger.debug(
+            log(
                 f"Could not load tokeniser for {model_id!r}. Falling back to "
-                f"{adapter_base_model_id!r}."
+                f"{adapter_base_model_id!r}.",
+                level=logging.DEBUG,
             )
             model_id = adapter_base_model_id
         except (TimeoutError, RequestError):
-            logger.info(f"Couldn't load tokeniser for {model_id!r}. Retrying.")
+            log(
+                f"Couldn't load tokeniser for {model_id!r}. Retrying.",
+                level=logging.WARNING,
+            )
             sleep(5)
             continue
         except (KeyError, ValueError) as e:
@@ -1165,27 +1215,50 @@ def get_custom_stop_tokens(
         if stop_token in prompt or stop_token in completion
     ]
     if stop_tokens:
-        logger.debug(
+        log(
             f"Found the following custom stop tokens for model {model_id!r}: "
-            f"{stop_tokens}."
+            f"{stop_tokens}.",
+            level=logging.DEBUG,
         )
     else:
-        logger.debug(f"Found no custom stop tokens for model {model_id!r}.")
+        log(f"Found no custom stop tokens for model {model_id!r}.", level=logging.DEBUG)

     return stop_tokens


-def get_pbar_without_leave(*tqdm_args, **tqdm_kwargs) -> tqdm:
-    """Get a progress bar for vLLM which disappears after completion.
+def get_vllm_tokenisation_params(
+    tokeniser: "PreTrainedTokenizer", model_config: "ModelConfig"
+) -> dict[str, t.Any]:
+    """Get the tokenisation parameters for vLLM.

     Args:
-        *tqdm_args:
-            Positional arguments to pass to tqdm.
-        **tqdm_kwargs:
-            Additional keyword arguments to pass to tqdm.
+        tokeniser:
+            The tokeniser.
+        model_config:
+            The model configuration.

     Returns:
-        A tqdm progress bar.
+        A dictionary of tokenisation parameters to pass to vLLM.
     """
-    tqdm_kwargs.pop("leave", None)  # Remove the 'leave' key if it exists
-    return tqdm(*tqdm_args, leave=False, **tqdm_kwargs)
+    if isinstance(tokeniser, MistralCommonTokenizer):
+        tokeniser_mode = "mistral"
+    elif model_config.param == "slow-tokenizer":
+        tokeniser_mode = "slow"
+    else:
+        tokeniser_mode = "auto"
+
+    if isinstance(tokeniser, MistralCommonTokenizer):
+        config_format = "mistral"
+    else:
+        config_format = "auto"
+
+    if isinstance(tokeniser, MistralCommonTokenizer):
+        load_format = "mistral"
+    else:
+        load_format = "auto"
+
+    return dict(
+        tokenizer_mode=tokeniser_mode,
+        config_format=config_format,
+        load_format=load_format,
+    )
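For reference, the dictionary returned by get_vllm_tokenisation_params is expanded directly into the vLLM LLM constructor in load_model_and_tokeniser (the **vllm_tokenisation_params line above), replacing the previous inline Mistral special-casing. A short usage sketch, assuming tokeniser and model_config are already in scope and refer to a Mistral instruction-tuned checkpoint:

    vllm_tokenisation_params = get_vllm_tokenisation_params(
        tokeniser=tokeniser, model_config=model_config
    )
    # For a MistralCommonTokenizer this yields:
    # {"tokenizer_mode": "mistral", "config_format": "mistral", "load_format": "mistral"}
    # For a non-Mistral model with the slow-tokenizer parameter, tokenizer_mode is "slow" instead.

    model = LLM(
        model=model_config.model_id,  # ModelConfig exposes model_id, as used elsewhere in this diff
        **vllm_tokenisation_params,
    )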