xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
Potentially problematic release: this version of xinference has been flagged as possibly problematic.
- xinference/_compat.py +51 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +209 -40
- xinference/client/restful/restful_client.py +7 -26
- xinference/conftest.py +1 -1
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +1 -1
- xinference/core/chat_interface.py +8 -14
- xinference/core/event.py +1 -1
- xinference/core/image_interface.py +28 -0
- xinference/core/model.py +110 -31
- xinference/core/scheduler.py +37 -37
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +17 -10
- xinference/core/utils.py +80 -22
- xinference/core/worker.py +17 -16
- xinference/deploy/cmdline.py +8 -16
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/deploy/utils.py +1 -1
- xinference/deploy/worker.py +1 -1
- xinference/model/audio/cosyvoice.py +86 -41
- xinference/model/audio/fish_speech.py +9 -9
- xinference/model/audio/model_spec.json +9 -9
- xinference/model/audio/whisper.py +4 -1
- xinference/model/embedding/core.py +52 -31
- xinference/model/image/core.py +2 -1
- xinference/model/image/model_spec.json +16 -4
- xinference/model/image/model_spec_modelscope.json +16 -4
- xinference/model/image/sdapi.py +136 -0
- xinference/model/image/stable_diffusion/core.py +164 -19
- xinference/model/llm/__init__.py +29 -11
- xinference/model/llm/llama_cpp/core.py +16 -33
- xinference/model/llm/llm_family.json +1011 -1296
- xinference/model/llm/llm_family.py +34 -53
- xinference/model/llm/llm_family_csghub.json +18 -35
- xinference/model/llm/llm_family_modelscope.json +981 -1122
- xinference/model/llm/lmdeploy/core.py +56 -88
- xinference/model/llm/mlx/core.py +46 -69
- xinference/model/llm/sglang/core.py +36 -18
- xinference/model/llm/transformers/chatglm.py +168 -306
- xinference/model/llm/transformers/cogvlm2.py +36 -63
- xinference/model/llm/transformers/cogvlm2_video.py +33 -223
- xinference/model/llm/transformers/core.py +55 -50
- xinference/model/llm/transformers/deepseek_v2.py +340 -0
- xinference/model/llm/transformers/deepseek_vl.py +53 -96
- xinference/model/llm/transformers/glm4v.py +55 -111
- xinference/model/llm/transformers/intern_vl.py +39 -70
- xinference/model/llm/transformers/internlm2.py +32 -54
- xinference/model/llm/transformers/minicpmv25.py +22 -55
- xinference/model/llm/transformers/minicpmv26.py +158 -68
- xinference/model/llm/transformers/omnilmm.py +5 -28
- xinference/model/llm/transformers/qwen2_audio.py +168 -0
- xinference/model/llm/transformers/qwen2_vl.py +234 -0
- xinference/model/llm/transformers/qwen_vl.py +34 -86
- xinference/model/llm/transformers/utils.py +32 -38
- xinference/model/llm/transformers/yi_vl.py +32 -72
- xinference/model/llm/utils.py +280 -554
- xinference/model/llm/vllm/core.py +161 -100
- xinference/model/rerank/core.py +41 -8
- xinference/model/rerank/model_spec.json +7 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -1
- xinference/model/utils.py +1 -31
- xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
- xinference/thirdparty/cosyvoice/cli/model.py +139 -26
- xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
- xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
- xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
- xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
- xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
- xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
- xinference/thirdparty/cosyvoice/utils/common.py +36 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
- xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
- xinference/thirdparty/fish_speech/tools/api.py +79 -134
- xinference/thirdparty/fish_speech/tools/commons.py +35 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
- xinference/thirdparty/fish_speech/tools/file.py +17 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
- xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
- xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
- xinference/thirdparty/fish_speech/tools/webui.py +12 -146
- xinference/thirdparty/matcha/VERSION +1 -0
- xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
- xinference/thirdparty/matcha/hifigan/README.md +101 -0
- xinference/thirdparty/omnilmm/LICENSE +201 -0
- xinference/thirdparty/whisper/__init__.py +156 -0
- xinference/thirdparty/whisper/__main__.py +3 -0
- xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
- xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
- xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
- xinference/thirdparty/whisper/audio.py +157 -0
- xinference/thirdparty/whisper/decoding.py +826 -0
- xinference/thirdparty/whisper/model.py +314 -0
- xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
- xinference/thirdparty/whisper/normalizers/basic.py +76 -0
- xinference/thirdparty/whisper/normalizers/english.json +1741 -0
- xinference/thirdparty/whisper/normalizers/english.py +550 -0
- xinference/thirdparty/whisper/timing.py +386 -0
- xinference/thirdparty/whisper/tokenizer.py +395 -0
- xinference/thirdparty/whisper/transcribe.py +605 -0
- xinference/thirdparty/whisper/triton_ops.py +109 -0
- xinference/thirdparty/whisper/utils.py +316 -0
- xinference/thirdparty/whisper/version.py +1 -0
- xinference/types.py +14 -53
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
- xinference/web/ui/build/static/js/main.754740c0.js +3 -0
- xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
- xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +37 -0
- xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
- xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
- xinference/web/ui/node_modules/nunjucks/package.json +112 -0
- xinference/web/ui/package-lock.json +38 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
- xinference/model/llm/transformers/llama_2.py +0 -108
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
- xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
- xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
- xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
- xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/chatglm.py

