xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_compat.py +51 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +209 -40
- xinference/client/restful/restful_client.py +7 -26
- xinference/conftest.py +1 -1
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +1 -1
- xinference/core/chat_interface.py +8 -14
- xinference/core/event.py +1 -1
- xinference/core/image_interface.py +28 -0
- xinference/core/model.py +110 -31
- xinference/core/scheduler.py +37 -37
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +17 -10
- xinference/core/utils.py +80 -22
- xinference/core/worker.py +17 -16
- xinference/deploy/cmdline.py +8 -16
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/deploy/utils.py +1 -1
- xinference/deploy/worker.py +1 -1
- xinference/model/audio/cosyvoice.py +86 -41
- xinference/model/audio/fish_speech.py +9 -9
- xinference/model/audio/model_spec.json +9 -9
- xinference/model/audio/whisper.py +4 -1
- xinference/model/embedding/core.py +52 -31
- xinference/model/image/core.py +2 -1
- xinference/model/image/model_spec.json +16 -4
- xinference/model/image/model_spec_modelscope.json +16 -4
- xinference/model/image/sdapi.py +136 -0
- xinference/model/image/stable_diffusion/core.py +164 -19
- xinference/model/llm/__init__.py +29 -11
- xinference/model/llm/llama_cpp/core.py +16 -33
- xinference/model/llm/llm_family.json +1011 -1296
- xinference/model/llm/llm_family.py +34 -53
- xinference/model/llm/llm_family_csghub.json +18 -35
- xinference/model/llm/llm_family_modelscope.json +981 -1122
- xinference/model/llm/lmdeploy/core.py +56 -88
- xinference/model/llm/mlx/core.py +46 -69
- xinference/model/llm/sglang/core.py +36 -18
- xinference/model/llm/transformers/chatglm.py +168 -306
- xinference/model/llm/transformers/cogvlm2.py +36 -63
- xinference/model/llm/transformers/cogvlm2_video.py +33 -223
- xinference/model/llm/transformers/core.py +55 -50
- xinference/model/llm/transformers/deepseek_v2.py +340 -0
- xinference/model/llm/transformers/deepseek_vl.py +53 -96
- xinference/model/llm/transformers/glm4v.py +55 -111
- xinference/model/llm/transformers/intern_vl.py +39 -70
- xinference/model/llm/transformers/internlm2.py +32 -54
- xinference/model/llm/transformers/minicpmv25.py +22 -55
- xinference/model/llm/transformers/minicpmv26.py +158 -68
- xinference/model/llm/transformers/omnilmm.py +5 -28
- xinference/model/llm/transformers/qwen2_audio.py +168 -0
- xinference/model/llm/transformers/qwen2_vl.py +234 -0
- xinference/model/llm/transformers/qwen_vl.py +34 -86
- xinference/model/llm/transformers/utils.py +32 -38
- xinference/model/llm/transformers/yi_vl.py +32 -72
- xinference/model/llm/utils.py +280 -554
- xinference/model/llm/vllm/core.py +161 -100
- xinference/model/rerank/core.py +41 -8
- xinference/model/rerank/model_spec.json +7 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -1
- xinference/model/utils.py +1 -31
- xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
- xinference/thirdparty/cosyvoice/cli/model.py +139 -26
- xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
- xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
- xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
- xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
- xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
- xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
- xinference/thirdparty/cosyvoice/utils/common.py +36 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
- xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
- xinference/thirdparty/fish_speech/tools/api.py +79 -134
- xinference/thirdparty/fish_speech/tools/commons.py +35 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
- xinference/thirdparty/fish_speech/tools/file.py +17 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
- xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
- xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
- xinference/thirdparty/fish_speech/tools/webui.py +12 -146
- xinference/thirdparty/matcha/VERSION +1 -0
- xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
- xinference/thirdparty/matcha/hifigan/README.md +101 -0
- xinference/thirdparty/omnilmm/LICENSE +201 -0
- xinference/thirdparty/whisper/__init__.py +156 -0
- xinference/thirdparty/whisper/__main__.py +3 -0
- xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
- xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
- xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
- xinference/thirdparty/whisper/audio.py +157 -0
- xinference/thirdparty/whisper/decoding.py +826 -0
- xinference/thirdparty/whisper/model.py +314 -0
- xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
- xinference/thirdparty/whisper/normalizers/basic.py +76 -0
- xinference/thirdparty/whisper/normalizers/english.json +1741 -0
- xinference/thirdparty/whisper/normalizers/english.py +550 -0
- xinference/thirdparty/whisper/timing.py +386 -0
- xinference/thirdparty/whisper/tokenizer.py +395 -0
- xinference/thirdparty/whisper/transcribe.py +605 -0
- xinference/thirdparty/whisper/triton_ops.py +109 -0
- xinference/thirdparty/whisper/utils.py +316 -0
- xinference/thirdparty/whisper/version.py +1 -0
- xinference/types.py +14 -53
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
- xinference/web/ui/build/static/js/main.754740c0.js +3 -0
- xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
- xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +37 -0
- xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
- xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
- xinference/web/ui/node_modules/nunjucks/package.json +112 -0
- xinference/web/ui/package-lock.json +38 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
- xinference/model/llm/transformers/llama_2.py +0 -108
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
- xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
- xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
- xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
- xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/utils.py
CHANGED
@@ -17,6 +17,7 @@ import json
 import logging
 import os
 import time
