evalscope 0.17.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evalscope/__init__.py +4 -1
- evalscope/api/__init__.py +0 -0
- evalscope/api/benchmark/__init__.py +3 -0
- evalscope/api/benchmark/adapters/__init__.py +3 -0
- evalscope/api/benchmark/adapters/default_data_adapter.py +683 -0
- evalscope/api/benchmark/adapters/multi_choice_adapter.py +83 -0
- evalscope/api/benchmark/adapters/text2image_adapter.py +155 -0
- evalscope/api/benchmark/benchmark.py +321 -0
- evalscope/api/benchmark/meta.py +115 -0
- evalscope/api/dataset/__init__.py +2 -0
- evalscope/api/dataset/dataset.py +349 -0
- evalscope/api/dataset/loader.py +261 -0
- evalscope/api/dataset/utils.py +143 -0
- evalscope/api/evaluator/__init__.py +3 -0
- evalscope/api/evaluator/cache.py +355 -0
- evalscope/api/evaluator/evaluator.py +56 -0
- evalscope/api/evaluator/state.py +264 -0
- evalscope/api/filter/__init__.py +1 -0
- evalscope/api/filter/filter.py +72 -0
- evalscope/api/messages/__init__.py +11 -0
- evalscope/api/messages/chat_message.py +198 -0
- evalscope/api/messages/content.py +102 -0
- evalscope/api/messages/utils.py +35 -0
- evalscope/api/metric/__init__.py +2 -0
- evalscope/api/metric/metric.py +55 -0
- evalscope/api/metric/scorer.py +105 -0
- evalscope/api/mixin/__init__.py +2 -0
- evalscope/api/mixin/dataset_mixin.py +105 -0
- evalscope/api/mixin/llm_judge_mixin.py +168 -0
- evalscope/api/model/__init__.py +12 -0
- evalscope/api/model/generate_config.py +157 -0
- evalscope/api/model/model.py +383 -0
- evalscope/api/model/model_output.py +285 -0
- evalscope/api/registry.py +182 -0
- evalscope/api/tool/__init__.py +3 -0
- evalscope/api/tool/tool_call.py +101 -0
- evalscope/api/tool/tool_info.py +173 -0
- evalscope/api/tool/utils.py +64 -0
- evalscope/app/ui/app_ui.py +2 -1
- evalscope/app/ui/multi_model.py +50 -25
- evalscope/app/ui/single_model.py +23 -11
- evalscope/app/utils/data_utils.py +42 -26
- evalscope/app/utils/text_utils.py +0 -2
- evalscope/app/utils/visualization.py +9 -4
- evalscope/arguments.py +6 -7
- evalscope/backend/opencompass/api_meta_template.py +2 -1
- evalscope/backend/opencompass/backend_manager.py +6 -3
- evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +10 -10
- evalscope/backend/rag_eval/clip_benchmark/task_template.py +8 -4
- evalscope/backend/rag_eval/ragas/task_template.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/build_distribution.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/build_transform.py +7 -4
- evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +2 -1
- evalscope/backend/rag_eval/utils/embedding.py +2 -1
- evalscope/backend/rag_eval/utils/llm.py +13 -12
- evalscope/benchmarks/__init__.py +0 -2
- evalscope/benchmarks/aigc/i2i/__init__.py +0 -0
- evalscope/benchmarks/aigc/i2i/general_i2i_adapter.py +44 -0
- evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +53 -55
- evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +41 -46
- evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +29 -45
- evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +34 -44
- evalscope/benchmarks/aigc/t2i/tifa_adapter.py +16 -27
- evalscope/benchmarks/aime/aime24_adapter.py +38 -40
- evalscope/benchmarks/aime/aime25_adapter.py +34 -40
- evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +86 -60
- evalscope/benchmarks/arc/arc_adapter.py +34 -147
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +96 -70
- evalscope/benchmarks/arena_hard/utils.py +37 -1
- evalscope/benchmarks/bbh/bbh_adapter.py +72 -144
- evalscope/benchmarks/bfcl/bfcl_adapter.py +181 -160
- evalscope/benchmarks/bfcl/generation.py +222 -0
- evalscope/benchmarks/ceval/ceval_adapter.py +94 -162
- evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +85 -82
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +34 -125
- evalscope/benchmarks/competition_math/competition_math_adapter.py +56 -108
- evalscope/benchmarks/data_collection/data_collection_adapter.py +183 -45
- evalscope/benchmarks/docmath/docmath_adapter.py +109 -51
- evalscope/benchmarks/docmath/utils.py +4 -5
- evalscope/benchmarks/drop/drop_adapter.py +88 -40
- evalscope/benchmarks/frames/frames_adapter.py +135 -52
- evalscope/benchmarks/general_arena/general_arena_adapter.py +136 -98
- evalscope/benchmarks/general_arena/utils.py +23 -27
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +40 -101
- evalscope/benchmarks/general_qa/general_qa_adapter.py +73 -134
- evalscope/benchmarks/gpqa/gpqa_adapter.py +61 -100
- evalscope/benchmarks/gpqa/{chain_of_thought.txt → prompt.py} +12 -5
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +62 -142
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +35 -124
- evalscope/benchmarks/hle/hle_adapter.py +127 -93
- evalscope/benchmarks/humaneval/humaneval_adapter.py +86 -55
- evalscope/benchmarks/ifeval/ifeval_adapter.py +69 -40
- evalscope/benchmarks/ifeval/instructions.py +109 -64
- evalscope/benchmarks/ifeval/instructions_registry.py +1 -1
- evalscope/benchmarks/ifeval/utils.py +6 -7
- evalscope/benchmarks/iquiz/iquiz_adapter.py +30 -65
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +2 -2
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +121 -71
- evalscope/benchmarks/live_code_bench/load_utils.py +13 -21
- evalscope/benchmarks/live_code_bench/testing_util.py +6 -2
- evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +49 -75
- evalscope/benchmarks/math_500/math_500_adapter.py +41 -48
- evalscope/benchmarks/mmlu/mmlu_adapter.py +32 -205
- evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +80 -99
- evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +64 -110
- evalscope/benchmarks/musr/musr_adapter.py +33 -64
- evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +192 -152
- evalscope/benchmarks/process_bench/process_bench_adapter.py +144 -76
- evalscope/benchmarks/race/race_adapter.py +33 -119
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +72 -70
- evalscope/benchmarks/super_gpqa/{five_shot_prompt.txt → prompt.py} +14 -16
- evalscope/benchmarks/super_gpqa/super_gpqa_adapter.py +73 -117
- evalscope/benchmarks/super_gpqa/utils.py +2 -1
- evalscope/benchmarks/tau_bench/generation.py +147 -0
- evalscope/benchmarks/tau_bench/tau_bench_adapter.py +112 -54
- evalscope/benchmarks/tool_bench/tool_bench_adapter.py +91 -70
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +56 -124
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +70 -265
- evalscope/benchmarks/winogrande/winogrande_adapter.py +28 -54
- evalscope/cli/cli.py +2 -0
- evalscope/cli/start_server.py +6 -3
- evalscope/collections/__init__.py +2 -10
- evalscope/collections/sampler.py +10 -10
- evalscope/collections/schema.py +13 -11
- evalscope/config.py +95 -54
- evalscope/constants.py +29 -61
- evalscope/evaluator/__init__.py +1 -1
- evalscope/evaluator/evaluator.py +277 -423
- evalscope/filters/__init__.py +2 -0
- evalscope/filters/extraction.py +126 -0
- evalscope/filters/selection.py +57 -0
- evalscope/metrics/__init__.py +13 -13
- evalscope/metrics/llm_judge.py +32 -30
- evalscope/metrics/math_parser.py +27 -22
- evalscope/metrics/metric.py +307 -0
- evalscope/metrics/metrics.py +22 -18
- evalscope/metrics/t2v_metrics/__init__.py +0 -52
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +4 -2
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +9 -13
- evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +3 -2
- evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +2 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +4 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +10 -5
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +4 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +15 -9
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +15 -10
- evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +9 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +2 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +3 -9
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +16 -10
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +8 -4
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +47 -25
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +12 -7
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +23 -17
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +33 -23
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +2 -1
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +46 -30
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +69 -37
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +7 -5
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +6 -4
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +7 -5
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +5 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +17 -13
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +35 -19
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +14 -12
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +63 -52
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +63 -38
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +6 -3
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +6 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +15 -13
- evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +3 -2
- evalscope/models/__init__.py +6 -29
- evalscope/models/mockllm.py +65 -0
- evalscope/models/model_apis.py +47 -0
- evalscope/models/modelscope.py +455 -0
- evalscope/models/openai_compatible.py +123 -0
- evalscope/models/text2image_model.py +124 -0
- evalscope/models/utils/openai.py +698 -0
- evalscope/perf/benchmark.py +2 -1
- evalscope/perf/http_client.py +4 -2
- evalscope/perf/plugin/api/custom_api.py +5 -4
- evalscope/perf/plugin/api/openai_api.py +11 -9
- evalscope/perf/plugin/datasets/custom.py +2 -1
- evalscope/perf/plugin/datasets/flickr8k.py +1 -1
- evalscope/perf/plugin/datasets/kontext_bench.py +1 -1
- evalscope/perf/plugin/datasets/line_by_line.py +2 -1
- evalscope/perf/plugin/datasets/longalpaca.py +2 -1
- evalscope/perf/plugin/datasets/openqa.py +4 -2
- evalscope/perf/utils/benchmark_util.py +7 -5
- evalscope/perf/utils/db_util.py +9 -6
- evalscope/perf/utils/local_server.py +8 -3
- evalscope/perf/utils/rich_display.py +16 -10
- evalscope/report/__init__.py +2 -2
- evalscope/report/combinator.py +18 -12
- evalscope/report/generator.py +101 -6
- evalscope/report/{utils.py → report.py} +8 -6
- evalscope/run.py +26 -44
- evalscope/summarizer.py +1 -1
- evalscope/utils/__init__.py +21 -2
- evalscope/utils/chat_service.py +2 -1
- evalscope/utils/deprecation_utils.py +12 -1
- evalscope/utils/function_utils.py +29 -0
- evalscope/utils/io_utils.py +100 -5
- evalscope/utils/json_schema.py +208 -0
- evalscope/utils/logger.py +51 -12
- evalscope/utils/model_utils.py +10 -7
- evalscope/utils/multi_choices.py +271 -0
- evalscope/utils/url_utils.py +65 -0
- evalscope/version.py +2 -2
- {evalscope-0.17.1.dist-info → evalscope-1.0.0.dist-info}/METADATA +98 -49
- {evalscope-0.17.1.dist-info → evalscope-1.0.0.dist-info}/RECORD +234 -216
- tests/aigc/test_t2i.py +22 -4
- tests/benchmark/__init__.py +1 -0
- tests/benchmark/test_eval.py +386 -0
- tests/cli/test_all.py +3 -5
- tests/cli/test_collection.py +13 -4
- tests/cli/test_custom.py +22 -15
- tests/rag/test_clip_benchmark.py +1 -0
- evalscope/benchmarks/aigc/t2i/base.py +0 -56
- evalscope/benchmarks/arc/ai2_arc.py +0 -151
- evalscope/benchmarks/benchmark.py +0 -81
- evalscope/benchmarks/ceval/ceval_exam.py +0 -146
- evalscope/benchmarks/cmmlu/cmmlu.py +0 -161
- evalscope/benchmarks/cmmlu/samples.jsonl +0 -5
- evalscope/benchmarks/competition_math/competition_math.py +0 -79
- evalscope/benchmarks/data_adapter.py +0 -528
- evalscope/benchmarks/filters.py +0 -59
- evalscope/benchmarks/gsm8k/gsm8k.py +0 -121
- evalscope/benchmarks/hellaswag/hellaswag.py +0 -112
- evalscope/benchmarks/humaneval/humaneval.py +0 -79
- evalscope/benchmarks/mmlu/mmlu.py +0 -160
- evalscope/benchmarks/mmlu/samples.jsonl +0 -5
- evalscope/benchmarks/process_bench/critique_template.txt +0 -13
- evalscope/benchmarks/race/race.py +0 -104
- evalscope/benchmarks/race/samples.jsonl +0 -5
- evalscope/benchmarks/super_gpqa/zero_shot_prompt.txt +0 -4
- evalscope/benchmarks/trivia_qa/trivia_qa.py +0 -89
- evalscope/benchmarks/truthful_qa/truthful_qa.py +0 -163
- evalscope/benchmarks/utils.py +0 -60
- evalscope/collections/evaluator.py +0 -375
- evalscope/metrics/completion_parsers.py +0 -227
- evalscope/metrics/named_metrics.py +0 -55
- evalscope/models/adapters/__init__.py +0 -14
- evalscope/models/adapters/base_adapter.py +0 -84
- evalscope/models/adapters/bfcl_adapter.py +0 -246
- evalscope/models/adapters/chat_adapter.py +0 -207
- evalscope/models/adapters/choice_adapter.py +0 -222
- evalscope/models/adapters/custom_adapter.py +0 -71
- evalscope/models/adapters/server_adapter.py +0 -236
- evalscope/models/adapters/t2i_adapter.py +0 -79
- evalscope/models/adapters/tau_bench_adapter.py +0 -189
- evalscope/models/custom/__init__.py +0 -4
- evalscope/models/custom/custom_model.py +0 -50
- evalscope/models/custom/dummy_model.py +0 -99
- evalscope/models/local_model.py +0 -128
- evalscope/models/register.py +0 -41
- tests/cli/test_run.py +0 -489
- {evalscope-0.17.1.dist-info → evalscope-1.0.0.dist-info}/LICENSE +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.0.dist-info}/WHEEL +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.0.dist-info}/entry_points.txt +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.0.dist-info}/top_level.txt +0 -0
evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py
CHANGED
@@ -61,17 +61,18 @@ def t5_tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_IN
 
 
 def load_pretrained_model(
-    (old parameter list, 11 lines, not captured in this diff view)
+    model_cls,
+    model_args,
+    model_path=None,
+    tokenizer_path=None,
+    model_max_length=None,
+    padding_side=None,
+    image_aspect_ratio='pad', # or 'square'
+    mmprojector_repo=None,
+    mmprojector_name=None,
+    device='cuda',
+    cache_dir=CACHE_DIR
+):
     tokenizer_dict = {}
     if model_max_length:
         tokenizer_dict['model_max_length'] = model_max_length
@@ -80,7 +81,7 @@ def load_pretrained_model(
 
         from ..utils import download_file
 
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_dict)
         # tokenizer.pad_token = tokenizer.unk_token # could be redundant
 
         model_path = download_file(model_path, cache_dir=cache_dir)
@@ -106,7 +107,8 @@ def load_pretrained_model(
         model_args.pretrain_mm_mlp_adapter = pretrain_mm_mlp_adapter # important to set to correct path
 
         model.get_model().initialize_vision_modules(
-            model_args
+            model_args
+        ) # This will load the CLIP vision encoder and MLP projector
     else:
         model.resize_token_embeddings(len(tokenizer)) # perhaps not needed
 
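In the second hunk, the `AutoTokenizer.from_pretrained` call now explicitly forwards `tokenizer_dict`, which carries the `model_max_length` argument built a few lines earlier. A minimal sketch of the resulting behaviour; the checkpoint id and length below are illustrative assumptions, not values from the package:

from transformers import AutoTokenizer

# load_pretrained_model builds tokenizer_dict from its model_max_length argument
# and now forwards it like this (values assumed for illustration):
tokenizer_dict = {'model_max_length': 2048}
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-xl', **tokenizer_dict)
print(tokenizer.model_max_length)  # 2048 instead of the checkpoint default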
evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py
CHANGED
@@ -8,8 +8,9 @@ from ..model import ScoreModel
 class VQAScoreModel(ScoreModel):
 
     @abstractmethod
-    def forward(
-
+    def forward(
+        self, images: List[str], texts: List[str], question_template: str, answer_template: str
+    ) -> torch.Tensor:
         """Forward pass of the model to return n scores for n (image, text) pairs (in PyTorch Tensor)
         question_template: a string with optional {} to be replaced with the 'text'
         answer_template: a string with optional {} to be replaced with the 'text'
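The abstract `forward` now has an explicit, typed signature. A minimal hypothetical subclass, purely to illustrate the contract (one score per (image, text) pair); the real implementations live in the vqascore model wrappers (e.g. clip_t5_model.py):

from typing import List

import torch

from evalscope.metrics.t2v_metrics.models.vqascore_models.vqa_model import VQAScoreModel


class ConstantVQAScore(VQAScoreModel):  # hypothetical, for illustration only
    """Returns a constant score for every (image, text) pair."""

    def forward(
        self, images: List[str], texts: List[str], question_template: str, answer_template: str
    ) -> torch.Tensor:
        # One score per (image, text) pair, as the docstring above requires.
        return torch.zeros(len(images))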
evalscope/models/__init__.py
CHANGED
@@ -4,38 +4,15 @@ from typing import TYPE_CHECKING
 from evalscope.utils.import_utils import _LazyModule
 
 if TYPE_CHECKING:
-    from .
-        CustomModelAdapter, MultiChoiceModelAdapter, ServerModelAdapter, T2IModelAdapter,
-        TauBenchAdapter, initialize_model_adapter)
-    from .custom import CustomModel, DummyCustomModel
-    from .local_model import LocalModel, get_local_model
-    from .register import get_model_adapter
+    from .model_apis import llm_ckpt, mockllm, openai_api
 
 else:
     _import_structure = {
-        '
-        '
-        '
-        '
-
-            'MultiChoiceModelAdapter',
-            'CustomModelAdapter',
-            'ServerModelAdapter',
-            'T2IModelAdapter',
-            'TauBenchAdapter',
-            'BFCLAdapter',
-        ],
-        'custom': [
-            'CustomModel',
-            'DummyCustomModel',
-        ],
-        'local_model': [
-            'LocalModel',
-            'get_local_model',
-        ],
-        'register': [
-            'get_model_adapter',
-        ],
+        'model_apis': [
+            'openai_api',
+            'mockllm',
+            'llm_ckpt',
+        ]
     }
 
     import sys
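For downstream code, the trimmed import table means the old adapter classes and helpers (`ServerModelAdapter`, `LocalModel`, `get_model_adapter`, ...) are no longer importable from `evalscope.models`; what remains are the three provider factories. A small usage sketch, assuming the lazy module resolves exactly the names listed in `_import_structure`:

from evalscope.models import llm_ckpt, mockllm, openai_api

# Each factory returns a ModelAPI subclass registered in model_apis.py (shown below):
api_cls = openai_api()   # -> OpenAICompatibleAPI
mock_cls = mockllm()     # -> MockLLM
ckpt_cls = llm_ckpt()    # -> ModelScopeAPI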
evalscope/models/mockllm.py
ADDED
@@ -0,0 +1,65 @@
from typing import Any, Dict, Generator, Iterable, Iterator, List, Optional, Union

from evalscope.api.dataset import Dataset
from evalscope.api.messages import ChatMessage
from evalscope.api.model import GenerateConfig, ModelAPI, ModelOutput
from evalscope.api.tool import ToolChoice, ToolInfo
from evalscope.utils.function_utils import thread_safe


class MockLLM(ModelAPI):
    """A mock implementation of the ModelAPI class for testing purposes.

    Always returns default_output, unless you pass in a model_args
    key "custom_outputs" with a value of an Iterable[ModelOutput]
    """

    default_output = 'Default output from mockllm/model'

    outputs: Iterator[ModelOutput]

    def __init__(
        self,
        model_name: str,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        config: GenerateConfig = GenerateConfig(),
        custom_outputs: Iterable[ModelOutput] = None,
        **model_args: Dict[str, Any],
    ) -> None:
        super().__init__(model_name, base_url, api_key, config)
        self.model_args = model_args
        if custom_outputs is not None:
            # We cannot rely on the user of this model giving custom_outputs
            # the correct type since they do not call this constructor
            # Hence this type check and the one in generate.
            if not isinstance(custom_outputs, (Iterable, Generator)):
                raise ValueError(
                    f"model_args['custom_outputs'] must be an Iterable or a Generator, got {custom_outputs}"
                )
            self.outputs = iter(custom_outputs)
        else:
            self.outputs = iter((
                ModelOutput.from_content(model='mockllm', content=self.default_output)
                for _ in iter(int, 1) # produce an infinite iterator
            ))

    @thread_safe
    def generate(
        self,
        input: List[ChatMessage],
        tools: List[ToolInfo],
        tool_choice: ToolChoice,
        config: GenerateConfig,
    ) -> ModelOutput:
        try:
            output = next(self.outputs)
        except StopIteration:
            raise ValueError('custom_outputs ran out of values')

        if not isinstance(output, ModelOutput):
            raise ValueError(f'output must be an instance of ModelOutput; got {type(output)}; content: {repr(output)}')
        return output

    def batch_generate(inputs: Dataset, config: GenerateConfig) -> List[ModelOutput]:
        return super().batch_generate(inputs, config)
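A short usage sketch of the mock backend, useful for exercising evaluators without a real model. The constructor and `ModelOutput.from_content` come from the file above; `ChatMessageUser` is referenced elsewhere in this diff, and its import path plus the concrete values are assumptions:

from evalscope.api.messages import ChatMessageUser
from evalscope.api.model import GenerateConfig, ModelOutput
from evalscope.models.mockllm import MockLLM

# Scripted responses: generate() returns them one by one and raises once exhausted.
mock = MockLLM(
    model_name='mockllm/model',
    custom_outputs=[
        ModelOutput.from_content(model='mockllm', content='4'),
        ModelOutput.from_content(model='mockllm', content='Paris'),
    ],
)

out = mock.generate(
    input=[ChatMessageUser(content='What is 2 + 2?')],
    tools=[],
    tool_choice='none',  # assumed value; MockLLM ignores tool_choice
    config=GenerateConfig(),
)
print(out)  # first scripted ModelOutput ('4')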
evalscope/models/model_apis.py
ADDED
@@ -0,0 +1,47 @@
from evalscope.api.model import ModelAPI
from evalscope.api.registry import register_model_api
from evalscope.utils.deprecation_utils import deprecated


@register_model_api(name='mock_llm')
def mockllm() -> type[ModelAPI]:
    from .mockllm import MockLLM

    return MockLLM


@register_model_api(name='openai_api')
def openai_api() -> type[ModelAPI]:
    from .openai_compatible import OpenAICompatibleAPI

    return OpenAICompatibleAPI


@register_model_api(name='server')
@deprecated(since='1.0.0', remove_in='1.1.0', alternative='openai_api')
def server() -> type[ModelAPI]:
    from .openai_compatible import OpenAICompatibleAPI

    return OpenAICompatibleAPI


@register_model_api(name='llm_ckpt')
def llm_ckpt() -> type[ModelAPI]:
    from .modelscope import ModelScopeAPI

    return ModelScopeAPI


@register_model_api(name='checkpoint')
@deprecated(since='1.0.0', remove_in='1.1.0', alternative='llm_ckpt')
def checkpoint() -> type[ModelAPI]:
    from .modelscope import ModelScopeAPI

    return ModelScopeAPI


@register_model_api(name='text2image')
def text2image() -> type[ModelAPI]:
    from .text2image_model import Text2ImageAPI

    return Text2ImageAPI
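Note that `server` and `checkpoint` are kept only as deprecated aliases of `openai_api` and `llm_ckpt`, scheduled for removal in 1.1.0. The same decorator can presumably register third-party backends; a hedged sketch that mirrors the pattern above (only `register_model_api` and `ModelAPI` come from the diff, the provider name and class are hypothetical):

from evalscope.api.model import ModelAPI
from evalscope.api.registry import register_model_api


@register_model_api(name='my_backend')  # hypothetical provider name
def my_backend() -> type[ModelAPI]:
    # Import lazily, mirroring the built-in providers above.
    from my_package.backend import MyBackendAPI  # hypothetical module

    return MyBackendAPI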
evalscope/models/modelscope.py
ADDED
@@ -0,0 +1,455 @@
from __future__ import annotations

import copy
import functools
import json
import time
import torch # type: ignore
from concurrent.futures import Future
from dataclasses import dataclass
from logging import getLogger
from modelscope import AutoModelForCausalLM, AutoTokenizer
from queue import Empty, Queue
from threading import Thread
from torch import Tensor # type: ignore
from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union, cast
from typing_extensions import override

from evalscope.api.messages import (
    ChatMessage,
    ChatMessageAssistant,
    ContentAudio,
    ContentImage,
    ContentText,
    ContentVideo,
)
from evalscope.api.model import (
    ChatCompletionChoice,
    GenerateConfig,
    Logprob,
    Logprobs,
    ModelAPI,
    ModelOutput,
    ModelUsage,
    TopLogprob,
)
from evalscope.api.tool import ToolChoice, ToolInfo
from evalscope.utils.model_utils import get_device

logger = getLogger()


class ModelScopeAPI(ModelAPI):

    def __init__(
        self,
        model_name: str,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ):
        super().__init__(
            model_name=model_name,
            base_url=base_url,
            api_key=api_key,
            config=config,
        )

        # collect known model_args (then delete them so we can pass the rest on)
        def collect_model_arg(name: str) -> Optional[Any]:
            nonlocal model_args
            value = model_args.get(name, None)
            if value is not None:
                model_args.pop(name)
            return value

        model_path = collect_model_arg('model_path')
        device_map = collect_model_arg('device_map')
        torch_dtype = collect_model_arg('precision')
        tokenizer_path = collect_model_arg('tokenizer_path')
        self.chat_template = collect_model_arg('chat_template')
        self.tokenizer_call_args = collect_model_arg('tokenizer_call_args')
        self.enable_thinking = collect_model_arg('enable_thinking')
        if self.tokenizer_call_args is None:
            self.tokenizer_call_args = {}

        # device
        self.device = device_map or get_device()

        # torch dtype
        DTYPE_MAP = {'float16': torch.float16, 'float32': torch.float32, 'bfloat16': torch.bfloat16, 'auto': 'auto'}

        if isinstance(torch_dtype, str) and torch_dtype != 'auto':
            torch_dtype = DTYPE_MAP.get(torch_dtype, torch.float32)
        self.torch_dtype = torch_dtype

        # model
        model_name_or_path = model_path or model_name
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map=self.device,
            token=self.api_key,
            torch_dtype=self.torch_dtype,
            trust_remote_code=True,
            **model_args
        )

        # tokenizer
        tokenizer_name_or_path = tokenizer_path or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True)
        # LLMs generally don't have a pad token and we need one for batching
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                # add a pad token
                self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # set padding side to left for LLMs
        self.tokenizer.padding_side = 'left'
        # set chat template if provided
        if self.chat_template:
            self.tokenizer.chat_template = self.chat_template
            logger.info(f'Using custom chat template: {self.chat_template}')

    def generate(
        self,
        input: List[ChatMessage],
        tools: List[ToolInfo],
        tool_choice: ToolChoice,
        config: GenerateConfig,
    ) -> ModelOutput:

        # create chat
        chat = self.ms_chat(input, tools)

        assert isinstance(self.tokenizer_call_args, dict)
        # prepare tokenizer
        tokenizer = functools.partial(
            self.tokenizer,
            return_tensors='pt',
            padding=True,
            **self.tokenizer_call_args,
        )

        # prepare generator
        kwargs: Dict[str, Any] = {}
        if config.do_sample is not None:
            kwargs['do_sample'] = config.do_sample
        if config.n is not None:
            if config.n > 1:
                assert config.do_sample, 'n > 1 requires do_sample=True in GenerateConfig'
            kwargs['num_return_sequences'] = config.n
        if config.max_tokens is not None:
            kwargs['max_new_tokens'] = config.max_tokens
        if config.temperature is not None:
            kwargs['temperature'] = config.temperature
        if config.top_p is not None:
            kwargs['top_p'] = config.top_p
        if config.top_k is not None:
            kwargs['top_k'] = config.top_k
        if config.logprobs is not None:
            kwargs['output_logits'] = config.logprobs
        if 'return_dict_in_generate' in kwargs:
            assert kwargs['return_dict_in_generate']
        if config.stop_seqs is not None:
            from transformers.generation import StopStringCriteria # type: ignore

            stopping_criteria = [StopStringCriteria(self.tokenizer, config.stop_seqs)]
            kwargs['stopping_criteria'] = stopping_criteria

        kwargs['return_dict_in_generate'] = True
        generator = functools.partial(self.model.generate, **kwargs)

        # prepare decoder
        decoder = functools.partial(
            self.tokenizer.batch_decode,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        # generate
        responses = batched_generate(
            GenerateInput(
                input=chat,
                device=self.model.device,
                tokenizer=tokenizer,
                generator=generator,
                decoder=decoder,
                batch_size=config.batch_size or self.max_connections(),
            )
        )

        choices: List[ChatCompletionChoice] = []
        for response in responses:
            # gather logprobs
            final_logprobs = None
            if config.logprobs is not None:
                final_logprobs = extract_logprobs(
                    response=response,
                    top=config.top_logprobs,
                    tokenizer=self.tokenizer,
                )

            # construct choice
            # TODO: Handle tool calls
            choice = ChatCompletionChoice(
                message=ChatMessageAssistant(content=response.output, model=self.model_name, source='generate'),
                logprobs=(Logprobs(content=final_logprobs) if final_logprobs is not None else None),
            )
            choices.append(choice)

        # return output
        return ModelOutput(
            model=self.model_name,
            choices=choices,
            usage=ModelUsage(
                input_tokens=response.input_tokens,
                output_tokens=response.output_tokens,
                total_tokens=response.total_tokens,
            ),
            time=response.time,
        )

    @override
    def max_tokens(self) -> Optional[int]:
        """Default is 2048, bump it up to a value suitable for evals."""
        return 2048

    @override
    def max_connections(self) -> int:
        """Effectively the batch size."""
        return 8

    def ms_chat(self, messages: List[ChatMessage], tools: List[ToolInfo]) -> str:
        # convert to ms format
        tools_list = []
        ms_messages = copy.deepcopy(messages)
        if len(tools) > 0:
            tools_list = [json.loads(tool.model_dump_json(exclude_none=True, indent=2)) for tool in tools]

        ms_messages = message_content_to_string(ms_messages)
        # apply chat template
        if self.tokenizer.chat_template is not None:
            chat = self.tokenizer.apply_chat_template(
                ms_messages,
                add_generation_prompt=True,
                tokenize=False,
                tools=tools_list if len(tools_list) > 0 else None,
                enable_thinking=self.enable_thinking, # not all models use this, check if it is supported
            )
        else:
            chat = ''
            for message in ms_messages:
                chat += f'{message.role}: {message.content}\n'
        # return
        return cast(str, chat)


def message_content_to_string(messages: List[ChatMessage]) -> List[ChatMessage]:
    """Convert list of content in `ChatMessageAssistant`, `ChatMessageUser` or `ChatMessageSystem` to a string."""
    for message in messages:
        if isinstance(message.content, list):
            is_multimodal = any(
                isinstance(item, (ContentAudio, ContentImage, ContentVideo)) for item in message.content
            )
            if is_multimodal:
                raise NotImplementedError(
                    'Transformer model does not support multimodal content, please provide text inputs only.'
                )
            message.content = message.text
    return messages


# return value from generate as a result of specifying return_dict_in_generate
class ModelGenerateOutput:
    sequences: Tensor
    logits: tuple[Tensor]


class Tokenizer(Protocol):

    def __call__(self, input: List[str]) -> Dict[Literal['input_ids', 'attention_mask'], Tensor]:
        ...


class Generator(Protocol):

    def __call__(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        ...


class Decoder(Protocol):

    def __call__(self, sequences: Tensor) -> list[str]:
        ...


@dataclass
class GenerateInput:
    input: str
    device: str
    tokenizer: Tokenizer
    generator: Generator
    decoder: Decoder
    batch_size: int


@dataclass
class GenerateOutput:
    output: str
    input_tokens: int
    output_tokens: int
    total_tokens: int
    logprobs: Optional[torch.Tensor]
    time: float


@dataclass
class _QueueItem:
    input: GenerateInput
    future: Future[GenerateOutput]


batch_thread: Optional[Thread] = None

batch_queue: 'Queue[_QueueItem]' = Queue()


def batched_generate(input: GenerateInput) -> List[GenerateOutput]:
    # start the background thread if necessary
    global batch_thread
    if batch_thread is None:
        batch_thread = Thread(target=process_batches, daemon=True)
        batch_thread.start()

    # enqueue the job
    future = Future[GenerateOutput]()
    batch_queue.put(_QueueItem(input=input, future=future))

    return future.result()


def process_batches() -> None:
    while True:
        # drain the queue (wait until no new messages have shown up for 2 seconds)
        inputs: List[Tuple[GenerateInput, Future[GenerateOutput]]] = []
        while True:
            try:
                input = batch_queue.get(timeout=2)
                inputs.append((input.input, input.future))
                if len(inputs) == input.input.batch_size:
                    # max batch size reached
                    break
            except Empty:
                # we have exhausted the queue
                break

        # see if we have any work to do
        if len(inputs) == 0:
            continue

        try:
            # capture the generator and decoder functions
            start_time = time.monotonic()
            first_input = inputs[0][0]
            device = first_input.device
            tokenizer = first_input.tokenizer
            generator = first_input.generator
            decoder = first_input.decoder
            num_return_sequences = generator.keywords.get('num_return_sequences', 1)

            # tokenize and move to device
            tokenized_inputs = tokenizer([item[0].input for item in inputs])
            input_ids = tokenized_inputs['input_ids']
            attention_mask = tokenized_inputs['attention_mask']
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            # generate
            with torch.inference_mode():
                generation_outputs = cast(
                    ModelGenerateOutput,
                    generator(input_ids=input_ids, attention_mask=attention_mask),
                )
                generate_ids = generation_outputs.sequences
                logits = generation_outputs.logits

            # get logprobs from logits
            logprobs = None
            if logits is not None:
                stacked_logits = torch.stack(logits).transpose(0, 1)
                logprobs = torch.nn.functional.log_softmax(stacked_logits, dim=-1)

            # decode
            generated_tokens = generate_ids[:, input_ids.size(dim=1):]
            if logprobs is not None:
                assert logprobs.shape[1] == generated_tokens.shape[1]
            outputs = decoder(sequences=generated_tokens)

            # call back futures
            total_time = time.monotonic() - start_time
            for input_index in range(len(inputs)):
                choices: List[GenerateOutput] = []
                # handle input
                future = inputs[input_index][1]
                input_tokens = input_ids[input_index].shape[-1]
                # handle choices
                for choice_index in range(num_return_sequences):
                    output_index = input_index * num_return_sequences + choice_index
                    # handle out of
                    output = outputs[output_index]
                    output_tokens = generate_ids[output_index].shape[-1] - input_tokens
                    logprobs_tensor = logprobs[output_index] if logprobs is not None else None
                    # create the output
                    choices.append(
                        GenerateOutput(
                            output=output,
                            input_tokens=input_tokens,
                            output_tokens=output_tokens,
                            total_tokens=input_tokens + output_tokens,
                            logprobs=logprobs_tensor,
                            time=total_time,
                        )
                    )

                # asyncio futures are not thread safe, so we need to pass the event loop
                # down to this point, so we can mark the future as done in a thread safe manner.
                # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading
                future.set_result(choices)

        except Exception as ex:
            for inp in inputs:
                future = inp[1]
                future.set_exception(ex)


def extract_logprobs(
    response: GenerateOutput,
    top: Optional[int],
    tokenizer,
) -> List[Logprob]:
    assert response.logprobs is not None
    k = top or 1
    topk_values, topk_inds = response.logprobs.topk(k=k, dim=-1)
    final_logprobs = []
    for toks, vals in zip(topk_inds, topk_values):
        top_logprobs: List[TopLogprob] = []
        for tok, val in zip(toks, vals):
            # TODO: you get byte artifacts converting single ids to tokens like this...
            # but `tokenizer.decode` strips spaces. There must be a better way to do this.
            token_str = tokenizer.convert_ids_to_tokens(tok.item())
            top_logprobs.append(TopLogprob(
                token=token_str,
                logprob=val,
                bytes=list(map(ord, token_str)),
            ))
        final_logprobs.append(
            Logprob(
                token=top_logprobs[0].token,
                logprob=top_logprobs[0].logprob,
                bytes=top_logprobs[0].bytes,
                top_logprobs=top_logprobs,
            )
        )
    return final_logprobs