xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of xinference has been flagged as potentially problematic.

Files changed (194)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +209 -40
  4. xinference/client/restful/restful_client.py +7 -26
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/image_interface.py +28 -0
  11. xinference/core/model.py +110 -31
  12. xinference/core/scheduler.py +37 -37
  13. xinference/core/status_guard.py +1 -1
  14. xinference/core/supervisor.py +17 -10
  15. xinference/core/utils.py +80 -22
  16. xinference/core/worker.py +17 -16
  17. xinference/deploy/cmdline.py +8 -16
  18. xinference/deploy/local.py +1 -1
  19. xinference/deploy/supervisor.py +1 -1
  20. xinference/deploy/utils.py +1 -1
  21. xinference/deploy/worker.py +1 -1
  22. xinference/model/audio/cosyvoice.py +86 -41
  23. xinference/model/audio/fish_speech.py +9 -9
  24. xinference/model/audio/model_spec.json +9 -9
  25. xinference/model/audio/whisper.py +4 -1
  26. xinference/model/embedding/core.py +52 -31
  27. xinference/model/image/core.py +2 -1
  28. xinference/model/image/model_spec.json +16 -4
  29. xinference/model/image/model_spec_modelscope.json +16 -4
  30. xinference/model/image/sdapi.py +136 -0
  31. xinference/model/image/stable_diffusion/core.py +164 -19
  32. xinference/model/llm/__init__.py +29 -11
  33. xinference/model/llm/llama_cpp/core.py +16 -33
  34. xinference/model/llm/llm_family.json +1011 -1296
  35. xinference/model/llm/llm_family.py +34 -53
  36. xinference/model/llm/llm_family_csghub.json +18 -35
  37. xinference/model/llm/llm_family_modelscope.json +981 -1122
  38. xinference/model/llm/lmdeploy/core.py +56 -88
  39. xinference/model/llm/mlx/core.py +46 -69
  40. xinference/model/llm/sglang/core.py +36 -18
  41. xinference/model/llm/transformers/chatglm.py +168 -306
  42. xinference/model/llm/transformers/cogvlm2.py +36 -63
  43. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  44. xinference/model/llm/transformers/core.py +55 -50
  45. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  46. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  47. xinference/model/llm/transformers/glm4v.py +55 -111
  48. xinference/model/llm/transformers/intern_vl.py +39 -70
  49. xinference/model/llm/transformers/internlm2.py +32 -54
  50. xinference/model/llm/transformers/minicpmv25.py +22 -55
  51. xinference/model/llm/transformers/minicpmv26.py +158 -68
  52. xinference/model/llm/transformers/omnilmm.py +5 -28
  53. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  54. xinference/model/llm/transformers/qwen2_vl.py +234 -0
  55. xinference/model/llm/transformers/qwen_vl.py +34 -86
  56. xinference/model/llm/transformers/utils.py +32 -38
  57. xinference/model/llm/transformers/yi_vl.py +32 -72
  58. xinference/model/llm/utils.py +280 -554
  59. xinference/model/llm/vllm/core.py +161 -100
  60. xinference/model/rerank/core.py +41 -8
  61. xinference/model/rerank/model_spec.json +7 -0
  62. xinference/model/rerank/model_spec_modelscope.json +7 -1
  63. xinference/model/utils.py +1 -31
  64. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  65. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  66. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  67. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  68. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  69. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  70. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  71. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  72. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  73. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  74. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  77. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  78. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  79. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  80. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  81. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  82. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  83. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  84. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  85. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  86. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  87. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  88. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  89. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  90. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  91. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
  92. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  93. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  94. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  95. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  96. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  97. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  98. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  99. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  100. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  101. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  102. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  103. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  104. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  105. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  107. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  108. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  109. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  110. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  111. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  112. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  113. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  114. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  115. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  116. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  117. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  118. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  122. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  123. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  124. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  126. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  127. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  128. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  129. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  130. xinference/thirdparty/matcha/VERSION +1 -0
  131. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  132. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  133. xinference/thirdparty/omnilmm/LICENSE +201 -0
  134. xinference/thirdparty/whisper/__init__.py +156 -0
  135. xinference/thirdparty/whisper/__main__.py +3 -0
  136. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  137. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  138. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  139. xinference/thirdparty/whisper/audio.py +157 -0
  140. xinference/thirdparty/whisper/decoding.py +826 -0
  141. xinference/thirdparty/whisper/model.py +314 -0
  142. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  143. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  144. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  145. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  146. xinference/thirdparty/whisper/timing.py +386 -0
  147. xinference/thirdparty/whisper/tokenizer.py +395 -0
  148. xinference/thirdparty/whisper/transcribe.py +605 -0
  149. xinference/thirdparty/whisper/triton_ops.py +109 -0
  150. xinference/thirdparty/whisper/utils.py +316 -0
  151. xinference/thirdparty/whisper/version.py +1 -0
  152. xinference/types.py +14 -53
  153. xinference/web/ui/build/asset-manifest.json +6 -6
  154. xinference/web/ui/build/index.html +1 -1
  155. xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
  156. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  157. xinference/web/ui/build/static/js/main.754740c0.js +3 -0
  158. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
  159. xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
  160. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  161. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  162. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  163. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  164. xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
  165. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  166. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  167. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  168. xinference/web/ui/node_modules/.package-lock.json +37 -0
  169. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  170. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  171. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  172. xinference/web/ui/package-lock.json +38 -0
  173. xinference/web/ui/package.json +1 -0
  174. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
  175. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
  176. xinference/model/llm/transformers/llama_2.py +0 -108
  177. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  178. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  179. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  180. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  181. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  182. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  183. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  184. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  185. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  186. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  187. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  188. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  189. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  190. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  191. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
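
The expanded hunks that follow appear to come from xinference/model/llm/utils.py (file 58 above): the hard-coded per-model prompt-style builders are removed in favor of Jinja2 chat templates, and tool-call parsing is reworked to return a list of calls.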
@@ -17,6 +17,7 @@ import json
  import logging
  import os
  import time
+ import typing
  import uuid
  from io import BytesIO
  from typing import AsyncGenerator, Dict, Iterator, List, Optional, Tuple, cast
@@ -25,19 +26,18 @@ import requests
  from PIL import Image

  from ...types import (
- SPECIAL_TOOL_PROMPT,
  ChatCompletion,
+ ChatCompletionChoice,
  ChatCompletionChunk,
- ChatCompletionMessage,
  Completion,
+ CompletionChoice,
  CompletionChunk,
+ CompletionUsage,
  )
- from ..utils import ensure_cache_cleared
  from .llm_family import (
  LlamaCppLLMSpecV1,
  LLMFamilyV1,
  LLMSpecV1,
- PromptStyleV1,
  _get_cache_dir,
  get_cache_status,
  )
@@ -46,7 +46,6 @@ logger = logging.getLogger(__name__)


  QWEN_TOOL_CALL_FAMILY = [
- "qwen-chat",
  "qwen1.5-chat",
  "qwen1.5-moe-chat",
  "qwen2-instruct",
@@ -58,416 +57,90 @@ GLM4_TOOL_CALL_FAMILY = [
  "glm4-chat-1m",
  ]

+ QWEN_TOOL_CALL_SYMBOLS = ["<tool_call>", "</tool_call>"]
+

  class ChatModelMixin:
  @staticmethod
- def get_prompt(
- prompt: str,
- chat_history: List[ChatCompletionMessage],
- prompt_style: PromptStyleV1,
- tools: Optional[List[Dict]] = None,
- ):
+ @functools.lru_cache
+ def _compile_jinja_template(chat_template):
  """
- Inspired by FastChat. Format chat history into a prompt according to the prompty style of
- different models.
+ Copied from transformers source code.
  """
- assert prompt_style.roles is not None
- if prompt != SPECIAL_TOOL_PROMPT:
- chat_history.append(
- ChatCompletionMessage(role=prompt_style.roles[0], content=prompt)
- )
- chat_history.append(
- ChatCompletionMessage(role=prompt_style.roles[1], content="")
+ try:
+ from jinja2.exceptions import TemplateError
+ from jinja2.sandbox import ImmutableSandboxedEnvironment
+ except ImportError:
+ raise ImportError("xinference requires jinja2 to be installed.")
+
+ def raise_exception(message):
+ raise TemplateError(message)
+
+ jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
+ jinja_env.globals["raise_exception"] = raise_exception
+ return jinja_env.from_string(chat_template)
+
+ def _build_from_raw_template(
+ self, messages: List, chat_template: str, **kwargs
+ ) -> str:
+ compiled_template = self._compile_jinja_template(chat_template)
+ rendered = compiled_template.render(
+ messages=messages, add_generation_prompt=True, **kwargs
  )
-
- def get_role(role_name: str):
- if role_name == "user":
- return prompt_style.roles[0]
- elif role_name == "assistant":
- return prompt_style.roles[1]
- else:
- return role_name
-
- if prompt_style.style_name == "ADD_COLON_SINGLE":
- ret = prompt_style.system_prompt + prompt_style.intra_message_sep
- for message in chat_history:
- role = get_role(message["role"])
- content = message["content"]
- if content:
- ret += role + ": " + content + prompt_style.intra_message_sep
- else:
- ret += role + ":"
- return ret
- elif prompt_style.style_name == "NO_COLON_TWO":
- seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
- ret = prompt_style.system_prompt
- for i, message in enumerate(chat_history):
- role = get_role(message["role"])
- content = message["content"]
- if content:
- ret += role + content + seps[i % 2]
- else:
- ret += role
- return ret
- elif prompt_style.style_name == "LLAMA2":
- seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
- ret = ""
- for i, message in enumerate(chat_history):
- role = get_role(message["role"])
- content = message["content"]
- if content:
- if i == 0:
- ret += prompt_style.system_prompt + content
- else:
- ret += role + " " + content + seps[i % 2]
- else:
- ret += role
- return ret
- elif prompt_style.style_name == "LLAMA3":
- ret = (
- f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>"
- f"{prompt_style.intra_message_sep}{prompt_style.system_prompt}{prompt_style.inter_message_sep}"
- )
- for i, message in enumerate(chat_history):
- role = get_role(message["role"])
- content = message["content"]
- if content:
- ret += (
- f"<|start_header_id|>{role}<|end_header_id|>"
- f"{prompt_style.intra_message_sep}{content}{prompt_style.inter_message_sep}"
- )
- else:
- ret += f"<|start_header_id|>{role}<|end_header_id|>{prompt_style.intra_message_sep}"
- return ret
- elif prompt_style.style_name == "MIXTRAL_V01":
- ret = ""
- for i, message in enumerate(chat_history):
- content = message["content"]
- if i % 2 == 0: # user
- ret += f"<s> [INST] {content} [/INST]"
- else: # assistant
- ret += f"{content} </s>"
- return ret
- elif prompt_style.style_name == "CHATGLM3":
- prompts = (
- [f"<|system|>\n {prompt_style.system_prompt}"]
- if prompt_style.system_prompt
- else []
- )
-
- for i, message in enumerate(chat_history):
- role = get_role(message["role"])
- content = message.get("content")
- tool_calls = message.get("tool_calls")
- if tool_calls:
- content = tool_calls[0]["function"]
- if content:
- if role == "tool":
- role = "observation"
- prompts.append(f"<|{role}|>\n {content}")
- else:
- prompts.append(f"<|{role}|>")
- return "\n".join(prompts)
- elif prompt_style.style_name == "XVERSE":
- ret = (
- f"<|system|> \n {prompt_style.system_prompt}"
- if prompt_style.system_prompt
- else ""
- )
- for i, message in enumerate(chat_history):
- role = get_role(message["role"])
- content = message["content"]
- if content:
- ret += f"<|{role}|> \n {content}"
- else:
- ret += f"<|{role}|>"
- return ret
- elif prompt_style.style_name == "QWEN":
- if tools:
- tool_desc = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters} Format the arguments as a JSON object."""
-
- react_instruction = """Answer the following questions as best you can. You have access to the following APIs:
-
- {tools_text}
-
- Use the following format:
-
- Question: the input question you must answer
- Thought: you should always think about what to do
- Action: the action to take, should be one of [{tools_name_text}]
- Action Input: the input to the action
- Observation: the result of the action
- ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
- Thought: I now know the final answer
- Final Answer: the final answer to the original input question
-
- Begin!"""
- tools_text = []
- tools_name_text = []
- for func_info in tools:
- parameters = []
- fp = func_info["function"].get("parameters", {})
- if fp:
- required_parameters = fp.get("required", [])
- for name, p in fp["properties"].items():
- param = dict({"name": name}, **p)
- if name in required_parameters:
- param["required"] = True
- parameters.append(param)
-
- name = func_info["function"]["name"]
- desc = func_info["function"]["description"]
- tool_string = tool_desc.format(
- name_for_model=name,
- name_for_human=name,
- # Hint: You can add the following format requirements in description:
- # "Format the arguments as a JSON object."
- # "Enclose the code within triple backticks (`) at the beginning and end of the code."
- description_for_model=desc,
- parameters=json.dumps(parameters, ensure_ascii=False),
- )
- tools_text.append(tool_string)
- tools_name_text.append(name)
- tools_text_string = "\n\n".join(tools_text)
- tools_name_text_string = ", ".join(tools_name_text)
- tool_system = react_instruction.format(
- tools_text=tools_text_string,
- tools_name_text=tools_name_text_string,
+ return rendered
+
+ def get_full_context(
+ self, messages: List, chat_template: str, tokenizer=None, **kwargs
+ ) -> str:
+ if tokenizer is not None:
+ try:
+ full_context = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ chat_template=chat_template,
+ add_generation_prompt=True,
+ **kwargs,
  )
- else:
- tool_system = ""
-
- ret = f"<|im_start|>system\n{prompt_style.system_prompt}<|im_end|>"
- for message in chat_history:
- role = get_role(message["role"])
- content = message.get("content")
-
- ret += prompt_style.intra_message_sep
- if tools:
- if role == "user":
- if tool_system:
- content = tool_system + f"\n\nQuestion: {content}"
- tool_system = ""
- else:
- content = f"Question: {content}"
- elif role == "assistant":
- tool_calls = message.get("tool_calls")
- if tool_calls:
- func_call = tool_calls[0]["function"]
- f_name, f_args = (
- func_call["name"],
- func_call["arguments"],
- )
- content = f"Thought: I can use {f_name}.\nAction: {f_name}\nAction Input: {f_args}"
- elif content:
- content = f"Thought: I now know the final answer.\nFinal answer: {content}"
- elif role == "tool":
- role = "function"
- content = f"Observation: {content}"
- else:
- raise Exception(f"Unsupported message role: {role}")
- if content:
- content = content.lstrip("\n").rstrip()
- ret += f"<|im_start|>{role}\n{content}<|im_end|>"
- else:
- ret += f"<|im_start|>{role}\n"
- return ret
- elif prompt_style.style_name == "CHATML":
- ret = (
- ""
- if prompt_style.system_prompt == ""
- else prompt_style.system_prompt + prompt_style.intra_message_sep + "\n"
- )
- for message in chat_history:
- role = get_role(message["role"])
- content = message["content"]
+ return full_context
+ except Exception as e:
+ logger.warning(
+ f"tokenizer.apply_chat_template error. Maybe this is an old model: {e}"
+ )
+ return self._build_from_raw_template(messages, chat_template, **kwargs)
+ else:
+ # build from jinja
+ # Compilation function uses a cache to avoid recompiling the same template
+ return self._build_from_raw_template(messages, chat_template, **kwargs)

- if content:
- ret += role + "\n" + content + prompt_style.intra_message_sep + "\n"
- else:
- ret += role + "\n"
- return ret
- elif prompt_style.style_name == "INTERNLM2":
- ret = (
- "<s>"
- if prompt_style.system_prompt == ""
- else "<s><|im_start|>system\n"
- + prompt_style.system_prompt
- + prompt_style.intra_message_sep
- + "\n"
- )
- for message in chat_history:
- role = get_role(message["role"])
- content = message["content"]
+ @staticmethod
+ def get_specific_prompt(model_family: str, messages: List[Dict]):
+ """
+ Inspired by FastChat. Format chat history into a prompt according to the prompty style of
+ different models.
+ """
+ _messages = [x for x in messages] # copy for not modifying the origin messages
+ _messages.append({"role": "assistant", "content": ""})

- if content:
- ret += role + "\n" + content + prompt_style.intra_message_sep + "\n"
- else:
- ret += role + "\n"
- return ret
- elif prompt_style.style_name == "ADD_COLON_SINGLE_COT":
- ret = prompt_style.system_prompt + prompt_style.intra_message_sep
- for message in chat_history:
- role = get_role(message["role"])
- content = message["content"]
- if content:
- ret += role + ": " + content + prompt_style.intra_message_sep
- else:
- ret += role + ": Let's think step by step."
- return ret
- elif prompt_style.style_name == "DEEPSEEK_CHAT":
- seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
- ret = prompt_style.system_prompt
- for i, message in enumerate(chat_history):
- role = get_role(message["role"])
- content = message["content"]
- if content:
- ret += role + ": " + content + seps[i % 2]
- else:
- ret += role + ":"
- return ret
- elif prompt_style.style_name == "DEEPSEEK_CODER":
- sep = prompt_style.inter_message_sep
- ret = prompt_style.system_prompt + sep
- for i, message in enumerate(chat_history):
- role = get_role(message["role"])
- content = message["content"]
- if content:
- ret += role + "\n" + content + sep
- else:
- ret += role + "\n"
- return ret
- elif prompt_style.style_name == "GORILLA_OPENFUNCTIONS":
- if tools:
- gorilla_functions = []
- for tool in tools:
- gorilla_functions.append(
- {
- "name": tool["function"]["name"],
- "api_name": tool["function"]["name"],
- "description": tool["function"]["description"],
- "parameters": [
- dict({"name": name}, **p)
- for name, p in tool["function"]["parameters"][
- "properties"
- ].items()
- ],
- }
- )
- tools_string = json.dumps(gorilla_functions)
- return f"USER: <<question>> {prompt} <<function>> {tools_string}\nASSISTANT: "
- else:
- return f"USER: <<question>> {prompt}\nASSISTANT: "
- elif prompt_style.style_name == "orion":
- ret = "<s>"
- for i, message in enumerate(chat_history):
- content = message["content"]
- role = get_role(message["role"])
- if i % 2 == 0: # Human
- assert content is not None
- ret += role + ": " + content + "\n\n"
- else: # Assistant
- if content:
- ret += role + ": </s>" + content + "</s>"
- else:
- ret += role + ": </s>"
- return ret
- elif prompt_style.style_name == "gemma":
- ret = ""
- for message in chat_history:
- content = message["content"]
- role = get_role(message["role"])
- ret += "<start_of_turn>" + role + "\n"
- if content:
- ret += content + "<end_of_turn>\n"
- return ret
- elif prompt_style.style_name == "CodeShell":
- ret = ""
- for message in chat_history:
- content = message["content"]
- role = get_role(message["role"])
- if content:
- ret += f"{role}{content}|<end>|"
- else:
- ret += f"{role}".rstrip()
- return ret
- elif prompt_style.style_name == "MINICPM-2B":
- ret = ""
- for message in chat_history:
- content = message["content"] or ""
- role = get_role(message["role"])
- if role == "user":
- ret += "<用户>" + content.strip()
- else:
- ret += "<AI>" + content.strip()
- return ret
- elif prompt_style.style_name == "PHI3":
- ret = f"<|system|>{prompt_style.intra_message_sep}{prompt_style.system_prompt}{prompt_style.inter_message_sep}"
- for message in chat_history:
- content = message["content"] or ""
- role = get_role(message["role"])
- if content:
- ret += f"<|{role}|>{prompt_style.intra_message_sep}{content}{prompt_style.inter_message_sep}"
- else:
- ret += f"<|{role}|>{prompt_style.intra_message_sep}"
- ret += "<|assistant|>\n"
- return ret
- elif prompt_style.style_name == "c4ai-command-r":
- ret = (
- f"<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>"
- f"{prompt_style.system_prompt}{prompt_style.inter_message_sep}"
+ if model_family == "internvl2":
+ system_prompt = (
+ messages[0]["content"] if messages[0]["role"] == "system" else ""
  )
- for i, message in enumerate(chat_history):
- role = get_role(message["role"])
- content = message["content"]
- if content:
- ret += f"{role}{content}{prompt_style.inter_message_sep}"
- else:
- ret += role
- return ret
- elif prompt_style.style_name == "mistral-nemo":
- seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
- ret = "<s>"
- for i, message in enumerate(chat_history):
- role = get_role(message["role"])
- content = message["content"]
- if content:
- if i == len(chat_history) - 2 and prompt_style.system_prompt:
- ret += (
- role
- + " "
- + prompt_style.system_prompt
- + "\n\n"
- + content
- + seps[i % 2]
- )
- else:
- ret += role + " " + content + seps[i % 2]
- else:
- ret += role
- return ret
- elif prompt_style.style_name == "INTERNVL":
+ intra_message_sep = "<|im_end|>"
  ret = (
  "<s>"
- if prompt_style.system_prompt == ""
+ if system_prompt == ""
  else "<s><|im_start|>system\n"
- + prompt_style.system_prompt
- + prompt_style.intra_message_sep
+ + system_prompt
+ + intra_message_sep
  + "\n"
  )
  images = [] # type: ignore
- for message in chat_history:
- role = get_role(message["role"])
+ for message in _messages:
+ role = "<|im_start|>" + message["role"]
  content = message["content"]
  if isinstance(content, str):
  if content:
- ret += (
- role
- + "\n"
- + content
- + prompt_style.intra_message_sep
- + "\n"
- )
+ ret += role + "\n" + content + intra_message_sep + "\n"
  else:
  ret += role + "\n"
  elif isinstance(content, list):
@@ -488,21 +161,15 @@ Begin!"""
  image_futures.append(fut)
  images = [fut.result() for fut in image_futures]
  if len(image_futures) == 0:
- ret += (
- role + "\n" + text + prompt_style.intra_message_sep + "\n"
- )
+ ret += role + "\n" + text + intra_message_sep + "\n"
  else:
  ret += (
- role
- + "\n"
- + f"<image>\n{text}"
- + prompt_style.intra_message_sep
- + "\n"
+ role + "\n" + f"<image>\n{text}" + intra_message_sep + "\n"
  )

- return (ret, images)
+ return ret, images
  else:
- raise ValueError(f"Invalid prompt style: {prompt_style.style_name}")
+ raise ValueError(f"Invalid model family: {model_family}")

  @classmethod
  def _to_chat_completion_chunk(cls, chunk: CompletionChunk) -> ChatCompletionChunk:
@@ -523,7 +190,11 @@ Begin!"""
  {
  "index": i,
  "delta": {
- "content": choice.get("text"),
+ **(
+ {"content": choice["text"]}
+ if ("text" in choice and choice["finish_reason"] is None)
+ else {}
+ ),
  **(
  {"tool_calls": choice["tool_calls"]}
  if "tool_calls" in choice
@@ -577,7 +248,6 @@ Begin!"""
  return cast(ChatCompletionChunk, chat_chunk)

  @classmethod
- @ensure_cache_cleared
  def _to_chat_completion_chunks(
  cls,
  chunks: Iterator[CompletionChunk],
@@ -610,7 +280,6 @@ Begin!"""
  i += 1

  @staticmethod
- @ensure_cache_cleared
  def _to_chat_completion(completion: Completion) -> ChatCompletion:
  return {
  "id": "chat" + completion["id"],
@@ -632,143 +301,89 @@ Begin!"""
  }

  @staticmethod
- def _eval_gorilla_openfunctions_arguments(c, tools):
- tool_names = [tool["function"]["name"] for tool in tools]
- arguments = c["choices"][0]["text"]
-
- def tool_call(n, **kwargs):
- return None, n, kwargs
-
- try:
- a, b, c = eval(
- arguments, {n: functools.partial(tool_call, n) for n in tool_names}
- )
- return a, b, c
- except Exception as e:
- logger.error("Eval tool calls completion failed: %s", e)
- return arguments, None, None
-
- @staticmethod
- def _eval_glm_chat_arguments(c, tools):
+ def _eval_glm_chat_arguments(c) -> List[Tuple]:
+ """
+ Currently, glm4 tool call only supports one function
+ """
  try:
- if isinstance(c[0], str):
- return c[0], None, None
- return None, c[0]["name"], c[0]["parameters"]
+ if isinstance(c, dict):
+ return [(None, c["name"], c["arguments"])]
  except KeyError:
  logger.error("Can't parse glm output: %s", c)
- return str(c), None, None
+ return [(str(c), None, None)]
+ else:
+ return [(str(c), None, None)]

- @staticmethod
- def _eval_qwen_chat_arguments(c, tools):
+ @classmethod
+ def _handle_qwen_tool_result(cls, text: str) -> List[Tuple]:
+ text: str = text.strip() # type: ignore
+ contents: List[str] = text.split(QWEN_TOOL_CALL_SYMBOLS[1])
+ results: List[Tuple] = []
+ for content in contents:
+ content = content.strip()
+ if content:
+ if content.startswith(QWEN_TOOL_CALL_SYMBOLS[0]):
+ content = content[len(QWEN_TOOL_CALL_SYMBOLS[0]) :]
+ content = content.strip()
+ try:
+ res = json.loads(content)
+ results.append((None, res["name"], res["arguments"]))
+ except Exception as e:
+ logger.error(
+ "Can't parse single qwen tool call output: %s. Error: %s",
+ content,
+ e,
+ )
+ results.append((content, None, None))
+ return results
+
+ @classmethod
+ def _eval_qwen_chat_arguments(cls, c) -> List[Tuple]:
  text = c["choices"][0]["text"]
- try:
- # Refer to:
- # https://github.com/QwenLM/Qwen/blob/main/examples/react_prompt.md
- # https://github.com/QwenLM/Qwen/blob/main/openai_api.py#L297
- func_name, func_args, content = "", "", ""
- i = text.rfind("\nAction:")
- j = text.rfind("\nAction Input:")
- k = text.rfind("\nObservation:")
- t = max(
- text.rfind("\nThought:", 0, i), text.rfind("Thought:", 0, i)
- ) # find the last thought just before Action, considering the Thought at the very beginning
- if 0 <= i < j: # If the text has `Action` and `Action input`,
- if k < j: # but does not contain `Observation`,
- # then it is likely that `Observation` is omitted by the LLM,
- # because the output text may have discarded the stop word.
- text = text.rstrip() + "\nObservation:" # Add it back.
- k = text.rfind("\nObservation:")
- if 0 <= t < i < j < k:
- func_name = text[i + len("\nAction:") : j].strip()
- func_args = text[j + len("\nAction Input:") : k].strip()
- content = text[
- t + len("\nThought:") : i
- ].strip() # len("\nThought:") and len("Thought:") both are OK since there is a space after :
- if func_name:
- return content, func_name, json.loads(func_args)
- except Exception as e:
- logger.error("Eval tool calls completion failed: %s", e)
- t = max(text.rfind("\nThought:"), text.rfind("Thought:"))
- z = max(text.rfind("\nFinal Answer:"), text.rfind("Final Answer:"))
- if z >= 0:
- text = text[
- z + len("\nFinal Answer:") :
- ] # len("\nFinal Answer::") and len("Final Answer::") both are OK since there is a space after :
- else:
- text = text[
- t + len("\nThought:") :
- ] # There is only Thought: no Final Answer:
- return text, None, None
+ return cls._handle_qwen_tool_result(text)

  @classmethod
- def _eval_tool_arguments(cls, model_family, c, tools):
+ def _eval_tool_arguments(cls, model_family, c):
  family = model_family.model_family or model_family.model_name
- if family in ["gorilla-openfunctions-v1", "gorilla-openfunctions-v2"]:
- content, func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
- elif family in GLM4_TOOL_CALL_FAMILY:
- content, func, args = cls._eval_glm_chat_arguments(c, tools)
+ if family in GLM4_TOOL_CALL_FAMILY:
+ result = cls._eval_glm_chat_arguments(c)
  elif family in QWEN_TOOL_CALL_FAMILY:
- content, func, args = cls._eval_qwen_chat_arguments(c, tools)
+ result = cls._eval_qwen_chat_arguments(c)
  else:
  raise Exception(
  f"Model {model_family.model_name} is not support tool calls."
  )
- logger.debug("Tool call content: %s, func: %s, args: %s", content, func, args)
- return content, func, args
-
- @classmethod
- def _tools_token_filter(cls, model_family):
- """
- Generates a filter function for Qwen series models to retain outputs after "\nFinal Answer:".
-
- Returns:
- A function that takes tokens (string output by the model so far) and delta (new tokens added) as input,
- returns the part after "\nFinal Answer:" if found, else returns delta.
- """
- family = model_family.model_family or model_family.model_name
- if family in QWEN_TOOL_CALL_FAMILY:
- # Encapsulating function to reset 'found' after each call
- found = False
-
- def process_tokens(tokens: str, delta: str):
- nonlocal found
- # Once "Final Answer:" is found, future tokens are allowed.
- if found:
- return delta
- # Check if the token ends with "\nFinal Answer:" and update `found`.
- final_answer_idx = tokens.lower().rfind("\nfinal answer:")
- if final_answer_idx != -1:
- found = True
- return tokens[final_answer_idx + len("\nfinal answer:") :]
- return ""
-
- return process_tokens
- else:
- return lambda tokens, delta: delta
+ logger.debug(f"Tool call content: {result}")
+ return result

  @classmethod
- def _tool_calls_completion_chunk(cls, model_family, model_uid, c, tools):
+ def _tool_calls_completion_chunk(cls, model_family, model_uid, c):
  _id = str(uuid.uuid4())
- content, func, args = cls._eval_tool_arguments(model_family, c, tools)
- if func:
- d = {
- "role": "assistant",
- "content": content,
- "tool_calls": [
- {
- "id": f"call_{_id}",
- "type": "function",
- "function": {
- "name": func,
- "arguments": json.dumps(args),
- },
- }
- ],
- }
- finish_reason = "tool_calls"
- else:
- d = {"role": "assistant", "content": content, "tool_calls": []}
- finish_reason = "stop"
+ tool_result = cls._eval_tool_arguments(model_family, c)
+ tool_calls = []
+ failed_contents = []
+ for content, func, args in tool_result:
+ if func:
+ tool_calls.append(
+ [
+ {
+ "id": f"call_{_id}",
+ "type": "function",
+ "function": {
+ "name": func,
+ "arguments": json.dumps(args, ensure_ascii=False),
+ },
+ }
+ ]
+ )
+ else:
+ failed_contents.append(content)
+ finish_reason = "tool_calls" if tool_calls else "stop"
+ d = {
+ "role": "assistant",
+ "content": ". ".join(failed_contents) if failed_contents else None,
+ "tool_calls": tool_calls,
+ }
  try:
  usage = c.get("usage")
  assert "prompt_tokens" in usage
@@ -795,28 +410,32 @@ Begin!"""
  }

  @classmethod
- def _tool_calls_completion(cls, model_family, model_uid, c, tools):
+ def _tool_calls_completion(cls, model_family, model_uid, c):
  _id = str(uuid.uuid4())
- content, func, args = cls._eval_tool_arguments(model_family, c, tools)
- if func:
- m = {
- "role": "assistant",
- "content": content,
- "tool_calls": [
+ tool_result = cls._eval_tool_arguments(model_family, c)
+
+ tool_calls = []
+ failed_contents = []
+ for content, func, args in tool_result:
+ if func:
+ tool_calls.append(
  {
  "id": f"call_{_id}",
  "type": "function",
  "function": {
  "name": func,
- "arguments": json.dumps(args),
+ "arguments": json.dumps(args, ensure_ascii=False),
  },
  }
- ],
- }
- finish_reason = "tool_calls"
- else:
- m = {"role": "assistant", "content": content, "tool_calls": []}
- finish_reason = "stop"
+ )
+ else:
+ failed_contents.append(content)
+ finish_reason = "tool_calls" if tool_calls else "stop"
+ m = {
+ "role": "assistant",
+ "content": ". ".join(failed_contents) if failed_contents else None,
+ "tool_calls": tool_calls,
+ }
  try:
  usage = c.get("usage")
  assert "prompt_tokens" in usage
@@ -841,16 +460,6 @@ Begin!"""
  "usage": usage,
  }

- @classmethod
- def get_full_prompt(cls, model_family, prompt, system_prompt, chat_history, tools):
- assert model_family.prompt_style is not None
- prompt_style = model_family.prompt_style.copy()
- if system_prompt:
- prompt_style.system_prompt = system_prompt
- chat_history = chat_history or []
- full_prompt = cls.get_prompt(prompt, chat_history, prompt_style, tools=tools)
- return full_prompt
-

  def get_file_location(
  llm_family: LLMFamilyV1, spec: LLMSpecV1, quantization: str
@@ -903,3 +512,120 @@ def _decode_image(_url):
  return Image.open(_url).convert("RGB")
  else:
  return Image.open(BytesIO(response.content)).convert("RGB")
+
+
+ @typing.no_type_check
+ def generate_completion_chunk(
+ chunk_text: Optional[str],
+ finish_reason: Optional[str],
+ chunk_id: str,
+ model_uid: str,
+ prompt_tokens: int,
+ completion_tokens: int,
+ total_tokens: int,
+ has_choice: bool = True,
+ has_content: bool = True,
+ ):
+ choices = []
+ if has_choice:
+ choices.append(
+ CompletionChoice(
+ text=chunk_text, index=0, logprobs=None, finish_reason=finish_reason
+ )
+ if has_content
+ else CompletionChoice(index=0, logprobs=None, finish_reason=finish_reason)
+ )
+ return CompletionChunk(
+ id=chunk_id,
+ object="text_completion",
+ created=int(time.time()),
+ model=model_uid,
+ choices=choices,
+ usage=CompletionUsage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ ),
+ )
+
+
+ def generate_completion(
+ model_uid: str,
+ response: str,
+ prompt_tokens=-1,
+ completion_tokens=-1,
+ total_tokens=-1,
+ finish_reason="stop",
+ ) -> Completion:
+ return Completion(
+ id=str(uuid.uuid1()),
+ object="text_completion",
+ created=int(time.time()),
+ model=model_uid,
+ choices=[
+ CompletionChoice(
+ text=response, index=0, logprobs=None, finish_reason=finish_reason
+ )
+ ],
+ usage=CompletionUsage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ ),
+ )
+
+
+ def generate_chat_completion(
+ model_uid: str,
+ response: str,
+ prompt_tokens=-1,
+ completion_tokens=-1,
+ total_tokens=-1,
+ finish_reason="stop",
+ ) -> ChatCompletion:
+ return ChatCompletion(
+ id="chat" + str(uuid.uuid1()),
+ object="chat.completion",
+ created=int(time.time()),
+ model=model_uid,
+ choices=[
+ ChatCompletionChoice(
+ index=0,
+ message={"role": "assistant", "content": response},
+ finish_reason=finish_reason,
+ )
+ ],
+ usage=CompletionUsage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ ),
+ )
+
+
+ @functools.lru_cache
+ def get_stop_token_ids_from_config_file(model_path: str) -> Optional[List[int]]:
+ from transformers import GenerationConfig as TransformersGenerationConfig
+
+ transformers_config = TransformersGenerationConfig.from_pretrained(model_path)
+ if transformers_config.eos_token_id is not None:
+ stop_token_ids = (
+ transformers_config.eos_token_id
+ if isinstance(transformers_config.eos_token_id, list)
+ else [transformers_config.eos_token_id]
+ )
+ return stop_token_ids
+ return None
+
+
+ def parse_messages(messages: List[Dict]) -> Tuple:
+ """
+ Some older models still follow the old way of parameter passing.
+ This function helps to parse out the needed information from OpenAI-compatible `messages`.
+ """
+ system_messages = [mess["content"] for mess in messages if mess["role"] == "system"]
+ content_messages = [mess for mess in messages if mess["role"] != "system"]
+ prompt = content_messages[-1]["content"]
+ system_prompt = ". ".join(system_messages) if system_messages else None
+ chat_history = content_messages[:-1]
+ return prompt, system_prompt, chat_history
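
The hunks above replace the per-model prompt-style builders with Jinja2 chat templates and add helpers such as parse_messages. The sketch below is not part of the package; it only illustrates the two ideas under stated assumptions: render_chat_template is a hypothetical name that mirrors _build_from_raw_template, CHAT_TEMPLATE is an illustrative ChatML-style template rather than any real model's template, and parse_messages copies the logic of the helper added in this diff.

# Minimal sketch (assumes jinja2 is installed); mirrors the new template path.
from jinja2.sandbox import ImmutableSandboxedEnvironment

# Illustrative ChatML-style template; real templates ship with each model.
CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)

def render_chat_template(messages, chat_template=CHAT_TEMPLATE):
    # Like _build_from_raw_template: compile in a sandboxed Jinja2 environment
    # and render with add_generation_prompt=True.
    env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    return env.from_string(chat_template).render(
        messages=messages, add_generation_prompt=True
    )

def parse_messages(messages):
    # Same logic as the parse_messages helper added above: split an
    # OpenAI-compatible message list into (prompt, system_prompt, chat_history).
    system_messages = [m["content"] for m in messages if m["role"] == "system"]
    content_messages = [m for m in messages if m["role"] != "system"]
    prompt = content_messages[-1]["content"]
    system_prompt = ". ".join(system_messages) if system_messages else None
    return prompt, system_prompt, content_messages[:-1]

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is Xinference?"},
]
print(render_chat_template(messages))
print(parse_messages(messages))

Rendering produces a ChatML-style prompt ending with an open assistant turn, which is what 0.15.x feeds to non-chat backends; backends with a tokenizer instead call tokenizer.apply_chat_template and fall back to this raw-template path on error.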