+import typing
 import uuid
 from io import BytesIO
 from typing import AsyncGenerator, Dict, Iterator, List, Optional, Tuple, cast
@@ -25,19 +26,18 @@ import requests
 from PIL import Image

 from ...types import (
-    SPECIAL_TOOL_PROMPT,
     ChatCompletion,
+    ChatCompletionChoice,
     ChatCompletionChunk,
-    ChatCompletionMessage,
     Completion,
+    CompletionChoice,
     CompletionChunk,
+    CompletionUsage,
 )
-from ..utils import ensure_cache_cleared
 from .llm_family import (
     LlamaCppLLMSpecV1,
     LLMFamilyV1,
     LLMSpecV1,
-    PromptStyleV1,
     _get_cache_dir,
     get_cache_status,
 )
@@ -46,7 +46,6 @@ logger = logging.getLogger(__name__)


 QWEN_TOOL_CALL_FAMILY = [
-    "qwen-chat",
     "qwen1.5-chat",
     "qwen1.5-moe-chat",
     "qwen2-instruct",
@@ -58,416 +57,90 @@ GLM4_TOOL_CALL_FAMILY = [
     "glm4-chat-1m",
 ]

+QWEN_TOOL_CALL_SYMBOLS = ["<tool_call>", "</tool_call>"]
+

 class ChatModelMixin:
     @staticmethod
-
-
-        chat_history: List[ChatCompletionMessage],
-        prompt_style: PromptStyleV1,
-        tools: Optional[List[Dict]] = None,
-    ):
+    @functools.lru_cache
+    def _compile_jinja_template(chat_template):
         """
-
-        different models.
+        Copied from transformers source code.
         """
