xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference has been flagged as possibly problematic by the registry scanner.

Files changed (194)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +209 -40
  4. xinference/client/restful/restful_client.py +7 -26
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/image_interface.py +28 -0
  11. xinference/core/model.py +110 -31
  12. xinference/core/scheduler.py +37 -37
  13. xinference/core/status_guard.py +1 -1
  14. xinference/core/supervisor.py +17 -10
  15. xinference/core/utils.py +80 -22
  16. xinference/core/worker.py +17 -16
  17. xinference/deploy/cmdline.py +8 -16
  18. xinference/deploy/local.py +1 -1
  19. xinference/deploy/supervisor.py +1 -1
  20. xinference/deploy/utils.py +1 -1
  21. xinference/deploy/worker.py +1 -1
  22. xinference/model/audio/cosyvoice.py +86 -41
  23. xinference/model/audio/fish_speech.py +9 -9
  24. xinference/model/audio/model_spec.json +9 -9
  25. xinference/model/audio/whisper.py +4 -1
  26. xinference/model/embedding/core.py +52 -31
  27. xinference/model/image/core.py +2 -1
  28. xinference/model/image/model_spec.json +16 -4
  29. xinference/model/image/model_spec_modelscope.json +16 -4
  30. xinference/model/image/sdapi.py +136 -0
  31. xinference/model/image/stable_diffusion/core.py +164 -19
  32. xinference/model/llm/__init__.py +29 -11
  33. xinference/model/llm/llama_cpp/core.py +16 -33
  34. xinference/model/llm/llm_family.json +1011 -1296
  35. xinference/model/llm/llm_family.py +34 -53
  36. xinference/model/llm/llm_family_csghub.json +18 -35
  37. xinference/model/llm/llm_family_modelscope.json +981 -1122
  38. xinference/model/llm/lmdeploy/core.py +56 -88
  39. xinference/model/llm/mlx/core.py +46 -69
  40. xinference/model/llm/sglang/core.py +36 -18
  41. xinference/model/llm/transformers/chatglm.py +168 -306
  42. xinference/model/llm/transformers/cogvlm2.py +36 -63
  43. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  44. xinference/model/llm/transformers/core.py +55 -50
  45. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  46. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  47. xinference/model/llm/transformers/glm4v.py +55 -111
  48. xinference/model/llm/transformers/intern_vl.py +39 -70
  49. xinference/model/llm/transformers/internlm2.py +32 -54
  50. xinference/model/llm/transformers/minicpmv25.py +22 -55
  51. xinference/model/llm/transformers/minicpmv26.py +158 -68
  52. xinference/model/llm/transformers/omnilmm.py +5 -28
  53. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  54. xinference/model/llm/transformers/qwen2_vl.py +234 -0
  55. xinference/model/llm/transformers/qwen_vl.py +34 -86
  56. xinference/model/llm/transformers/utils.py +32 -38
  57. xinference/model/llm/transformers/yi_vl.py +32 -72
  58. xinference/model/llm/utils.py +280 -554
  59. xinference/model/llm/vllm/core.py +161 -100
  60. xinference/model/rerank/core.py +41 -8
  61. xinference/model/rerank/model_spec.json +7 -0
  62. xinference/model/rerank/model_spec_modelscope.json +7 -1
  63. xinference/model/utils.py +1 -31
  64. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  65. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  66. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  67. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  68. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  69. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  70. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  71. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  72. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  73. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  74. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  77. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  78. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  79. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  80. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  81. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  82. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  83. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  84. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  85. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  86. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  87. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  88. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  89. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  90. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  91. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
  92. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  93. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  94. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  95. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  96. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  97. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  98. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  99. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  100. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  101. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  102. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  103. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  104. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  105. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  107. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  108. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  109. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  110. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  111. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  112. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  113. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  114. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  115. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  116. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  117. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  118. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  122. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  123. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  124. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  126. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  127. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  128. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  129. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  130. xinference/thirdparty/matcha/VERSION +1 -0
  131. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  132. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  133. xinference/thirdparty/omnilmm/LICENSE +201 -0
  134. xinference/thirdparty/whisper/__init__.py +156 -0
  135. xinference/thirdparty/whisper/__main__.py +3 -0
  136. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  137. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  138. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  139. xinference/thirdparty/whisper/audio.py +157 -0
  140. xinference/thirdparty/whisper/decoding.py +826 -0
  141. xinference/thirdparty/whisper/model.py +314 -0
  142. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  143. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  144. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  145. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  146. xinference/thirdparty/whisper/timing.py +386 -0
  147. xinference/thirdparty/whisper/tokenizer.py +395 -0
  148. xinference/thirdparty/whisper/transcribe.py +605 -0
  149. xinference/thirdparty/whisper/triton_ops.py +109 -0
  150. xinference/thirdparty/whisper/utils.py +316 -0
  151. xinference/thirdparty/whisper/version.py +1 -0
  152. xinference/types.py +14 -53
  153. xinference/web/ui/build/asset-manifest.json +6 -6
  154. xinference/web/ui/build/index.html +1 -1
  155. xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
  156. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  157. xinference/web/ui/build/static/js/main.754740c0.js +3 -0
  158. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
  159. xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
  160. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  161. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  162. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  163. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  164. xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
  165. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  166. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  167. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  168. xinference/web/ui/node_modules/.package-lock.json +37 -0
  169. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  170. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  171. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  172. xinference/web/ui/package-lock.json +38 -0
  173. xinference/web/ui/package.json +1 -0
  174. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
  175. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
  176. xinference/model/llm/transformers/llama_2.py +0 -108
  177. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  178. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  179. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  180. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  181. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  182. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  183. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  184. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  185. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  186. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  187. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  188. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  189. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  190. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  191. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/chatglm.py

@@ -11,45 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import json
-import threading
-import time
+import typing
 import uuid
+from threading import Thread
 from typing import Any, Dict, Iterator, List, Optional, Union

 import torch
-from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList

 from ....core.scheduler import InferenceRequest
-from ....types import (
-    SPECIAL_TOOL_PROMPT,
-    ChatCompletion,
-    ChatCompletionChoice,
-    ChatCompletionChunk,
-    ChatCompletionMessage,
-    CompletionChoice,
-    CompletionChunk,
-    CompletionUsage,
-    LoRA,
-    PytorchGenerateConfig,
-)
+from ....types import ChatCompletion, ChatCompletionChunk, LoRA, PytorchGenerateConfig
 from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import GLM4_TOOL_CALL_FAMILY
+from ..utils import (
+    GLM4_TOOL_CALL_FAMILY,
+    generate_chat_completion,
+    generate_completion_chunk,
+)
 from .core import PytorchChatModel, PytorchModelConfig


-class InvalidScoreLogitsProcessor(LogitsProcessor):
-    def __call__(
-        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
-    ) -> torch.FloatTensor:
-        if torch.isnan(scores).any() or torch.isinf(scores).any():
-            scores.zero_()
-            scores[..., 198] = 5e4
-        return scores
-
-
 class ChatglmPytorchChatModel(PytorchChatModel):
     def __init__(
         self,
@@ -107,40 +87,28 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         if llm_spec.model_format != "pytorch":
             return False
         model_family = llm_family.model_family or llm_family.model_name
-        if "chatglm" not in model_family and "glm4" not in model_family:
+        if "glm4" not in model_family:
             return False
         if "chat" not in llm_family.model_ability:
             return False
         return True

-    def _handle_tools(self, chat_history, generate_config) -> bool:
+    def _handle_tools(self, messages, generate_config):
         """Convert openai tools to ChatGLM tools."""
+        if self.model_family.model_name not in GLM4_TOOL_CALL_FAMILY:
+            return None
         if generate_config is None:
-            return False
+            return None
         tools = generate_config.pop("tools", None)
         if tools is None:
-            return False
-        # Convert a iterable to a list
+            return None
+        # Convert an iterable to a list
         tools = list(tools)
         tool_choice = generate_config.pop("tool_choice", "none")
-        if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
-            chat_history[:] = self._process_messages(
-                chat_history, tools=tools, tool_choice=tool_choice
-            )
-            return True
-        else:
-            chatglm_tools = []
-            for elem in tools:
-                if elem.get("type") != "function" or "function" not in elem:
-                    raise ValueError("ChatGLM tools only support function type.")
-                chatglm_tools.append(elem["function"])
-            tool_prompt_message = {
-                "role": "system",
-                "content": f"Answer the following questions as best as you can. You have access to the following tools:",
-                "tools": chatglm_tools,
-            }
-            chat_history.insert(0, tool_prompt_message)
-            return True
+        messages[:] = self._process_messages(
+            messages, tools=tools, tool_choice=tool_choice
+        )
+        return tools

     @staticmethod
     def _process_messages(messages, tools=None, tool_choice="none"):
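Note: after this hunk, `_handle_tools` pops OpenAI-style tool definitions out of `generate_config`, rewrites `messages` in place via `_process_messages`, and returns the tool list (or `None`) instead of a bool. A minimal sketch of the `generate_config` shape it expects; the concrete `get_weather` tool is illustrative, not from the diff:

    generate_config = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Query current weather for a city",
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "required": ["city"],
                    },
                },
            }
        ],
        "tool_choice": "none",  # popped with default "none", as in the diff
    }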
@@ -230,12 +198,70 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         return processed_messages

     @staticmethod
-    def _process_response(output, history, tools, end=False):
+    @typing.no_type_check
+    def _process_response_non_streaming(
+        output: str, tools: Union[Dict, List[Dict]] = None, use_tool: bool = False
+    ) -> Union[str, dict]:
+        """
+        Copied from https://github.com/THUDM/GLM-4/blob/main/basic_demo/openai_api_server.py#L150
+        """
+        import re
+
+        lines = output.strip().split("\n")
+        arguments_json = None
+        special_tools = ["cogview", "simple_browser"]
+        tools = {tool["function"]["name"] for tool in tools} if tools else {}
+
+        # This is a simple tool-call comparison; it cannot guarantee intercepting every
+        # non-tool output, e.g. special cases where the arguments are misaligned.
+        # TODO: if you want stricter checks, refine the logic here.
+
+        if len(lines) >= 2 and lines[1].startswith("{"):
+            function_name = lines[0].strip()
+            arguments = "\n".join(lines[1:]).strip()
+            if function_name in tools or function_name in special_tools:
+                try:
+                    arguments_json = json.loads(arguments)
+                    is_tool_call = True
+                except json.JSONDecodeError:
+                    is_tool_call = function_name in special_tools
+
+                if is_tool_call and use_tool:
+                    content = {
+                        "name": function_name,
+                        "arguments": json.dumps(
+                            arguments_json
+                            if isinstance(arguments_json, dict)
+                            else arguments,
+                            ensure_ascii=False,
+                        ),
+                    }
+                    if function_name == "simple_browser":
+                        search_pattern = re.compile(
+                            r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)'
+                        )
+                        match = search_pattern.match(arguments)
+                        if match:
+                            content["arguments"] = json.dumps(
+                                {
+                                    "query": match.group(1),
+                                    "recency_days": int(match.group(2)),
+                                },
+                                ensure_ascii=False,
+                            )
+                    elif function_name == "cogview":
+                        content["arguments"] = json.dumps(
+                            {"prompt": arguments}, ensure_ascii=False
+                        )
+
+                    return content
+        return output.strip()
+
+    @staticmethod
+    def _process_response_streaming(output, tools, end=False):
         # Copy from https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/modeling_chatglm.py
         content = ""
-        history = copy.deepcopy(history)
         if not tools and end:
-            return None, None
+            return None
         for response in output.split("<|assistant|>"):
             if "\n" in response:
                 metadata, content = response.split("\n", maxsplit=1)
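Note: `_process_response_non_streaming` expects GLM-4 tool calls in the form "function_name" on the first line followed by a JSON object of arguments. A minimal, self-contained sketch of that parsing idea (the `get_weather` tool name is illustrative; the real method additionally handles the "simple_browser"/"cogview" special tools and rewrites their arguments):

    import json

    def parse_glm4_tool_call(output: str, tool_names: set):
        # First line: function name; remaining lines: a JSON object of arguments.
        lines = output.strip().split("\n")
        if len(lines) >= 2 and lines[1].startswith("{") and lines[0].strip() in tool_names:
            try:
                args = json.loads("\n".join(lines[1:]))
                return {
                    "name": lines[0].strip(),
                    "arguments": json.dumps(args, ensure_ascii=False),
                }
            except json.JSONDecodeError:
                pass
        # Fall back to plain text when the output is not a tool call.
        return output.strip()

    print(parse_glm4_tool_call('get_weather\n{"city": "Beijing"}', {"get_weather"}))
    # {'name': 'get_weather', 'arguments': '{"city": "Beijing"}'}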
@@ -244,205 +270,54 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             if not metadata.strip():
                 if tools and any(t.startswith(response) for t in tools) and not end:
                     # Waiting for tool call complete.
-                    return None, None
+                    return None
                 content = content.strip()
-                history.append(
-                    {"role": "assistant", "metadata": metadata, "content": content}
-                )
                 content = content.replace("[[训练时间]]", "2023年")
             else:
                 if tools and metadata in tools and not end:
-                    return None, None
-                history.append(
-                    {"role": "assistant", "metadata": metadata, "content": content}
-                )
+                    return None
                 metadata = metadata.strip()
                 if tools and metadata in tools and end:
                     try:
                         parameters = json.loads(content)
-                        content = {"name": metadata.strip(), "parameters": parameters}
+                        content = {"name": metadata.strip(), "arguments": parameters}
                     except json.JSONDecodeError:
                         content = {"name": metadata.strip(), "content": content}
                 else:
                     content = {"name": metadata.strip(), "content": content}
-        return content, history
-
-    def _get_generate_args(
-        self,
-        tokenizer,
-        query: str,
-        history: Optional[List[Dict]] = None,
-        role: str = "user",
-        past_key_values=None,
-        max_length: int = 8192,
-        do_sample=True,
-        top_p=0.8,
-        temperature=0.8,
-        logits_processor=None,
-        **kwargs,
-    ):
-        # Copy from https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/modeling_chatglm.py
-        if history is None:
-            history = []
-        if logits_processor is None:
-            logits_processor = LogitsProcessorList()
-        logits_processor.append(InvalidScoreLogitsProcessor())
-        eos_token_id = [
-            tokenizer.eos_token_id,
-            tokenizer.convert_tokens_to_ids("<|user|>"),
-            tokenizer.convert_tokens_to_ids("<|observation|>"),
-        ]
-        gen_kwargs = {
-            "max_length": max_length,
-            "do_sample": do_sample,
-            "top_p": top_p,
-            "temperature": temperature,
-            "logits_processor": logits_processor,
-            **kwargs,
-        }
-        if past_key_values is None:
-            inputs = tokenizer.apply_chat_template(
-                history + [{"role": role, "content": query}],
-                add_generation_prompt=True,
-                tokenize=True,
-                return_tensors="pt",
-                return_dict=True,
-            )
-        else:
-            inputs = tokenizer.apply_chat_template(
-                [{"role": role, "content": query}],
-                add_special_tokens=False,
-                add_generation_prompt=True,
-                tokenize=True,
-                return_tensors="pt",
-                return_dict=True,
-            )
-        inputs = inputs.to(self._model.device)
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-            inputs.position_ids += past_length
-            attention_mask = inputs.attention_mask
-            attention_mask = torch.cat(
-                (attention_mask.new_ones(1, past_length), attention_mask), dim=1
-            )
-            inputs["attention_mask"] = attention_mask
-        history.append({"role": role, "content": query})
-        tools = history[0]["role"] == "system" and history[0].get("tools")
-        tools = (
-            [
-                t.get("function", {}).get("name", "")
-                for t in tools
-                if isinstance(t, dict)
-            ]
-            if tools
-            else []
-        )
-        kwargs = dict(inputs)
-        kwargs["past_key_values"] = past_key_values
-        kwargs["eos_token_id"] = eos_token_id
-        kwargs.update(gen_kwargs)
-        return kwargs, tools
+        return content

     @torch.inference_mode()
-    def _stream_chat(
-        self,
-        tokenizer,
-        query: str,
-        history: Optional[List[Dict]] = None,
-        role: str = "user",
-        past_key_values=None,
-        max_length: int = 8192,
-        do_sample=True,
-        top_p=0.8,
-        temperature=0.8,
-        logits_processor=None,
-        **kwargs,
-    ):
+    def _stream_chat(self, inputs, tools, **kwargs):
         from transformers import TextIteratorStreamer

-        kwargs, tools = self._get_generate_args(
-            tokenizer=tokenizer,
-            query=query,
-            history=history,
-            role=role,
-            past_key_values=past_key_values,
-            max_length=max_length,
-            do_sample=do_sample,
-            top_p=top_p,
-            temperature=temperature,
-            logits_processor=logits_processor,
-            **kwargs,
-        )
-
         streamer = TextIteratorStreamer(
-            tokenizer, skip_prompt=True, skip_special_tokens=True
+            self._tokenizer, skip_prompt=True, skip_special_tokens=True
         )
-        kwargs["streamer"] = streamer
-        thread = threading.Thread(target=self._model.generate, kwargs=kwargs)
+        tools = {tool["function"]["name"] for tool in tools} if tools else {}
+        generation_kwargs = dict(inputs, streamer=streamer)
+        generation_kwargs.update(kwargs)
+        thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
        thread.start()

         response = ""
         for token in streamer:
             response += token
             if response and response[-1] != "�":
-                new_response, new_history = self._process_response(
-                    response, history, tools, end=False
+                new_response = self._process_response_streaming(
+                    response, tools, end=False
                 )
                 if new_response is None:
                     continue
-                yield new_response, new_history
+                yield new_response
         if tools:
-            new_response, new_history = self._process_response(
-                response, history, tools, end=True
-            )
+            new_response = self._process_response_streaming(response, tools, end=True)
             if new_response:
-                yield new_response, new_history
-
-    @torch.inference_mode()
-    def _non_stream_chat(
-        self,
-        tokenizer,
-        query: str,
-        history: Optional[List[Dict]] = None,
-        role: str = "user",
-        past_key_values=None,
-        max_length: int = 8192,
-        do_sample=True,
-        top_p=0.8,
-        temperature=0.8,
-        logits_processor=None,
-        **kwargs,
-    ):
-        kwargs, tools = self._get_generate_args(
-            tokenizer=tokenizer,
-            query=query,
-            history=history,
-            role=role,
-            past_key_values=past_key_values,
-            max_length=max_length,
-            do_sample=do_sample,
-            top_p=top_p,
-            temperature=temperature,
-            logits_processor=logits_processor,
-            **kwargs,
-        )
-
-        outputs = self._model.generate(**kwargs)
-        outputs = outputs[:, kwargs["input_ids"].shape[1] :]
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        if tools:
-            return self._process_response(response, history, tools, end=True)
-        else:
-            return self._process_response(response, history, tools)
+                yield new_response

-    def chat(
-        self,
-        prompt: str,
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
-        generate_config: Optional[PytorchGenerateConfig] = None,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        kwargs: Dict[str, Any] = {}
+    @staticmethod
+    def _get_generate_kwargs(generate_config):
+        kwargs: Dict[str, Any] = {}  # type: ignore
         generate_config = generate_config or {}
         temperature = generate_config.get("temperature")
         if temperature is not None:
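Note: the rewritten `_stream_chat` follows the standard transformers streaming pattern, running `model.generate` in a background thread and iterating a `TextIteratorStreamer`. A generic, self-contained sketch of that pattern (how the model and tokenizer are loaded is illustrative and not part of this diff):

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_generate(model, tokenizer, inputs, **gen_kwargs):
        # The streamer yields decoded text pieces as generate() produces tokens.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        kwargs = dict(inputs, streamer=streamer)
        kwargs.update(gen_kwargs)
        Thread(target=model.generate, kwargs=kwargs).start()
        text = ""
        for piece in streamer:
            text += piece
            yield text  # cumulative response so far, as _stream_chat yields it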
@@ -453,18 +328,26 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         max_new_tokens = generate_config.get("max_tokens")
         if max_new_tokens is not None:
             kwargs["max_new_tokens"] = int(max_new_tokens)
-        chat_history = chat_history or []
-        tools = self._handle_tools(chat_history, generate_config)
-        # Tool calls only works for non stream, so we call chat directly.
-        if prompt == SPECIAL_TOOL_PROMPT and chat_history:
-            tool_message = chat_history.pop()
-            content = tool_message.get("content")
-            assert content is not None
-            prompt = content
-            kwargs["role"] = "observation"
-            chat_history = [h for h in chat_history if not h.get("tool_calls")]
-        if system_prompt:
-            chat_history.append({"role": "system", "content": system_prompt})
+        do_sample = generate_config.get("do_sample")
+        if do_sample is not None:
+            kwargs["do_sample"] = bool(do_sample)
+        top_k = generate_config.get("top_k")
+        if top_k is not None:
+            kwargs["top_k"] = top_k
+        repetition_penalty = generate_config.get("repetition_penalty")
+        if repetition_penalty is not None:
+            kwargs["repetition_penalty"] = repetition_penalty
+        return kwargs
+
+    def chat(
+        self,
+        messages: List[Dict],
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        generate_config = generate_config or {}
+        kwargs: Dict[str, Any] = self._get_generate_kwargs(generate_config)
+        tools = self._handle_tools(messages, generate_config)
+        has_tools = tools is not None
         stream = generate_config.get("stream", False)
         stream_options = generate_config.pop("stream_options", None)
         include_usage = (
@@ -472,103 +355,82 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             if isinstance(stream_options, dict)
             else False
         )
-        if stream and (
-            not tools or self.model_family.model_name in GLM4_TOOL_CALL_FAMILY
-        ):
+        inputs = self._tokenizer.apply_chat_template(
+            messages,
+            return_tensors="pt",
+            chat_template=self.model_family.chat_template,
+            add_generation_prompt=True,
+            return_dict=True,
+        )
+        inputs = inputs.to(self._model.device)
+
+        if not stream:
+            with torch.no_grad():
+                outputs = self._model.generate(**inputs, **kwargs)
+            outputs = outputs[:, inputs["input_ids"].shape[1] :]
+            response = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
+            # In some cases, the response starts with `\n`
+            if response.startswith("\n"):
+                response = response[1:]
+            if has_tools:
+                function_call = self._process_response_non_streaming(
+                    response, tools, use_tool=True
+                )
+                return self._tool_calls_completion(
+                    self.model_family, self.model_uid, function_call
+                )
+            else:
+                return generate_chat_completion(self.model_uid, response)
+        else:

             def _stream_generator():
                 last_chunk_text_length = 0
                 chunk_id = "chat-" + str(uuid.uuid1())
                 prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
-                inputs = self._tokenizer([prompt], return_tensors="pt")
-                inputs = inputs.to(self._model.device)
                 prompt_tokens = len(inputs["input_ids"][0])
-                for chunk_text, _ in self._stream_chat(
-                    self._tokenizer, prompt, chat_history, **kwargs
-                ):
+                for chunk_text in self._stream_chat(inputs, tools, **kwargs):
                     if tools and isinstance(chunk_text, dict):
                         yield self._tool_calls_completion_chunk(
-                            self.model_family, self.model_uid, [chunk_text, _], tools
+                            self.model_family, self.model_uid, chunk_text
                         )
                         return
                     completion_tokens = completion_tokens + 1
                     total_tokens = prompt_tokens + completion_tokens
                     chunk_text = chunk_text[last_chunk_text_length:]
                     last_chunk_text_length += len(chunk_text)
-                    completion_choice = CompletionChoice(
-                        text=chunk_text, index=0, logprobs=None, finish_reason=None
-                    )
-                    yield CompletionChunk(
-                        id=chunk_id,
-                        object="text_completion",
-                        created=int(time.time()),
-                        model=self.model_uid,
-                        choices=[completion_choice],
-                        usage=CompletionUsage(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=total_tokens,
-                        ),
+                    yield generate_completion_chunk(
+                        chunk_text,
+                        finish_reason=None,
+                        chunk_id=chunk_id,
+                        model_uid=self.model_uid,
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=total_tokens,
                     )
-                completion_choice = CompletionChoice(
-                    text="", index=0, logprobs=None, finish_reason="stop"
-                )
-                chunk = CompletionChunk(
-                    id=chunk_id,
-                    object="text_completion",
-                    created=int(time.time()),
-                    model=self.model_uid,
-                    choices=[completion_choice],
-                )
-                completion_usage = CompletionUsage(
+                yield generate_completion_chunk(
+                    None,
+                    finish_reason="stop",
+                    chunk_id=chunk_id,
+                    model_uid=self.model_uid,
                     prompt_tokens=prompt_tokens,
                     completion_tokens=completion_tokens,
                     total_tokens=total_tokens,
+                    has_choice=True,
+                    has_content=False,
                 )
-                chunk["usage"] = completion_usage
-                yield chunk
                 if include_usage:
-                    chunk = CompletionChunk(
-                        id=chunk_id,
-                        object="text_completion",
-                        created=int(time.time()),
-                        model=self.model_uid,
-                        choices=[],
-                    )
-                    chunk["usage"] = CompletionUsage(
+                    yield generate_completion_chunk(
+                        None,
+                        finish_reason=None,
+                        chunk_id=chunk_id,
+                        model_uid=self.model_uid,
                         prompt_tokens=prompt_tokens,
                         completion_tokens=completion_tokens,
                         total_tokens=total_tokens,
+                        has_choice=False,
                     )
-                    yield chunk

             return self._to_chat_completion_chunks(_stream_generator())
-        else:
-            response = self._non_stream_chat(
-                self._tokenizer, prompt, chat_history, **kwargs
-            )
-            if tools:
-                return self._tool_calls_completion(
-                    self.model_family, self.model_uid, response, tools
-                )
-            else:
-                content, _ = response
-                return ChatCompletion(
-                    id="chat" + str(uuid.uuid1()),
-                    object="chat.completion",
-                    created=int(time.time()),
-                    model=self.model_uid,
-                    choices=[
-                        ChatCompletionChoice(
-                            index=0,
-                            message={"role": "assistant", "content": content},
-                            finish_reason="stop",
-                        )
-                    ],
-                    usage=CompletionUsage(
-                        prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
-                    ),
-                )

     def prepare_sanitize_generate_config(self, req: InferenceRequest):
         """