evalscope 0.17.1__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of evalscope might be problematic. Click here for more details.
- evalscope/__init__.py +4 -1
- evalscope/api/benchmark/__init__.py +3 -0
- evalscope/api/benchmark/adapters/__init__.py +5 -0
- evalscope/api/benchmark/adapters/default_data_adapter.py +684 -0
- evalscope/api/benchmark/adapters/image_edit_adapter.py +82 -0
- evalscope/api/benchmark/adapters/multi_choice_adapter.py +83 -0
- evalscope/api/benchmark/adapters/text2image_adapter.py +156 -0
- evalscope/api/benchmark/adapters/vision_language_adapter.py +6 -0
- evalscope/api/benchmark/benchmark.py +356 -0
- evalscope/api/benchmark/meta.py +121 -0
- evalscope/api/dataset/__init__.py +2 -0
- evalscope/api/dataset/dataset.py +349 -0
- evalscope/api/dataset/loader.py +262 -0
- evalscope/api/dataset/utils.py +143 -0
- evalscope/api/evaluator/__init__.py +3 -0
- evalscope/api/evaluator/cache.py +378 -0
- evalscope/api/evaluator/evaluator.py +56 -0
- evalscope/api/evaluator/state.py +275 -0
- evalscope/api/filter/__init__.py +1 -0
- evalscope/api/filter/filter.py +72 -0
- evalscope/api/messages/__init__.py +12 -0
- evalscope/api/messages/chat_message.py +243 -0
- evalscope/api/messages/content.py +102 -0
- evalscope/api/messages/utils.py +35 -0
- evalscope/api/metric/__init__.py +2 -0
- evalscope/api/metric/metric.py +55 -0
- evalscope/api/metric/scorer.py +113 -0
- evalscope/api/mixin/__init__.py +1 -0
- evalscope/api/mixin/llm_judge_mixin.py +168 -0
- evalscope/api/model/__init__.py +12 -0
- evalscope/api/model/generate_config.py +155 -0
- evalscope/api/model/model.py +386 -0
- evalscope/api/model/model_output.py +285 -0
- evalscope/api/registry.py +182 -0
- evalscope/api/tool/__init__.py +3 -0
- evalscope/api/tool/tool_call.py +101 -0
- evalscope/api/tool/tool_info.py +173 -0
- evalscope/api/tool/utils.py +64 -0
- evalscope/app/app.py +3 -0
- evalscope/app/ui/app_ui.py +2 -1
- evalscope/app/ui/multi_model.py +50 -25
- evalscope/app/ui/single_model.py +26 -14
- evalscope/app/utils/data_utils.py +43 -27
- evalscope/app/utils/env_utils.py +12 -0
- evalscope/app/utils/text_utils.py +14 -14
- evalscope/app/utils/visualization.py +9 -4
- evalscope/arguments.py +7 -10
- evalscope/backend/opencompass/api_meta_template.py +2 -1
- evalscope/backend/opencompass/backend_manager.py +6 -5
- evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +10 -10
- evalscope/backend/rag_eval/clip_benchmark/task_template.py +8 -4
- evalscope/backend/rag_eval/ragas/task_template.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/build_distribution.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/build_transform.py +7 -4
- evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +2 -1
- evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +2 -1
- evalscope/backend/rag_eval/utils/embedding.py +10 -1
- evalscope/backend/rag_eval/utils/llm.py +13 -12
- evalscope/benchmarks/__init__.py +0 -2
- evalscope/benchmarks/aime/aime24_adapter.py +38 -40
- evalscope/benchmarks/aime/aime25_adapter.py +34 -40
- evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +86 -60
- evalscope/benchmarks/arc/arc_adapter.py +34 -147
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +96 -70
- evalscope/benchmarks/arena_hard/utils.py +37 -1
- evalscope/benchmarks/bbh/bbh_adapter.py +72 -144
- evalscope/benchmarks/bfcl/bfcl_adapter.py +188 -171
- evalscope/benchmarks/bfcl/generation.py +222 -0
- evalscope/benchmarks/ceval/ceval_adapter.py +93 -162
- evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +85 -82
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +34 -125
- evalscope/benchmarks/competition_math/competition_math_adapter.py +56 -108
- evalscope/benchmarks/data_collection/data_collection_adapter.py +187 -45
- evalscope/benchmarks/docmath/docmath_adapter.py +109 -51
- evalscope/benchmarks/docmath/utils.py +4 -5
- evalscope/benchmarks/drop/drop_adapter.py +88 -40
- evalscope/benchmarks/frames/frames_adapter.py +136 -52
- evalscope/benchmarks/general_arena/general_arena_adapter.py +140 -98
- evalscope/benchmarks/general_arena/utils.py +23 -27
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +40 -101
- evalscope/benchmarks/general_qa/general_qa_adapter.py +73 -134
- evalscope/benchmarks/gpqa/gpqa_adapter.py +61 -100
- evalscope/benchmarks/gpqa/{chain_of_thought.txt → prompt.py} +12 -5
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +62 -142
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +35 -124
- evalscope/benchmarks/hle/hle_adapter.py +127 -93
- evalscope/benchmarks/humaneval/humaneval_adapter.py +86 -55
- evalscope/benchmarks/ifeval/ifeval_adapter.py +69 -40
- evalscope/benchmarks/ifeval/instructions.py +109 -64
- evalscope/benchmarks/ifeval/instructions_registry.py +1 -1
- evalscope/benchmarks/ifeval/instructions_util.py +2 -3
- evalscope/benchmarks/ifeval/utils.py +6 -7
- evalscope/benchmarks/image_edit/gedit/__init__.py +0 -0
- evalscope/benchmarks/image_edit/gedit/gedit_adapter.py +138 -0
- evalscope/benchmarks/image_edit/gedit/utils.py +372 -0
- evalscope/benchmarks/image_edit/gedit/vie_prompts.py +406 -0
- evalscope/benchmarks/iquiz/iquiz_adapter.py +30 -65
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +2 -2
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +121 -71
- evalscope/benchmarks/live_code_bench/load_utils.py +13 -21
- evalscope/benchmarks/live_code_bench/testing_util.py +6 -2
- evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +49 -75
- evalscope/benchmarks/math_500/math_500_adapter.py +41 -48
- evalscope/benchmarks/math_vista/__init__.py +0 -0
- evalscope/benchmarks/math_vista/math_vista_adapter.py +129 -0
- evalscope/benchmarks/mmlu/mmlu_adapter.py +32 -205
- evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +80 -99
- evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +64 -110
- evalscope/benchmarks/mmmu/__init__.py +0 -0
- evalscope/benchmarks/mmmu/mmmu_adapter.py +159 -0
- evalscope/benchmarks/mmmu_pro/__init__.py +0 -0
- evalscope/benchmarks/mmmu_pro/mmmu_pro_adapter.py +129 -0
- evalscope/benchmarks/musr/musr_adapter.py +33 -64
- evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +196 -152
- evalscope/benchmarks/process_bench/process_bench_adapter.py +144 -76
- evalscope/benchmarks/race/race_adapter.py +33 -119
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +72 -70
- evalscope/benchmarks/super_gpqa/{five_shot_prompt.txt → prompt.py} +14 -16
- evalscope/benchmarks/super_gpqa/super_gpqa_adapter.py +73 -117
- evalscope/benchmarks/super_gpqa/utils.py +2 -1
- evalscope/benchmarks/tau_bench/generation.py +147 -0
- evalscope/benchmarks/tau_bench/tau_bench_adapter.py +114 -60
- evalscope/benchmarks/text2image/__init__.py +0 -0
- evalscope/benchmarks/text2image/evalmuse_adapter.py +78 -0
- evalscope/benchmarks/text2image/genai_bench_adapter.py +53 -0
- evalscope/benchmarks/text2image/general_t2i_adapter.py +42 -0
- evalscope/benchmarks/text2image/hpdv2_adapter.py +52 -0
- evalscope/benchmarks/text2image/tifa_adapter.py +27 -0
- evalscope/benchmarks/tool_bench/tool_bench_adapter.py +91 -70
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +56 -124
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +70 -266
- evalscope/benchmarks/winogrande/winogrande_adapter.py +28 -54
- evalscope/cli/cli.py +2 -0
- evalscope/cli/start_app.py +7 -1
- evalscope/cli/start_perf.py +7 -1
- evalscope/cli/start_server.py +6 -3
- evalscope/collections/__init__.py +2 -10
- evalscope/collections/sampler.py +10 -10
- evalscope/collections/schema.py +13 -11
- evalscope/config.py +157 -57
- evalscope/constants.py +37 -61
- evalscope/evaluator/__init__.py +1 -1
- evalscope/evaluator/evaluator.py +275 -419
- evalscope/filters/__init__.py +2 -0
- evalscope/filters/extraction.py +126 -0
- evalscope/filters/selection.py +57 -0
- evalscope/metrics/__init__.py +13 -13
- evalscope/metrics/llm_judge.py +47 -33
- evalscope/metrics/math_parser.py +27 -22
- evalscope/metrics/metric.py +307 -0
- evalscope/metrics/metrics.py +22 -18
- evalscope/metrics/t2v_metrics/__init__.py +0 -52
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +4 -2
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +9 -13
- evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +3 -2
- evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +2 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +4 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +10 -5
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +4 -2
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +2 -1
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +15 -9
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +15 -10
- evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +9 -6
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +2 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +3 -9
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +16 -10
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +4 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +8 -4
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +47 -25
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +12 -7
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +23 -17
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +33 -23
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +2 -1
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +46 -30
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +69 -37
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +7 -5
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +6 -4
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +7 -5
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +5 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +17 -13
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +35 -19
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +14 -12
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +63 -52
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +63 -38
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +6 -3
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +6 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +3 -2
- evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +15 -13
- evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +3 -2
- evalscope/models/__init__.py +6 -29
- evalscope/models/image_edit_model.py +125 -0
- evalscope/models/mockllm.py +65 -0
- evalscope/models/model_apis.py +67 -0
- evalscope/models/modelscope.py +455 -0
- evalscope/models/openai_compatible.py +126 -0
- evalscope/models/text2image_model.py +124 -0
- evalscope/models/utils/openai.py +701 -0
- evalscope/perf/benchmark.py +4 -1
- evalscope/perf/http_client.py +4 -2
- evalscope/perf/plugin/api/custom_api.py +5 -4
- evalscope/perf/plugin/api/openai_api.py +11 -9
- evalscope/perf/plugin/datasets/custom.py +2 -1
- evalscope/perf/plugin/datasets/flickr8k.py +1 -1
- evalscope/perf/plugin/datasets/kontext_bench.py +1 -1
- evalscope/perf/plugin/datasets/line_by_line.py +2 -1
- evalscope/perf/plugin/datasets/longalpaca.py +2 -1
- evalscope/perf/plugin/datasets/openqa.py +4 -2
- evalscope/perf/utils/benchmark_util.py +15 -10
- evalscope/perf/utils/db_util.py +9 -6
- evalscope/perf/utils/local_server.py +11 -3
- evalscope/perf/utils/rich_display.py +16 -10
- evalscope/report/__init__.py +2 -3
- evalscope/report/combinator.py +18 -12
- evalscope/report/generator.py +51 -35
- evalscope/report/{utils.py → report.py} +8 -6
- evalscope/run.py +33 -47
- evalscope/summarizer.py +1 -1
- evalscope/third_party/toolbench_static/llm/swift_infer.py +0 -4
- evalscope/utils/__init__.py +21 -2
- evalscope/utils/chat_service.py +3 -2
- evalscope/utils/deprecation_utils.py +12 -1
- evalscope/utils/function_utils.py +29 -0
- evalscope/utils/import_utils.py +23 -1
- evalscope/utils/io_utils.py +142 -6
- evalscope/utils/json_schema.py +208 -0
- evalscope/utils/logger.py +51 -12
- evalscope/utils/model_utils.py +11 -7
- evalscope/utils/multi_choices.py +288 -0
- evalscope/utils/url_utils.py +65 -0
- evalscope/version.py +2 -2
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/METADATA +108 -62
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/RECORD +258 -226
- tests/benchmark/test_eval.py +385 -0
- tests/benchmark/test_image_edit.py +65 -0
- tests/{aigc → benchmark}/test_t2i.py +22 -4
- tests/benchmark/test_vlm.py +80 -0
- tests/cli/test_all.py +85 -47
- tests/cli/test_collection.py +20 -8
- tests/cli/test_custom.py +22 -15
- tests/cli/test_reasoning.py +81 -0
- tests/common.py +73 -0
- tests/perf/test_perf.py +4 -2
- tests/rag/test_clip_benchmark.py +0 -2
- evalscope/benchmarks/aigc/t2i/base.py +0 -56
- evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +0 -78
- evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +0 -58
- evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +0 -58
- evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +0 -57
- evalscope/benchmarks/aigc/t2i/tifa_adapter.py +0 -37
- evalscope/benchmarks/arc/ai2_arc.py +0 -151
- evalscope/benchmarks/benchmark.py +0 -81
- evalscope/benchmarks/ceval/ceval_exam.py +0 -146
- evalscope/benchmarks/cmmlu/cmmlu.py +0 -161
- evalscope/benchmarks/cmmlu/samples.jsonl +0 -5
- evalscope/benchmarks/competition_math/competition_math.py +0 -79
- evalscope/benchmarks/data_adapter.py +0 -528
- evalscope/benchmarks/filters.py +0 -59
- evalscope/benchmarks/gsm8k/gsm8k.py +0 -121
- evalscope/benchmarks/hellaswag/hellaswag.py +0 -112
- evalscope/benchmarks/humaneval/humaneval.py +0 -79
- evalscope/benchmarks/mmlu/mmlu.py +0 -160
- evalscope/benchmarks/mmlu/samples.jsonl +0 -5
- evalscope/benchmarks/process_bench/critique_template.txt +0 -13
- evalscope/benchmarks/race/race.py +0 -104
- evalscope/benchmarks/race/samples.jsonl +0 -5
- evalscope/benchmarks/super_gpqa/zero_shot_prompt.txt +0 -4
- evalscope/benchmarks/trivia_qa/trivia_qa.py +0 -89
- evalscope/benchmarks/truthful_qa/truthful_qa.py +0 -163
- evalscope/benchmarks/utils.py +0 -60
- evalscope/collections/evaluator.py +0 -375
- evalscope/metrics/completion_parsers.py +0 -227
- evalscope/metrics/named_metrics.py +0 -55
- evalscope/models/adapters/__init__.py +0 -14
- evalscope/models/adapters/base_adapter.py +0 -84
- evalscope/models/adapters/bfcl_adapter.py +0 -246
- evalscope/models/adapters/chat_adapter.py +0 -207
- evalscope/models/adapters/choice_adapter.py +0 -222
- evalscope/models/adapters/custom_adapter.py +0 -71
- evalscope/models/adapters/server_adapter.py +0 -236
- evalscope/models/adapters/t2i_adapter.py +0 -79
- evalscope/models/adapters/tau_bench_adapter.py +0 -189
- evalscope/models/custom/__init__.py +0 -4
- evalscope/models/custom/custom_model.py +0 -50
- evalscope/models/custom/dummy_model.py +0 -99
- evalscope/models/local_model.py +0 -128
- evalscope/models/register.py +0 -41
- tests/cli/test_run.py +0 -489
- /evalscope/{benchmarks/aigc → api}/__init__.py +0 -0
- /evalscope/benchmarks/{aigc/t2i → image_edit}/__init__.py +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/LICENSE +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/WHEEL +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/entry_points.txt +0 -0
- {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/top_level.txt +0 -0
- /tests/{aigc → benchmark}/__init__.py +0 -0
|
@@ -0,0 +1,455 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import functools
|
|
5
|
+
import json
|
|
6
|
+
import time
|
|
7
|
+
import torch # type: ignore
|
|
8
|
+
from concurrent.futures import Future
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from logging import getLogger
|
|
11
|
+
from modelscope import AutoModelForCausalLM, AutoTokenizer
|
|
12
|
+
from queue import Empty, Queue
|
|
13
|
+
from threading import Thread
|
|
14
|
+
from torch import Tensor # type: ignore
|
|
15
|
+
from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union, cast
|
|
16
|
+
from typing_extensions import override
|
|
17
|
+
|
|
18
|
+
from evalscope.api.messages import (
|
|
19
|
+
ChatMessage,
|
|
20
|
+
ChatMessageAssistant,
|
|
21
|
+
ContentAudio,
|
|
22
|
+
ContentImage,
|
|
23
|
+
ContentText,
|
|
24
|
+
ContentVideo,
|
|
25
|
+
)
|
|
26
|
+
from evalscope.api.model import (
|
|
27
|
+
ChatCompletionChoice,
|
|
28
|
+
GenerateConfig,
|
|
29
|
+
Logprob,
|
|
30
|
+
Logprobs,
|
|
31
|
+
ModelAPI,
|
|
32
|
+
ModelOutput,
|
|
33
|
+
ModelUsage,
|
|
34
|
+
TopLogprob,
|
|
35
|
+
)
|
|
36
|
+
from evalscope.api.tool import ToolChoice, ToolInfo
|
|
37
|
+
from evalscope.utils.model_utils import get_device
|
|
38
|
+
|
|
39
|
+
# Module-level logger. Named after the module (rather than the root logger returned by a
# bare ``getLogger()``) so records propagate through the package's logger hierarchy.
logger = getLogger(__name__)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ModelScopeAPI(ModelAPI):
    """Run a causal language model locally through ModelScope's ``AutoModelForCausalLM``.

    Prompts are rendered with the tokenizer's chat template (when one is set) and
    handed to the module-level batching worker via ``batched_generate``, so that
    concurrent ``generate`` calls can share one ``model.generate`` batch.
    """

    def __init__(
        self,
        model_name: str,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ):
        """Load the model and tokenizer.

        Args:
            model_name: Model id or path, used when ``model_path`` is not supplied.
            base_url: Forwarded to ``ModelAPI``; not used for local loading here.
            api_key: Forwarded to ``ModelAPI`` and passed as ``token`` to
                ``from_pretrained``.
            config: Default generation config forwarded to ``ModelAPI``.
            **model_args: The keys ``model_path``, ``device_map``, ``precision``,
                ``tokenizer_path``, ``chat_template``, ``tokenizer_call_args`` and
                ``enable_thinking`` are consumed here; every other key is passed
                through to ``AutoModelForCausalLM.from_pretrained``.
        """
        super().__init__(
            model_name=model_name,
            base_url=base_url,
            api_key=api_key,
            config=config,
        )

        def collect_model_arg(name: str) -> Optional[Any]:
            # Always remove the key, even when its value is None. Otherwise e.g.
            # ``device_map=None`` would stay in ``model_args`` and clash with the
            # explicit ``device_map=`` keyword of ``from_pretrained`` below
            # (TypeError: multiple values for keyword argument), and options such as
            # ``precision=None`` would leak into the model loader.
            return model_args.pop(name, None)

        model_path = collect_model_arg('model_path')
        device_map = collect_model_arg('device_map')
        torch_dtype = collect_model_arg('precision')
        tokenizer_path = collect_model_arg('tokenizer_path')
        self.chat_template = collect_model_arg('chat_template')
        self.tokenizer_call_args = collect_model_arg('tokenizer_call_args')
        self.enable_thinking = collect_model_arg('enable_thinking')
        if self.tokenizer_call_args is None:
            self.tokenizer_call_args = {}

        # device: an explicit device_map wins, otherwise the auto-detected device
        self.device = device_map or get_device()

        # torch dtype: known names map to torch dtypes; 'auto' is passed through as-is
        dtype_map = {
            'float16': torch.float16,
            'fp16': torch.float16,
            'float32': torch.float32,
            'fp32': torch.float32,
            'bfloat16': torch.bfloat16,
            'bf16': torch.bfloat16,
            'auto': 'auto',
        }
        if isinstance(torch_dtype, str) and torch_dtype != 'auto':
            if torch_dtype not in dtype_map:
                logger.warning(f'Unknown precision {torch_dtype!r}, falling back to torch.float32.')
            torch_dtype = dtype_map.get(torch_dtype, torch.float32)
        self.torch_dtype = torch_dtype

        # model
        model_name_or_path = model_path or model_name
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map=self.device,
            token=self.api_key,
            torch_dtype=self.torch_dtype,
            trust_remote_code=True,
            **model_args
        )

        # tokenizer
        tokenizer_name_or_path = tokenizer_path or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True)
        # LLMs generally don't have a pad token and we need one for batching
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                # Add a dedicated pad token. It receives a brand-new id; if that id lies
                # outside the model's embedding matrix, the left-padded input_ids would
                # index out of range, so grow the embeddings to cover it.
                self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                num_embeddings = self.model.get_input_embeddings().num_embeddings
                if self.tokenizer.pad_token_id >= num_embeddings:
                    self.model.resize_token_embeddings(len(self.tokenizer))
        # decoder-only models are left-padded so generation continues right after the prompt
        self.tokenizer.padding_side = 'left'
        # set chat template if provided
        if self.chat_template:
            self.tokenizer.chat_template = self.chat_template
            logger.info(f'Using custom chat template: {self.chat_template}')

    def generate(
        self,
        input: List[ChatMessage],
        tools: List[ToolInfo],
        tool_choice: ToolChoice,
        config: GenerateConfig,
    ) -> ModelOutput:
        """Generate completions for one conversation.

        Args:
            input: Conversation messages (text-only content is supported).
            tools: Tool definitions rendered into the prompt via the chat template.
            tool_choice: Currently unused by this backend.
            config: Generation settings for this call.

        Returns:
            A ``ModelOutput`` with one choice per returned sequence.

        Raises:
            ValueError: If ``config.n > 1`` without ``config.do_sample``.
            NotImplementedError: If a message carries audio/image/video content.
        """
        # create chat
        chat = self.ms_chat(input, tools)

        assert isinstance(self.tokenizer_call_args, dict)
        # prepare tokenizer; user-supplied call args may override padding, but tensors
        # must stay PyTorch tensors for the batching worker. Merging into one dict also
        # avoids "multiple values for keyword argument" if the user passes these keys.
        tokenizer_kwargs = {'padding': True, **self.tokenizer_call_args, 'return_tensors': 'pt'}
        tokenizer = functools.partial(self.tokenizer, **tokenizer_kwargs)

        # prepare generator
        kwargs: Dict[str, Any] = {}
        if config.do_sample is not None:
            kwargs['do_sample'] = config.do_sample
        if config.n is not None:
            if config.n > 1 and not config.do_sample:
                raise ValueError('n > 1 requires do_sample=True in GenerateConfig')
            kwargs['num_return_sequences'] = config.n
        if config.max_tokens is not None:
            kwargs['max_new_tokens'] = config.max_tokens
        if config.temperature is not None:
            kwargs['temperature'] = config.temperature
        if config.top_p is not None:
            kwargs['top_p'] = config.top_p
        if config.top_k is not None:
            kwargs['top_k'] = config.top_k
        # Only request logits when logprobs are actually wanted. ``logprobs=False`` must
        # neither request logits nor trigger extraction on missing logits later.
        want_logprobs = bool(config.logprobs)
        if want_logprobs:
            kwargs['output_logits'] = True
        if config.stop_seqs:
            from transformers.generation import StopStringCriteria  # type: ignore

            kwargs['stopping_criteria'] = [StopStringCriteria(self.tokenizer, config.stop_seqs)]

        # the batching worker reads ``.sequences`` / ``.logits`` from a structured output
        kwargs['return_dict_in_generate'] = True
        generator = functools.partial(self.model.generate, **kwargs)

        # prepare decoder
        decoder = functools.partial(
            self.tokenizer.batch_decode,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        # generate
        responses = batched_generate(
            GenerateInput(
                input=chat,
                device=self.model.device,
                tokenizer=tokenizer,
                generator=generator,
                decoder=decoder,
                batch_size=config.batch_size or self.max_connections(),
            )
        )

        choices: List[ChatCompletionChoice] = []
        for response in responses:
            # gather logprobs (the response may lack them if it was batched with a
            # request that did not ask for logits)
            final_logprobs = None
            if want_logprobs and response.logprobs is not None:
                final_logprobs = extract_logprobs(
                    response=response,
                    top=config.top_logprobs,
                    tokenizer=self.tokenizer,
                )

            # construct choice
            # TODO: Handle tool calls
            choices.append(
                ChatCompletionChoice(
                    message=ChatMessageAssistant(content=response.output, model=self.model_name, source='generate'),
                    logprobs=(Logprobs(content=final_logprobs) if final_logprobs is not None else None),
                )
            )

        # All choices share the same prompt; output tokens are summed over every
        # returned sequence rather than taken from the last one only.
        input_tokens = responses[0].input_tokens
        output_tokens = sum(response.output_tokens for response in responses)
        return ModelOutput(
            model=self.model_name,
            choices=choices,
            usage=ModelUsage(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens,
            ),
            time=responses[-1].time,
        )

    @override
    def max_tokens(self) -> Optional[int]:
        """Default maximum number of new tokens for this backend (2048)."""
        return 2048

    @override
    def max_connections(self) -> int:
        """Effectively the batch size; used when ``config.batch_size`` is not set."""
        return 8

    def ms_chat(self, messages: List[ChatMessage], tools: List[ToolInfo]) -> str:
        """Render ``messages`` (and optional ``tools``) into a single prompt string.

        Uses the tokenizer's chat template (with ``add_generation_prompt=True``) when
        one is set; otherwise falls back to one ``"role: content"`` line per message.

        Raises:
            NotImplementedError: If any message carries audio/image/video content.
        """
        # work on a copy: message_content_to_string rewrites ``content`` in place
        ms_messages = message_content_to_string(copy.deepcopy(messages))
        tools_list = [json.loads(tool.model_dump_json(exclude_none=True)) for tool in tools]

        # apply chat template
        if self.tokenizer.chat_template is not None:
            chat = self.tokenizer.apply_chat_template(
                ms_messages,
                add_generation_prompt=True,
                tokenize=False,
                tools=tools_list or None,
                enable_thinking=self.enable_thinking,  # not all models use this, check if it is supported
            )
        else:
            chat = ''.join(f'{message.role}: {message.content}\n' for message in ms_messages)
        return cast(str, chat)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def message_content_to_string(messages: List[ChatMessage]) -> List[ChatMessage]:
    """Flatten list-valued message content into plain text, in place.

    Messages whose ``content`` is already a string are left untouched. For messages
    whose ``content`` is a list, the list is replaced by ``message.text``.

    Args:
        messages: Messages to normalise; they are mutated directly.

    Returns:
        The same ``messages`` list object.

    Raises:
        NotImplementedError: If a list-valued content contains audio, image or video parts.
    """
    for msg in messages:
        parts = msg.content
        if not isinstance(parts, list):
            continue

        media_types = (ContentAudio, ContentImage, ContentVideo)
        has_media = any(isinstance(part, media_types) for part in parts)
        if has_media:
            raise NotImplementedError(
                'Transformer model does not support multimodal content, please provide text inputs only.'
            )

        msg.content = msg.text
    return messages
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
class ModelGenerateOutput:
    """Attributes read from the ``model.generate`` result when ``return_dict_in_generate=True``.

    Within this module it only serves as a ``typing.cast`` target and is never instantiated.
    """

    # prompt ids followed by generated ids, one row per returned sequence
    sequences: Tensor
    # one logits tensor per generation step, or None when logits were not requested
    logits: Optional[Tuple[Tensor, ...]]
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class Tokenizer(Protocol):
    """Callable that tokenizes a batch of prompt strings into tensors.

    The batching worker reads the ``input_ids`` and ``attention_mask`` entries of
    the returned mapping.
    """

    def __call__(self, input: List[str]) -> Dict[Literal['input_ids', 'attention_mask'], Tensor]:
        """Tokenize every string in ``input`` and return id and mask tensors."""
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
class Generator(Protocol):
    """Callable that runs generation on a tokenized batch.

    ``ModelScopeAPI.generate`` supplies ``functools.partial(model.generate, ...,
    return_dict_in_generate=True)``, so the result is a structured output exposing
    ``sequences`` and ``logits`` rather than a bare tensor.
    """

    def __call__(self, input_ids: Tensor, attention_mask: Tensor) -> ModelGenerateOutput:
        """Generate continuations for ``input_ids`` under ``attention_mask``."""
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
class Decoder(Protocol):
    """Callable that turns a batch of generated token-id rows back into text."""

    def __call__(self, sequences: Tensor) -> List[str]:
        """Decode each row of ``sequences`` into one string."""
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
@dataclass
class GenerateInput:
    """One generation request handed to the batching worker."""

    # fully rendered prompt text
    input: str
    # device the tokenized tensors are moved to; ``generate`` passes ``model.device``
    # (a ``torch.device``), so both forms are accepted
    device: Union[str, torch.device]
    tokenizer: Tokenizer
    generator: Generator
    decoder: Decoder
    # maximum number of queued requests the worker groups into one batch
    batch_size: int
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
@dataclass
class GenerateOutput:
    """Result for one returned sequence of a generation request."""

    # decoded completion text
    output: str
    # token counts reported back as usage
    input_tokens: int
    output_tokens: int
    total_tokens: int
    # log-probabilities for this sequence, or None when logits were not produced
    logprobs: Optional[Tensor]
    # seconds spent processing the batch this sequence belonged to
    time: float
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
@dataclass
class _QueueItem:
    """A pending request in ``batch_queue`` plus the future its caller blocks on."""

    input: GenerateInput
    # resolved by the worker with one GenerateOutput per returned sequence
    future: Future[List[GenerateOutput]]
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
from threading import Lock  # noqa: E402  (guards the lazy start of the batching worker)

# Background worker serving ``batch_queue``; started lazily on first use.
batch_thread: Optional[Thread] = None

# Pending requests shared by every caller of ``batched_generate``.
batch_queue: 'Queue[_QueueItem]' = Queue()

# Serialises the lazy worker start so that concurrent first callers cannot each
# start their own worker thread.
_batch_thread_lock = Lock()


def batched_generate(input: GenerateInput) -> List[GenerateOutput]:
    """Submit ``input`` to the shared batching worker and block until it is processed.

    Args:
        input: The request to enqueue.

    Returns:
        One ``GenerateOutput`` per returned sequence for this request.

    Raises:
        Exception: Whatever the worker raised while processing the batch that
            contained this request (re-raised by ``Future.result``).
    """
    # start the background thread if necessary
    global batch_thread
    with _batch_thread_lock:
        if batch_thread is None:
            batch_thread = Thread(target=process_batches, daemon=True)
            batch_thread.start()

    # enqueue the job and wait for the worker to resolve it
    future: Future[List[GenerateOutput]] = Future()
    batch_queue.put(_QueueItem(input=input, future=future))

    return future.result()
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def process_batches() -> None:
    """Worker loop: repeatedly collect queued requests and run them as one batch.

    Runs forever; ``batched_generate`` starts it on a daemon thread.
    """
    while True:
        inputs = _collect_batch()
        # see if we have any work to do
        if not inputs:
            continue
        _run_batch(inputs)


def _collect_batch() -> List[Tuple[GenerateInput, Future[List[GenerateOutput]]]]:
    """Drain ``batch_queue`` until no item arrives for 2 seconds or the batch is full."""
    inputs: List[Tuple[GenerateInput, Future[List[GenerateOutput]]]] = []
    while True:
        try:
            item = batch_queue.get(timeout=2)
        except Empty:
            # we have exhausted the queue
            break
        inputs.append((item.input, item.future))
        # ``>=`` (not ``==``) so a later item with a smaller batch_size still closes the batch
        if len(inputs) >= item.input.batch_size:
            # max batch size reached
            break
    return inputs


def _run_batch(inputs: List[Tuple[GenerateInput, Future[List[GenerateOutput]]]]) -> None:
    """Tokenize, generate and decode one batch, then resolve each caller's future.

    NOTE(review): the tokenizer/generator/decoder of the *first* queued request are
    used for the whole batch, so requests with different generation settings that
    are drained together are generated with the first request's settings.
    """
    try:
        start_time = time.monotonic()
        first_input = inputs[0][0]
        device = first_input.device
        tokenizer = first_input.tokenizer
        generator = first_input.generator
        decoder = first_input.decoder
        # ``generator`` is a functools.partial in practice; read n from its bound kwargs
        num_return_sequences = getattr(generator, 'keywords', {}).get('num_return_sequences', 1)

        # tokenize and move to device
        tokenized_inputs = tokenizer([item[0].input for item in inputs])
        input_ids = tokenized_inputs['input_ids'].to(device)
        attention_mask = tokenized_inputs['attention_mask'].to(device)

        # generate
        with torch.inference_mode():
            generation_outputs = cast(
                ModelGenerateOutput,
                generator(input_ids=input_ids, attention_mask=attention_mask),
            )
        generate_ids = generation_outputs.sequences
        logits = generation_outputs.logits

        # get logprobs from logits: stack per-step logits, then put the step axis second
        logprobs = None
        if logits is not None:
            stacked_logits = torch.stack(logits).transpose(0, 1)
            logprobs = torch.nn.functional.log_softmax(stacked_logits, dim=-1)

        # decode: every row of ``generate_ids`` starts with the (left-padded) prompt
        padded_input_len = input_ids.size(dim=1)
        generated_tokens = generate_ids[:, padded_input_len:]
        if logprobs is not None:
            assert logprobs.shape[1] == generated_tokens.shape[1]
        outputs = decoder(sequences=generated_tokens)

        # real prompt lengths: left padding must not be counted as input tokens
        prompt_lengths = attention_mask.sum(dim=-1).tolist()

        # call back futures
        total_time = time.monotonic() - start_time
        for input_index, (_, future) in enumerate(inputs):
            input_tokens = int(prompt_lengths[input_index])
            choices: List[GenerateOutput] = []
            for choice_index in range(num_return_sequences):
                output_index = input_index * num_return_sequences + choice_index
                output_tokens = generate_ids[output_index].shape[-1] - padded_input_len
                choices.append(
                    GenerateOutput(
                        output=outputs[output_index],
                        input_tokens=input_tokens,
                        output_tokens=output_tokens,
                        total_tokens=input_tokens + output_tokens,
                        logprobs=logprobs[output_index] if logprobs is not None else None,
                        time=total_time,
                    )
                )
            # concurrent.futures.Future is thread-safe, so resolving it from this
            # worker thread wakes the caller blocked in ``future.result()``.
            future.set_result(choices)

    except Exception as ex:
        for _, future in inputs:
            # Futures resolved before the failure keep their result; calling
            # set_exception on them raises InvalidStateError, which would kill the
            # worker thread and leave every later caller blocked forever.
            if not future.done():
                future.set_exception(ex)
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def extract_logprobs(
    response: GenerateOutput,
    top: Optional[int],
    tokenizer,
) -> List[Logprob]:
    """Convert the per-token log-probabilities of a local generation into ``Logprob`` records.

    For every generated position the ``top`` most likely token ids (at least one) are
    selected with ``topk`` and turned into ``TopLogprob`` entries; the best-ranked entry
    also fills the position's own ``Logprob`` fields.

    Args:
        response: Generation result whose ``logprobs`` tensor must be set. It is iterated
            row by row with ``topk`` over the last dimension, i.e. presumably shaped
            (generated_position, vocab) for a single sequence — confirm with the producer.
        top: Number of alternatives kept per position; ``None`` or ``0`` keeps only the best.
        tokenizer: Object exposing ``convert_ids_to_tokens`` (e.g. a Hugging Face tokenizer).

    Returns:
        One ``Logprob`` per generated position, in generation order.
    """
    assert response.logprobs is not None
    k = top or 1
    topk_values, topk_inds = response.logprobs.topk(k=k, dim=-1)

    final_logprobs: List[Logprob] = []
    for toks, vals in zip(topk_inds, topk_values):
        top_logprobs: List[TopLogprob] = []
        for tok, val in zip(toks, vals):
            # TODO: you get byte artifacts converting single ids to tokens like this...
            # but `tokenizer.decode` strips spaces. There must be a better way to do this.
            token = tokenizer.convert_ids_to_tokens(tok.item())
            if isinstance(token, bytes):
                # some tokenizers (e.g. tiktoken-style vocabularies) return raw byte tokens
                token_bytes = list(token)
                token_str = token.decode('utf-8', errors='replace')
            else:
                token_str = token
                # UTF-8 byte values, as in the OpenAI logprobs `bytes` field; `ord` would
                # yield code points > 255 for non-ASCII characters.
                token_bytes = list(token_str.encode('utf-8'))
            top_logprobs.append(
                TopLogprob(
                    token=token_str,
                    # store a plain Python float rather than a 0-d (possibly GPU) tensor
                    logprob=val.item(),
                    bytes=token_bytes,
                ))
        # k >= 1, so the highest-ranked alternative always exists
        best = top_logprobs[0]
        final_logprobs.append(
            Logprob(
                token=best.token,
                logprob=best.logprob,
                bytes=best.bytes,
                top_logprobs=top_logprobs,
            ))
    return final_logprobs
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from openai import APIStatusError, BadRequestError, OpenAI, PermissionDeniedError, UnprocessableEntityError
|
|
3
|
+
from openai._types import NOT_GIVEN
|
|
4
|
+
from openai.types.chat import ChatCompletion
|
|
5
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
6
|
+
|
|
7
|
+
from evalscope.api.messages import ChatMessage
|
|
8
|
+
from evalscope.api.model import ChatCompletionChoice, GenerateConfig, ModelAPI, ModelOutput
|
|
9
|
+
from evalscope.api.tool import ToolChoice, ToolInfo
|
|
10
|
+
from evalscope.utils import get_logger
|
|
11
|
+
from .utils.openai import (
|
|
12
|
+
chat_choices_from_openai,
|
|
13
|
+
collect_stream_response,
|
|
14
|
+
model_output_from_openai,
|
|
15
|
+
openai_chat_messages,
|
|
16
|
+
openai_chat_tool_choice,
|
|
17
|
+
openai_chat_tools,
|
|
18
|
+
openai_completion_params,
|
|
19
|
+
openai_handle_bad_request,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
logger = get_logger()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class OpenAICompatibleAPI(ModelAPI):
    """Model API for services implementing the OpenAI Chat Completions protocol.

    Requests go through the official ``openai`` client. Missing ``api_key`` / ``base_url``
    arguments fall back to the ``EVALSCOPE_API_KEY`` / ``EVALSCOPE_BASE_URL`` environment
    variables. Subclasses customise behaviour by overriding the ``resolve_tools``,
    ``completion_params``, ``on_response``, ``chat_choices_from_completion`` and
    ``handle_bad_request`` hooks.
    """

    def __init__(
        self,
        model_name: str,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        config: Optional[GenerateConfig] = None,
        **model_args: Any,
    ) -> None:
        """Resolve credentials and endpoint, then create the OpenAI client.

        Args:
            model_name: Model identifier sent with every request.
            base_url: Service endpoint; a trailing ``/`` and a ``/chat/completions``
                suffix are stripped so full endpoint URLs are accepted too.
            api_key: API key for the service.
            config: Default generation config; a fresh ``GenerateConfig()`` when omitted.
            **model_args: Extra keyword arguments forwarded to ``openai.OpenAI``.

        Raises:
            AssertionError: If no API key or base URL can be resolved.
        """
        # build the default per instance instead of sharing one object created at import time
        if config is None:
            config = GenerateConfig()

        super().__init__(
            model_name=model_name,
            base_url=base_url,
            api_key=api_key,
            config=config,
        )

        # explicit argument wins, otherwise fall back to the environment
        self.api_key = api_key or os.environ.get('EVALSCOPE_API_KEY', None)
        assert self.api_key, f'API key for {model_name} not found'

        self.base_url = base_url or os.environ.get('EVALSCOPE_BASE_URL', None)
        assert self.base_url, f'Base URL for {model_name} not found'

        # the OpenAI client appends '/chat/completions' itself
        self.base_url = self.base_url.rstrip('/').removesuffix('/chat/completions')

        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            **model_args,
        )

    def generate(
        self,
        input: List[ChatMessage],
        tools: List[ToolInfo],
        tool_choice: ToolChoice,
        config: GenerateConfig,
    ) -> ModelOutput:
        """Send one chat completion request and convert the reply into a ``ModelOutput``.

        Streaming responses are collected into a single ``ChatCompletion`` first.

        Raises:
            Exception: The exception returned by ``handle_bad_request`` for bad-request,
                unprocessable-entity and permission-denied errors; any other error from
                the OpenAI client propagates unchanged.
        """
        tools, tool_choice, config = self.resolve_tools(tools, tool_choice, config)
        has_tools = len(tools) > 0

        completion_params = self.completion_params(
            config=config,
            tools=has_tools,
        )

        request: Dict[str, Any] = dict(
            messages=openai_chat_messages(input),
            tools=openai_chat_tools(tools) if has_tools else NOT_GIVEN,
            tool_choice=openai_chat_tool_choice(tool_choice) if has_tools else NOT_GIVEN,
            **completion_params,
        )

        try:
            completion = self.client.chat.completions.create(**request)
            # with stream=True the client returns an iterator of chunks
            if not isinstance(completion, ChatCompletion):
                completion = collect_stream_response(completion)
            response: Dict[str, Any] = completion.model_dump()
            self.on_response(response)

            choices = self.chat_choices_from_completion(completion, tools)
            return model_output_from_openai(completion, choices)

        except (BadRequestError, UnprocessableEntityError, PermissionDeniedError) as ex:
            result = self.handle_bad_request(ex)
            # the hook may hand back an exception; raise it rather than returning it
            # where a ModelOutput is promised
            if isinstance(result, Exception):
                raise result
            return result

    def resolve_tools(self, tool_choice_placeholder=None, *args, **kwargs):  # pragma: no cover
        raise NotImplementedError
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import time
|
|
5
|
+
import torch
|
|
6
|
+
from logging import getLogger
|
|
7
|
+
from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union, cast
|
|
8
|
+
|
|
9
|
+
from evalscope.api.messages import (
|
|
10
|
+
ChatMessage,
|
|
11
|
+
ChatMessageAssistant,
|
|
12
|
+
ContentAudio,
|
|
13
|
+
ContentImage,
|
|
14
|
+
ContentText,
|
|
15
|
+
ContentVideo,
|
|
16
|
+
)
|
|
17
|
+
from evalscope.api.model import (
|
|
18
|
+
ChatCompletionChoice,
|
|
19
|
+
GenerateConfig,
|
|
20
|
+
Logprob,
|
|
21
|
+
Logprobs,
|
|
22
|
+
ModelAPI,
|
|
23
|
+
ModelOutput,
|
|
24
|
+
ModelUsage,
|
|
25
|
+
TopLogprob,
|
|
26
|
+
)
|
|
27
|
+
from evalscope.api.tool import ToolChoice, ToolInfo
|
|
28
|
+
from evalscope.utils.io_utils import PIL_to_base64
|
|
29
|
+
from evalscope.utils.model_utils import get_device
|
|
30
|
+
|
|
31
|
+
logger = getLogger()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Text2ImageAPI(ModelAPI):
|
|
35
|
+
|
|
36
|
+
def __init__(
|
|
37
|
+
self,
|
|
38
|
+
model_name: str,
|
|
39
|
+
base_url: Optional[str] = None,
|
|
40
|
+
api_key: Optional[str] = None,
|
|
41
|
+
config: GenerateConfig = GenerateConfig(),
|
|
42
|
+
**model_args: Any,
|
|
43
|
+
):
|
|
44
|
+
super().__init__(
|
|
45
|
+
model_name=model_name,
|
|
46
|
+
base_url=base_url,
|
|
47
|
+
api_key=api_key,
|
|
48
|
+
config=config,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
# collect known model_args (then delete them so we can pass the rest on)
|
|
52
|
+
def collect_model_arg(name: str) -> Optional[Any]:
|
|
53
|
+
nonlocal model_args
|
|
54
|
+
value = model_args.get(name, None)
|
|
55
|
+
if value is not None:
|
|
56
|
+
model_args.pop(name)
|
|
57
|
+
return value
|
|
58
|
+
|
|
59
|
+
model_path = collect_model_arg('model_path')
|
|
60
|
+
torch_dtype = collect_model_arg('precision') or collect_model_arg('torch_dtype')
|
|
61
|
+
device_map = collect_model_arg('device_map')
|
|
62
|
+
# torch dtype
|
|
63
|
+
DTYPE_MAP = {'float16': torch.float16, 'float32': torch.float32, 'bfloat16': torch.bfloat16, 'auto': 'auto'}
|
|
64
|
+
|
|
65
|
+
if isinstance(torch_dtype, str) and torch_dtype != 'auto':
|
|
66
|
+
torch_dtype = DTYPE_MAP.get(torch_dtype, torch.float32)
|
|
67
|
+
self.torch_dtype = torch_dtype
|
|
68
|
+
self.device = device_map or get_device()
|
|
69
|
+
|
|
70
|
+
self.pipeline_cls = collect_model_arg('pipeline_cls')
|
|
71
|
+
# default to DiffusionPipeline if not specified
|
|
72
|
+
if self.pipeline_cls is None:
|
|
73
|
+
if 'flux' in model_name.lower():
|
|
74
|
+
self.pipeline_cls = 'FluxPipeline'
|
|
75
|
+
else:
|
|
76
|
+
self.pipeline_cls = 'DiffusionPipeline'
|
|
77
|
+
|
|
78
|
+
model_name_or_path = model_path or model_name
|
|
79
|
+
|
|
80
|
+
# from modelscope import pipeline_cls
|
|
81
|
+
module = getattr(importlib.import_module('modelscope'), self.pipeline_cls)
|
|
82
|
+
logger.info(f'Loading model {model_name_or_path} with {self.pipeline_cls} ...')
|
|
83
|
+
|
|
84
|
+
self.model = module.from_pretrained(
|
|
85
|
+
model_name_or_path,
|
|
86
|
+
torch_dtype=self.torch_dtype,
|
|
87
|
+
**model_args,
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
self.model.to(self.device)
|
|
91
|
+
|
|
92
|
+
def generate(
|
|
93
|
+
self,
|
|
94
|
+
input: List[ChatMessage],
|
|
95
|
+
tools: List[ToolInfo],
|
|
96
|
+
tool_choice: ToolChoice,
|
|
97
|
+
config: GenerateConfig,
|
|
98
|
+
) -> ModelOutput:
|
|
99
|
+
|
|
100
|
+
# prepare generator
|
|
101
|
+
kwargs: Dict[str, Any] = {}
|
|
102
|
+
if config.height is not None:
|
|
103
|
+
kwargs['height'] = config.height
|
|
104
|
+
if config.width is not None:
|
|
105
|
+
kwargs['width'] = config.width
|
|
106
|
+
if config.num_inference_steps is not None:
|
|
107
|
+
kwargs['num_inference_steps'] = config.num_inference_steps
|
|
108
|
+
if config.guidance_scale is not None:
|
|
109
|
+
kwargs['guidance_scale'] = config.guidance_scale
|
|
110
|
+
# update with extra model parameters
|
|
111
|
+
kwargs.update(config.model_extra)
|
|
112
|
+
|
|
113
|
+
# assume the first text as prompt
|
|
114
|
+
prompt = input[0].text
|
|
115
|
+
# get the first image as output
|
|
116
|
+
image = self.model(prompt=prompt, **kwargs).images[0]
|
|
117
|
+
|
|
118
|
+
image_base64 = PIL_to_base64(image)
|
|
119
|
+
|
|
120
|
+
return ModelOutput(
|
|
121
|
+
model=self.model_name,
|
|
122
|
+
choices=[ChatCompletionChoice.from_content(content=[ContentImage(image=image_base64)])],
|
|
123
|
+
time=time.time(),
|
|
124
|
+
)
|