-
-
-
-
-        )
-
-
+        try:
+            from jinja2.exceptions import TemplateError
+            from jinja2.sandbox import ImmutableSandboxedEnvironment
+        except ImportError:
+            raise ImportError("xinference requires jinja2 to be installed.")
+
+        def raise_exception(message):
+            raise TemplateError(message)
+
+        jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
+        jinja_env.globals["raise_exception"] = raise_exception
+        return jinja_env.from_string(chat_template)
+
+    def _build_from_raw_template(
+        self, messages: List, chat_template: str, **kwargs
+    ) -> str:
+        compiled_template = self._compile_jinja_template(chat_template)
+        rendered = compiled_template.render(
+            messages=messages, add_generation_prompt=True, **kwargs
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-                content = message["content"]
-                if content:
-                    ret += role + ": " + content + prompt_style.intra_message_sep
-                else:
-                    ret += role + ":"
-            return ret
-        elif prompt_style.style_name == "NO_COLON_TWO":
-            seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
-            ret = prompt_style.system_prompt
-            for i, message in enumerate(chat_history):
-                role = get_role(message["role"])
-                content = message["content"]
-                if content:
-                    ret += role + content + seps[i % 2]
-                else:
-                    ret += role
-            return ret
-        elif prompt_style.style_name == "LLAMA2":
-            seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
-            ret = ""
-            for i, message in enumerate(chat_history):
-                role = get_role(message["role"])
-                content = message["content"]
-                if content:
-                    if i == 0:
-                        ret += prompt_style.system_prompt + content
-                    else:
-                        ret += role + " " + content + seps[i % 2]
-                else:
-                    ret += role
-            return ret
-        elif prompt_style.style_name == "LLAMA3":
-            ret = (
-                f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>"
-                f"{prompt_style.intra_message_sep}{prompt_style.system_prompt}{prompt_style.inter_message_sep}"
-            )
-            for i, message in enumerate(chat_history):
-                role = get_role(message["role"])
-                content = message["content"]
-                if content:
-                    ret += (
-                        f"<|start_header_id|>{role}<|end_header_id|>"
-                        f"{prompt_style.intra_message_sep}{content}{prompt_style.inter_message_sep}"
-                    )
-                else:
-                    ret += f"<|start_header_id|>{role}<|end_header_id|>{prompt_style.intra_message_sep}"
-            return ret
-        elif prompt_style.style_name == "MIXTRAL_V01":
-            ret = ""
-            for i, message in enumerate(chat_history):
-                content = message["content"]
-                if i % 2 == 0:  # user
-                    ret += f"<s> [INST] {content} [/INST]"
-                else:  # assistant
-                    ret += f"{content} </s>"
-            return ret
-        elif prompt_style.style_name == "CHATGLM3":
-            prompts = (
-                [f"<|system|>\n {prompt_style.system_prompt}"]
-                if prompt_style.system_prompt
-                else []
-            )
-
-            for i, message in enumerate(chat_history):
-                role = get_role(message["role"])
-                content = message.get("content")
-                tool_calls = message.get("tool_calls")
-                if tool_calls:
-                    content = tool_calls[0]["function"]
-                if content:
-                    if role == "tool":
-                        role = "observation"
-                    prompts.append(f"<|{role}|>\n {content}")
-                else:
-                    prompts.append(f"<|{role}|>")
-            return "\n".join(prompts)
-        elif prompt_style.style_name == "XVERSE":
-            ret = (
-                f"<|system|> \n {prompt_style.system_prompt}"
-                if prompt_style.system_prompt
-                else ""
-            )
-            for i, message in enumerate(chat_history):
-                role = get_role(message["role"])
-                content = message["content"]
-                if content:
-                    ret += f"<|{role}|> \n {content}"
-                else:
-                    ret += f"<|{role}|>"
-            return ret
-        elif prompt_style.style_name == "QWEN":
-            if tools:
-                tool_desc = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters} Format the arguments as a JSON object."""
-
-                react_instruction = """Answer the following questions as best you can. You have access to the following APIs:
-
-{tools_text}
-
-Use the following format:
-
-Question: the input question you must answer
-Thought: you should always think about what to do
-Action: the action to take, should be one of [{tools_name_text}]
-Action Input: the input to the action
-Observation: the result of the action
-... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
-Thought: I now know the final answer
-Final Answer: the final answer to the original input question
-
-Begin!"""
-                tools_text = []
-                tools_name_text = []
-                for func_info in tools:
-                    parameters = []
-                    fp = func_info["function"].get("parameters", {})
-                    if fp:
-                        required_parameters = fp.get("required", [])
-                        for name, p in fp["properties"].items():
-                            param = dict({"name": name}, **p)
-                            if name in required_parameters:
-                                param["required"] = True
-                            parameters.append(param)
-
-                    name = func_info["function"]["name"]
-                    desc = func_info["function"]["description"]
-                    tool_string = tool_desc.format(
-                        name_for_model=name,
-                        name_for_human=name,
-                        # Hint: You can add the following format requirements in description:
-                        # "Format the arguments as a JSON object."
-                        # "Enclose the code within triple backticks (`) at the beginning and end of the code."
-                        description_for_model=desc,
-                        parameters=json.dumps(parameters, ensure_ascii=False),
-                    )
-                    tools_text.append(tool_string)
-                    tools_name_text.append(name)
-                tools_text_string = "\n\n".join(tools_text)
-                tools_name_text_string = ", ".join(tools_name_text)
-                tool_system = react_instruction.format(
-                    tools_text=tools_text_string,
-                    tools_name_text=tools_name_text_string,
+        return rendered
+
+    def get_full_context(
+        self, messages: List, chat_template: str, tokenizer=None, **kwargs
+    ) -> str:
+        if tokenizer is not None:
+            try:
+                full_context = tokenizer.apply_chat_template(
+                    messages,
+                    tokenize=False,
+                    chat_template=chat_template,
+                    add_generation_prompt=True,
+                    **kwargs,
                 )
-
-
-
-
-
-
-
-
-
-
-                if role == "user":
-                    if tool_system:
-                        content = tool_system + f"\n\nQuestion: {content}"
-                        tool_system = ""
-                    else:
-                        content = f"Question: {content}"
-                elif role == "assistant":
-                    tool_calls = message.get("tool_calls")
-                    if tool_calls:
-                        func_call = tool_calls[0]["function"]
-                        f_name, f_args = (
-                            func_call["name"],
-                            func_call["arguments"],
-                        )
-                        content = f"Thought: I can use {f_name}.\nAction: {f_name}\nAction Input: {f_args}"
-                    elif content:
-                        content = f"Thought: I now know the final answer.\nFinal answer: {content}"
-                elif role == "tool":
-                    role = "function"
-                    content = f"Observation: {content}"
-                else:
-                    raise Exception(f"Unsupported message role: {role}")
-                if content:
-                    content = content.lstrip("\n").rstrip()
-                    ret += f"<|im_start|>{role}\n{content}<|im_end|>"
-                else:
-                    ret += f"<|im_start|>{role}\n"
-            return ret
-        elif prompt_style.style_name == "CHATML":
-            ret = (
-                ""
-                if prompt_style.system_prompt == ""
-                else prompt_style.system_prompt + prompt_style.intra_message_sep + "\n"
-            )
-            for message in chat_history:
-                role = get_role(message["role"])
-                content = message["content"]
+                return full_context
+            except Exception as e:
+                logger.warning(
+                    f"tokenizer.apply_chat_template error. Maybe this is an old model: {e}"
+                )
+                return self._build_from_raw_template(messages, chat_template, **kwargs)
+        else:
+            # build from jinja
+            # Compilation function uses a cache to avoid recompiling the same template
+            return self._build_from_raw_template(messages, chat_template, **kwargs)

-
-
-
-
-
-
-
-
-                if prompt_style.system_prompt == ""
-                else "<s><|im_start|>system\n"
-                + prompt_style.system_prompt
-                + prompt_style.intra_message_sep
-                + "\n"
-            )
-            for message in chat_history:
-                role = get_role(message["role"])
-                content = message["content"]
+    @staticmethod
+    def get_specific_prompt(model_family: str, messages: List[Dict]):
+        """
+        Inspired by FastChat. Format chat history into a prompt according to the prompty style of
+        different models.
+        """
+        _messages = [x for x in messages]  # copy for not modifying the origin messages
+        _messages.append({"role": "assistant", "content": ""})

-
-
-                else
-                    ret += role + "\n"
-            return ret
-        elif prompt_style.style_name == "ADD_COLON_SINGLE_COT":
-            ret = prompt_style.system_prompt + prompt_style.intra_message_sep
-            for message in chat_history:
-                role = get_role(message["role"])
-                content = message["content"]
-                if content:
-                    ret += role + ": " + content + prompt_style.intra_message_sep
-                else:
-                    ret += role + ": Let's think step by step."
-            return ret
-        elif prompt_style.style_name == "DEEPSEEK_CHAT":
-            seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
-            ret = prompt_style.system_prompt
-            for i, message in enumerate(chat_history):
-                role = get_role(message["role"])
-                content = message["content"]
-                if content:
-                    ret += role + ": " + content + seps[i % 2]
-                else:
-                    ret += role + ":"
-            return ret
-        elif prompt_style.style_name == "DEEPSEEK_CODER":
-            sep = prompt_style.inter_message_sep
-            ret = prompt_style.system_prompt + sep
-            for i, message in enumerate(chat_history):
-                role = get_role(message["role"])
-                content = message["content"]
-                if content:
-                    ret += role + "\n" + content + sep
-                else:
-                    ret += role + "\n"
-            return ret
-        elif prompt_style.style_name == "GORILLA_OPENFUNCTIONS":
-            if tools:
-                gorilla_functions = []
-                for tool in tools:
-                    gorilla_functions.append(
-                        {
-                            "name": tool["function"]["name"],
-                            "api_name": tool["function"]["name"],
-                            "description": tool["function"]["description"],
-                            "parameters": [
-                                dict({"name": name}, **p)
-                                for name, p in tool["function"]["parameters"][
-                                    "properties"
-                                ].items()
-                            ],
-                        }
-                    )
-                tools_string = json.dumps(gorilla_functions)
-                return f"USER: <<question>> {prompt} <<function>> {tools_string}\nASSISTANT: "
-            else:
-                return f"USER: <<question>> {prompt}\nASSISTANT: "
-        elif prompt_style.style_name == "orion":
-            ret = "<s>"
-            for i, message in enumerate(chat_history):
-                content = message["content"]
-                role = get_role(message["role"])
-                if i % 2 == 0:  # Human
-                    assert content is not None
-                    ret += role + ": " + content + "\n\n"
-                else:  # Assistant
-                    if content:
-                        ret += role + ": </s>" + content + "</s>"
-                    else:
-                        ret += role + ": </s>"
-            return ret
-        elif prompt_style.style_name == "gemma":
-            ret = ""
-            for message in chat_history:
-                content = message["content"]
-                role = get_role(message["role"])
-                ret += "<start_of_turn>" + role + "\n"
-                if content:
-                    ret += content + "<end_of_turn>\n"
-            return ret
-        elif prompt_style.style_name == "CodeShell":
-            ret = ""
-            for message in chat_history:
-                content = message["content"]
-                role = get_role(message["role"])
-                if content:
-                    ret += f"{role}{content}|<end>|"
-                else:
-                    ret += f"{role}".rstrip()
-            return ret
-        elif prompt_style.style_name == "MINICPM-2B":
-            ret = ""
-            for message in chat_history:
-                content = message["content"] or ""
-                role = get_role(message["role"])
-                if role == "user":
-                    ret += "<用户>" + content.strip()
-                else:
-                    ret += "<AI>" + content.strip()
-            return ret
-        elif prompt_style.style_name == "PHI3":
-            ret = f"<|system|>{prompt_style.intra_message_sep}{prompt_style.system_prompt}{prompt_style.inter_message_sep}"
-            for message in chat_history:
-                content = message["content"] or ""
-                role = get_role(message["role"])
-                if content:
-                    ret += f"<|{role}|>{prompt_style.intra_message_sep}{content}{prompt_style.inter_message_sep}"
-                else:
-                    ret += f"<|{role}|>{prompt_style.intra_message_sep}"
-            ret += "<|assistant|>\n"
-            return ret
-        elif prompt_style.style_name == "c4ai-command-r":
-            ret = (
-                f"<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>"
-                f"{prompt_style.system_prompt}{prompt_style.inter_message_sep}"
+        if model_family == "internvl2":
+            system_prompt = (
+                messages[0]["content"] if messages[0]["role"] == "system" else ""
             )
-
-                role = get_role(message["role"])
-                content = message["content"]
-                if content:
-                    ret += f"{role}{content}{prompt_style.inter_message_sep}"
-                else:
-                    ret += role
-            return ret
-        elif prompt_style.style_name == "mistral-nemo":
-            seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
-            ret = "<s>"
-            for i, message in enumerate(chat_history):
-                role = get_role(message["role"])
-                content = message["content"]
-                if content:
-                    if i == len(chat_history) - 2 and prompt_style.system_prompt:
-                        ret += (
-                            role
-                            + " "
-                            + prompt_style.system_prompt
-                            + "\n\n"
-                            + content
-                            + seps[i % 2]
-                        )
-                    else:
-                        ret += role + " " + content + seps[i % 2]
-                else:
-                    ret += role
-            return ret
-        elif prompt_style.style_name == "INTERNVL":
+            intra_message_sep = "<|im_end|>"
             ret = (
                 "<s>"
-                if
+                if system_prompt == ""
                 else "<s><|im_start|>system\n"
-                +
-                +
+                + system_prompt
+                + intra_message_sep
                 + "\n"
             )
             images = []  # type: ignore
-            for message in
-                role =
+            for message in _messages:
+                role = "<|im_start|>" + message["role"]
                 content = message["content"]
                 if isinstance(content, str):
                     if content:
-                        ret +=
-                            role
-                            + "\n"
-                            + content
-                            + prompt_style.intra_message_sep
-                            + "\n"
-                        )
+                        ret += role + "\n" + content + intra_message_sep + "\n"
                     else:
                         ret += role + "\n"
                 elif isinstance(content, list):
@@ -488,21 +161,15 @@ Begin!"""
                             image_futures.append(fut)
                     images = [fut.result() for fut in image_futures]
                     if len(image_futures) == 0:
