euroeval-16.2.2-py3-none-any.whl → euroeval-16.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of EuroEval has been flagged as potentially problematic.

Files changed (65)
  1. euroeval/__init__.py +7 -4
  2. euroeval/benchmark_config_factory.py +0 -4
  3. euroeval/benchmark_modules/base.py +3 -16
  4. euroeval/benchmark_modules/fresh.py +5 -2
  5. euroeval/benchmark_modules/hf.py +107 -66
  6. euroeval/benchmark_modules/litellm.py +103 -55
  7. euroeval/benchmark_modules/vllm.py +155 -82
  8. euroeval/benchmarker.py +184 -129
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/cli.py +1 -1
  12. euroeval/constants.py +9 -0
  13. euroeval/data_loading.py +14 -11
  14. euroeval/data_models.py +12 -4
  15. euroeval/dataset_configs/__init__.py +3 -0
  16. euroeval/dataset_configs/czech.py +79 -0
  17. euroeval/dataset_configs/danish.py +10 -13
  18. euroeval/dataset_configs/dutch.py +0 -3
  19. euroeval/dataset_configs/english.py +0 -3
  20. euroeval/dataset_configs/estonian.py +11 -1
  21. euroeval/dataset_configs/finnish.py +0 -3
  22. euroeval/dataset_configs/french.py +0 -3
  23. euroeval/dataset_configs/german.py +0 -3
  24. euroeval/dataset_configs/italian.py +0 -3
  25. euroeval/dataset_configs/latvian.py +2 -4
  26. euroeval/dataset_configs/lithuanian.py +68 -0
  27. euroeval/dataset_configs/norwegian.py +0 -3
  28. euroeval/dataset_configs/polish.py +0 -3
  29. euroeval/dataset_configs/portuguese.py +0 -3
  30. euroeval/dataset_configs/slovak.py +60 -0
  31. euroeval/dataset_configs/spanish.py +0 -3
  32. euroeval/dataset_configs/swedish.py +10 -15
  33. euroeval/finetuning.py +21 -15
  34. euroeval/generation.py +10 -10
  35. euroeval/generation_utils.py +2 -3
  36. euroeval/logging_utils.py +250 -0
  37. euroeval/metrics/base.py +0 -3
  38. euroeval/metrics/huggingface.py +10 -6
  39. euroeval/metrics/llm_as_a_judge.py +5 -3
  40. euroeval/metrics/pipeline.py +22 -9
  41. euroeval/metrics/speed.py +0 -3
  42. euroeval/model_cache.py +11 -14
  43. euroeval/model_config.py +4 -5
  44. euroeval/model_loading.py +3 -0
  45. euroeval/prompt_templates/linguistic_acceptability.py +30 -3
  46. euroeval/prompt_templates/multiple_choice.py +34 -1
  47. euroeval/prompt_templates/named_entity_recognition.py +71 -11
  48. euroeval/prompt_templates/reading_comprehension.py +41 -3
  49. euroeval/prompt_templates/sentiment_classification.py +34 -1
  50. euroeval/prompt_templates/summarization.py +26 -6
  51. euroeval/scores.py +7 -7
  52. euroeval/speed_benchmark.py +3 -5
  53. euroeval/task_group_utils/multiple_choice_classification.py +0 -3
  54. euroeval/task_group_utils/question_answering.py +0 -3
  55. euroeval/task_group_utils/sequence_classification.py +43 -31
  56. euroeval/task_group_utils/text_to_text.py +17 -8
  57. euroeval/task_group_utils/token_classification.py +10 -9
  58. euroeval/tokenisation_utils.py +22 -20
  59. euroeval/utils.py +30 -147
  60. {euroeval-16.2.2.dist-info → euroeval-16.4.0.dist-info}/METADATA +182 -61
  61. euroeval-16.4.0.dist-info/RECORD +75 -0
  62. euroeval-16.2.2.dist-info/RECORD +0 -70
  63. {euroeval-16.2.2.dist-info → euroeval-16.4.0.dist-info}/WHEEL +0 -0
  64. {euroeval-16.2.2.dist-info → euroeval-16.4.0.dist-info}/entry_points.txt +0 -0
  65. {euroeval-16.2.2.dist-info → euroeval-16.4.0.dist-info}/licenses/LICENSE +0 -0
euroeval/__init__.py CHANGED
@@ -21,7 +21,8 @@ if os.getenv("FULL_LOG") != "1":
     os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
 
 # Set up logging
-fmt = colored("%(asctime)s", "light_blue") + " ⋅ " + colored("%(message)s", "green")
+# fmt = colored("%(asctime)s", "light_blue") + " ⋅ " + colored("%(message)s", "green")
+fmt = colored("%(message)s", "light_yellow")
 logging.basicConfig(
     level=logging.CRITICAL if hasattr(sys, "_called_from_test") else logging.INFO,
     format=fmt,
@@ -50,7 +51,7 @@ import importlib.metadata  # noqa: E402
 from dotenv import load_dotenv  # noqa: E402
 
 from .benchmarker import Benchmarker  # noqa: E402
-from .utils import block_terminal_output  # noqa: E402
+from .logging_utils import block_terminal_output  # noqa: E402
 
 # Block unwanted terminal outputs. This blocks way more than the above, but since it
 # relies on importing from the `utils` module, external modules are already imported
@@ -103,8 +104,10 @@ os.environ["DISABLE_AIOHTTP_TRANSPORT"] = "True"
 os.environ["VLLM_USE_V1"] = "1"
 
 
-# Use the FlashInfer flash-attention backend for vLLM
-os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
+# Use the FlashInfer flash-attention backend for vLLM, unless the user has already
+# specified a different backend.
+if os.getenv("VLLM_ATTENTION_BACKEND") is None:
+    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
 
 
 # Set the HF_TOKEN env var to copy the HUGGINGFACE_API_KEY env var, as vLLM uses the
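
The new guard only applies the FLASHINFER default when the variable is unset, so a user-chosen backend survives. For reference, the same behaviour can be written as a one-liner with `os.environ.setdefault` — a minimal sketch of the idiom, not the code EuroEval ships:

    import os

    # Only sets the backend when the variable is absent, exactly like the
    # `os.getenv(...) is None` guard in the diff above.
    os.environ.setdefault("VLLM_ATTENTION_BACKEND", "FLASHINFER")
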
euroeval/benchmark_config_factory.py CHANGED
@@ -1,6 +1,5 @@
 """Factory class for creating dataset configurations."""
 
-import logging
 import sys
 import typing as t
 
@@ -17,9 +16,6 @@ if t.TYPE_CHECKING:
     from .data_models import Language, Task
 
 
-logger = logging.getLogger("euroeval")
-
-
 def build_benchmark_config(
     benchmark_config_params: BenchmarkConfigParams,
 ) -> BenchmarkConfig:
euroeval/benchmark_modules/base.py CHANGED
@@ -3,24 +3,22 @@
 import collections.abc as c
 import logging
 import re
-import sys
 import typing as t
 from abc import ABC, abstractmethod
 from functools import cached_property, partial
 
 from datasets import Dataset, DatasetDict
 from torch import nn
-from tqdm.auto import tqdm
 
 from ..enums import TaskGroup
 from ..exceptions import InvalidBenchmark, NeedsEnvironmentVariable, NeedsExtraInstalled
+from ..logging_utils import get_pbar, log_once
 from ..task_group_utils import (
     question_answering,
     sequence_classification,
     text_to_text,
     token_classification,
 )
-from ..utils import log_once
 
 if t.TYPE_CHECKING:
     from transformers.tokenization_utils import PreTrainedTokenizer
@@ -36,8 +34,6 @@ if t.TYPE_CHECKING:
     from ..enums import BatchingPreference, GenerativeType
     from ..types import ComputeMetricsFunction, ExtractLabelsFunction
 
-logger = logging.getLogger("euroeval")
-
 
 class BenchmarkModule(ABC):
     """Abstract class for a benchmark module.
@@ -87,16 +83,7 @@ class BenchmarkModule(ABC):
 
     def _log_metadata(self) -> None:
         """Log the metadata of the model."""
-        # Set logging level based on verbosity
-        if hasattr(sys, "_called_from_test"):
-            logging_level = logging.CRITICAL
-        elif self.benchmark_config.verbose:
-            logging_level = logging.DEBUG
-        else:
-            logging_level = logging.INFO
-        logger.setLevel(logging_level)
-
-        logging_msg: str = ""
+        logging_msg: str = " ↳ "
         if self.num_params < 0:
             logging_msg += "The model has an unknown number of parameters, "
         else:
@@ -273,7 +260,7 @@
             tasks.
         """
         for idx, dataset in enumerate(
-            tqdm(iterable=datasets, desc="Preparing datasets")
+            get_pbar(iterable=datasets, desc="Preparing datasets")
        ):
             prepared_dataset = self.prepare_dataset(
                 dataset=dataset, task=task, itr_idx=idx
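
Module-level loggers and direct `tqdm` calls are replaced throughout by helpers from the new euroeval/logging_utils.py (+250 lines, not expanded in this diff). Based on the call sites — `log(msg, level=...)`, `log_once(...)`, `get_pbar(iterable=..., desc=...)` — a minimal sketch of what these helpers plausibly look like; the names and signatures are inferred from usage, not copied from the actual module:

    import collections.abc as c
    import logging
    import typing as t

    from tqdm.auto import tqdm

    logger = logging.getLogger("euroeval")
    _logged_messages: set[str] = set()

    def log(message: str, level: int = logging.INFO) -> None:
        """Log a message at the given level via the shared 'euroeval' logger."""
        logger.log(level=level, msg=message)

    def log_once(message: str, level: int = logging.INFO) -> None:
        """Log a message only the first time it is seen."""
        if message not in _logged_messages:
            _logged_messages.add(message)
            log(message=message, level=level)

    def get_pbar(iterable: c.Iterable, desc: str, **kwargs: t.Any) -> tqdm:
        """Return a progress bar consistent with EuroEval's logging set-up."""
        return tqdm(iterable=iterable, desc=desc, **kwargs)
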
euroeval/benchmark_modules/fresh.py CHANGED
@@ -1,5 +1,6 @@
 """Freshly initialised encoder models."""
 
+import re
 import typing as t
 from functools import cached_property
 from json import JSONDecodeError
@@ -26,7 +27,8 @@ from ..exceptions import (
     NeedsExtraInstalled,
 )
 from ..generation_utils import raise_if_wrong_params
-from ..utils import block_terminal_output, create_model_cache_dir, get_hf_token
+from ..logging_utils import block_terminal_output
+from ..utils import create_model_cache_dir, get_hf_token
 from .hf import (
     HuggingFaceEncoderModel,
     align_model_and_tokeniser,
@@ -45,6 +47,7 @@ class FreshEncoderModel(HuggingFaceEncoderModel):
     """A freshly initialised encoder model."""
 
     fresh_model = True
+    allowed_params = {re.compile(r".*"): ["slow-tokenizer"]}
 
     def __init__(
         self,
@@ -294,7 +297,7 @@ def load_model_and_tokeniser(
         token=get_hf_token(api_key=benchmark_config.api_key),
         add_prefix_space=prefix,
         cache_dir=model_config.model_cache_dir,
-        use_fast=True,
+        use_fast=False if model_config.param == "slow-tokenizer" else True,
         verbose=False,
         trust_remote_code=benchmark_config.trust_remote_code,
     )
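
The new `slow-tokenizer` model parameter toggles the `use_fast` flag when loading tokenisers. The difference is plain transformers behaviour — a small illustration using a public checkpoint, not EuroEval code:

    from transformers import AutoTokenizer

    # use_fast=True loads the Rust-backed tokeniser, use_fast=False the
    # pure-Python implementation of the same tokeniser.
    fast = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True)
    slow = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False)
    print(type(fast).__name__)  # BertTokenizerFast
    print(type(slow).__name__)  # BertTokenizer
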
euroeval/benchmark_modules/hf.py CHANGED
@@ -2,6 +2,7 @@
 
 import collections.abc as c
 import logging
+import re
 import typing as t
 from functools import cached_property, partial
 from json import JSONDecodeError
@@ -35,6 +36,7 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer
 from transformers.trainer import Trainer
 from urllib3.exceptions import RequestError
 
+from ..caching_utils import cache_arguments
 from ..constants import (
     DUMMY_FILL_VALUE,
     GENERATIVE_PIPELINE_TAGS,
@@ -42,7 +44,7 @@ from ..constants import (
     MAX_CONTEXT_LENGTH,
     MERGE_TAGS,
 )
-from ..data_models import HFModelInfo, ModelConfig
+from ..data_models import HashableDict, HFModelInfo, ModelConfig
 from ..enums import (
     BatchingPreference,
     GenerativeType,
@@ -59,6 +61,7 @@ from ..exceptions import (
 )
 from ..generation_utils import raise_if_wrong_params
 from ..languages import get_all_languages
+from ..logging_utils import block_terminal_output, log, log_once
 from ..task_group_utils import (
     multiple_choice_classification,
     question_answering,
@@ -66,12 +69,10 @@ from ..task_group_utils import (
 )
 from ..tokenisation_utils import get_bos_token, get_eos_token
 from ..utils import (
-    block_terminal_output,
     create_model_cache_dir,
     get_class_by_name,
     get_hf_token,
     internet_connection_available,
-    log_once,
     split_model_id,
 )
 from .base import BenchmarkModule
@@ -84,8 +85,6 @@ if t.TYPE_CHECKING:
     from ..data_models import BenchmarkConfig, DatasetConfig, Task
     from ..types import ExtractLabelsFunction
 
-logger = logging.getLogger("euroeval")
-
 
 class HuggingFaceEncoderModel(BenchmarkModule):
     """An encoder model from the Hugging Face Hub."""
@@ -93,6 +92,7 @@ class HuggingFaceEncoderModel(BenchmarkModule):
     fresh_model = False
     batching_preference = BatchingPreference.NO_PREFERENCE
     high_priority = True
+    allowed_params = {re.compile(r".*"): ["slow-tokenizer"]}
 
     def __init__(
         self,
@@ -181,12 +181,13 @@
         elif hasattr(self._model, "parameters"):
             num_params = sum(p.numel() for p in self._model.parameters())
         else:
-            logger.warning(
+            log(
                 "The number of parameters could not be determined for the model, since "
                 "the model is not stored in the safetensors format. If this is your "
                 "own model, then you can use this Hugging Face Space to convert your "
                 "model to the safetensors format: "
-                "https://huggingface.co/spaces/safetensors/convert."
+                "https://huggingface.co/spaces/safetensors/convert.",
+                level=logging.WARNING,
             )
             num_params = -1
         return num_params
@@ -489,7 +490,11 @@
         model_info = get_model_repo_info(
             model_id=model_id_components.model_id,
             revision=model_id_components.revision,
-            benchmark_config=benchmark_config,
+            api_key=benchmark_config.api_key,
+            cache_dir=benchmark_config.cache_dir,
+            trust_remote_code=benchmark_config.trust_remote_code,
+            requires_safetensors=benchmark_config.requires_safetensors,
+            run_with_cli=benchmark_config.run_with_cli,
         )
         return (
             model_info is not None
@@ -515,7 +520,11 @@
         model_info = get_model_repo_info(
             model_id=model_id_components.model_id,
             revision=model_id_components.revision,
-            benchmark_config=benchmark_config,
+            api_key=benchmark_config.api_key,
+            cache_dir=benchmark_config.cache_dir,
+            trust_remote_code=benchmark_config.trust_remote_code,
+            requires_safetensors=benchmark_config.requires_safetensors,
+            run_with_cli=benchmark_config.run_with_cli,
         )
         if model_info is None:
             raise InvalidModel(f"The model {model_id!r} could not be found.")
@@ -581,8 +590,8 @@ def load_model_and_tokeniser(
     config = load_hf_model_config(
         model_id=model_id,
         num_labels=len(id2label),
-        id2label=id2label,
-        label2id={label: idx for idx, label in id2label.items()},
+        id2label=HashableDict(id2label),
+        label2id=HashableDict({label: idx for idx, label in id2label.items()}),
         revision=model_config.revision,
         model_cache_dir=model_config.model_cache_dir,
         api_key=benchmark_config.api_key,
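
`HashableDict` is newly imported from euroeval.data_models, whose diff (+12 −4) is not expanded here. Its purpose follows from the `@cache_arguments` decorators introduced below: plain dicts are unhashable and so cannot be part of a cache key. A plausible sketch, inferred from usage rather than copied from the package:

    class HashableDict(dict):
        """A dict that can be hashed, e.g. for use in memoisation cache keys.

        Assumes all keys and values are themselves hashable.
        """

        def __hash__(self) -> int:  # type: ignore[override]
            return hash(frozenset(self.items()))
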
@@ -606,11 +615,8 @@
         ),
     )
 
-    # These are used when a timeout occurs
-    attempts_left = 5
-
     model: "PreTrainedModel | None" = None
-    while True:
+    for _ in range(num_attempts := 5):
         # Get the model class associated with the task group
         model_cls_or_none: t.Type["PreTrainedModel"] | None = get_class_by_name(
             class_name=task_group_to_class_name(task_group=task_group),
@@ -637,22 +643,21 @@
             break
         except (KeyError, RuntimeError) as e:
             if not model_kwargs["ignore_mismatched_sizes"]:
-                logger.debug(
+                log(
                     f"{type(e).__name__} occurred during the loading "
                     f"of the {model_id!r} model. Retrying with "
-                    "`ignore_mismatched_sizes` set to True."
+                    "`ignore_mismatched_sizes` set to True.",
+                    level=logging.DEBUG,
                 )
                 model_kwargs["ignore_mismatched_sizes"] = True
                 continue
             else:
                 raise InvalidModel(str(e)) from e
-        except (TimeoutError, RequestError) as e:
-            attempts_left -= 1
-            if attempts_left == 0:
-                raise InvalidModel(
-                    "The model could not be loaded after 5 attempts."
-                ) from e
-            logger.info(f"Couldn't load the model {model_id!r}. Retrying.")
+        except (TimeoutError, RequestError):
+            log(
+                f"Couldn't load the model {model_id!r}. Retrying.",
+                level=logging.WARNING,
+            )
             sleep(5)
             continue
         except (OSError, ValueError) as e:
@@ -669,6 +674,10 @@
             raise InvalidModel(
                 f"The model {model_id!r} could not be loaded. The error was {e!r}."
             ) from e
+    else:
+        raise InvalidModel(
+            f"Could not load the model {model_id!r} after {num_attempts} attempts."
+        )
 
     if isinstance(model_or_tuple, tuple):
         model = model_or_tuple[0]
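
The `while True` loop with a manual `attempts_left` counter is replaced by a bounded `for` loop with an `else` clause: the `else` branch of a `for` only runs when the loop finishes without hitting `break`, i.e. when every attempt failed. A minimal standalone sketch of the pattern (hypothetical helper, not EuroEval code):

    import time

    def fetch_with_retries(fetch, num_attempts: int = 5):
        """Call `fetch` up to `num_attempts` times, sleeping between failures."""
        for _ in range(num_attempts):
            try:
                result = fetch()
                break  # success: skips the for-else below
            except TimeoutError:
                time.sleep(5)
        else:
            # Only reached when no attempt ever hit `break`.
            raise RuntimeError(f"Gave up after {num_attempts} attempts.")
        return result
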
@@ -690,14 +699,21 @@
         model=model,
         model_id=model_id,
         trust_remote_code=benchmark_config.trust_remote_code,
-        model_cache_dir=model_config.model_cache_dir,
+        model_config=model_config,
     )
 
     return model, tokeniser
 
 
+@cache_arguments("model_id", "revision")
 def get_model_repo_info(
-    model_id: str, revision: str, benchmark_config: "BenchmarkConfig"
+    model_id: str,
+    revision: str,
+    api_key: str | None,
+    cache_dir: str,
+    trust_remote_code: bool,
+    requires_safetensors: bool,
+    run_with_cli: bool,
 ) -> "HFModelInfo | None":
     """Get the information about the model from the HF Hub or a local directory.
 
@@ -706,13 +722,11 @@
             The model ID.
         revision:
             The revision of the model.
-        benchmark_config:
-            The benchmark configuration.
 
     Returns:
         The information about the model, or None if the model could not be found.
     """
-    token = get_hf_token(api_key=benchmark_config.api_key)
+    token = get_hf_token(api_key=api_key)
     hf_api = HfApi(token=token)
 
     # Get information on the model.
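
`cache_arguments` comes from the new euroeval/caching_utils.py (+79 lines), which this diff does not expand. Judging from the call sites — `@cache_arguments("model_id", "revision")` here, `@cache_arguments()` further down — it memoises a function on a named subset of its arguments, or on all of them when called without names, which is also why dict arguments must now be `HashableDict`s. A rough sketch of such a decorator under those assumptions:

    import functools
    import inspect
    import typing as t

    def cache_arguments(*arg_names: str) -> t.Callable:
        """Memoise a function on the named arguments (all arguments if none given)."""
        def decorator(func: t.Callable) -> t.Callable:
            signature = inspect.signature(func)
            cache: dict[tuple, t.Any] = {}

            @functools.wraps(func)
            def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any:
                bound = signature.bind(*args, **kwargs)
                bound.apply_defaults()
                names = arg_names or tuple(bound.arguments)
                # Requires every selected argument value to be hashable, hence
                # the HashableDict wrappers at the call sites above.
                key = tuple(bound.arguments[name] for name in names)
                if key not in cache:
                    cache[key] = func(*args, **kwargs)
                return cache[key]

            return wrapper
        return decorator
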
@@ -720,7 +734,7 @@
     # model info object.
     model_info: HfApiModelInfo | None = None
     if Path(model_id).is_dir():
-        logger.debug(f"Checking for local model in {model_id}.")
+        log(f"Checking for local model in {model_id}.", level=logging.DEBUG)
         if all(
             (Path(model_id) / required_file).exists()
             for required_file in LOCAL_MODELS_REQUIRED_FILES
@@ -746,17 +760,19 @@
         except (GatedRepoError, LocalTokenNotFoundError) as e:
             try:
                 hf_whoami(token=token)
-                logger.debug(
+                log(
                     f"Could not access the model {model_id} with the revision "
-                    f"{revision}. The error was {str(e)!r}."
+                    f"{revision}. The error was {str(e)!r}.",
+                    level=logging.DEBUG,
                 )
                 return None
             except LocalTokenNotFoundError:
-                logger.debug(
+                log(
                     f"Could not access the model {model_id} with the revision "
                     f"{revision}. The error was {str(e)!r}. Please set the "
                     "`HUGGINGFACE_API_KEY` environment variable or use the "
-                    "`--api-key` argument."
+                    "`--api-key` argument.",
+                    level=logging.DEBUG,
                 )
                 return None
         except (RepositoryNotFoundError, HFValidationError):
@@ -772,16 +788,18 @@
             if internet_connection_available():
                 errors.append(e)
                 continue
-            logger.debug(
+            log(
                 "Could not access the Hugging Face Hub. Please check your internet "
-                "connection."
+                "connection.",
+                level=logging.DEBUG,
             )
             return None
     else:
-        logger.debug(
+        log(
             f"Could not access model info for the model {model_id!r} from the "
             f"Hugging Face Hub, after {num_attempts} attempts. The errors "
-            f"encountered were {errors!r}."
+            f"encountered were {errors!r}.",
+            level=logging.DEBUG,
         )
         return None
 
@@ -812,15 +830,15 @@
     hf_config = load_hf_model_config(
         model_id=base_model_id or model_id,
         num_labels=0,
-        id2label=dict(),
-        label2id=dict(),
+        id2label=HashableDict(),
+        label2id=HashableDict(),
         revision=revision,
         model_cache_dir=create_model_cache_dir(
-            cache_dir=benchmark_config.cache_dir, model_id=model_id
+            cache_dir=cache_dir, model_id=model_id
         ),
-        api_key=benchmark_config.api_key,
-        trust_remote_code=benchmark_config.trust_remote_code,
-        run_with_cli=benchmark_config.run_with_cli,
+        api_key=api_key,
+        trust_remote_code=trust_remote_code,
+        run_with_cli=run_with_cli,
     )
     class_names = hf_config.architectures
     generative_class_names = [
@@ -835,19 +853,19 @@
     else:
         pipeline_tag = "fill-mask"
 
-    if benchmark_config.requires_safetensors:
+    if requires_safetensors:
         repo_files = hf_api.list_repo_files(repo_id=model_id, revision=revision)
         has_safetensors = any(f.endswith(".safetensors") for f in repo_files)
         if not has_safetensors:
             msg = f"Model {model_id} does not have safetensors weights available. "
-            if benchmark_config.run_with_cli:
+            if run_with_cli:
                 msg += "Skipping since the `--only-allow-safetensors` flag is set."
             else:
                 msg += (
                     "Skipping since the `requires_safetensors` argument is set "
                     "to `True`."
                 )
-            logger.warning(msg)
+            log(msg, level=logging.WARNING)
             return None
 
     # Also check base model if we are evaluating an adapter
@@ -861,7 +879,7 @@
                 f"Base model {base_model_id} does not have safetensors weights "
                 "available."
             )
-            if benchmark_config.run_with_cli:
+            if run_with_cli:
                 msg += " Skipping since the `--only-allow-safetensors` flag is set."
             else:
                 msg += (
@@ -880,7 +898,7 @@ def load_tokeniser(
     model: "PreTrainedModel | None",
     model_id: str,
     trust_remote_code: bool,
-    model_cache_dir: str,
+    model_config: "ModelConfig",
 ) -> "PreTrainedTokenizer":
     """Load the tokeniser.
 
@@ -892,17 +910,19 @@
             The model identifier. Used for logging.
         trust_remote_code:
             Whether to trust remote code.
+        model_config:
+            The model configuration.
 
     Returns:
         The loaded tokeniser.
     """
     loading_kwargs: dict[str, bool | str] = dict(
-        use_fast=True,
+        use_fast=False if model_config.param == "slow-tokenizer" else True,
         verbose=False,
         trust_remote_code=trust_remote_code,
        padding_side="right",
         truncation_side="right",
-        cache_dir=model_cache_dir,
+        cache_dir=model_config.model_cache_dir,
     )
 
     # If the model is a subclass of a certain model types then we have to add a prefix
@@ -925,7 +945,10 @@
                 f"Could not load tokeniser for model {model_id!r}."
             ) from e
         except (TimeoutError, RequestError):
-            logger.info(f"Couldn't load tokeniser for {model_id!r}. Retrying.")
+            log(
+                f"Couldn't load tokeniser for {model_id!r}. Retrying.",
+                level=logging.WARNING,
+            )
             sleep(5)
             continue
     else:
@@ -941,6 +964,7 @@
     return tokeniser
 
 
+@cache_arguments()
 def get_dtype(
     device: torch.device, dtype_is_set: bool, bf16_available: bool
 ) -> str | torch.dtype:
@@ -949,6 +973,7 @@
     Args:
         device:
             The device to use.
+        dtype_is_set:
             Whether the data type is set in the model configuration.
         bf16_available:
             Whether bfloat16 is available.
@@ -966,6 +991,7 @@
     return torch.float32
 
 
+@cache_arguments("model_id", "revision", "num_labels", "id2label", "label2id")
 def load_hf_model_config(
     model_id: str,
     num_labels: int,
@@ -1002,7 +1028,7 @@
     Returns:
         The Hugging Face model configuration.
     """
-    while True:
+    for _ in range(num_attempts := 5):
         try:
             config = AutoConfig.from_pretrained(
                 model_id,
@@ -1015,12 +1041,7 @@
                 cache_dir=model_cache_dir,
                 local_files_only=not internet_connection_available(),
             )
-            if config.eos_token_id is not None and config.pad_token_id is None:
-                if isinstance(config.eos_token_id, list):
-                    config.pad_token_id = config.eos_token_id[0]
-                else:
-                    config.pad_token_id = config.eos_token_id
-            return config
+            break
         except KeyError as e:
             key = e.args[0]
             raise InvalidModel(
@@ -1028,18 +1049,23 @@
                 f"loaded, as the key {key!r} was not found in the config."
             ) from e
         except (OSError, GatedRepoError) as e:
-            # TEMP: When the model is gated then we cannot set cache dir, for some
-            # reason (since transformers v4.38.2, still a problem in v4.48.0). This
-            # should be included back in when this is fixed.
-            if "gated repo" in str(e):
-                model_cache_dir = None
-                continue
+            if isinstance(e, GatedRepoError) or "gated repo" in str(e).lower():
+                raise InvalidModel(
+                    f"The model {model_id!r} is a gated repository. Please ensure "
+                    "that you are logged in with `hf auth login` or have provided a "
+                    "valid Hugging Face access token with the `HUGGINGFACE_API_KEY` "
+                    "environment variable or the `--api-key` argument. Also check that "
+                    "your account has access to this model."
+                ) from e
             raise InvalidModel(
                 f"Couldn't load model config for {model_id!r}. The error was "
                 f"{e!r}. Skipping"
             ) from e
         except (TimeoutError, RequestError):
-            logger.info(f"Couldn't load model config for {model_id!r}. Retrying.")
+            log(
+                f"Couldn't load model config for {model_id!r}. Retrying.",
+                level=logging.WARNING,
+            )
             sleep(5)
             continue
         except ValueError as e:
@@ -1058,6 +1084,20 @@
                 f"The config for the model {model_id!r} could not be loaded. The "
                 f"error was {e!r}."
             ) from e
+    else:
+        raise InvalidModel(
+            f"Couldn't load model config for {model_id!r} after {num_attempts} "
+            "attempts."
+        )
+
+    # Ensure that the PAD token ID is set
+    if config.eos_token_id is not None and config.pad_token_id is None:
+        if isinstance(config.eos_token_id, list):
+            config.pad_token_id = config.eos_token_id[0]
+        else:
+            config.pad_token_id = config.eos_token_id
+
+    return config
 
 
 def setup_model_for_question_answering(model: "PreTrainedModel") -> "PreTrainedModel":
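
The pad-token fallback now runs once, after the retry loop, instead of inside the `try` block. The behaviour itself is plain transformers usage — a small illustration (gpt2 ships an `eos_token_id` but no `pad_token_id`):

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("gpt2")
    print(config.eos_token_id, config.pad_token_id)  # 50256 None
    if config.eos_token_id is not None and config.pad_token_id is None:
        eos = config.eos_token_id
        # Some configs store a list of EOS token IDs; use the first one then.
        config.pad_token_id = eos[0] if isinstance(eos, list) else eos
    print(config.pad_token_id)  # 50256
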
@@ -1226,6 +1266,7 @@
     return model, tokeniser
 
 
+@cache_arguments()
 def task_group_to_class_name(task_group: TaskGroup) -> str:
     """Convert a task group to a class name.