EuroEval 16.3.0-py3-none-any.whl → 16.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- euroeval/__init__.py +3 -2
- euroeval/benchmark_config_factory.py +0 -4
- euroeval/benchmark_modules/base.py +3 -16
- euroeval/benchmark_modules/fresh.py +2 -1
- euroeval/benchmark_modules/hf.py +99 -62
- euroeval/benchmark_modules/litellm.py +101 -41
- euroeval/benchmark_modules/vllm.py +91 -83
- euroeval/benchmarker.py +84 -78
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/constants.py +6 -0
- euroeval/data_loading.py +14 -11
- euroeval/data_models.py +12 -4
- euroeval/dataset_configs/__init__.py +2 -0
- euroeval/dataset_configs/czech.py +79 -0
- euroeval/dataset_configs/danish.py +10 -11
- euroeval/dataset_configs/dutch.py +0 -1
- euroeval/dataset_configs/english.py +0 -1
- euroeval/dataset_configs/estonian.py +11 -1
- euroeval/dataset_configs/finnish.py +0 -1
- euroeval/dataset_configs/french.py +0 -1
- euroeval/dataset_configs/german.py +0 -1
- euroeval/dataset_configs/italian.py +0 -1
- euroeval/dataset_configs/latvian.py +0 -1
- euroeval/dataset_configs/lithuanian.py +9 -3
- euroeval/dataset_configs/norwegian.py +0 -1
- euroeval/dataset_configs/polish.py +0 -1
- euroeval/dataset_configs/portuguese.py +0 -1
- euroeval/dataset_configs/slovak.py +60 -0
- euroeval/dataset_configs/spanish.py +0 -1
- euroeval/dataset_configs/swedish.py +10 -12
- euroeval/finetuning.py +21 -15
- euroeval/generation.py +10 -10
- euroeval/generation_utils.py +2 -3
- euroeval/logging_utils.py +250 -0
- euroeval/metrics/base.py +0 -3
- euroeval/metrics/huggingface.py +9 -5
- euroeval/metrics/llm_as_a_judge.py +5 -3
- euroeval/metrics/pipeline.py +17 -9
- euroeval/metrics/speed.py +0 -3
- euroeval/model_cache.py +11 -14
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/linguistic_acceptability.py +21 -3
- euroeval/prompt_templates/multiple_choice.py +25 -1
- euroeval/prompt_templates/named_entity_recognition.py +51 -11
- euroeval/prompt_templates/reading_comprehension.py +31 -3
- euroeval/prompt_templates/sentiment_classification.py +23 -1
- euroeval/prompt_templates/summarization.py +26 -6
- euroeval/scores.py +7 -7
- euroeval/speed_benchmark.py +3 -5
- euroeval/task_group_utils/multiple_choice_classification.py +0 -3
- euroeval/task_group_utils/question_answering.py +0 -3
- euroeval/task_group_utils/sequence_classification.py +43 -31
- euroeval/task_group_utils/text_to_text.py +17 -8
- euroeval/task_group_utils/token_classification.py +10 -9
- euroeval/tokenisation_utils.py +14 -12
- euroeval/utils.py +29 -146
- {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/METADATA +4 -4
- euroeval-16.4.0.dist-info/RECORD +75 -0
- euroeval-16.3.0.dist-info/RECORD +0 -71
- {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/WHEEL +0 -0
- {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/entry_points.txt +0 -0
- {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/licenses/LICENSE +0 -0
euroeval/__init__.py
CHANGED
@@ -21,7 +21,8 @@ if os.getenv("FULL_LOG") != "1":
     os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
 
 # Set up logging
-fmt = colored("%(asctime)s", "light_blue") + " ⋅ " + colored("%(message)s", "green")
+# fmt = colored("%(asctime)s", "light_blue") + " ⋅ " + colored("%(message)s", "green")
+fmt = colored("%(message)s", "light_yellow")
 logging.basicConfig(
     level=logging.CRITICAL if hasattr(sys, "_called_from_test") else logging.INFO,
     format=fmt,
@@ -50,7 +51,7 @@ import importlib.metadata  # noqa: E402
 from dotenv import load_dotenv  # noqa: E402
 
 from .benchmarker import Benchmarker  # noqa: E402
-from .utils import block_terminal_output  # noqa: E402
+from .logging_utils import block_terminal_output  # noqa: E402
 
 # Block unwanted terminal outputs. This blocks way more than the above, but since it
 # relies on importing from the `utils` module, external modules are already imported
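Note: the change above swaps the timestamped two-colour log format for a single light-yellow message. A minimal standalone sketch of the new behaviour, using the same `termcolor` helper the package already imports:

import logging

from termcolor import colored

# New EuroEval-style format: no timestamp, message rendered in light yellow.
fmt = colored("%(message)s", "light_yellow")
logging.basicConfig(level=logging.INFO, format=fmt)
logging.getLogger("euroeval").info("This message renders in light yellow.")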
euroeval/benchmark_config_factory.py
CHANGED
@@ -1,6 +1,5 @@
 """Factory class for creating dataset configurations."""
 
-import logging
 import sys
 import typing as t
 
@@ -17,9 +16,6 @@ if t.TYPE_CHECKING:
     from .data_models import Language, Task
 
 
-logger = logging.getLogger("euroeval")
-
-
 def build_benchmark_config(
     benchmark_config_params: BenchmarkConfigParams,
 ) -> BenchmarkConfig:
euroeval/benchmark_modules/base.py
CHANGED
@@ -3,24 +3,22 @@
 import collections.abc as c
 import logging
 import re
-import sys
 import typing as t
 from abc import ABC, abstractmethod
 from functools import cached_property, partial
 
 from datasets import Dataset, DatasetDict
 from torch import nn
-from tqdm.auto import tqdm
 
 from ..enums import TaskGroup
 from ..exceptions import InvalidBenchmark, NeedsEnvironmentVariable, NeedsExtraInstalled
+from ..logging_utils import get_pbar, log_once
 from ..task_group_utils import (
     question_answering,
     sequence_classification,
     text_to_text,
     token_classification,
 )
-from ..utils import log_once
 
 if t.TYPE_CHECKING:
     from transformers.tokenization_utils import PreTrainedTokenizer
@@ -36,8 +34,6 @@ if t.TYPE_CHECKING:
     from ..enums import BatchingPreference, GenerativeType
     from ..types import ComputeMetricsFunction, ExtractLabelsFunction
 
-logger = logging.getLogger("euroeval")
-
 
 class BenchmarkModule(ABC):
     """Abstract class for a benchmark module.
@@ -87,16 +83,7 @@ class BenchmarkModule(ABC):
 
     def _log_metadata(self) -> None:
         """Log the metadata of the model."""
-
-        if hasattr(sys, "_called_from_test"):
-            logging_level = logging.CRITICAL
-        elif self.benchmark_config.verbose:
-            logging_level = logging.DEBUG
-        else:
-            logging_level = logging.INFO
-        logger.setLevel(logging_level)
-
-        logging_msg: str = ""
+        logging_msg: str = " ↳ "
         if self.num_params < 0:
             logging_msg += "The model has an unknown number of parameters, "
         else:
@@ -273,7 +260,7 @@ class BenchmarkModule(ABC):
             tasks.
         """
         for idx, dataset in enumerate(
-            tqdm(iterable=datasets, desc="Preparing datasets")
+            get_pbar(iterable=datasets, desc="Preparing datasets")
         ):
             prepared_dataset = self.prepare_dataset(
                 dataset=dataset, task=task, itr_idx=idx
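Note: `get_pbar`, `log` and `log_once` now come from the new `euroeval/logging_utils.py` (+250 lines, not shown in this diff). Based purely on the call sites visible here, a minimal sketch of helpers with these names could look like the following; the actual 16.4.0 implementations may well differ:

import logging
import typing as t

from tqdm.auto import tqdm

logger = logging.getLogger("euroeval")
_seen_messages: set[str] = set()


def log(message: str, level: int = logging.INFO) -> None:
    """Log a message at the given level on the shared logger."""
    logger.log(level, message)


def log_once(message: str, level: int = logging.INFO) -> None:
    """Log a message, skipping messages that have been logged before."""
    if message not in _seen_messages:
        _seen_messages.add(message)
        logger.log(level, message)


def get_pbar(iterable: t.Iterable, desc: str) -> tqdm:
    """Wrap an iterable in a tqdm progress bar with a description."""
    return tqdm(iterable=iterable, desc=desc)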
euroeval/benchmark_modules/fresh.py
CHANGED
@@ -27,7 +27,8 @@ from ..exceptions import (
     NeedsExtraInstalled,
 )
 from ..generation_utils import raise_if_wrong_params
-from ..utils import block_terminal_output, create_model_cache_dir, get_hf_token
+from ..logging_utils import block_terminal_output
+from ..utils import create_model_cache_dir, get_hf_token
 from .hf import (
     HuggingFaceEncoderModel,
     align_model_and_tokeniser,
euroeval/benchmark_modules/hf.py
CHANGED
@@ -36,6 +36,7 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer
 from transformers.trainer import Trainer
 from urllib3.exceptions import RequestError
 
+from ..caching_utils import cache_arguments
 from ..constants import (
     DUMMY_FILL_VALUE,
     GENERATIVE_PIPELINE_TAGS,
@@ -43,7 +44,7 @@ from ..constants import (
     MAX_CONTEXT_LENGTH,
     MERGE_TAGS,
 )
-from ..data_models import HFModelInfo, ModelConfig
+from ..data_models import HashableDict, HFModelInfo, ModelConfig
 from ..enums import (
     BatchingPreference,
     GenerativeType,
@@ -60,6 +61,7 @@ from ..exceptions import (
 )
 from ..generation_utils import raise_if_wrong_params
 from ..languages import get_all_languages
+from ..logging_utils import block_terminal_output, log, log_once
 from ..task_group_utils import (
     multiple_choice_classification,
     question_answering,
@@ -67,12 +69,10 @@ from ..task_group_utils import (
 )
 from ..tokenisation_utils import get_bos_token, get_eos_token
 from ..utils import (
-    block_terminal_output,
     create_model_cache_dir,
     get_class_by_name,
     get_hf_token,
     internet_connection_available,
-    log_once,
     split_model_id,
 )
 from .base import BenchmarkModule
@@ -85,8 +85,6 @@ if t.TYPE_CHECKING:
     from ..data_models import BenchmarkConfig, DatasetConfig, Task
     from ..types import ExtractLabelsFunction
 
-logger = logging.getLogger("euroeval")
-
 
 class HuggingFaceEncoderModel(BenchmarkModule):
     """An encoder model from the Hugging Face Hub."""
@@ -183,12 +181,13 @@ class HuggingFaceEncoderModel(BenchmarkModule):
         elif hasattr(self._model, "parameters"):
             num_params = sum(p.numel() for p in self._model.parameters())
         else:
-            logger.warning(
+            log(
                 "The number of parameters could not be determined for the model, since "
                 "the model is not stored in the safetensors format. If this is your "
                 "own model, then you can use this Hugging Face Space to convert your "
                 "model to the safetensors format: "
-                "https://huggingface.co/spaces/safetensors/convert."
+                "https://huggingface.co/spaces/safetensors/convert.",
+                level=logging.WARNING,
             )
             num_params = -1
         return num_params
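The `sum(p.numel() for p in ...)` fallback counts parameters directly on any `torch.nn.Module`; a quick example:

from torch import nn

model = nn.Linear(in_features=10, out_features=2)
num_params = sum(p.numel() for p in model.parameters())
print(num_params)  # 10 * 2 weights + 2 biases = 22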
@@ -491,7 +490,11 @@ class HuggingFaceEncoderModel(BenchmarkModule):
         model_info = get_model_repo_info(
             model_id=model_id_components.model_id,
             revision=model_id_components.revision,
-            benchmark_config=benchmark_config,
+            api_key=benchmark_config.api_key,
+            cache_dir=benchmark_config.cache_dir,
+            trust_remote_code=benchmark_config.trust_remote_code,
+            requires_safetensors=benchmark_config.requires_safetensors,
+            run_with_cli=benchmark_config.run_with_cli,
         )
         return (
             model_info is not None
@@ -517,7 +520,11 @@ class HuggingFaceEncoderModel(BenchmarkModule):
         model_info = get_model_repo_info(
             model_id=model_id_components.model_id,
             revision=model_id_components.revision,
-            benchmark_config=benchmark_config,
+            api_key=benchmark_config.api_key,
+            cache_dir=benchmark_config.cache_dir,
+            trust_remote_code=benchmark_config.trust_remote_code,
+            requires_safetensors=benchmark_config.requires_safetensors,
+            run_with_cli=benchmark_config.run_with_cli,
         )
         if model_info is None:
             raise InvalidModel(f"The model {model_id!r} could not be found.")
@@ -583,8 +590,8 @@ def load_model_and_tokeniser(
     config = load_hf_model_config(
         model_id=model_id,
         num_labels=len(id2label),
-        id2label=id2label,
-        label2id={label: idx for idx, label in id2label.items()},
+        id2label=HashableDict(id2label),
+        label2id=HashableDict({label: idx for idx, label in id2label.items()}),
         revision=model_config.revision,
         model_cache_dir=model_config.model_cache_dir,
         api_key=benchmark_config.api_key,
@@ -608,11 +615,8 @@ def load_model_and_tokeniser(
         ),
     )
 
-    # These are used when a timeout occurs
-    attempts_left = 5
-
     model: "PreTrainedModel | None" = None
-    while True:
+    for _ in range(num_attempts := 5):
         # Get the model class associated with the task group
         model_cls_or_none: t.Type["PreTrainedModel"] | None = get_class_by_name(
             class_name=task_group_to_class_name(task_group=task_group),
@@ -639,22 +643,21 @@ def load_model_and_tokeniser(
             break
         except (KeyError, RuntimeError) as e:
             if not model_kwargs["ignore_mismatched_sizes"]:
-                logger.debug(
+                log(
                     f"{type(e).__name__} occurred during the loading "
                     f"of the {model_id!r} model. Retrying with "
-                    "`ignore_mismatched_sizes` set to True."
+                    "`ignore_mismatched_sizes` set to True.",
+                    level=logging.DEBUG,
                 )
                 model_kwargs["ignore_mismatched_sizes"] = True
                 continue
             else:
                 raise InvalidModel(str(e)) from e
-        except (TimeoutError, RequestError)
-
-
-
-
-            ) from e
-            logger.info(f"Couldn't load the model {model_id!r}. Retrying.")
+        except (TimeoutError, RequestError):
+            log(
+                f"Couldn't load the model {model_id!r}. Retrying.",
+                level=logging.WARNING,
+            )
             sleep(5)
             continue
         except (OSError, ValueError) as e:
@@ -671,6 +674,10 @@ def load_model_and_tokeniser(
             raise InvalidModel(
                 f"The model {model_id!r} could not be loaded. The error was {e!r}."
             ) from e
+    else:
+        raise InvalidModel(
+            f"Could not load the model {model_id!r} after {num_attempts} attempts."
+        )
 
     if isinstance(model_or_tuple, tuple):
         model = model_or_tuple[0]
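The retry rewrite above replaces a manual attempts counter with a bounded `for` loop whose `else` clause raises once every attempt has failed: a Python `for` loop's `else` branch runs only when the loop completes without `break`. A self-contained sketch of the pattern:

import random
import time


def flaky_operation() -> str:
    """A stand-in for a network call that sometimes times out."""
    if random.random() < 0.5:
        raise TimeoutError("simulated timeout")
    return "ok"


def run_with_retries() -> str:
    for _ in range(num_attempts := 5):
        try:
            result = flaky_operation()
        except TimeoutError:
            time.sleep(0.1)  # back off before the next attempt
            continue
        break
    else:
        # Only reached if every attempt raised, i.e. the loop never hit `break`.
        raise RuntimeError(f"Operation failed after {num_attempts} attempts.")
    return result


print(run_with_retries())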
@@ -698,8 +705,15 @@ def load_model_and_tokeniser(
     return model, tokeniser
 
 
+@cache_arguments("model_id", "revision")
 def get_model_repo_info(
-    model_id: str,
+    model_id: str,
+    revision: str,
+    api_key: str | None,
+    cache_dir: str,
+    trust_remote_code: bool,
+    requires_safetensors: bool,
+    run_with_cli: bool,
 ) -> "HFModelInfo | None":
     """Get the information about the model from the HF Hub or a local directory.
 
@@ -708,13 +722,11 @@ def get_model_repo_info(
             The model ID.
         revision:
             The revision of the model.
-        benchmark_config:
-            The benchmark configuration.
 
     Returns:
         The information about the model, or None if the model could not be found.
     """
-    token = get_hf_token(api_key=benchmark_config.api_key)
+    token = get_hf_token(api_key=api_key)
     hf_api = HfApi(token=token)
 
     # Get information on the model.
@@ -722,7 +734,7 @@ def get_model_repo_info(
     # model info object.
     model_info: HfApiModelInfo | None = None
     if Path(model_id).is_dir():
-        logger.debug(f"Checking for local model in {model_id}.")
+        log(f"Checking for local model in {model_id}.", level=logging.DEBUG)
         if all(
             (Path(model_id) / required_file).exists()
             for required_file in LOCAL_MODELS_REQUIRED_FILES
@@ -748,17 +760,19 @@ def get_model_repo_info(
     except (GatedRepoError, LocalTokenNotFoundError) as e:
         try:
             hf_whoami(token=token)
-            logger.debug(
+            log(
                 f"Could not access the model {model_id} with the revision "
-                f"{revision}. The error was {str(e)!r}."
+                f"{revision}. The error was {str(e)!r}.",
+                level=logging.DEBUG,
             )
             return None
         except LocalTokenNotFoundError:
-            logger.debug(
+            log(
                 f"Could not access the model {model_id} with the revision "
                 f"{revision}. The error was {str(e)!r}. Please set the "
                 "`HUGGINGFACE_API_KEY` environment variable or use the "
-                "`--api-key` argument."
+                "`--api-key` argument.",
+                level=logging.DEBUG,
             )
             return None
     except (RepositoryNotFoundError, HFValidationError):
@@ -774,16 +788,18 @@ def get_model_repo_info(
         if internet_connection_available():
             errors.append(e)
             continue
-        logger.debug(
+        log(
             "Could not access the Hugging Face Hub. Please check your internet "
-            "connection."
+            "connection.",
+            level=logging.DEBUG,
         )
         return None
     else:
-        logger.debug(
+        log(
             f"Could not access model info for the model {model_id!r} from the "
             f"Hugging Face Hub, after {num_attempts} attempts. The errors "
-            f"encountered were {errors!r}."
+            f"encountered were {errors!r}.",
+            level=logging.DEBUG,
         )
         return None
 
@@ -814,15 +830,15 @@ def get_model_repo_info(
     hf_config = load_hf_model_config(
         model_id=base_model_id or model_id,
         num_labels=0,
-        id2label=dict(),
-        label2id=dict(),
+        id2label=HashableDict(),
+        label2id=HashableDict(),
         revision=revision,
         model_cache_dir=create_model_cache_dir(
-            cache_dir=benchmark_config.cache_dir, model_id=model_id
+            cache_dir=cache_dir, model_id=model_id
         ),
-        api_key=benchmark_config.api_key,
-        trust_remote_code=benchmark_config.trust_remote_code,
-        run_with_cli=benchmark_config.run_with_cli,
+        api_key=api_key,
+        trust_remote_code=trust_remote_code,
+        run_with_cli=run_with_cli,
     )
     class_names = hf_config.architectures
     generative_class_names = [
@@ -837,19 +853,19 @@ def get_model_repo_info(
     else:
         pipeline_tag = "fill-mask"
 
-    if benchmark_config.requires_safetensors:
+    if requires_safetensors:
         repo_files = hf_api.list_repo_files(repo_id=model_id, revision=revision)
         has_safetensors = any(f.endswith(".safetensors") for f in repo_files)
         if not has_safetensors:
             msg = f"Model {model_id} does not have safetensors weights available. "
-            if benchmark_config.run_with_cli:
+            if run_with_cli:
                 msg += "Skipping since the `--only-allow-safetensors` flag is set."
             else:
                 msg += (
                     "Skipping since the `requires_safetensors` argument is set "
                     "to `True`."
                 )
-            logger.warning(msg)
+            log(msg, level=logging.WARNING)
             return None
 
     # Also check base model if we are evaluating an adapter
@@ -863,7 +879,7 @@ def get_model_repo_info(
                 f"Base model {base_model_id} does not have safetensors weights "
                 "available."
             )
-            if benchmark_config.run_with_cli:
+            if run_with_cli:
                 msg += " Skipping since the `--only-allow-safetensors` flag is set."
             else:
                 msg += (
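The safetensors gate now reads plain arguments rather than `benchmark_config` attributes, but the underlying check is unchanged: list the repository files on the Hub and look for a `.safetensors` weight file. A minimal standalone example (with an arbitrary public model ID):

from huggingface_hub import HfApi

api = HfApi()
repo_files = api.list_repo_files(repo_id="bert-base-cased", revision="main")
has_safetensors = any(f.endswith(".safetensors") for f in repo_files)
print(has_safetensors)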
@@ -929,7 +945,10 @@ def load_tokeniser(
             f"Could not load tokeniser for model {model_id!r}."
         ) from e
     except (TimeoutError, RequestError):
-        logger.info(f"Couldn't load tokeniser for {model_id!r}. Retrying.")
+        log(
+            f"Couldn't load tokeniser for {model_id!r}. Retrying.",
+            level=logging.WARNING,
+        )
         sleep(5)
         continue
     else:
@@ -945,6 +964,7 @@ def load_tokeniser(
     return tokeniser
 
 
+@cache_arguments()
 def get_dtype(
     device: torch.device, dtype_is_set: bool, bf16_available: bool
 ) -> str | torch.dtype:
@@ -953,6 +973,7 @@ def get_dtype(
     Args:
         device:
             The device to use.
+        dtype_is_set:
             Whether the data type is set in the model configuration.
         bf16_available:
             Whether bfloat16 is available.
@@ -970,6 +991,7 @@ def get_dtype(
         return torch.float32
 
 
+@cache_arguments("model_id", "revision", "num_labels", "id2label", "label2id")
 def load_hf_model_config(
     model_id: str,
     num_labels: int,
@@ -1006,7 +1028,7 @@ def load_hf_model_config(
     Returns:
         The Hugging Face model configuration.
     """
-    while True:
+    for _ in range(num_attempts := 5):
         try:
             config = AutoConfig.from_pretrained(
                 model_id,
@@ -1019,12 +1041,7 @@ def load_hf_model_config(
                 cache_dir=model_cache_dir,
                 local_files_only=not internet_connection_available(),
             )
-
-            if isinstance(config.eos_token_id, list):
-                config.pad_token_id = config.eos_token_id[0]
-            else:
-                config.pad_token_id = config.eos_token_id
-            return config
+            break
         except KeyError as e:
             key = e.args[0]
             raise InvalidModel(
@@ -1032,18 +1049,23 @@ def load_hf_model_config(
                 f"loaded, as the key {key!r} was not found in the config."
             ) from e
         except (OSError, GatedRepoError) as e:
-
-
-
-
-
-
+            if isinstance(e, GatedRepoError) or "gated repo" in str(e).lower():
+                raise InvalidModel(
+                    f"The model {model_id!r} is a gated repository. Please ensure "
+                    "that you are logged in with `hf auth login` or have provided a "
+                    "valid Hugging Face access token with the `HUGGINGFACE_API_KEY` "
+                    "environment variable or the `--api-key` argument. Also check that "
+                    "your account has access to this model."
+                ) from e
             raise InvalidModel(
                 f"Couldn't load model config for {model_id!r}. The error was "
                 f"{e!r}. Skipping"
             ) from e
         except (TimeoutError, RequestError):
-            logger.info(f"Couldn't load model config for {model_id!r}. Retrying.")
+            log(
+                f"Couldn't load model config for {model_id!r}. Retrying.",
+                level=logging.WARNING,
+            )
             sleep(5)
             continue
         except ValueError as e:
@@ -1062,6 +1084,20 @@ def load_hf_model_config(
             f"The config for the model {model_id!r} could not be loaded. The "
             f"error was {e!r}."
         ) from e
+    else:
+        raise InvalidModel(
+            f"Couldn't load model config for {model_id!r} after {num_attempts} "
+            "attempts."
+        )
+
+    # Ensure that the PAD token ID is set
+    if config.eos_token_id is not None and config.pad_token_id is None:
+        if isinstance(config.eos_token_id, list):
+            config.pad_token_id = config.eos_token_id[0]
+        else:
+            config.pad_token_id = config.eos_token_id
+
+    return config
 
 
 def setup_model_for_question_answering(model: "PreTrainedModel") -> "PreTrainedModel":
@@ -1230,6 +1266,7 @@ def align_model_and_tokeniser(
     return model, tokeniser
 
 
+@cache_arguments()
 def task_group_to_class_name(task_group: TaskGroup) -> str:
     """Convert a task group to a class name.
 
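The `@cache_arguments(...)` decorator comes from the new `euroeval/caching_utils.py` (+79 lines, not shown in this excerpt). Judging from the call sites, it memoises a function on a named subset of its arguments, or on all of them when called with no names, which is also why `load_hf_model_config` now receives `HashableDict` instances instead of plain dicts: every argument that becomes part of a cache key must be hashable. A minimal sketch of such a pair of utilities, assuming those semantics; the shipped implementation may differ:

import collections.abc as c
import functools
import inspect
import typing as t


class HashableDict(dict):
    """A dict that can be used inside a cache key."""

    def __hash__(self) -> int:  # type: ignore[override]
        return hash(frozenset(self.items()))


def cache_arguments(*arg_names: str) -> c.Callable:
    """Memoise a function on the named arguments, or on all of them."""

    def decorator(func: c.Callable) -> c.Callable:
        cache: dict[tuple, t.Any] = {}
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any:
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            if arg_names:
                key = tuple(bound.arguments[name] for name in arg_names)
            else:
                key = tuple(bound.arguments.values())
            if key not in cache:
                cache[key] = func(*args, **kwargs)
            return cache[key]

        return wrapper

    return decorator


@cache_arguments("model_id", "revision")
def expensive_lookup(model_id: str, revision: str, verbose: bool = False) -> str:
    return f"{model_id}@{revision}"  # computed once per (model_id, revision) pair

With a decorator along these lines, repeated `get_model_repo_info` calls for the same model and revision would skip the Hub round-trip entirely, which matches the retry and caching changes visible throughout this file.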