EuroEval 16.2.2-py3-none-any.whl → 16.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of EuroEval might be problematic.
- euroeval/__init__.py +7 -4
- euroeval/benchmark_config_factory.py +0 -4
- euroeval/benchmark_modules/base.py +3 -16
- euroeval/benchmark_modules/fresh.py +5 -2
- euroeval/benchmark_modules/hf.py +107 -66
- euroeval/benchmark_modules/litellm.py +103 -55
- euroeval/benchmark_modules/vllm.py +155 -82
- euroeval/benchmarker.py +184 -129
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/cli.py +1 -1
- euroeval/constants.py +9 -0
- euroeval/data_loading.py +14 -11
- euroeval/data_models.py +12 -4
- euroeval/dataset_configs/__init__.py +3 -0
- euroeval/dataset_configs/czech.py +79 -0
- euroeval/dataset_configs/danish.py +10 -13
- euroeval/dataset_configs/dutch.py +0 -3
- euroeval/dataset_configs/english.py +0 -3
- euroeval/dataset_configs/estonian.py +11 -1
- euroeval/dataset_configs/finnish.py +0 -3
- euroeval/dataset_configs/french.py +0 -3
- euroeval/dataset_configs/german.py +0 -3
- euroeval/dataset_configs/italian.py +0 -3
- euroeval/dataset_configs/latvian.py +2 -4
- euroeval/dataset_configs/lithuanian.py +68 -0
- euroeval/dataset_configs/norwegian.py +0 -3
- euroeval/dataset_configs/polish.py +0 -3
- euroeval/dataset_configs/portuguese.py +0 -3
- euroeval/dataset_configs/slovak.py +60 -0
- euroeval/dataset_configs/spanish.py +0 -3
- euroeval/dataset_configs/swedish.py +10 -15
- euroeval/finetuning.py +21 -15
- euroeval/generation.py +10 -10
- euroeval/generation_utils.py +2 -3
- euroeval/logging_utils.py +250 -0
- euroeval/metrics/base.py +0 -3
- euroeval/metrics/huggingface.py +10 -6
- euroeval/metrics/llm_as_a_judge.py +5 -3
- euroeval/metrics/pipeline.py +22 -9
- euroeval/metrics/speed.py +0 -3
- euroeval/model_cache.py +11 -14
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/linguistic_acceptability.py +30 -3
- euroeval/prompt_templates/multiple_choice.py +34 -1
- euroeval/prompt_templates/named_entity_recognition.py +71 -11
- euroeval/prompt_templates/reading_comprehension.py +41 -3
- euroeval/prompt_templates/sentiment_classification.py +34 -1
- euroeval/prompt_templates/summarization.py +26 -6
- euroeval/scores.py +7 -7
- euroeval/speed_benchmark.py +3 -5
- euroeval/task_group_utils/multiple_choice_classification.py +0 -3
- euroeval/task_group_utils/question_answering.py +0 -3
- euroeval/task_group_utils/sequence_classification.py +43 -31
- euroeval/task_group_utils/text_to_text.py +17 -8
- euroeval/task_group_utils/token_classification.py +10 -9
- euroeval/tokenisation_utils.py +22 -20
- euroeval/utils.py +30 -147
- {euroeval-16.2.2.dist-info → euroeval-16.4.0.dist-info}/METADATA +182 -61
- euroeval-16.4.0.dist-info/RECORD +75 -0
- euroeval-16.2.2.dist-info/RECORD +0 -70
- {euroeval-16.2.2.dist-info → euroeval-16.4.0.dist-info}/WHEEL +0 -0
- {euroeval-16.2.2.dist-info → euroeval-16.4.0.dist-info}/entry_points.txt +0 -0
- {euroeval-16.2.2.dist-info → euroeval-16.4.0.dist-info}/licenses/LICENSE +0 -0
euroeval/benchmark_modules/vllm.py

@@ -14,10 +14,9 @@ from time import sleep
 import torch
 from huggingface_hub import snapshot_download
 from pydantic import conlist, create_model
-from tqdm.auto import tqdm
-from transformers import MistralCommonTokenizer
 from transformers.models.auto.configuration_auto import AutoConfig
 from transformers.models.auto.tokenization_auto import AutoTokenizer
+from transformers.tokenization_mistral_common import MistralCommonTokenizer
 from urllib3.exceptions import RequestError
 
 from ..constants import (
@@ -30,7 +29,7 @@ from ..constants import (
     REASONING_TOKENS,
     VLLM_BF16_MIN_CUDA_COMPUTE_CAPABILITY,
 )
-from ..data_models import GenerativeModelOutput, ModelConfig
+from ..data_models import GenerativeModelOutput, HashableDict, ModelConfig
 from ..enums import (
     BatchingPreference,
     GenerativeType,
@@ -50,6 +49,7 @@ from ..generation_utils import (
     raise_if_wrong_params,
 )
 from ..languages import get_all_languages
+from ..logging_utils import get_pbar, log, log_once, no_terminal_output
 from ..task_group_utils import (
     question_answering,
     sequence_classification,
@@ -73,7 +73,6 @@ from ..utils import (
     get_hf_token,
     get_min_cuda_compute_capability,
     internet_connection_available,
-    log_once,
     resolve_model_path,
     split_model_id,
 )
@@ -86,7 +85,7 @@ if t.TYPE_CHECKING or importlib.util.find_spec("vllm") is not None:
         destroy_model_parallel,
     )
     from vllm.lora.request import LoRARequest
-    from vllm.sampling_params import
+    from vllm.sampling_params import StructuredOutputsParams
 
 if t.TYPE_CHECKING:
     from datasets import DatasetDict
@@ -95,8 +94,6 @@ if t.TYPE_CHECKING:
 
     from ..data_models import BenchmarkConfig, DatasetConfig, Task
 
-logger = logging.getLogger("euroeval")
-
 
 class VLLMModel(HuggingFaceEncoderModel):
     """A generative model using the vLLM inference framework."""
@@ -104,7 +101,7 @@ class VLLMModel(HuggingFaceEncoderModel):
     fresh_model = False
     batching_preference = BatchingPreference.ALL_AT_ONCE
     high_priority = True
-    allowed_params = {re.compile(r".*"): ["thinking", "no-thinking"]}
+    allowed_params = {re.compile(r".*"): ["thinking", "no-thinking", "slow-tokenizer"]}
 
     def __init__(
         self,
@@ -132,9 +129,10 @@ class VLLMModel(HuggingFaceEncoderModel):
             model_config=model_config, allowed_params=self.allowed_params
         )
 
-
-
-
+        with no_terminal_output(disable=benchmark_config.verbose):
+            model, tokeniser = load_model_and_tokeniser(
+                model_config=model_config, benchmark_config=benchmark_config
+            )
         self._model: "LLM" = model
         self._tokeniser: "PreTrainedTokenizer" = tokeniser
 
@@ -245,6 +243,7 @@ class VLLMModel(HuggingFaceEncoderModel):
                 return partial(
                     sequence_classification.extract_labels_from_generation,
                     dataset_config=self.dataset_config,
+                    model_config=self.model_config,
                     first_label_token_mapping=self.buffer["first_label_token_mapping"],
                 )
             case TaskGroup.TEXT_TO_TEXT:
@@ -394,10 +393,11 @@ class VLLMModel(HuggingFaceEncoderModel):
             self.dataset_config.task.uses_structured_output
             or (self.dataset_config.task.uses_logprobs and self.dataset_config.labels)
         ) and self.generative_type == GenerativeType.REASONING:
-
-
+            structured_outputs = None
+            log(
                 "The dataset uses structured output, but we are not using it as the "
-                "model is a reasoning model."
+                "model is a reasoning model.",
+                level=logging.DEBUG,
             )
         elif self.dataset_config.task.uses_structured_output:
             ner_tag_names = list(self.dataset_config.prompt_label_mapping.values())
@@ -412,9 +412,11 @@ class VLLMModel(HuggingFaceEncoderModel):
                 f"{json.dumps(structured_generation_schema)}",
                 level=logging.DEBUG,
             )
-
+            structured_outputs = StructuredOutputsParams(
+                json=structured_generation_schema
+            )
         elif self.dataset_config.task.uses_logprobs and self.dataset_config.labels:
-
+            structured_outputs = StructuredOutputsParams(
                 choice=[
                     self.dataset_config.prompt_label_mapping[label]
                     for label in self.dataset_config.labels
@@ -422,11 +424,11 @@ class VLLMModel(HuggingFaceEncoderModel):
             )
             log_once(
                 "Using structured generation with the choices: "
-                f"{
+                f"{structured_outputs.choice!r}.",
                 level=logging.DEBUG,
             )
         else:
-
+            structured_outputs = None
             log_once(
                 "Not using structured generation as the dataset does not require it.",
                 level=logging.DEBUG,
@@ -445,14 +447,14 @@ class VLLMModel(HuggingFaceEncoderModel):
             else None,
             temperature=0.0,
             stop=[stop_token for stop_token in stop_tokens if stop_token],
-
+            structured_outputs=structured_outputs,
         )
 
         # If any of the prompts are empty then we need to replace them with a BOS token
         # so that the vLLM model can generate from them
         prompts: list[str] = inputs["text"]
         if any(len(prompt) == 0 for prompt in prompts):
-
+            log("Found empty prompts, replacing with BOS token.", level=logging.DEBUG)
            prompts = [
                 prompt if len(prompt) > 0 else str(self._tokeniser.bos_token)
                 for prompt in prompts
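The hunks above swap vLLM's older guided-decoding plumbing for `StructuredOutputsParams`, passed to `SamplingParams` via the new `structured_outputs` argument. A minimal sketch of the same pattern outside EuroEval, assuming a vLLM release that exposes `StructuredOutputsParams` as imported above; the model ID, prompt, and label set are placeholders:

from vllm import LLM, SamplingParams
from vllm.sampling_params import StructuredOutputsParams

# Placeholder model ID; any model servable by vLLM works here.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")

# Constrain generation to a fixed label set, mirroring the classification branch above.
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=8,
    structured_outputs=StructuredOutputsParams(
        choice=["positive", "neutral", "negative"]
    ),
)

outputs = llm.generate(
    prompts=["Review: 'Great phone!' Sentiment:"],
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)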
@@ -480,13 +482,14 @@ class VLLMModel(HuggingFaceEncoderModel):
                 raw_outputs = self._model.generate(
                     prompts=prompts,
                     sampling_params=sampling_params,
-                    use_tqdm=False if input_is_a_test else
+                    use_tqdm=False if input_is_a_test else get_pbar,
                     lora_request=self.buffer.get("lora_request"),
                 )
                 break
             except TypeError as e:
-
-                    f"Encountered error during vLLM generation: {str(e)}. Retrying..."
+                log(
+                    f"Encountered error during vLLM generation: {str(e)}. Retrying...",
+                    level=logging.DEBUG,
                 )
                 sleep(1)
             except ValueError as e:
@@ -498,10 +501,11 @@ class VLLMModel(HuggingFaceEncoderModel):
                     re.search(pattern, str(e), flags=re.IGNORECASE) is not None
                     for pattern in truncate_error_messages
                 ):
-
-                        "Prompts are too long, so truncating them and trying again..."
+                    log(
+                        "Prompts are too long, so truncating them and trying again...",
+                        level=logging.WARNING,
                     )
-
+                    log(f"The error message was: {str(e)}", level=logging.DEBUG)
 
                     # If we have already tried truncating the prompts a few times, then
                     # we truncate a bit more aggressively
@@ -544,26 +548,49 @@ class VLLMModel(HuggingFaceEncoderModel):
                     f"{num_extra_outputs!r} extra outputs."
                 )
             else:
-
+                log(
                     f"Filtered out {num_extra_outputs:,} extra outputs from the model, "
                     "which occured as we interupted the generation when we truncated "
-                    "the prompts."
+                    "the prompts.",
+                    level=logging.DEBUG,
                 )
 
         # Parse the raw model outputs
         completion_ids: list[list[int]] = [
-            output.outputs[0].token_ids for output in raw_outputs
+            list(output.outputs[0].token_ids) for output in raw_outputs
         ]
         completions = self._tokeniser.batch_decode(
             sequences=[
                 torch.LongTensor(completion_id) for completion_id in completion_ids
             ]
         )
-        if
-
-
-
-
+        if (
+            self.end_of_reasoning_token is not None
+            and self.generative_type == GenerativeType.REASONING
+        ):
+            num_samples_without_eor_token = 0
+            for idx in range(len(completions)):
+                if self.end_of_reasoning_token in completions[idx]:
+                    completions[idx] = completions[idx].split(
+                        self.end_of_reasoning_token
+                    )[-1]
+                else:
+                    num_samples_without_eor_token += 1
+                    completions[idx] = ""
+            if num_samples_without_eor_token > 0:
+                log_once(
+                    f"The model {self.model_config.model_id!r} is a reasoning "
+                    "model, but the generated output did not contain the end of "
+                    f"reasoning token ({self.end_of_reasoning_token!r}) in "
+                    f"{num_samples_without_eor_token:,}/{len(completions):,} of "
+                    "the samples. Using an empty string for all these samples "
+                    "instead.",
+                    level=(
+                        logging.WARNING
+                        if num_samples_without_eor_token / len(completions) > 0.5
+                        else logging.DEBUG
+                    ),
+                )
         stop_token_pattern = re.compile(
             "|".join(re.escape(stop_token) for stop_token in stop_tokens)
         )
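The reasoning-model handling added above keeps only the text after the model's end-of-reasoning token and substitutes an empty string when the token never appears. A small self-contained sketch of that policy, using `</think>` as an assumed end-of-reasoning marker and a hypothetical helper name:

def strip_reasoning(
    completions: list[str], end_of_reasoning_token: str = "</think>"
) -> list[str]:
    """Keep only the text after the end-of-reasoning token; blank out samples without it."""
    stripped: list[str] = []
    num_missing = 0
    for completion in completions:
        if end_of_reasoning_token in completion:
            stripped.append(completion.split(end_of_reasoning_token)[-1])
        else:
            num_missing += 1
            stripped.append("")
    if num_missing > 0:
        print(f"{num_missing:,}/{len(completions):,} samples had no {end_of_reasoning_token!r}")
    return stripped

print(strip_reasoning(["Let me think.</think> positive", "no reasoning marker here"]))
# [' positive', '']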
@@ -584,10 +611,10 @@ class VLLMModel(HuggingFaceEncoderModel):
         scores: list[list[list[tuple[str, float]]]] = [
             [
                 [
-                    (obj.decoded_token, obj.logprob)
+                    (obj.decoded_token or "", obj.logprob)
                     for obj in token_logprobs_dict.values()
                 ]
-                for token_logprobs_dict in raw_output.outputs[0].logprobs
+                for token_logprobs_dict in raw_output.outputs[0].logprobs or list()
             ]
             for raw_output in raw_outputs
         ]
@@ -625,7 +652,13 @@ class VLLMModel(HuggingFaceEncoderModel):
         revision = model_id_components.revision
 
         model_info = get_model_repo_info(
-            model_id=model_id,
+            model_id=model_id,
+            revision=revision,
+            api_key=benchmark_config.api_key,
+            cache_dir=benchmark_config.cache_dir,
+            trust_remote_code=benchmark_config.trust_remote_code,
+            requires_safetensors=benchmark_config.requires_safetensors,
+            run_with_cli=benchmark_config.run_with_cli,
         )
         return (
             model_info is not None
@@ -651,7 +684,11 @@ class VLLMModel(HuggingFaceEncoderModel):
         model_info = get_model_repo_info(
             model_id=model_id_components.model_id,
             revision=model_id_components.revision,
-
+            api_key=benchmark_config.api_key,
+            cache_dir=benchmark_config.cache_dir,
+            trust_remote_code=benchmark_config.trust_remote_code,
+            requires_safetensors=benchmark_config.requires_safetensors,
+            run_with_cli=benchmark_config.run_with_cli,
         )
         if model_info is None:
             raise InvalidModel(f"The model {model_id!r} could not be found.")
@@ -728,8 +765,8 @@ def load_model_and_tokeniser(
     hf_model_config = load_hf_model_config(
         model_id=model_id,
         num_labels=0,
-        id2label=
-        label2id=
+        id2label=HashableDict(),
+        label2id=HashableDict(),
         revision=revision,
         model_cache_dir=model_config.model_cache_dir,
         api_key=benchmark_config.api_key,
@@ -756,32 +793,36 @@ def load_model_and_tokeniser(
     # Choose bf16 over fp16 if the model is a fp32 model and the GPU supports it
     if hf_model_config.dtype == torch.float32:
         if torch.cuda.is_bf16_supported():
-
+            log(
                 "You are loading a model with dtype FP32, which we will convert to "
                 "BF16 as FP32 is not supported by vLLM and BF16 is supported by your "
-                "GPU."
+                "GPU.",
+                level=logging.WARNING,
             )
             dtype = torch.bfloat16
         else:
-
+            log(
                 "You are loading a model with dtype FP32, which we will convert to "
                 "FP16 as FP32 is not supported by vLLM and BF16 is not supported by "
-                "your GPU."
+                "your GPU.",
+                level=logging.WARNING,
            )
             dtype = torch.float16
 
     # If the model is a quantized model, we might need to change the dtype
     if quantization == "mxfp4" and hf_model_config.dtype is None:
         dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-
+        log(
             "You are loading a quantized model where `dtype` has not been set. "
-            f"Setting dtype to {dtype!r}."
+            f"Setting dtype to {dtype!r}.",
+            level=logging.DEBUG,
         )
     elif quantization is not None and hf_model_config.dtype != torch.float16:
-
+        log(
             "You are loading a quantized model with dtype "
             f"{hf_model_config.dtype}, which vLLM does not support. Setting "
-            "dtype to float16 instead."
+            "dtype to float16 instead.",
+            level=logging.WARNING,
         )
         dtype = torch.float16
 
@@ -792,12 +833,13 @@ def load_model_and_tokeniser(
 
     if min_cuda_compute_capability is not None:
         if min_cuda_compute_capability < required_capability:
-
+            log(
                 f"You are loading a model with dtype {hf_model_config.dtype}, "
                 "which vLLM only supports for CUDA devices with CUDA compute "
                 f"capability >={required_capability}. You are using one or more "
                 f"devices with compute capability {min_cuda_compute_capability}. "
-                "Setting dtype to float16 instead."
+                "Setting dtype to float16 instead.",
+                level=logging.WARNING,
             )
             dtype = torch.float16
 
@@ -830,9 +872,12 @@ def load_model_and_tokeniser(
         adapter_base_model_id=model_config.adapter_base_model_id,
         trust_remote_code=benchmark_config.trust_remote_code,
         model_max_length=true_max_model_len,
-
+        model_config=model_config,
         token=get_hf_token(api_key=benchmark_config.api_key),
     )
+    vllm_tokenisation_params = get_vllm_tokenisation_params(
+        tokeniser=tokeniser, model_config=model_config
+    )
 
     clear_vllm()
 
@@ -865,16 +910,7 @@ def load_model_and_tokeniser(
             enable_prefix_caching=False,
             enable_lora=model_config.adapter_base_model_id is not None,
             max_lora_rank=256,
-
-            tokenizer_mode="mistral"
-            if isinstance(tokeniser, MistralCommonTokenizer)
-            else "auto",
-            config_format="mistral"
-            if isinstance(tokeniser, MistralCommonTokenizer)
-            else "auto",
-            load_format="mistral"
-            if isinstance(tokeniser, MistralCommonTokenizer)
-            else "auto",
+            **vllm_tokenisation_params,
         )
     except (RuntimeError, ValueError, OSError) as e:
         if "awaiting a review from the repo authors" in str(e):
@@ -903,7 +939,7 @@ def load_tokeniser(
     adapter_base_model_id: str | None,
     trust_remote_code: bool,
     model_max_length: int,
-
+    model_config: "ModelConfig",
     token: str | bool,
 ) -> "PreTrainedTokenizer":
     """Load the tokeniser.
@@ -920,8 +956,8 @@ def load_tokeniser(
             Whether to trust remote code.
         model_max_length:
             The maximum length of the model.
-
-            The
+        model_config:
+            The model configuration.
         token:
             The Hugging Face API token.
 
@@ -932,7 +968,7 @@ def load_tokeniser(
     config = AutoConfig.from_pretrained(
         adapter_base_model_id or model_id,
         revision=revision,
-        cache_dir=model_cache_dir,
+        cache_dir=model_config.model_cache_dir,
         token=token,
         trust_remote_code=trust_remote_code,
         local_files_only=not internet_connection_available(),
@@ -940,15 +976,25 @@ def load_tokeniser(
     num_retries = 5
     for _ in range(num_retries):
         try:
+            # Mistral instruction-tuned models need a custom tokeniser
+            if model_id.startswith("mistralai/") and "base" not in model_id.lower():
+                tokeniser = MistralCommonTokenizer.from_pretrained(
+                    model_id,
+                    padding_side="left",
+                    truncation_side="left",
+                    model_max_length=model_max_length,
+                    token=token,
+                )
+                break
             tokeniser = AutoTokenizer.from_pretrained(
                 model_id,
-                use_fast=True,
+                use_fast=False if model_config.param == "slow-tokenizer" else True,
                 verbose=False,
                 trust_remote_code=trust_remote_code,
                 padding_side="left",
                 truncation_side="left",
                 model_max_length=model_max_length,
-                cache_dir=model_cache_dir,
+                cache_dir=model_config.model_cache_dir,
                 config=config,
                 token=token,
                 local_files_only=not internet_connection_available(),
@@ -960,13 +1006,17 @@ def load_tokeniser(
                 f"Could not load tokeniser for model {model_id!r}. The error was "
                 f"{str(e)}."
             ) from e
-
+            log(
                 f"Could not load tokeniser for {model_id!r}. Falling back to "
-                f"{adapter_base_model_id!r}."
+                f"{adapter_base_model_id!r}.",
+                level=logging.DEBUG,
             )
             model_id = adapter_base_model_id
         except (TimeoutError, RequestError):
-
+            log(
+                f"Couldn't load tokeniser for {model_id!r}. Retrying.",
+                level=logging.WARNING,
+            )
             sleep(5)
             continue
         except (KeyError, ValueError) as e:
@@ -1165,27 +1215,50 @@ def get_custom_stop_tokens(
         if stop_token in prompt or stop_token in completion
     ]
     if stop_tokens:
-
+        log(
             f"Found the following custom stop tokens for model {model_id!r}: "
-            f"{stop_tokens}."
+            f"{stop_tokens}.",
+            level=logging.DEBUG,
         )
     else:
-
+        log(f"Found no custom stop tokens for model {model_id!r}.", level=logging.DEBUG)
 
     return stop_tokens
 
 
-def
-""
+def get_vllm_tokenisation_params(
+    tokeniser: "PreTrainedTokenizer", model_config: "ModelConfig"
+) -> dict[str, t.Any]:
+    """Get the tokenisation parameters for vLLM.
 
     Args:
-
-
-
-
+        tokeniser:
+            The tokeniser.
+        model_config:
+            The model configuration.
 
     Returns:
-        A
+        A dictionary of tokenisation parameters to pass to vLLM.
     """
-
-
+    if isinstance(tokeniser, MistralCommonTokenizer):
+        tokeniser_mode = "mistral"
+    elif model_config.param == "slow-tokenizer":
+        tokeniser_mode = "slow"
+    else:
+        tokeniser_mode = "auto"
+
+    if isinstance(tokeniser, MistralCommonTokenizer):
+        config_format = "mistral"
+    else:
+        config_format = "auto"
+
+    if isinstance(tokeniser, MistralCommonTokenizer):
+        load_format = "mistral"
+    else:
+        load_format = "auto"
+
+    return dict(
+        tokenizer_mode=tokeniser_mode,
+        config_format=config_format,
+        load_format=load_format,
+    )
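The new `get_vllm_tokenisation_params` helper centralises the `tokenizer_mode`, `config_format`, and `load_format` flags that were previously computed inline when constructing the `LLM` engine. A rough sketch of how such a dictionary is unpacked into the constructor; the model ID is a placeholder and the values shown are the defaults from the helper above:

from vllm import LLM

# The helper returns plain keyword arguments for vLLM's LLM() constructor, so they
# can be unpacked directly into the call (see the `**vllm_tokenisation_params` hunk
# above). "mistral" is used for MistralCommonTokenizer models, and "slow" when the
# slow-tokenizer model parameter is set.
vllm_tokenisation_params = {
    "tokenizer_mode": "auto",
    "config_format": "auto",
    "load_format": "auto",
}

llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model ID
    **vllm_tokenisation_params,
)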