-                        ret +=
-                            role + "\n" + text + prompt_style.intra_message_sep + "\n"
-                        )
+                        ret += role + "\n" + text + intra_message_sep + "\n"
                     else:
                         ret += (
-                            role
-                            + "\n"
-                            + f"<image>\n{text}"
-                            + prompt_style.intra_message_sep
-                            + "\n"
+                            role + "\n" + f"<image>\n{text}" + intra_message_sep + "\n"
                         )

-            return
+            return ret, images
         else:
-            raise ValueError(f"Invalid
+            raise ValueError(f"Invalid model family: {model_family}")

     @classmethod
     def _to_chat_completion_chunk(cls, chunk: CompletionChunk) -> ChatCompletionChunk:
@@ -523,7 +190,11 @@ Begin!"""
                 {
                     "index": i,
                     "delta": {
-
+                        **(
+                            {"content": choice["text"]}
+                            if ("text" in choice and choice["finish_reason"] is None)
+                            else {}
+                        ),
                         **(
                             {"tool_calls": choice["tool_calls"]}
                             if "tool_calls" in choice
@@ -577,7 +248,6 @@ Begin!"""
         return cast(ChatCompletionChunk, chat_chunk)

     @classmethod
-    @ensure_cache_cleared
     def _to_chat_completion_chunks(
         cls,
         chunks: Iterator[CompletionChunk],
@@ -610,7 +280,6 @@ Begin!"""
             i += 1

     @staticmethod
-    @ensure_cache_cleared
     def _to_chat_completion(completion: Completion) -> ChatCompletion:
         return {
             "id": "chat" + completion["id"],
@@ -632,143 +301,89 @@ Begin!"""
         }

     @staticmethod
-    def
-
-
-
-        def tool_call(n, **kwargs):
-            return None, n, kwargs
-
-        try:
-            a, b, c = eval(
-                arguments, {n: functools.partial(tool_call, n) for n in tool_names}
-            )
-            return a, b, c
-        except Exception as e:
-            logger.error("Eval tool calls completion failed: %s", e)
-            return arguments, None, None
-
-    @staticmethod
-    def _eval_glm_chat_arguments(c, tools):
+    def _eval_glm_chat_arguments(c) -> List[Tuple]:
+        """
+        Currently, glm4 tool call only supports one function
+        """
         try:
-            if isinstance(c
-                return c[
-            return None, c[0]["name"], c[0]["parameters"]
+            if isinstance(c, dict):
+                return [(None, c["name"], c["arguments"])]
         except KeyError:
             logger.error("Can't parse glm output: %s", c)
-            return str(c), None, None
+            return [(str(c), None, None)]
+        else:
+            return [(str(c), None, None)]

-    @
-    def
+    @classmethod
+    def _handle_qwen_tool_result(cls, text: str) -> List[Tuple]:
+        text: str = text.strip()  # type: ignore
+        contents: List[str] = text.split(QWEN_TOOL_CALL_SYMBOLS[1])
+        results: List[Tuple] = []
+        for content in contents:
+            content = content.strip()
+            if content:
+                if content.startswith(QWEN_TOOL_CALL_SYMBOLS[0]):
+                    content = content[len(QWEN_TOOL_CALL_SYMBOLS[0]) :]
+                content = content.strip()
+                try:
+                    res = json.loads(content)
+                    results.append((None, res["name"], res["arguments"]))
+                except Exception as e:
+                    logger.error(
+                        "Can't parse single qwen tool call output: %s. Error: %s",
+                        content,
+                        e,
+                    )
+                    results.append((content, None, None))
+        return results
+
+    @classmethod
+    def _eval_qwen_chat_arguments(cls, c) -> List[Tuple]:
         text = c["choices"][0]["text"]
-
-            # Refer to:
-            # https://github.com/QwenLM/Qwen/blob/main/examples/react_prompt.md
-            # https://github.com/QwenLM/Qwen/blob/main/openai_api.py#L297
-            func_name, func_args, content = "", "", ""
-            i = text.rfind("\nAction:")
-            j = text.rfind("\nAction Input:")
-            k = text.rfind("\nObservation:")
-            t = max(
-                text.rfind("\nThought:", 0, i), text.rfind("Thought:", 0, i)
-            )  # find the last thought just before Action, considering the Thought at the very beginning
-            if 0 <= i < j:  # If the text has `Action` and `Action input`,
-                if k < j:  # but does not contain `Observation`,
-                    # then it is likely that `Observation` is omitted by the LLM,
-                    # because the output text may have discarded the stop word.
-                    text = text.rstrip() + "\nObservation:"  # Add it back.
-                    k = text.rfind("\nObservation:")
-            if 0 <= t < i < j < k:
-                func_name = text[i + len("\nAction:") : j].strip()
-                func_args = text[j + len("\nAction Input:") : k].strip()
-                content = text[
-                    t + len("\nThought:") : i
-                ].strip()  # len("\nThought:") and len("Thought:") both are OK since there is a space after :
-            if func_name:
-                return content, func_name, json.loads(func_args)
-        except Exception as e:
-            logger.error("Eval tool calls completion failed: %s", e)
-        t = max(text.rfind("\nThought:"), text.rfind("Thought:"))
-        z = max(text.rfind("\nFinal Answer:"), text.rfind("Final Answer:"))
-        if z >= 0:
-            text = text[
-                z + len("\nFinal Answer:") :
-            ]  # len("\nFinal Answer::") and len("Final Answer::") both are OK since there is a space after :
-        else:
-            text = text[
-                t + len("\nThought:") :
-            ]  # There is only Thought: no Final Answer:
-        return text, None, None
+        return cls._handle_qwen_tool_result(text)

     @classmethod
-    def _eval_tool_arguments(cls, model_family, c
+    def _eval_tool_arguments(cls, model_family, c):
         family = model_family.model_family or model_family.model_name
-        if family in
-
-        elif family in GLM4_TOOL_CALL_FAMILY:
-            content, func, args = cls._eval_glm_chat_arguments(c, tools)
+        if family in GLM4_TOOL_CALL_FAMILY:
+            result = cls._eval_glm_chat_arguments(c)
         elif family in QWEN_TOOL_CALL_FAMILY:
-
+            result = cls._eval_qwen_chat_arguments(c)
         else:
             raise Exception(
                 f"Model {model_family.model_name} is not support tool calls."
             )
-        logger.debug("Tool call content:
-        return
-
-    @classmethod
-    def _tools_token_filter(cls, model_family):
-        """
-        Generates a filter function for Qwen series models to retain outputs after "\nFinal Answer:".
-
-        Returns:
-            A function that takes tokens (string output by the model so far) and delta (new tokens added) as input,
-            returns the part after "\nFinal Answer:" if found, else returns delta.
-        """
-        family = model_family.model_family or model_family.model_name
-        if family in QWEN_TOOL_CALL_FAMILY:
-            # Encapsulating function to reset 'found' after each call
-            found = False
-
-            def process_tokens(tokens: str, delta: str):
-                nonlocal found
-                # Once "Final Answer:" is found, future tokens are allowed.
-                if found:
-                    return delta
-                # Check if the token ends with "\nFinal Answer:" and update `found`.
-                final_answer_idx = tokens.lower().rfind("\nfinal answer:")
-                if final_answer_idx != -1:
-                    found = True
-                    return tokens[final_answer_idx + len("\nfinal answer:") :]
-                return ""
-
-            return process_tokens
-        else:
-            return lambda tokens, delta: delta
+        logger.debug(f"Tool call content: {result}")
+        return result

     @classmethod
-    def _tool_calls_completion_chunk(cls, model_family, model_uid, c
+    def _tool_calls_completion_chunk(cls, model_family, model_uid, c):
         _id = str(uuid.uuid4())
-
-
-
-
-
-
-
-
-
-
-        "
-
-
-
-
-
-
-
-
-
+        tool_result = cls._eval_tool_arguments(model_family, c)
+        tool_calls = []
+        failed_contents = []
+        for content, func, args in tool_result:
+            if func:
+                tool_calls.append(
+                    [
+                        {
+                            "id": f"call_{_id}",
+                            "type": "function",
+                            "function": {
+                                "name": func,
+                                "arguments": json.dumps(args, ensure_ascii=False),
+                            },
+                        }
+                    ]
+                )
+            else:
+                failed_contents.append(content)
+        finish_reason = "tool_calls" if tool_calls else "stop"
+        d = {
+            "role": "assistant",
+            "content": ". ".join(failed_contents) if failed_contents else None,
+            "tool_calls": tool_calls,
+        }
         try:
             usage = c.get("usage")
             assert "prompt_tokens" in usage
@@ -795,28 +410,32 @@ Begin!"""
         }

     @classmethod
-    def _tool_calls_completion(cls, model_family, model_uid, c
+    def _tool_calls_completion(cls, model_family, model_uid, c):
         _id = str(uuid.uuid4())
-
-
-
-
-
-
+        tool_result = cls._eval_tool_arguments(model_family, c)
+
+        tool_calls = []
+        failed_contents = []
+        for content, func, args in tool_result:
+            if func:
+                tool_calls.append(
                     {
                         "id": f"call_{_id}",
                         "type": "function",
                         "function": {
                             "name": func,
-                            "arguments": json.dumps(args),
+                            "arguments": json.dumps(args, ensure_ascii=False),
                         },
                     }
-
-
-
-                else
-
-
+                )
+            else:
+                failed_contents.append(content)
+        finish_reason = "tool_calls" if tool_calls else "stop"
+        m = {
+            "role": "assistant",
+            "content": ". ".join(failed_contents) if failed_contents else None,
+            "tool_calls": tool_calls,
+        }
         try:
             usage = c.get("usage")
             assert "prompt_tokens" in usage
@@ -841,16 +460,6 @@ Begin!"""
             "usage": usage,
         }

-    @classmethod
-    def get_full_prompt(cls, model_family, prompt, system_prompt, chat_history, tools):
-        assert model_family.prompt_style is not None
-        prompt_style = model_family.prompt_style.copy()
-        if system_prompt:
-            prompt_style.system_prompt = system_prompt
-        chat_history = chat_history or []
-        full_prompt = cls.get_prompt(prompt, chat_history, prompt_style, tools=tools)
-        return full_prompt
-

 def get_file_location(
     llm_family: LLMFamilyV1, spec: LLMSpecV1, quantization: str
@@ -903,3 +512,120 @@ def _decode_image(_url):
         return Image.open(_url).convert("RGB")
     else:
         return Image.open(BytesIO(response.content)).convert("RGB")
+
+
+@typing.no_type_check
+def generate_completion_chunk(
+    chunk_text: Optional[str],
+    finish_reason: Optional[str],
+    chunk_id: str,
+    model_uid: str,
+    prompt_tokens: int,
+    completion_tokens: int,
+    total_tokens: int,
+    has_choice: bool = True,
+    has_content: bool = True,
+):
+    choices = []
+    if has_choice:
+        choices.append(
+            CompletionChoice(
+                text=chunk_text, index=0, logprobs=None, finish_reason=finish_reason
+            )
+            if has_content
+            else CompletionChoice(index=0, logprobs=None, finish_reason=finish_reason)
+        )
+    return CompletionChunk(
+        id=chunk_id,
+        object="text_completion",
+        created=int(time.time()),
+        model=model_uid,
+        choices=choices,
+        usage=CompletionUsage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+        ),
+    )
+
+
+def generate_completion(
+    model_uid: str,
+    response: str,
+    prompt_tokens=-1,
+    completion_tokens=-1,
+    total_tokens=-1,
+    finish_reason="stop",
+) -> Completion:
+    return Completion(
+        id=str(uuid.uuid1()),
+        object="text_completion",
+        created=int(time.time()),
+        model=model_uid,
+        choices=[
+            CompletionChoice(
+                text=response, index=0, logprobs=None, finish_reason=finish_reason
+            )
+        ],
+        usage=CompletionUsage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+        ),
+    )
+
+
+def generate_chat_completion(
+    model_uid: str,
+    response: str,
+    prompt_tokens=-1,
+    completion_tokens=-1,
+    total_tokens=-1,
+    finish_reason="stop",
+) -> ChatCompletion:
+    return ChatCompletion(
+        id="chat" + str(uuid.uuid1()),
+        object="chat.completion",
+        created=int(time.time()),
+        model=model_uid,
+        choices=[
+            ChatCompletionChoice(
+                index=0,
+                message={"role": "assistant", "content": response},
+                finish_reason=finish_reason,
+            )
+        ],
+        usage=CompletionUsage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+        ),
+    )
+
+
+@functools.lru_cache
+def get_stop_token_ids_from_config_file(model_path: str) -> Optional[List[int]]:
+    from transformers import GenerationConfig as TransformersGenerationConfig
+
+    transformers_config = TransformersGenerationConfig.from_pretrained(model_path)
+    if transformers_config.eos_token_id is not None:
+        stop_token_ids = (
+            transformers_config.eos_token_id
+            if isinstance(transformers_config.eos_token_id, list)
+            else [transformers_config.eos_token_id]
+        )
+        return stop_token_ids
+    return None
+
+
+def parse_messages(messages: List[Dict]) -> Tuple:
+    """
+    Some older models still follow the old way of parameter passing.
+    This function helps to parse out the needed information from OpenAI-compatible `messages`.
+    """
+    system_messages = [mess["content"] for mess in messages if mess["role"] == "system"]
+    content_messages = [mess for mess in messages if mess["role"] != "system"]
+    prompt = content_messages[-1]["content"]
+    system_prompt = ". ".join(system_messages) if system_messages else None
+    chat_history = content_messages[:-1]
+    return prompt, system_prompt, chat_history