EuroEval 16.3.0-py3-none-any.whl → 16.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of EuroEval has been flagged as potentially problematic.
- euroeval/__init__.py +9 -2
- euroeval/benchmark_config_factory.py +51 -50
- euroeval/benchmark_modules/base.py +9 -21
- euroeval/benchmark_modules/fresh.py +2 -1
- euroeval/benchmark_modules/hf.py +101 -71
- euroeval/benchmark_modules/litellm.py +115 -53
- euroeval/benchmark_modules/vllm.py +107 -92
- euroeval/benchmarker.py +144 -121
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/cli.py +86 -8
- euroeval/constants.py +9 -0
- euroeval/data_loading.py +80 -29
- euroeval/data_models.py +338 -330
- euroeval/dataset_configs/__init__.py +12 -3
- euroeval/dataset_configs/bulgarian.py +56 -0
- euroeval/dataset_configs/czech.py +75 -0
- euroeval/dataset_configs/danish.py +55 -93
- euroeval/dataset_configs/dutch.py +48 -87
- euroeval/dataset_configs/english.py +45 -77
- euroeval/dataset_configs/estonian.py +42 -34
- euroeval/dataset_configs/faroese.py +19 -60
- euroeval/dataset_configs/finnish.py +36 -69
- euroeval/dataset_configs/french.py +39 -75
- euroeval/dataset_configs/german.py +45 -82
- euroeval/dataset_configs/greek.py +64 -0
- euroeval/dataset_configs/icelandic.py +54 -91
- euroeval/dataset_configs/italian.py +42 -79
- euroeval/dataset_configs/latvian.py +28 -35
- euroeval/dataset_configs/lithuanian.py +28 -26
- euroeval/dataset_configs/norwegian.py +72 -115
- euroeval/dataset_configs/polish.py +33 -61
- euroeval/dataset_configs/portuguese.py +33 -66
- euroeval/dataset_configs/serbian.py +64 -0
- euroeval/dataset_configs/slovak.py +55 -0
- euroeval/dataset_configs/spanish.py +42 -77
- euroeval/dataset_configs/swedish.py +52 -90
- euroeval/dataset_configs/ukrainian.py +64 -0
- euroeval/exceptions.py +1 -1
- euroeval/finetuning.py +24 -17
- euroeval/generation.py +15 -14
- euroeval/generation_utils.py +8 -8
- euroeval/languages.py +395 -323
- euroeval/logging_utils.py +250 -0
- euroeval/metrics/base.py +0 -3
- euroeval/metrics/huggingface.py +21 -6
- euroeval/metrics/llm_as_a_judge.py +6 -4
- euroeval/metrics/pipeline.py +17 -9
- euroeval/metrics/speed.py +0 -3
- euroeval/model_cache.py +17 -19
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/__init__.py +2 -0
- euroeval/prompt_templates/classification.py +206 -0
- euroeval/prompt_templates/linguistic_acceptability.py +99 -42
- euroeval/prompt_templates/multiple_choice.py +102 -38
- euroeval/prompt_templates/named_entity_recognition.py +172 -51
- euroeval/prompt_templates/reading_comprehension.py +119 -42
- euroeval/prompt_templates/sentiment_classification.py +110 -40
- euroeval/prompt_templates/summarization.py +85 -40
- euroeval/prompt_templates/token_classification.py +279 -0
- euroeval/scores.py +11 -10
- euroeval/speed_benchmark.py +5 -6
- euroeval/task_group_utils/multiple_choice_classification.py +2 -4
- euroeval/task_group_utils/question_answering.py +24 -16
- euroeval/task_group_utils/sequence_classification.py +48 -35
- euroeval/task_group_utils/text_to_text.py +19 -9
- euroeval/task_group_utils/token_classification.py +21 -17
- euroeval/tasks.py +44 -1
- euroeval/tokenisation_utils.py +33 -22
- euroeval/types.py +10 -9
- euroeval/utils.py +35 -149
- {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/METADATA +196 -39
- euroeval-16.5.0.dist-info/RECORD +81 -0
- euroeval-16.3.0.dist-info/RECORD +0 -71
- {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/WHEEL +0 -0
- {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/entry_points.txt +0 -0
- {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/licenses/LICENSE +0 -0
euroeval/benchmark_modules/hf.py
CHANGED

(In the rendering below, removed lines whose bodies the diff viewer did not capture are shown as "…".)
@@ -36,6 +36,7 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer
 from transformers.trainer import Trainer
 from urllib3.exceptions import RequestError
 
+from ..caching_utils import cache_arguments
 from ..constants import (
     DUMMY_FILL_VALUE,
     GENERATIVE_PIPELINE_TAGS,
@@ -43,7 +44,7 @@ from ..constants import (
     MAX_CONTEXT_LENGTH,
     MERGE_TAGS,
 )
-from ..data_models import HFModelInfo, ModelConfig
+from ..data_models import HashableDict, HFModelInfo, ModelConfig
 from ..enums import (
     BatchingPreference,
     GenerativeType,
@@ -60,6 +61,7 @@ from ..exceptions import (
 )
 from ..generation_utils import raise_if_wrong_params
 from ..languages import get_all_languages
+from ..logging_utils import block_terminal_output, log, log_once
 from ..task_group_utils import (
     multiple_choice_classification,
     question_answering,
@@ -67,12 +69,10 @@ from ..task_group_utils import (
 )
 from ..tokenisation_utils import get_bos_token, get_eos_token
 from ..utils import (
-    block_terminal_output,
     create_model_cache_dir,
     get_class_by_name,
     get_hf_token,
     internet_connection_available,
-    log_once,
     split_model_id,
 )
 from .base import BenchmarkModule
@@ -85,8 +85,6 @@ if t.TYPE_CHECKING:
     from ..data_models import BenchmarkConfig, DatasetConfig, Task
     from ..types import ExtractLabelsFunction
 
-logger = logging.getLogger("euroeval")
-
 
 class HuggingFaceEncoderModel(BenchmarkModule):
     """An encoder model from the Hugging Face Hub."""
@@ -183,12 +181,13 @@ class HuggingFaceEncoderModel(BenchmarkModule):
         elif hasattr(self._model, "parameters"):
             num_params = sum(p.numel() for p in self._model.parameters())
         else:
-            …
+            log(
                 "The number of parameters could not be determined for the model, since "
                 "the model is not stored in the safetensors format. If this is your "
                 "own model, then you can use this Hugging Face Space to convert your "
                 "model to the safetensors format: "
-                "https://huggingface.co/spaces/safetensors/convert."
+                "https://huggingface.co/spaces/safetensors/convert.",
+                level=logging.WARNING,
             )
             num_params = -1
         return num_params
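This hunk is the first of many in this file that swap calls on the old module-level `logger` for a shared `log(message, level=...)` helper imported from the new `euroeval/logging_utils.py` (which, per the import hunks above, also now hosts `block_terminal_output` and `log_once`). That module's implementation is not included in this diff; a minimal sketch of helpers with this shape, assuming a single package-wide logger and a dedupe set, might look like:

# Hypothetical sketch of the kind of helpers the new euroeval/logging_utils.py
# appears to provide; the real implementation is not shown in this diff.
import logging

_logger = logging.getLogger("euroeval")
_seen_messages: set[str] = set()


def log(message: str, level: int = logging.INFO) -> None:
    """Emit a message through the shared 'euroeval' logger."""
    _logger.log(level, message)


def log_once(message: str, level: int = logging.INFO) -> None:
    """Emit a message only the first time it is encountered."""
    if message not in _seen_messages:
        _seen_messages.add(message)
        _logger.log(level, message)

Centralising logging this way lets every call site pass an explicit `level` instead of picking between `logger.debug`/`logger.info`/`logger.warning`, which is exactly the shape of the replacements throughout this diff.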
@@ -268,7 +267,7 @@ class HuggingFaceEncoderModel(BenchmarkModule):
         return model_max_length
 
     @property
-    def data_collator(self) -> c.Callable[[…
+    def data_collator(self) -> c.Callable[[c.Sequence[t.Any]], dict[str, t.Any]]:
         """The data collator used to prepare samples during finetuning.
 
         Returns:
@@ -491,7 +490,11 @@ class HuggingFaceEncoderModel(BenchmarkModule):
         model_info = get_model_repo_info(
             model_id=model_id_components.model_id,
             revision=model_id_components.revision,
-            …
+            api_key=benchmark_config.api_key,
+            cache_dir=benchmark_config.cache_dir,
+            trust_remote_code=benchmark_config.trust_remote_code,
+            requires_safetensors=benchmark_config.requires_safetensors,
+            run_with_cli=benchmark_config.run_with_cli,
         )
         return (
             model_info is not None
@@ -517,7 +520,11 @@ class HuggingFaceEncoderModel(BenchmarkModule):
         model_info = get_model_repo_info(
             model_id=model_id_components.model_id,
             revision=model_id_components.revision,
-            …
+            api_key=benchmark_config.api_key,
+            cache_dir=benchmark_config.cache_dir,
+            trust_remote_code=benchmark_config.trust_remote_code,
+            requires_safetensors=benchmark_config.requires_safetensors,
+            run_with_cli=benchmark_config.run_with_cli,
         )
         if model_info is None:
             raise InvalidModel(f"The model {model_id!r} could not be found.")
@@ -583,8 +590,8 @@ def load_model_and_tokeniser(
     config = load_hf_model_config(
         model_id=model_id,
         num_labels=len(id2label),
-        id2label=id2label,
-        label2id={label: idx for idx, label in id2label.items()},
+        id2label=HashableDict(id2label),
+        label2id=HashableDict({label: idx for idx, label in id2label.items()}),
         revision=model_config.revision,
         model_cache_dir=model_config.model_cache_dir,
         api_key=benchmark_config.api_key,
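The `HashableDict` wrappers are needed because `load_hf_model_config` is now memoised (see the `@cache_arguments` decorator later in this diff), and plain dicts are unhashable, so they cannot be part of a cache key. Its actual definition lives in `euroeval/data_models.py` and is not shown here; a minimal sketch along the lines it appears to work would be:

# Guess at the shape of euroeval's HashableDict, which is defined in
# euroeval/data_models.py and not shown in this diff.
class HashableDict(dict):
    """A dict that can participate in hash-based cache keys."""

    def __hash__(self) -> int:
        # Hash the frozen set of items, so the hash is order-independent.
        # Requires all keys and values to themselves be hashable.
        return hash(frozenset(self.items()))


# A plain dict would raise "TypeError: unhashable type: 'dict'" here:
labels = HashableDict({0: "negative", 1: "positive"})
assert hash(labels) == hash(HashableDict({1: "positive", 0: "negative"}))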
@@ -608,11 +615,8 @@ def load_model_and_tokeniser(
         ),
     )
 
-    # These are used when a timeout occurs
-    attempts_left = 5
-
     model: "PreTrainedModel | None" = None
-    …
+    for _ in range(num_attempts := 5):
         # Get the model class associated with the task group
         model_cls_or_none: t.Type["PreTrainedModel"] | None = get_class_by_name(
             class_name=task_group_to_class_name(task_group=task_group),
@@ -639,22 +643,21 @@ def load_model_and_tokeniser(
             break
         except (KeyError, RuntimeError) as e:
             if not model_kwargs["ignore_mismatched_sizes"]:
-                …
+                log(
                     f"{type(e).__name__} occurred during the loading "
                     f"of the {model_id!r} model. Retrying with "
-                    "`ignore_mismatched_sizes` set to True."
+                    "`ignore_mismatched_sizes` set to True.",
+                    level=logging.DEBUG,
                 )
                 model_kwargs["ignore_mismatched_sizes"] = True
                 continue
             else:
                 raise InvalidModel(str(e)) from e
-        except (TimeoutError, RequestError)…
-            …
-            …
-            …
-            …
-            ) from e
-            logger.info(f"Couldn't load the model {model_id!r}. Retrying.")
+        except (TimeoutError, RequestError):
+            log(
+                f"Couldn't load the model {model_id!r}. Retrying.",
+                level=logging.WARNING,
+            )
             sleep(5)
             continue
         except (OSError, ValueError) as e:
@@ -671,6 +674,10 @@ def load_model_and_tokeniser(
             raise InvalidModel(
                 f"The model {model_id!r} could not be loaded. The error was {e!r}."
             ) from e
+    else:
+        raise InvalidModel(
+            f"Could not load the model {model_id!r} after {num_attempts} attempts."
+        )
 
     if isinstance(model_or_tuple, tuple):
         model = model_or_tuple[0]
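The added `else:` here belongs to the `for _ in range(num_attempts := 5):` loop introduced in an earlier hunk, not to an `if`: Python runs a loop's `else` clause only when the loop finishes without hitting `break`, which makes it a natural place to raise once every retry has failed. The walrus assignment binds `num_attempts` inline so the error message can reference it. A self-contained sketch of the pattern, with placeholder names rather than EuroEval's actual code:

import time


def load_with_retries(loader, num_attempts: int = 5):
    """Call `loader` until it succeeds or the attempts are exhausted."""
    for _ in range(num_attempts):
        try:
            result = loader()
            break  # success: the `else` clause below is skipped
        except (TimeoutError, ConnectionError):
            time.sleep(5)
            continue
    else:
        # Runs only if the loop completed without a single `break`.
        raise RuntimeError(f"Loading failed after {num_attempts} attempts.")
    return result

Compared with the old `attempts_left` counter that had to be decremented and checked inside the `except` block, the `for`/`else` form keeps the retry budget and the "out of attempts" error in one place.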
@@ -698,8 +705,15 @@ def load_model_and_tokeniser(
     return model, tokeniser
 
 
+@cache_arguments("model_id", "revision")
 def get_model_repo_info(
-    model_id: str,…
+    model_id: str,
+    revision: str,
+    api_key: str | None,
+    cache_dir: str,
+    trust_remote_code: bool,
+    requires_safetensors: bool,
+    run_with_cli: bool,
 ) -> "HFModelInfo | None":
     """Get the information about the model from the HF Hub or a local directory.
 
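`cache_arguments` comes from the new `euroeval/caching_utils.py`, whose source is not part of this diff. Judging by its use here and on `get_dtype`, `load_hf_model_config` and `task_group_to_class_name` below, it memoises a function on a named subset of its arguments (or on all of them when called with no names), which is also why dict-valued arguments have to be `HashableDict`s. A hypothetical sketch of such a decorator:

import functools
import inspect
import typing as t


def cache_arguments(*key_params: str) -> t.Callable:
    """Memoise a function on the named arguments, or all of them if none given.

    Hypothetical sketch; the real decorator lives in euroeval/caching_utils.py.
    """

    def decorator(func: t.Callable) -> t.Callable:
        cache: dict[tuple, t.Any] = {}
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any:
            # Resolve positional and keyword arguments to parameter names.
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            names = key_params or tuple(bound.arguments)
            key = tuple(bound.arguments[name] for name in names)
            if key not in cache:
                cache[key] = func(*args, **kwargs)
            return cache[key]

        return wrapper

    return decorator

Keying only on `model_id` and `revision` means repeated calls for the same model skip the Hub round-trips even when incidental arguments such as `run_with_cli` differ.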
@@ -708,13 +722,11 @@ def get_model_repo_info(
             The model ID.
         revision:
             The revision of the model.
-        benchmark_config:
-            The benchmark configuration.
 
     Returns:
         The information about the model, or None if the model could not be found.
     """
-    token = get_hf_token(api_key=…
+    token = get_hf_token(api_key=api_key)
     hf_api = HfApi(token=token)
 
     # Get information on the model.
@@ -722,7 +734,7 @@ def get_model_repo_info(
     # model info object.
     model_info: HfApiModelInfo | None = None
     if Path(model_id).is_dir():
-        …
+        log(f"Checking for local model in {model_id}.", level=logging.DEBUG)
         if all(
             (Path(model_id) / required_file).exists()
             for required_file in LOCAL_MODELS_REQUIRED_FILES
@@ -748,42 +760,39 @@ def get_model_repo_info(
         except (GatedRepoError, LocalTokenNotFoundError) as e:
             try:
                 hf_whoami(token=token)
-                …
+                log(
                     f"Could not access the model {model_id} with the revision "
-                    f"{revision}. The error was {str(e)!r}."
+                    f"{revision}. The error was {str(e)!r}.",
+                    level=logging.DEBUG,
                 )
                 return None
             except LocalTokenNotFoundError:
-                …
+                log(
                     f"Could not access the model {model_id} with the revision "
                     f"{revision}. The error was {str(e)!r}. Please set the "
                     "`HUGGINGFACE_API_KEY` environment variable or use the "
-                    "`--api-key` argument."
+                    "`--api-key` argument.",
+                    level=logging.DEBUG,
                 )
                 return None
-        except (RepositoryNotFoundError, HFValidationError):
+        except (RepositoryNotFoundError, HFValidationError, HfHubHTTPError):
             return None
-        except HfHubHTTPError as e:
-            if "unauthorized" in str(e).lower():
-                raise InvalidModel(
-                    "It seems like your specified Hugging Face API key is invalid. "
-                    "Please double-check your API key."
-                ) from e
-            raise InvalidModel(str(e)) from e
         except (OSError, RequestException) as e:
             if internet_connection_available():
                 errors.append(e)
                 continue
-            …
+            log(
                 "Could not access the Hugging Face Hub. Please check your internet "
-                "connection."
+                "connection.",
+                level=logging.DEBUG,
            )
             return None
     else:
-        …
+        log(
             f"Could not access model info for the model {model_id!r} from the "
             f"Hugging Face Hub, after {num_attempts} attempts. The errors "
-            f"encountered were {errors!r}."
+            f"encountered were {errors!r}.",
+            level=logging.DEBUG,
         )
         return None
 
@@ -814,15 +823,15 @@ def get_model_repo_info(
     hf_config = load_hf_model_config(
         model_id=base_model_id or model_id,
         num_labels=0,
-        id2label=…
-        label2id=…
+        id2label=HashableDict(),
+        label2id=HashableDict(),
         revision=revision,
         model_cache_dir=create_model_cache_dir(
-            cache_dir=…
+            cache_dir=cache_dir, model_id=model_id
         ),
-        api_key=…
-        trust_remote_code=…
-        run_with_cli=…
+        api_key=api_key,
+        trust_remote_code=trust_remote_code,
+        run_with_cli=run_with_cli,
     )
     class_names = hf_config.architectures
     generative_class_names = [
@@ -837,19 +846,19 @@ def get_model_repo_info(
     else:
         pipeline_tag = "fill-mask"
 
-    if…
+    if requires_safetensors:
         repo_files = hf_api.list_repo_files(repo_id=model_id, revision=revision)
         has_safetensors = any(f.endswith(".safetensors") for f in repo_files)
         if not has_safetensors:
             msg = f"Model {model_id} does not have safetensors weights available. "
-            if…
+            if run_with_cli:
                 msg += "Skipping since the `--only-allow-safetensors` flag is set."
             else:
                 msg += (
                     "Skipping since the `requires_safetensors` argument is set "
                     "to `True`."
                 )
-            …
+            log(msg, level=logging.WARNING)
             return None
 
     # Also check base model if we are evaluating an adapter
@@ -863,7 +872,7 @@ def get_model_repo_info(
                 f"Base model {base_model_id} does not have safetensors weights "
                 "available."
             )
-            if…
+            if run_with_cli:
                 msg += " Skipping since the `--only-allow-safetensors` flag is set."
             else:
                 msg += (
@@ -929,7 +938,10 @@ def load_tokeniser(
                 f"Could not load tokeniser for model {model_id!r}."
             ) from e
         except (TimeoutError, RequestError):
-            …
+            log(
+                f"Couldn't load tokeniser for {model_id!r}. Retrying.",
+                level=logging.WARNING,
+            )
             sleep(5)
             continue
         else:
@@ -945,6 +957,7 @@ def load_tokeniser(
     return tokeniser
 
 
+@cache_arguments()
 def get_dtype(
     device: torch.device, dtype_is_set: bool, bf16_available: bool
 ) -> str | torch.dtype:
@@ -953,6 +966,7 @@ def get_dtype(
     Args:
         device:
             The device to use.
+        dtype_is_set:
             Whether the data type is set in the model configuration.
         bf16_available:
             Whether bfloat16 is available.
@@ -970,6 +984,7 @@ def get_dtype(
         return torch.float32
 
 
+@cache_arguments("model_id", "revision", "num_labels", "id2label", "label2id")
 def load_hf_model_config(
     model_id: str,
     num_labels: int,
@@ -1006,7 +1021,7 @@ def load_hf_model_config(
     Returns:
         The Hugging Face model configuration.
     """
-    …
+    for _ in range(num_attempts := 5):
         try:
             config = AutoConfig.from_pretrained(
                 model_id,
@@ -1019,12 +1034,7 @@ def load_hf_model_config(
                 cache_dir=model_cache_dir,
                 local_files_only=not internet_connection_available(),
             )
-
-            if isinstance(config.eos_token_id, list):
-                config.pad_token_id = config.eos_token_id[0]
-            else:
-                config.pad_token_id = config.eos_token_id
-            return config
+            break
         except KeyError as e:
             key = e.args[0]
             raise InvalidModel(
@@ -1032,18 +1042,23 @@ def load_hf_model_config(
                 f"loaded, as the key {key!r} was not found in the config."
             ) from e
         except (OSError, GatedRepoError) as e:
-            …
-            …
-            …
-            …
-            …
-            …
+            if isinstance(e, GatedRepoError) or "gated repo" in str(e).lower():
+                raise InvalidModel(
+                    f"The model {model_id!r} is a gated repository. Please ensure "
+                    "that you are logged in with `hf auth login` or have provided a "
+                    "valid Hugging Face access token with the `HUGGINGFACE_API_KEY` "
+                    "environment variable or the `--api-key` argument. Also check that "
+                    "your account has access to this model."
+                ) from e
             raise InvalidModel(
                 f"Couldn't load model config for {model_id!r}. The error was "
                 f"{e!r}. Skipping"
             ) from e
         except (TimeoutError, RequestError):
-            …
+            log(
+                f"Couldn't load model config for {model_id!r}. Retrying.",
+                level=logging.WARNING,
+            )
             sleep(5)
             continue
         except ValueError as e:
@@ -1062,6 +1077,20 @@ def load_hf_model_config(
                 f"The config for the model {model_id!r} could not be loaded. The "
                 f"error was {e!r}."
             ) from e
+    else:
+        raise InvalidModel(
+            f"Couldn't load model config for {model_id!r} after {num_attempts} "
+            "attempts."
+        )
+
+    # Ensure that the PAD token ID is set
+    if config.eos_token_id is not None and config.pad_token_id is None:
+        if isinstance(config.eos_token_id, list):
+            config.pad_token_id = config.eos_token_id[0]
+        else:
+            config.pad_token_id = config.eos_token_id
+
+    return config
 
 
 def setup_model_for_question_answering(model: "PreTrainedModel") -> "PreTrainedModel":
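Note the behavioural change in the pad-token handling: the old code (removed in the previous hunk) overwrote `pad_token_id` unconditionally inside the `try` block on every load, while the relocated version runs once after a successful load and only fills in a missing pad token, leaving an explicitly configured one intact. A toy illustration with a stand-in config object, not a real `transformers` config:

from types import SimpleNamespace

# Stand-in for a model config with a list-valued EOS token and no PAD token.
config = SimpleNamespace(eos_token_id=[128001, 128009], pad_token_id=None)

if config.eos_token_id is not None and config.pad_token_id is None:
    if isinstance(config.eos_token_id, list):
        config.pad_token_id = config.eos_token_id[0]  # first EOS id becomes PAD
    else:
        config.pad_token_id = config.eos_token_id

assert config.pad_token_id == 128001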
@@ -1230,6 +1259,7 @@ def align_model_and_tokeniser(
     return model, tokeniser
 
 
+@cache_arguments()
 def task_group_to_class_name(task_group: TaskGroup) -> str:
     """Convert a task group to a class name.
 