@@ -11,45 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import json
-import
-import time
+import typing
 import uuid
+from threading import Thread
 from typing import Any, Dict, Iterator, List, Optional, Union
 
 import torch
-from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList
 
 from ....core.scheduler import InferenceRequest
-from ....types import
-    SPECIAL_TOOL_PROMPT,
-    ChatCompletion,
-    ChatCompletionChoice,
-    ChatCompletionChunk,
-    ChatCompletionMessage,
-    CompletionChoice,
-    CompletionChunk,
-    CompletionUsage,
-    LoRA,
-    PytorchGenerateConfig,
-)
+from ....types import ChatCompletion, ChatCompletionChunk, LoRA, PytorchGenerateConfig
 from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import
+from ..utils import (
+    GLM4_TOOL_CALL_FAMILY,
+    generate_chat_completion,
+    generate_completion_chunk,
+)
 from .core import PytorchChatModel, PytorchModelConfig
 
 
-class InvalidScoreLogitsProcessor(LogitsProcessor):
-    def __call__(
-        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
-    ) -> torch.FloatTensor:
-        if torch.isnan(scores).any() or torch.isinf(scores).any():
-            scores.zero_()
-            scores[..., 198] = 5e4
-        return scores
-
-
 class ChatglmPytorchChatModel(PytorchChatModel):
     def __init__(
         self,

@@ -107,40 +87,28 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         if llm_spec.model_format != "pytorch":
             return False
         model_family = llm_family.model_family or llm_family.model_name
-        if "
+        if "glm4" not in model_family:
             return False
         if "chat" not in llm_family.model_ability:
             return False
         return True
 
-    def _handle_tools(self,
+    def _handle_tools(self, messages, generate_config):
         """Convert openai tools to ChatGLM tools."""
+        if self.model_family.model_name not in GLM4_TOOL_CALL_FAMILY:
+            return None
         if generate_config is None:
-            return
+            return None
         tools = generate_config.pop("tools", None)
         if tools is None:
-            return
-        # Convert
+            return None
+        # Convert an iterable to a list
         tools = list(tools)
         tool_choice = generate_config.pop("tool_choice", "none")
-
-
-
-
-            return True
-        else:
-            chatglm_tools = []
-            for elem in tools:
-                if elem.get("type") != "function" or "function" not in elem:
-                    raise ValueError("ChatGLM tools only support function type.")
-                chatglm_tools.append(elem["function"])
-            tool_prompt_message = {
-                "role": "system",
-                "content": f"Answer the following questions as best as you can. You have access to the following tools:",
-                "tools": chatglm_tools,
-            }
-            chat_history.insert(0, tool_prompt_message)
-            return True
+        messages[:] = self._process_messages(
+            messages, tools=tools, tool_choice=tool_choice
+        )
+        return tools
 
     @staticmethod
     def _process_messages(messages, tools=None, tool_choice="none"):

@@ -230,12 +198,70 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         return processed_messages
 
     @staticmethod
-
+    @typing.no_type_check
+    def _process_response_non_streaming(
+        output: str, tools: Union[Dict, List[Dict]] = None, use_tool: bool = False
+    ) -> Union[str, dict]:
+        """
+        Copied from https://github.com/THUDM/GLM-4/blob/main/basic_demo/openai_api_server.py#L150
+        """
+        import re
+
+        lines = output.strip().split("\n")
+        arguments_json = None
+        special_tools = ["cogview", "simple_browser"]
+        tools = {tool["function"]["name"] for tool in tools} if tools else {}
+
+        # 这是一个简单的工具比较函数,不能保证拦截所有非工具输出的结果,比如参数未对齐等特殊情况。
+        ##TODO 如果你希望做更多判断,可以在这里进行逻辑完善。
+
+        if len(lines) >= 2 and lines[1].startswith("{"):
+            function_name = lines[0].strip()
+            arguments = "\n".join(lines[1:]).strip()
+            if function_name in tools or function_name in special_tools:
+                try:
+                    arguments_json = json.loads(arguments)
+                    is_tool_call = True
+                except json.JSONDecodeError:
+                    is_tool_call = function_name in special_tools
+
+                if is_tool_call and use_tool:
+                    content = {
+                        "name": function_name,
+                        "arguments": json.dumps(
+                            arguments_json
+                            if isinstance(arguments_json, dict)
+                            else arguments,
+                            ensure_ascii=False,
+                        ),
+                    }
+                    if function_name == "simple_browser":
+                        search_pattern = re.compile(
+                            r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)'
+                        )
+                        match = search_pattern.match(arguments)
+                        if match:
+                            content["arguments"] = json.dumps(
+                                {
+                                    "query": match.group(1),
+                                    "recency_days": int(match.group(2)),
+                                },
+                                ensure_ascii=False,
+                            )
+                    elif function_name == "cogview":
+                        content["arguments"] = json.dumps(
+                            {"prompt": arguments}, ensure_ascii=False
+                        )
+
+                    return content
+        return output.strip()
+
+    @staticmethod
+    def _process_response_streaming(output, tools, end=False):
         # Copy from https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/modeling_chatglm.py
         content = ""
-        history = copy.deepcopy(history)
         if not tools and end:
-            return None
+            return None
         for response in output.split("<|assistant|>"):
             if "\n" in response:
                 metadata, content = response.split("\n", maxsplit=1)

@@ -244,205 +270,54 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             if not metadata.strip():
                 if tools and any(t.startswith(response) for t in tools) and not end:
                     # Waiting for tool call complete.
-                    return None
+                    return None
                 content = content.strip()
-                history.append(
-                    {"role": "assistant", "metadata": metadata, "content": content}
-                )
                 content = content.replace("[[训练时间]]", "2023年")
             else:
                 if tools and metadata in tools and not end:
-                    return None
-                history.append(
-                    {"role": "assistant", "metadata": metadata, "content": content}
-                )
+                    return None
                 metadata = metadata.strip()
                 if tools and metadata in tools and end:
                     try:
                         parameters = json.loads(content)
-                        content = {"name": metadata.strip(), "
+                        content = {"name": metadata.strip(), "arguments": parameters}
                     except json.JSONDecodeError:
                         content = {"name": metadata.strip(), "content": content}
                 else:
                     content = {"name": metadata.strip(), "content": content}
-            return content
-
-    def _get_generate_args(
-        self,
-        tokenizer,
-        query: str,
-        history: Optional[List[Dict]] = None,
-        role: str = "user",
-        past_key_values=None,
-        max_length: int = 8192,
-        do_sample=True,
-        top_p=0.8,
-        temperature=0.8,
-        logits_processor=None,
-        **kwargs,
-    ):
-        # Copy from https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/modeling_chatglm.py
-        if history is None:
-            history = []
-        if logits_processor is None:
-            logits_processor = LogitsProcessorList()
-        logits_processor.append(InvalidScoreLogitsProcessor())
-        eos_token_id = [
-            tokenizer.eos_token_id,
-            tokenizer.convert_tokens_to_ids("<|user|>"),
-            tokenizer.convert_tokens_to_ids("<|observation|>"),
-        ]
-        gen_kwargs = {
-            "max_length": max_length,
-            "do_sample": do_sample,
-            "top_p": top_p,
-            "temperature": temperature,
-            "logits_processor": logits_processor,
-            **kwargs,
-        }
-        if past_key_values is None:
-            inputs = tokenizer.apply_chat_template(
-                history + [{"role": role, "content": query}],
-                add_generation_prompt=True,
-                tokenize=True,
-                return_tensors="pt",
-                return_dict=True,
-            )
-        else:
-            inputs = tokenizer.apply_chat_template(
-                [{"role": role, "content": query}],
-                add_special_tokens=False,
-                add_generation_prompt=True,
-                tokenize=True,
-                return_tensors="pt",
-                return_dict=True,
-            )
-        inputs = inputs.to(self._model.device)
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-            inputs.position_ids += past_length
-            attention_mask = inputs.attention_mask
-            attention_mask = torch.cat(
-                (attention_mask.new_ones(1, past_length), attention_mask), dim=1
-            )
-            inputs["attention_mask"] = attention_mask
-        history.append({"role": role, "content": query})
-        tools = history[0]["role"] == "system" and history[0].get("tools")
-        tools = (
-            [
-                t.get("function", {}).get("name", "")
-                for t in tools
-                if isinstance(t, dict)
-            ]
-            if tools
-            else []
-        )
-        kwargs = dict(inputs)
-        kwargs["past_key_values"] = past_key_values
-        kwargs["eos_token_id"] = eos_token_id
-        kwargs.update(gen_kwargs)
-        return kwargs, tools
+            return content
 
     @torch.inference_mode()
-    def _stream_chat(
-        self,
-        tokenizer,
-        query: str,
-        history: Optional[List[Dict]] = None,
-        role: str = "user",
-        past_key_values=None,
-        max_length: int = 8192,
-        do_sample=True,
-        top_p=0.8,
-        temperature=0.8,
-        logits_processor=None,
-        **kwargs,
-    ):
+    def _stream_chat(self, inputs, tools, **kwargs):
         from transformers import TextIteratorStreamer
 
-        kwargs, tools = self._get_generate_args(
-            tokenizer=tokenizer,
-            query=query,
-            history=history,
-            role=role,
-            past_key_values=past_key_values,
-            max_length=max_length,
-            do_sample=do_sample,
-            top_p=top_p,
-            temperature=temperature,
-            logits_processor=logits_processor,
-            **kwargs,
-        )
-
         streamer = TextIteratorStreamer(
-
+            self._tokenizer, skip_prompt=True, skip_special_tokens=True
         )
-
-
+        tools = {tool["function"]["name"] for tool in tools} if tools else {}
+        generation_kwargs = dict(inputs, streamer=streamer)
+        generation_kwargs.update(kwargs)
+        thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
        thread.start()
 
         response = ""
         for token in streamer:
             response += token
             if response and response[-1] != "�":
-                new_response
-                    response,
+                new_response = self._process_response_streaming(
+                    response, tools, end=False
                 )
                 if new_response is None:
                     continue
-                yield new_response
+                yield new_response
         if tools:
-            new_response
-                response, history, tools, end=True
-            )
+            new_response = self._process_response_streaming(response, tools, end=True)
             if new_response:
-                yield new_response
-
-    @torch.inference_mode()
-    def _non_stream_chat(
-        self,
-        tokenizer,
-        query: str,
-        history: Optional[List[Dict]] = None,
-        role: str = "user",
-        past_key_values=None,
-        max_length: int = 8192,
-        do_sample=True,
-        top_p=0.8,
-        temperature=0.8,
-        logits_processor=None,
-        **kwargs,
-    ):
-        kwargs, tools = self._get_generate_args(
-            tokenizer=tokenizer,
-            query=query,
-            history=history,
-            role=role,
-            past_key_values=past_key_values,
-            max_length=max_length,
-            do_sample=do_sample,
-            top_p=top_p,
-            temperature=temperature,
-            logits_processor=logits_processor,
-            **kwargs,
-        )
-
-        outputs = self._model.generate(**kwargs)
-        outputs = outputs[:, kwargs["input_ids"].shape[1] :]
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        if tools:
-            return self._process_response(response, history, tools, end=True)
-        else:
-            return self._process_response(response, history, tools)
+                yield new_response
 
-
-
-
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
-        generate_config: Optional[PytorchGenerateConfig] = None,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        kwargs: Dict[str, Any] = {}
+    @staticmethod
+    def _get_generate_kwargs(generate_config):
+        kwargs: Dict[str, Any] = {}  # type: ignore
         generate_config = generate_config or {}
         temperature = generate_config.get("temperature")
         if temperature is not None:

@@ -453,18 +328,26 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         max_new_tokens = generate_config.get("max_tokens")
         if max_new_tokens is not None:
             kwargs["max_new_tokens"] = int(max_new_tokens)
-
-
-
-
-
-
-
-
-        kwargs["
-
-
-
+        do_sample = generate_config.get("do_sample")
+        if do_sample is not None:
+            kwargs["do_sample"] = bool(do_sample)
+        top_k = generate_config.get("top_k")
+        if top_k is not None:
+            kwargs["top_k"] = top_k
+        repetition_penalty = generate_config.get("repetition_penalty")
+        if repetition_penalty is not None:
+            kwargs["repetition_penalty"] = repetition_penalty
+        return kwargs
+
+    def chat(
+        self,
+        messages: List[Dict],
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        generate_config = generate_config or {}
+        kwargs: Dict[str, Any] = self._get_generate_kwargs(generate_config)
+        tools = self._handle_tools(messages, generate_config)
+        has_tools = tools is not None
         stream = generate_config.get("stream", False)
         stream_options = generate_config.pop("stream_options", None)
         include_usage = (

@@ -472,103 +355,82 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             if isinstance(stream_options, dict)
             else False
         )
-
-
-
+        inputs = self._tokenizer.apply_chat_template(
+            messages,
+            return_tensors="pt",
+            chat_template=self.model_family.chat_template,
+            add_generation_prompt=True,
+            return_dict=True,
+        )
+        inputs = inputs.to(self._model.device)
+
+        if not stream:
+            with torch.no_grad():
+                outputs = self._model.generate(**inputs, **kwargs)
+            outputs = outputs[:, inputs["input_ids"].shape[1] :]
+            response = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
+            # In some cases, the response starts with `\n`
+            if response.startswith("\n"):
+                response = response[1:]
+            if has_tools:
+                function_call = self._process_response_non_streaming(
+                    response, tools, use_tool=True
+                )
+                return self._tool_calls_completion(
+                    self.model_family, self.model_uid, function_call
+                )
+            else:
+                return generate_chat_completion(self.model_uid, response)
+        else:
 
             def _stream_generator():
                 last_chunk_text_length = 0
                 chunk_id = "chat-" + str(uuid.uuid1())
                 prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
-                inputs = self._tokenizer([prompt], return_tensors="pt")
-                inputs = inputs.to(self._model.device)
                 prompt_tokens = len(inputs["input_ids"][0])
-                for chunk_text
-                    self._tokenizer, prompt, chat_history, **kwargs
-                ):
+                for chunk_text in self._stream_chat(inputs, tools, **kwargs):
                     if tools and isinstance(chunk_text, dict):
                         yield self._tool_calls_completion_chunk(
-                            self.model_family, self.model_uid,
+                            self.model_family, self.model_uid, chunk_text
                        )
                         return
                     completion_tokens = completion_tokens + 1
                     total_tokens = prompt_tokens + completion_tokens
                     chunk_text = chunk_text[last_chunk_text_length:]
                     last_chunk_text_length += len(chunk_text)
-
-
-
-
-
-
-
-
-                        choices=[completion_choice],
-                        usage=CompletionUsage(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=total_tokens,
-                        ),
+                    yield generate_completion_chunk(
+                        chunk_text,
+                        finish_reason=None,
+                        chunk_id=chunk_id,
+                        model_uid=self.model_uid,
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=total_tokens,
                     )
-
-
-
-
-
-                    object="text_completion",
-                    created=int(time.time()),
-                    model=self.model_uid,
-                    choices=[completion_choice],
-                )
-                completion_usage = CompletionUsage(
+                yield generate_completion_chunk(
+                    None,
+                    finish_reason="stop",
+                    chunk_id=chunk_id,
+                    model_uid=self.model_uid,
                     prompt_tokens=prompt_tokens,
                     completion_tokens=completion_tokens,
                     total_tokens=total_tokens,
+                    has_choice=True,
+                    has_content=False,
                 )
-                chunk["usage"] = completion_usage
-                yield chunk
                 if include_usage:
-
-
-
-
-
-                        choices=[],
-                    )
-                    chunk["usage"] = CompletionUsage(
+                    yield generate_completion_chunk(
+                        None,
+                        finish_reason=None,
+                        chunk_id=chunk_id,
+                        model_uid=self.model_uid,
                         prompt_tokens=prompt_tokens,
                         completion_tokens=completion_tokens,
                         total_tokens=total_tokens,
+                        has_choice=False,
                     )
-                    yield chunk
 
             return self._to_chat_completion_chunks(_stream_generator())
-        else:
-            response = self._non_stream_chat(
-                self._tokenizer, prompt, chat_history, **kwargs
-            )
-            if tools:
-                return self._tool_calls_completion(
-                    self.model_family, self.model_uid, response, tools
-                )
-            else:
-                content, _ = response
-                return ChatCompletion(
-                    id="chat" + str(uuid.uuid1()),
-                    object="chat.completion",
-                    created=int(time.time()),
-                    model=self.model_uid,
-                    choices=[
-                        ChatCompletionChoice(
-                            index=0,
-                            message={"role": "assistant", "content": content},
-                            finish_reason="stop",
-                        )
-                    ],
-                    usage=CompletionUsage(
-                        prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
-                    ),
-                )
 
     def prepare_sanitize_generate_config(self, req: InferenceRequest):
         """