EuroEval 16.3.0__py3-none-any.whl → 16.4.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of EuroEval might be problematic. Click here for more details.

Files changed (64)
  1. euroeval/__init__.py +3 -2
  2. euroeval/benchmark_config_factory.py +0 -4
  3. euroeval/benchmark_modules/base.py +3 -16
  4. euroeval/benchmark_modules/fresh.py +2 -1
  5. euroeval/benchmark_modules/hf.py +99 -62
  6. euroeval/benchmark_modules/litellm.py +101 -41
  7. euroeval/benchmark_modules/vllm.py +91 -83
  8. euroeval/benchmarker.py +84 -78
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/constants.py +6 -0
  12. euroeval/data_loading.py +14 -11
  13. euroeval/data_models.py +12 -4
  14. euroeval/dataset_configs/__init__.py +2 -0
  15. euroeval/dataset_configs/czech.py +79 -0
  16. euroeval/dataset_configs/danish.py +10 -11
  17. euroeval/dataset_configs/dutch.py +0 -1
  18. euroeval/dataset_configs/english.py +0 -1
  19. euroeval/dataset_configs/estonian.py +11 -1
  20. euroeval/dataset_configs/finnish.py +0 -1
  21. euroeval/dataset_configs/french.py +0 -1
  22. euroeval/dataset_configs/german.py +0 -1
  23. euroeval/dataset_configs/italian.py +0 -1
  24. euroeval/dataset_configs/latvian.py +0 -1
  25. euroeval/dataset_configs/lithuanian.py +9 -3
  26. euroeval/dataset_configs/norwegian.py +0 -1
  27. euroeval/dataset_configs/polish.py +0 -1
  28. euroeval/dataset_configs/portuguese.py +0 -1
  29. euroeval/dataset_configs/slovak.py +60 -0
  30. euroeval/dataset_configs/spanish.py +0 -1
  31. euroeval/dataset_configs/swedish.py +10 -12
  32. euroeval/finetuning.py +21 -15
  33. euroeval/generation.py +10 -10
  34. euroeval/generation_utils.py +2 -3
  35. euroeval/logging_utils.py +250 -0
  36. euroeval/metrics/base.py +0 -3
  37. euroeval/metrics/huggingface.py +9 -5
  38. euroeval/metrics/llm_as_a_judge.py +5 -3
  39. euroeval/metrics/pipeline.py +17 -9
  40. euroeval/metrics/speed.py +0 -3
  41. euroeval/model_cache.py +11 -14
  42. euroeval/model_config.py +4 -5
  43. euroeval/model_loading.py +3 -0
  44. euroeval/prompt_templates/linguistic_acceptability.py +21 -3
  45. euroeval/prompt_templates/multiple_choice.py +25 -1
  46. euroeval/prompt_templates/named_entity_recognition.py +51 -11
  47. euroeval/prompt_templates/reading_comprehension.py +31 -3
  48. euroeval/prompt_templates/sentiment_classification.py +23 -1
  49. euroeval/prompt_templates/summarization.py +26 -6
  50. euroeval/scores.py +7 -7
  51. euroeval/speed_benchmark.py +3 -5
  52. euroeval/task_group_utils/multiple_choice_classification.py +0 -3
  53. euroeval/task_group_utils/question_answering.py +0 -3
  54. euroeval/task_group_utils/sequence_classification.py +43 -31
  55. euroeval/task_group_utils/text_to_text.py +17 -8
  56. euroeval/task_group_utils/token_classification.py +10 -9
  57. euroeval/tokenisation_utils.py +14 -12
  58. euroeval/utils.py +29 -146
  59. {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/METADATA +4 -4
  60. euroeval-16.4.0.dist-info/RECORD +75 -0
  61. euroeval-16.3.0.dist-info/RECORD +0 -71
  62. {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/WHEEL +0 -0
  63. {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/entry_points.txt +0 -0
  64. {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/licenses/LICENSE +0 -0
euroeval/__init__.py CHANGED
@@ -21,7 +21,8 @@ if os.getenv("FULL_LOG") != "1":
21
21
  os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
22
22
 
23
23
  # Set up logging
24
- fmt = colored("%(asctime)s", "light_blue") + " ⋅ " + colored("%(message)s", "green")
24
+ # fmt = colored("%(asctime)s", "light_blue") + " ⋅ " + colored("%(message)s", "green")
25
+ fmt = colored("%(message)s", "light_yellow")
25
26
  logging.basicConfig(
26
27
  level=logging.CRITICAL if hasattr(sys, "_called_from_test") else logging.INFO,
27
28
  format=fmt,
@@ -50,7 +51,7 @@ import importlib.metadata # noqa: E402
50
51
  from dotenv import load_dotenv # noqa: E402
51
52
 
52
53
  from .benchmarker import Benchmarker # noqa: E402
53
- from .utils import block_terminal_output # noqa: E402
54
+ from .logging_utils import block_terminal_output # noqa: E402
54
55
 
55
56
  # Block unwanted terminal outputs. This blocks way more than the above, but since it
56
57
  # relies on importing from the `utils` module, external modules are already imported
euroeval/benchmark_config_factory.py CHANGED
@@ -1,6 +1,5 @@
1
1
  """Factory class for creating dataset configurations."""
2
2
 
3
- import logging
4
3
  import sys
5
4
  import typing as t
6
5
 
@@ -17,9 +16,6 @@ if t.TYPE_CHECKING:
17
16
  from .data_models import Language, Task
18
17
 
19
18
 
20
- logger = logging.getLogger("euroeval")
21
-
22
-
23
19
  def build_benchmark_config(
24
20
  benchmark_config_params: BenchmarkConfigParams,
25
21
  ) -> BenchmarkConfig:
euroeval/benchmark_modules/base.py CHANGED
@@ -3,24 +3,22 @@
3
3
  import collections.abc as c
4
4
  import logging
5
5
  import re
6
- import sys
7
6
  import typing as t
8
7
  from abc import ABC, abstractmethod
9
8
  from functools import cached_property, partial
10
9
 
11
10
  from datasets import Dataset, DatasetDict
12
11
  from torch import nn
13
- from tqdm.auto import tqdm
14
12
 
15
13
  from ..enums import TaskGroup
16
14
  from ..exceptions import InvalidBenchmark, NeedsEnvironmentVariable, NeedsExtraInstalled
15
+ from ..logging_utils import get_pbar, log_once
17
16
  from ..task_group_utils import (
18
17
  question_answering,
19
18
  sequence_classification,
20
19
  text_to_text,
21
20
  token_classification,
22
21
  )
23
- from ..utils import log_once
24
22
 
25
23
  if t.TYPE_CHECKING:
26
24
  from transformers.tokenization_utils import PreTrainedTokenizer
@@ -36,8 +34,6 @@ if t.TYPE_CHECKING:
36
34
  from ..enums import BatchingPreference, GenerativeType
37
35
  from ..types import ComputeMetricsFunction, ExtractLabelsFunction
38
36
 
39
- logger = logging.getLogger("euroeval")
40
-
41
37
 
42
38
  class BenchmarkModule(ABC):
43
39
  """Abstract class for a benchmark module.
@@ -87,16 +83,7 @@ class BenchmarkModule(ABC):
87
83
 
88
84
  def _log_metadata(self) -> None:
89
85
  """Log the metadata of the model."""
90
- # Set logging level based on verbosity
91
- if hasattr(sys, "_called_from_test"):
92
- logging_level = logging.CRITICAL
93
- elif self.benchmark_config.verbose:
94
- logging_level = logging.DEBUG
95
- else:
96
- logging_level = logging.INFO
97
- logger.setLevel(logging_level)
98
-
99
- logging_msg: str = ""
86
+ logging_msg: str = " ↳ "
100
87
  if self.num_params < 0:
101
88
  logging_msg += "The model has an unknown number of parameters, "
102
89
  else:
@@ -273,7 +260,7 @@ class BenchmarkModule(ABC):
273
260
  tasks.
274
261
  """
275
262
  for idx, dataset in enumerate(
276
- tqdm(iterable=datasets, desc="Preparing datasets")
263
+ get_pbar(iterable=datasets, desc="Preparing datasets")
277
264
  ):
278
265
  prepared_dataset = self.prepare_dataset(
279
266
  dataset=dataset, task=task, itr_idx=idx
euroeval/benchmark_modules/fresh.py CHANGED
@@ -27,7 +27,8 @@ from ..exceptions import (
27
27
  NeedsExtraInstalled,
28
28
  )
29
29
  from ..generation_utils import raise_if_wrong_params
30
- from ..utils import block_terminal_output, create_model_cache_dir, get_hf_token
30
+ from ..logging_utils import block_terminal_output
31
+ from ..utils import create_model_cache_dir, get_hf_token
31
32
  from .hf import (
32
33
  HuggingFaceEncoderModel,
33
34
  align_model_and_tokeniser,
euroeval/benchmark_modules/hf.py CHANGED
@@ -36,6 +36,7 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer
36
36
  from transformers.trainer import Trainer
37
37
  from urllib3.exceptions import RequestError
38
38
 
39
+ from ..caching_utils import cache_arguments
39
40
  from ..constants import (
40
41
  DUMMY_FILL_VALUE,
41
42
  GENERATIVE_PIPELINE_TAGS,
@@ -43,7 +44,7 @@ from ..constants import (
43
44
  MAX_CONTEXT_LENGTH,
44
45
  MERGE_TAGS,
45
46
  )
46
- from ..data_models import HFModelInfo, ModelConfig
47
+ from ..data_models import HashableDict, HFModelInfo, ModelConfig
47
48
  from ..enums import (
48
49
  BatchingPreference,
49
50
  GenerativeType,
@@ -60,6 +61,7 @@ from ..exceptions import (
60
61
  )
61
62
  from ..generation_utils import raise_if_wrong_params
62
63
  from ..languages import get_all_languages
64
+ from ..logging_utils import block_terminal_output, log, log_once
63
65
  from ..task_group_utils import (
64
66
  multiple_choice_classification,
65
67
  question_answering,
@@ -67,12 +69,10 @@ from ..task_group_utils import (
67
69
  )
68
70
  from ..tokenisation_utils import get_bos_token, get_eos_token
69
71
  from ..utils import (
70
- block_terminal_output,
71
72
  create_model_cache_dir,
72
73
  get_class_by_name,
73
74
  get_hf_token,
74
75
  internet_connection_available,
75
- log_once,
76
76
  split_model_id,
77
77
  )
78
78
  from .base import BenchmarkModule
@@ -85,8 +85,6 @@ if t.TYPE_CHECKING:
85
85
  from ..data_models import BenchmarkConfig, DatasetConfig, Task
86
86
  from ..types import ExtractLabelsFunction
87
87
 
88
- logger = logging.getLogger("euroeval")
89
-
90
88
 
91
89
  class HuggingFaceEncoderModel(BenchmarkModule):
92
90
  """An encoder model from the Hugging Face Hub."""
@@ -183,12 +181,13 @@ class HuggingFaceEncoderModel(BenchmarkModule):
183
181
  elif hasattr(self._model, "parameters"):
184
182
  num_params = sum(p.numel() for p in self._model.parameters())
185
183
  else:
186
- logger.warning(
184
+ log(
187
185
  "The number of parameters could not be determined for the model, since "
188
186
  "the model is not stored in the safetensors format. If this is your "
189
187
  "own model, then you can use this Hugging Face Space to convert your "
190
188
  "model to the safetensors format: "
191
- "https://huggingface.co/spaces/safetensors/convert."
189
+ "https://huggingface.co/spaces/safetensors/convert.",
190
+ level=logging.WARNING,
192
191
  )
193
192
  num_params = -1
194
193
  return num_params
@@ -491,7 +490,11 @@ class HuggingFaceEncoderModel(BenchmarkModule):
491
490
  model_info = get_model_repo_info(
492
491
  model_id=model_id_components.model_id,
493
492
  revision=model_id_components.revision,
494
- benchmark_config=benchmark_config,
493
+ api_key=benchmark_config.api_key,
494
+ cache_dir=benchmark_config.cache_dir,
495
+ trust_remote_code=benchmark_config.trust_remote_code,
496
+ requires_safetensors=benchmark_config.requires_safetensors,
497
+ run_with_cli=benchmark_config.run_with_cli,
495
498
  )
496
499
  return (
497
500
  model_info is not None
@@ -517,7 +520,11 @@ class HuggingFaceEncoderModel(BenchmarkModule):
517
520
  model_info = get_model_repo_info(
518
521
  model_id=model_id_components.model_id,
519
522
  revision=model_id_components.revision,
520
- benchmark_config=benchmark_config,
523
+ api_key=benchmark_config.api_key,
524
+ cache_dir=benchmark_config.cache_dir,
525
+ trust_remote_code=benchmark_config.trust_remote_code,
526
+ requires_safetensors=benchmark_config.requires_safetensors,
527
+ run_with_cli=benchmark_config.run_with_cli,
521
528
  )
522
529
  if model_info is None:
523
530
  raise InvalidModel(f"The model {model_id!r} could not be found.")
@@ -583,8 +590,8 @@ def load_model_and_tokeniser(
583
590
  config = load_hf_model_config(
584
591
  model_id=model_id,
585
592
  num_labels=len(id2label),
586
- id2label=id2label,
587
- label2id={label: idx for idx, label in id2label.items()},
593
+ id2label=HashableDict(id2label),
594
+ label2id=HashableDict({label: idx for idx, label in id2label.items()}),
588
595
  revision=model_config.revision,
589
596
  model_cache_dir=model_config.model_cache_dir,
590
597
  api_key=benchmark_config.api_key,
@@ -608,11 +615,8 @@ def load_model_and_tokeniser(
608
615
  ),
609
616
  )
610
617
 
611
- # These are used when a timeout occurs
612
- attempts_left = 5
613
-
614
618
  model: "PreTrainedModel | None" = None
615
- while True:
619
+ for _ in range(num_attempts := 5):
616
620
  # Get the model class associated with the task group
617
621
  model_cls_or_none: t.Type["PreTrainedModel"] | None = get_class_by_name(
618
622
  class_name=task_group_to_class_name(task_group=task_group),
@@ -639,22 +643,21 @@ def load_model_and_tokeniser(
639
643
  break
640
644
  except (KeyError, RuntimeError) as e:
641
645
  if not model_kwargs["ignore_mismatched_sizes"]:
642
- logger.debug(
646
+ log(
643
647
  f"{type(e).__name__} occurred during the loading "
644
648
  f"of the {model_id!r} model. Retrying with "
645
- "`ignore_mismatched_sizes` set to True."
649
+ "`ignore_mismatched_sizes` set to True.",
650
+ level=logging.DEBUG,
646
651
  )
647
652
  model_kwargs["ignore_mismatched_sizes"] = True
648
653
  continue
649
654
  else:
650
655
  raise InvalidModel(str(e)) from e
651
- except (TimeoutError, RequestError) as e:
652
- attempts_left -= 1
653
- if attempts_left == 0:
654
- raise InvalidModel(
655
- "The model could not be loaded after 5 attempts."
656
- ) from e
657
- logger.info(f"Couldn't load the model {model_id!r}. Retrying.")
656
+ except (TimeoutError, RequestError):
657
+ log(
658
+ f"Couldn't load the model {model_id!r}. Retrying.",
659
+ level=logging.WARNING,
660
+ )
658
661
  sleep(5)
659
662
  continue
660
663
  except (OSError, ValueError) as e:
@@ -671,6 +674,10 @@ def load_model_and_tokeniser(
671
674
  raise InvalidModel(
672
675
  f"The model {model_id!r} could not be loaded. The error was {e!r}."
673
676
  ) from e
677
+ else:
678
+ raise InvalidModel(
679
+ f"Could not load the model {model_id!r} after {num_attempts} attempts."
680
+ )
674
681
 
675
682
  if isinstance(model_or_tuple, tuple):
676
683
  model = model_or_tuple[0]
@@ -698,8 +705,15 @@ def load_model_and_tokeniser(
698
705
  return model, tokeniser
699
706
 
700
707
 
708
+ @cache_arguments("model_id", "revision")
701
709
  def get_model_repo_info(
702
- model_id: str, revision: str, benchmark_config: "BenchmarkConfig"
710
+ model_id: str,
711
+ revision: str,
712
+ api_key: str | None,
713
+ cache_dir: str,
714
+ trust_remote_code: bool,
715
+ requires_safetensors: bool,
716
+ run_with_cli: bool,
703
717
  ) -> "HFModelInfo | None":
704
718
  """Get the information about the model from the HF Hub or a local directory.
705
719
 
@@ -708,13 +722,11 @@ def get_model_repo_info(
708
722
  The model ID.
709
723
  revision:
710
724
  The revision of the model.
711
- benchmark_config:
712
- The benchmark configuration.
713
725
 
714
726
  Returns:
715
727
  The information about the model, or None if the model could not be found.
716
728
  """
717
- token = get_hf_token(api_key=benchmark_config.api_key)
729
+ token = get_hf_token(api_key=api_key)
718
730
  hf_api = HfApi(token=token)
719
731
 
720
732
  # Get information on the model.
@@ -722,7 +734,7 @@ def get_model_repo_info(
722
734
  # model info object.
723
735
  model_info: HfApiModelInfo | None = None
724
736
  if Path(model_id).is_dir():
725
- logger.debug(f"Checking for local model in {model_id}.")
737
+ log(f"Checking for local model in {model_id}.", level=logging.DEBUG)
726
738
  if all(
727
739
  (Path(model_id) / required_file).exists()
728
740
  for required_file in LOCAL_MODELS_REQUIRED_FILES
@@ -748,17 +760,19 @@ def get_model_repo_info(
748
760
  except (GatedRepoError, LocalTokenNotFoundError) as e:
749
761
  try:
750
762
  hf_whoami(token=token)
751
- logger.debug(
763
+ log(
752
764
  f"Could not access the model {model_id} with the revision "
753
- f"{revision}. The error was {str(e)!r}."
765
+ f"{revision}. The error was {str(e)!r}.",
766
+ level=logging.DEBUG,
754
767
  )
755
768
  return None
756
769
  except LocalTokenNotFoundError:
757
- logger.debug(
770
+ log(
758
771
  f"Could not access the model {model_id} with the revision "
759
772
  f"{revision}. The error was {str(e)!r}. Please set the "
760
773
  "`HUGGINGFACE_API_KEY` environment variable or use the "
761
- "`--api-key` argument."
774
+ "`--api-key` argument.",
775
+ level=logging.DEBUG,
762
776
  )
763
777
  return None
764
778
  except (RepositoryNotFoundError, HFValidationError):
@@ -774,16 +788,18 @@ def get_model_repo_info(
774
788
  if internet_connection_available():
775
789
  errors.append(e)
776
790
  continue
777
- logger.debug(
791
+ log(
778
792
  "Could not access the Hugging Face Hub. Please check your internet "
779
- "connection."
793
+ "connection.",
794
+ level=logging.DEBUG,
780
795
  )
781
796
  return None
782
797
  else:
783
- logger.debug(
798
+ log(
784
799
  f"Could not access model info for the model {model_id!r} from the "
785
800
  f"Hugging Face Hub, after {num_attempts} attempts. The errors "
786
- f"encountered were {errors!r}."
801
+ f"encountered were {errors!r}.",
802
+ level=logging.DEBUG,
787
803
  )
788
804
  return None
789
805
 
@@ -814,15 +830,15 @@ def get_model_repo_info(
814
830
  hf_config = load_hf_model_config(
815
831
  model_id=base_model_id or model_id,
816
832
  num_labels=0,
817
- id2label=dict(),
818
- label2id=dict(),
833
+ id2label=HashableDict(),
834
+ label2id=HashableDict(),
819
835
  revision=revision,
820
836
  model_cache_dir=create_model_cache_dir(
821
- cache_dir=benchmark_config.cache_dir, model_id=model_id
837
+ cache_dir=cache_dir, model_id=model_id
822
838
  ),
823
- api_key=benchmark_config.api_key,
824
- trust_remote_code=benchmark_config.trust_remote_code,
825
- run_with_cli=benchmark_config.run_with_cli,
839
+ api_key=api_key,
840
+ trust_remote_code=trust_remote_code,
841
+ run_with_cli=run_with_cli,
826
842
  )
827
843
  class_names = hf_config.architectures
828
844
  generative_class_names = [
@@ -837,19 +853,19 @@ def get_model_repo_info(
837
853
  else:
838
854
  pipeline_tag = "fill-mask"
839
855
 
840
- if benchmark_config.requires_safetensors:
856
+ if requires_safetensors:
841
857
  repo_files = hf_api.list_repo_files(repo_id=model_id, revision=revision)
842
858
  has_safetensors = any(f.endswith(".safetensors") for f in repo_files)
843
859
  if not has_safetensors:
844
860
  msg = f"Model {model_id} does not have safetensors weights available. "
845
- if benchmark_config.run_with_cli:
861
+ if run_with_cli:
846
862
  msg += "Skipping since the `--only-allow-safetensors` flag is set."
847
863
  else:
848
864
  msg += (
849
865
  "Skipping since the `requires_safetensors` argument is set "
850
866
  "to `True`."
851
867
  )
852
- logger.warning(msg)
868
+ log(msg, level=logging.WARNING)
853
869
  return None
854
870
 
855
871
  # Also check base model if we are evaluating an adapter
@@ -863,7 +879,7 @@ def get_model_repo_info(
863
879
  f"Base model {base_model_id} does not have safetensors weights "
864
880
  "available."
865
881
  )
866
- if benchmark_config.run_with_cli:
882
+ if run_with_cli:
867
883
  msg += " Skipping since the `--only-allow-safetensors` flag is set."
868
884
  else:
869
885
  msg += (
@@ -929,7 +945,10 @@ def load_tokeniser(
929
945
  f"Could not load tokeniser for model {model_id!r}."
930
946
  ) from e
931
947
  except (TimeoutError, RequestError):
932
- logger.info(f"Couldn't load tokeniser for {model_id!r}. Retrying.")
948
+ log(
949
+ f"Couldn't load tokeniser for {model_id!r}. Retrying.",
950
+ level=logging.WARNING,
951
+ )
933
952
  sleep(5)
934
953
  continue
935
954
  else:
@@ -945,6 +964,7 @@ def load_tokeniser(
945
964
  return tokeniser
946
965
 
947
966
 
967
+ @cache_arguments()
948
968
  def get_dtype(
949
969
  device: torch.device, dtype_is_set: bool, bf16_available: bool
950
970
  ) -> str | torch.dtype:
@@ -953,6 +973,7 @@ def get_dtype(
953
973
  Args:
954
974
  device:
955
975
  The device to use.
976
+ dtype_is_set:
956
977
  Whether the data type is set in the model configuration.
957
978
  bf16_available:
958
979
  Whether bfloat16 is available.
@@ -970,6 +991,7 @@ def get_dtype(
970
991
  return torch.float32
971
992
 
972
993
 
994
+ @cache_arguments("model_id", "revision", "num_labels", "id2label", "label2id")
973
995
  def load_hf_model_config(
974
996
  model_id: str,
975
997
  num_labels: int,
@@ -1006,7 +1028,7 @@ def load_hf_model_config(
1006
1028
  Returns:
1007
1029
  The Hugging Face model configuration.
1008
1030
  """
1009
- while True:
1031
+ for _ in range(num_attempts := 5):
1010
1032
  try:
1011
1033
  config = AutoConfig.from_pretrained(
1012
1034
  model_id,
@@ -1019,12 +1041,7 @@ def load_hf_model_config(
1019
1041
  cache_dir=model_cache_dir,
1020
1042
  local_files_only=not internet_connection_available(),
1021
1043
  )
1022
- if config.eos_token_id is not None and config.pad_token_id is None:
1023
- if isinstance(config.eos_token_id, list):
1024
- config.pad_token_id = config.eos_token_id[0]
1025
- else:
1026
- config.pad_token_id = config.eos_token_id
1027
- return config
1044
+ break
1028
1045
  except KeyError as e:
1029
1046
  key = e.args[0]
1030
1047
  raise InvalidModel(
@@ -1032,18 +1049,23 @@ def load_hf_model_config(
1032
1049
  f"loaded, as the key {key!r} was not found in the config."
1033
1050
  ) from e
1034
1051
  except (OSError, GatedRepoError) as e:
1035
- # TEMP: When the model is gated then we cannot set cache dir, for some
1036
- # reason (since transformers v4.38.2, still a problem in v4.48.0). This
1037
- # should be included back in when this is fixed.
1038
- if "gated repo" in str(e):
1039
- model_cache_dir = None
1040
- continue
1052
+ if isinstance(e, GatedRepoError) or "gated repo" in str(e).lower():
1053
+ raise InvalidModel(
1054
+ f"The model {model_id!r} is a gated repository. Please ensure "
1055
+ "that you are logged in with `hf auth login` or have provided a "
1056
+ "valid Hugging Face access token with the `HUGGINGFACE_API_KEY` "
1057
+ "environment variable or the `--api-key` argument. Also check that "
1058
+ "your account has access to this model."
1059
+ ) from e
1041
1060
  raise InvalidModel(
1042
1061
  f"Couldn't load model config for {model_id!r}. The error was "
1043
1062
  f"{e!r}. Skipping"
1044
1063
  ) from e
1045
1064
  except (TimeoutError, RequestError):
1046
- logger.info(f"Couldn't load model config for {model_id!r}. Retrying.")
1065
+ log(
1066
+ f"Couldn't load model config for {model_id!r}. Retrying.",
1067
+ level=logging.WARNING,
1068
+ )
1047
1069
  sleep(5)
1048
1070
  continue
1049
1071
  except ValueError as e:
@@ -1062,6 +1084,20 @@ def load_hf_model_config(
1062
1084
  f"The config for the model {model_id!r} could not be loaded. The "
1063
1085
  f"error was {e!r}."
1064
1086
  ) from e
1087
+ else:
1088
+ raise InvalidModel(
1089
+ f"Couldn't load model config for {model_id!r} after {num_attempts} "
1090
+ "attempts."
1091
+ )
1092
+
1093
+ # Ensure that the PAD token ID is set
1094
+ if config.eos_token_id is not None and config.pad_token_id is None:
1095
+ if isinstance(config.eos_token_id, list):
1096
+ config.pad_token_id = config.eos_token_id[0]
1097
+ else:
1098
+ config.pad_token_id = config.eos_token_id
1099
+
1100
+ return config
1065
1101
 
1066
1102
 
1067
1103
  def setup_model_for_question_answering(model: "PreTrainedModel") -> "PreTrainedModel":
@@ -1230,6 +1266,7 @@ def align_model_and_tokeniser(
1230
1266
  return model, tokeniser
1231
1267
 
1232
1268
 
1269
+ @cache_arguments()
1233
1270
  def task_group_to_class_name(task_group: TaskGroup) -> str:
1234
1271
  """Convert a task group to a class name.
1235
1272