xinference 0.14.4.post1__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (149)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +5 -39
  4. xinference/client/restful/restful_client.py +3 -24
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/model.py +82 -31
  11. xinference/core/scheduler.py +37 -37
  12. xinference/core/status_guard.py +1 -1
  13. xinference/core/supervisor.py +11 -10
  14. xinference/core/utils.py +80 -22
  15. xinference/core/worker.py +17 -16
  16. xinference/deploy/cmdline.py +8 -16
  17. xinference/deploy/local.py +1 -1
  18. xinference/deploy/supervisor.py +1 -1
  19. xinference/deploy/utils.py +1 -1
  20. xinference/deploy/worker.py +1 -1
  21. xinference/model/audio/cosyvoice.py +86 -41
  22. xinference/model/embedding/core.py +52 -31
  23. xinference/model/image/stable_diffusion/core.py +18 -1
  24. xinference/model/llm/__init__.py +21 -11
  25. xinference/model/llm/llama_cpp/core.py +16 -33
  26. xinference/model/llm/llm_family.json +619 -1297
  27. xinference/model/llm/llm_family.py +31 -52
  28. xinference/model/llm/llm_family_csghub.json +18 -35
  29. xinference/model/llm/llm_family_modelscope.json +573 -1119
  30. xinference/model/llm/lmdeploy/core.py +56 -88
  31. xinference/model/llm/mlx/core.py +46 -69
  32. xinference/model/llm/sglang/core.py +33 -18
  33. xinference/model/llm/transformers/chatglm.py +167 -305
  34. xinference/model/llm/transformers/cogvlm2.py +36 -63
  35. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  36. xinference/model/llm/transformers/core.py +49 -50
  37. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  38. xinference/model/llm/transformers/glm4v.py +55 -111
  39. xinference/model/llm/transformers/intern_vl.py +39 -70
  40. xinference/model/llm/transformers/internlm2.py +32 -54
  41. xinference/model/llm/transformers/minicpmv25.py +22 -55
  42. xinference/model/llm/transformers/minicpmv26.py +158 -68
  43. xinference/model/llm/transformers/omnilmm.py +5 -28
  44. xinference/model/llm/transformers/qwen2_vl.py +208 -0
  45. xinference/model/llm/transformers/qwen_vl.py +34 -86
  46. xinference/model/llm/transformers/utils.py +32 -38
  47. xinference/model/llm/transformers/yi_vl.py +32 -72
  48. xinference/model/llm/utils.py +195 -489
  49. xinference/model/llm/vllm/core.py +153 -100
  50. xinference/model/rerank/core.py +41 -8
  51. xinference/model/rerank/model_spec.json +7 -0
  52. xinference/model/rerank/model_spec_modelscope.json +7 -1
  53. xinference/model/utils.py +1 -31
  54. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  55. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  56. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  57. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  58. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  59. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  60. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  61. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  62. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  63. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  64. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  65. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  66. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  67. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  68. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  69. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  70. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  71. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  72. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  73. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  74. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  75. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  76. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  77. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  78. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  79. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  80. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  81. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +34 -0
  82. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  83. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  84. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  85. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  88. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  89. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  90. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  91. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  92. xinference/thirdparty/matcha/VERSION +1 -0
  93. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  94. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  95. xinference/thirdparty/omnilmm/LICENSE +201 -0
  96. xinference/thirdparty/whisper/__init__.py +156 -0
  97. xinference/thirdparty/whisper/__main__.py +3 -0
  98. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  99. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  100. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  101. xinference/thirdparty/whisper/audio.py +157 -0
  102. xinference/thirdparty/whisper/decoding.py +826 -0
  103. xinference/thirdparty/whisper/model.py +314 -0
  104. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  105. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  106. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  107. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  108. xinference/thirdparty/whisper/timing.py +386 -0
  109. xinference/thirdparty/whisper/tokenizer.py +395 -0
  110. xinference/thirdparty/whisper/transcribe.py +605 -0
  111. xinference/thirdparty/whisper/triton_ops.py +109 -0
  112. xinference/thirdparty/whisper/utils.py +316 -0
  113. xinference/thirdparty/whisper/version.py +1 -0
  114. xinference/types.py +7 -49
  115. xinference/web/ui/build/asset-manifest.json +6 -6
  116. xinference/web/ui/build/index.html +1 -1
  117. xinference/web/ui/build/static/css/{main.4bafd904.css → main.632e9148.css} +2 -2
  118. xinference/web/ui/build/static/css/main.632e9148.css.map +1 -0
  119. xinference/web/ui/build/static/js/main.9cfafbd6.js +3 -0
  120. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.9cfafbd6.js.LICENSE.txt} +2 -0
  121. xinference/web/ui/build/static/js/main.9cfafbd6.js.map +1 -0
  122. xinference/web/ui/node_modules/.cache/babel-loader/01d6d198156bacbd436c51435edbd4b2cacd47a79db929105eba30f74b67d48d.json +1 -0
  123. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  124. xinference/web/ui/node_modules/.cache/babel-loader/59eb25f514afcc4fefd1b309d192b2455f1e0aec68a9de598ca4b2333fe2c774.json +1 -0
  125. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  126. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  127. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  128. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  129. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  130. xinference/web/ui/node_modules/.package-lock.json +37 -0
  131. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  132. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  133. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  134. xinference/web/ui/package-lock.json +38 -0
  135. xinference/web/ui/package.json +1 -0
  136. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/METADATA +8 -8
  137. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/RECORD +141 -87
  138. xinference/model/llm/transformers/llama_2.py +0 -108
  139. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  140. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  141. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  142. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  144. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  145. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  146. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/LICENSE +0 -0
  147. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/WHEEL +0 -0
  148. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/entry_points.txt +0 -0
  149. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/top_level.txt +0 -0

xinference/model/llm/transformers/qwen2_vl.py (new file)

@@ -0,0 +1,208 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import logging
+ import uuid
+ from typing import Iterator, List, Optional, Union
+
+ from ....model.utils import select_device
+ from ....types import (
+     ChatCompletion,
+     ChatCompletionChunk,
+     ChatCompletionMessage,
+     CompletionChunk,
+ )
+ from ..llm_family import LLMFamilyV1, LLMSpecV1
+ from ..utils import generate_chat_completion, generate_completion_chunk
+ from .core import PytorchChatModel, PytorchGenerateConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ class Qwen2VLChatModel(PytorchChatModel):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._tokenizer = None
+         self._model = None
+         self._device = None
+         self._processor = None
+
+     @classmethod
+     def match(
+         cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+     ) -> bool:
+         llm_family = model_family.model_family or model_family.model_name
+         if "qwen2-vl-instruct".lower() in llm_family.lower():
+             return True
+         return False
+
+     def load(self):
+         from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
+
+         device = self._pytorch_model_config.get("device", "auto")
+         device = select_device(device)
+         self._device = device
+         # for multiple GPU, set back to auto to make multiple devices work
+         device = "auto" if device == "cuda" else device
+
+         self._processor = AutoProcessor.from_pretrained(
+             self.model_path, trust_remote_code=True
+         )
+         self._tokenizer = self._processor.tokenizer
+         self._model = Qwen2VLForConditionalGeneration.from_pretrained(
+             self.model_path, device_map=device, trust_remote_code=True
+         ).eval()
+
+     def _transform_messages(
+         self,
+         messages: List[ChatCompletionMessage],
+     ):
+         transformed_messages = []
+         for msg in messages:
+             new_content = []
+             role = msg["role"]
+             content = msg["content"]
+             if isinstance(content, str):
+                 new_content.append({"type": "text", "text": content})
+             elif isinstance(content, List):
+                 for item in content:  # type: ignore
+                     if "text" in item:
+                         new_content.append({"type": "text", "text": item["text"]})
+                     elif "image_url" in item:
+                         new_content.append(
+                             {"type": "image", "image": item["image_url"]["url"]}
+                         )
+                     elif "video_url" in item:
+                         new_content.append(
+                             {"type": "video", "video": item["video_url"]["url"]}
+                         )
+             new_message = {"role": role, "content": new_content}
+             transformed_messages.append(new_message)
+
+         return transformed_messages
+
+     def chat(
+         self,
+         messages: List[ChatCompletionMessage],  # type: ignore
+         generate_config: Optional[PytorchGenerateConfig] = None,
+     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+         messages = self._transform_messages(messages)
+
+         generate_config = generate_config if generate_config else {}
+
+         stream = generate_config.get("stream", False) if generate_config else False
+
+         if stream:
+             it = self._generate_stream(messages, generate_config)
+             return self._to_chat_completion_chunks(it)
+         else:
+             c = self._generate(messages, generate_config)
+             return c
+
+     def _generate(
+         self, messages: List, config: PytorchGenerateConfig = {}
+     ) -> ChatCompletion:
+         from qwen_vl_utils import process_vision_info
+
+         # Preparation for inference
+         text = self._processor.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+         image_inputs, video_inputs = process_vision_info(messages)
+         inputs = self._processor(
+             text=[text],
+             images=image_inputs,
+             videos=video_inputs,
+             padding=True,
+             return_tensors="pt",
+         )
+         inputs = inputs.to("cuda")
+
+         # Inference: Generation of the output
+         generated_ids = self._model.generate(
+             **inputs,
+             max_new_tokens=config.get("max_tokens", 512),
+             temperature=config.get("temperature", 1),
+         )
+         generated_ids_trimmed = [
+             out_ids[len(in_ids) :]
+             for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+         ]
+         output_text = self._processor.batch_decode(
+             generated_ids_trimmed,
+             skip_special_tokens=True,
+             clean_up_tokenization_spaces=False,
+         )[0]
+         return generate_chat_completion(self.model_uid, output_text)
+
+     def _generate_stream(
+         self, messages: List, config: PytorchGenerateConfig = {}
+     ) -> Iterator[CompletionChunk]:
+         from threading import Thread
+
+         from qwen_vl_utils import process_vision_info
+         from transformers import TextIteratorStreamer
+
+         text = self._processor.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+         image_inputs, video_inputs = process_vision_info(messages)
+         inputs = self._processor(
+             text=[text],
+             images=image_inputs,
+             videos=video_inputs,
+             padding=True,
+             return_tensors="pt",
+         )
+         inputs = inputs.to(self._model.device)
+
+         tokenizer = self._tokenizer
+         streamer = TextIteratorStreamer(
+             tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
+         )
+
+         gen_kwargs = {
+             "max_new_tokens": config.get("max_tokens", 512),
+             "temperature": config.get("temperature", 1),
+             "streamer": streamer,
+             **inputs,
+         }
+
+         thread = Thread(target=self._model.generate, kwargs=gen_kwargs)
+         thread.start()
+
+         completion_id = str(uuid.uuid1())
+         for new_text in streamer:
+             yield generate_completion_chunk(
+                 chunk_text=new_text,
+                 finish_reason=None,
+                 chunk_id=completion_id,
+                 model_uid=self.model_uid,
+                 prompt_tokens=-1,
+                 completion_tokens=-1,
+                 total_tokens=-1,
+                 has_choice=True,
+                 has_content=True,
+             )
+
+         yield generate_completion_chunk(
+             chunk_text=None,
+             finish_reason="stop",
+             chunk_id=completion_id,
+             model_uid=self.model_uid,
+             prompt_tokens=-1,
+             completion_tokens=-1,
+             total_tokens=-1,
+             has_choice=True,
+             has_content=False,
+         )
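
For context on the new model class above: the sketch below is not part of the diff. It only illustrates, following the _transform_messages logic shown, how an OpenAI-style message maps to the structure consumed by the processor's chat template and qwen_vl_utils.process_vision_info; the image URL is a hypothetical placeholder.

# Not part of the diff: example input/output of Qwen2VLChatModel._transform_messages.
# The URL below is a hypothetical placeholder.
openai_style_messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }
]

# What _transform_messages produces for the message above, which is then passed to
# processor.apply_chat_template and qwen_vl_utils.process_vision_info:
qwen_style_messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image", "image": "https://example.com/cat.png"},
        ],
    }
]
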
xinference/model/llm/transformers/qwen_vl.py

@@ -15,7 +15,6 @@ import base64
  import logging
  import operator
  import tempfile
- import time
  import typing
  import uuid
  from typing import Dict, Iterator, List, Optional, Tuple, Union
@@ -25,16 +24,9 @@ from transformers import PreTrainedTokenizer

  from ....core.scheduler import InferenceRequest
  from ....model.utils import select_device
- from ....types import (
-     ChatCompletion,
-     ChatCompletionChunk,
-     ChatCompletionMessage,
-     Completion,
-     CompletionChoice,
-     CompletionChunk,
-     CompletionUsage,
- )
+ from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
  from ..llm_family import LLMFamilyV1, LLMSpecV1
+ from ..utils import generate_chat_completion, generate_completion_chunk
  from .core import PytorchChatModel, PytorchGenerateConfig
  from .utils import pad_prefill_tokens

@@ -53,7 +45,7 @@ class QwenVLChatModel(PytorchChatModel):
          cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
      ) -> bool:
          llm_family = model_family.model_family or model_family.model_name
-         if "qwen" in llm_family and "vision" in model_family.model_ability:
+         if "qwen-" in llm_family and "vision" in model_family.model_ability:
              return True
          return False

@@ -129,18 +121,12 @@ class QwenVLChatModel(PytorchChatModel):
              return self._tokenizer.from_list_format(content)
          return content

-     def _get_prompt_and_chat_history(
-         self,
-         prompt: Union[str, List[Dict]],
-         chat_history: Optional[List[ChatCompletionMessage]] = None,
-     ):
-         prompt = self._message_content_to_qwen(prompt)
-         # Convert openai history to qwen vl history
+     def _get_prompt_and_chat_history(self, messages: List[Dict]):
          qwen_history = []
         query_to_response: List = []
-         for h in chat_history or []:
-             role = h["role"]
-             content = self._message_content_to_qwen(h["content"])
+         for message in messages[:-1]:
+             role = message["role"]
+             content = self._message_content_to_qwen(message["content"])
              if len(query_to_response) == 0 and role == "user":
                  query_to_response.append(content)
              if len(query_to_response) == 1 and role == "assistant":
@@ -148,18 +134,15 @@ class QwenVLChatModel(PytorchChatModel):
              if len(query_to_response) == 2:
                  qwen_history.append(query_to_response)
                  query_to_response = []
+         prompt = self._message_content_to_qwen(messages[-1]["content"])
          return prompt, qwen_history

      def chat(
          self,
-         prompt: Union[str, List[Dict]],
-         system_prompt: Optional[str] = None,
-         chat_history: Optional[List[ChatCompletionMessage]] = None,
+         messages: List[Dict],
          generate_config: Optional[PytorchGenerateConfig] = None,
      ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-         prompt, qwen_history = self._get_prompt_and_chat_history(
-             prompt, chat_history=chat_history
-         )
+         prompt, qwen_history = self._get_prompt_and_chat_history(messages)

          stream = generate_config.get("stream", False) if generate_config else False
          stream_options = (
@@ -174,33 +157,17 @@ class QwenVLChatModel(PytorchChatModel):
              it = self._generate_stream(prompt, qwen_history, include_usage)  # type: ignore
              return self._to_chat_completion_chunks(it)
          else:
-             c = self._generate(prompt, qwen_history)  # type: ignore
-             return self._to_chat_completion(c)
+             return self._generate(prompt, qwen_history)  # type: ignore

-     def _generate(self, prompt: str, qwen_history: List) -> Completion:
+     def _generate(self, prompt: str, qwen_history: List) -> ChatCompletion:
          response, history = self._model.chat(
              self._tokenizer, query=prompt, history=qwen_history
          )
-         c = Completion(
-             id=str(uuid.uuid1()),
-             object="text_completion",
-             created=int(time.time()),
-             model=self.model_uid,
-             choices=[
-                 CompletionChoice(
-                     index=0, text=response, finish_reason="stop", logprobs=None
-                 )
-             ],
-             usage=CompletionUsage(
-                 prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
-             ),
-         )
-         return c
+         return generate_chat_completion(self.model_uid, response)

      def _generate_stream(
          self, prompt: str, qwen_history: List, include_usage
      ) -> Iterator[CompletionChunk]:
-         # response, history = model.chat(tokenizer, message, history=history)
          response_generator = self._model.chat_stream(
              self._tokenizer, query=prompt, history=qwen_history
          )
@@ -212,57 +179,40 @@ class QwenVLChatModel(PytorchChatModel):
          for response in response_generator:
              inc_content = response[len(full_response) :]
              full_response = response
-             completion_choice = CompletionChoice(
-                 text=inc_content, index=0, logprobs=None, finish_reason=None
-             )
-             completion_chunk = CompletionChunk(
-                 id=completion_id,
-                 object="text_completion",
-                 created=int(time.time()),
-                 model=self.model_uid,
-                 choices=[completion_choice],
-             )
              completion_tokens = completion_tokens + 1
              total_tokens = prompt_tokens + completion_tokens
-             completion_usage = CompletionUsage(
+             yield generate_completion_chunk(
+                 chunk_text=inc_content,
+                 finish_reason=None,
+                 chunk_id=completion_id,
+                 model_uid=self.model_uid,
                  prompt_tokens=prompt_tokens,
                  completion_tokens=completion_tokens,
                  total_tokens=total_tokens,
              )
-             completion_chunk["usage"] = completion_usage
-             yield completion_chunk
-
-         completion_choice = CompletionChoice(
-             text="", index=0, logprobs=None, finish_reason="stop"
-         )
-         completion_chunk = CompletionChunk(
-             id=completion_id,
-             object="text_completion",
-             created=int(time.time()),
-             model=self.model_uid,
-             choices=[completion_choice],
-         )
-         completion_usage = CompletionUsage(
+         yield generate_completion_chunk(
+             chunk_text=None,
+             finish_reason="stop",
+             chunk_id=completion_id,
+             model_uid=self.model_uid,
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
+             has_choice=True,
+             has_content=False,
         )
-         completion_chunk["usage"] = completion_usage
-         yield completion_chunk
          if include_usage:
-             chunk = CompletionChunk(
-                 id=completion_id,
-                 object="text_completion",
-                 created=int(time.time()),
-                 model=self.model_uid,
-                 choices=[],
-             )
-             chunk["usage"] = CompletionUsage(
+             yield generate_completion_chunk(
+                 chunk_text=None,
+                 finish_reason=None,
+                 chunk_id=completion_id,
+                 model_uid=self.model_uid,
                  prompt_tokens=prompt_tokens,
                  completion_tokens=completion_tokens,
                  total_tokens=total_tokens,
+                 has_choice=False,
+                 has_content=False,
              )
-             yield chunk

      @staticmethod
      def get_batch_size_and_seq_len_indexes_from_kv() -> Tuple[int, int]:
@@ -359,10 +309,8 @@ class QwenVLChatModel(PytorchChatModel):

          return raw_text, context_tokens

-     def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
-         prompt, qwen_history = self._get_prompt_and_chat_history(
-             prompt, chat_history=chat_history
-         )
+     def _get_full_prompt(self, messages: List[Dict], tools):
+         prompt, qwen_history = self._get_prompt_and_chat_history(messages)
          _, context_tokens = self.make_context(self._tokenizer, prompt, qwen_history)
          return context_tokens

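The qwen_vl.py changes above fold the old prompt/system_prompt/chat_history arguments into a single OpenAI-style messages list. As a rough standalone sketch of the new _get_prompt_and_chat_history contract (simplified to plain string contents; the real method also routes each content through _message_content_to_qwen):

# Standalone sketch, not the xinference code itself: all but the last message are
# paired into [query, response] history entries; the last message becomes the prompt.
messages = [
    {"role": "user", "content": "What is in this picture?"},
    {"role": "assistant", "content": "A red bicycle leaning against a wall."},
    {"role": "user", "content": "What color is the wall?"},
]

qwen_history, pair = [], []
for message in messages[:-1]:
    role, content = message["role"], message["content"]
    if len(pair) == 0 and role == "user":
        pair.append(content)
    if len(pair) == 1 and role == "assistant":
        pair.append(content)
    if len(pair) == 2:
        qwen_history.append(pair)
        pair = []
prompt = messages[-1]["content"]

print(prompt)        # What color is the wall?
print(qwen_history)  # [['What is in this picture?', 'A red bicycle leaning against a wall.']]
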
xinference/model/llm/transformers/utils.py

@@ -321,7 +321,7 @@ def generate_stream(

      if stream:
          completion_choice = CompletionChoice(
-             text="", index=0, logprobs=None, finish_reason=finish_reason
+             text=output, index=0, logprobs=None, finish_reason=finish_reason
          )
      else:
          completion_choice = CompletionChoice(
@@ -430,39 +430,6 @@ def pad_prefill_tokens(
      return prompt_tokens


- def _get_completion_chunk(
-     output: str,
-     chunk_id: str,
-     finish_reason: Optional[str],
-     model_uid: str,
-     r: InferenceRequest,
-     just_usage: bool,
- ):
-     completion_choice = (
-         [
-             CompletionChoice(
-                 text=output, index=0, logprobs=None, finish_reason=finish_reason
-             )
-         ]
-         if not just_usage
-         else []
-     )
-     completion_chunk = CompletionChunk(
-         id=chunk_id,
-         object="text_completion",
-         created=int(time.time()),
-         model=model_uid,
-         choices=completion_choice,
-     )
-     completion_usage = CompletionUsage(
-         prompt_tokens=len(r.prompt_tokens),
-         completion_tokens=len(r.new_tokens),
-         total_tokens=len(r.prompt_tokens) + len(r.new_tokens),
-     )
-     completion_chunk["usage"] = completion_usage
-     return completion_chunk
-
-
  def _get_completion(
      output: str,
      chunk_id: str,
@@ -551,6 +518,8 @@ def _batch_inference_one_step_internal(
      bos_flag: str = "<bos_stream>",
      eos_flag: str = "<eos_stream>",
  ):
+     from ..utils import generate_completion_chunk
+
      # need to judge stopped here,
      # since some requests state may change to stopped due to invalid parameters, e.g. max_src_len
      valid_req_list = [r for r in req_list if not r.stopped]
@@ -710,11 +679,28 @@ def _batch_inference_one_step_internal(
                  output = output[r.last_output_length :]
                  r.last_output_length += len(output)

-                 completion_chunk = _get_completion_chunk(
-                     output, r.chunk_id, r.finish_reason, model_uid, r, False
+                 completion_chunk = generate_completion_chunk(
+                     chunk_text=output,
+                     finish_reason=None,
+                     chunk_id=r.chunk_id,
+                     model_uid=model_uid,
+                     prompt_tokens=len(r.prompt_tokens),
+                     completion_tokens=len(r.new_tokens),
+                     total_tokens=len(r.prompt_tokens) + len(r.new_tokens),
                  )
                  r.completion.append(completion_chunk)
                  if r.stopped:
+                     # OpenAI compatible chunk
+                     completion_chunk = generate_completion_chunk(
+                         chunk_text="",
+                         finish_reason=r.finish_reason,
+                         chunk_id=r.chunk_id,
+                         model_uid=model_uid,
+                         prompt_tokens=len(r.prompt_tokens),
+                         completion_tokens=len(r.new_tokens),
+                         total_tokens=len(r.prompt_tokens) + len(r.new_tokens),
+                     )
+                     r.completion.append(completion_chunk)
                      r.completion.append(eos_flag)

              # last round, handle stream result
@@ -723,8 +709,16 @@ def _batch_inference_one_step_internal(
              # these tokens are real generated and should be counted.
              if r.stopped and _i == decode_round - 1 and include_usage:
                  r.completion.append(
-                     _get_completion_chunk(
-                         "", r.chunk_id, r.finish_reason, model_uid, r, True
+                     generate_completion_chunk(
+                         chunk_text=None,
+                         finish_reason=None,
+                         chunk_id=r.chunk_id,
+                         model_uid=model_uid,
+                         prompt_tokens=len(r.prompt_tokens),
+                         completion_tokens=len(r.new_tokens),
+                         total_tokens=len(r.prompt_tokens) + len(r.new_tokens),
+                         has_choice=False,
+                         has_content=False,
                      )
                  )
              else:
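
The utils.py hunks above drop the module-local _get_completion_chunk helper in favor of the shared generate_completion_chunk imported from ..utils. That helper's implementation is not part of this diff; the standalone sketch below only reconstructs its apparent contract from the call sites (the keyword arguments plus the has_choice/has_content flags), with plain dicts standing in for xinference's CompletionChunk/CompletionChoice/CompletionUsage typed dicts, so details may differ from the real code.

import time
import uuid
from typing import Optional


def sketch_generate_completion_chunk(
    chunk_text: Optional[str],
    finish_reason: Optional[str],
    chunk_id: str,
    model_uid: str,
    prompt_tokens: int,
    completion_tokens: int,
    total_tokens: int,
    has_choice: bool = True,
    has_content: bool = True,
):
    # has_choice=False -> usage-only chunk (stream_options.include_usage);
    # has_content=False -> a choice that only carries the finish_reason.
    choices = []
    if has_choice:
        choice = {"index": 0, "logprobs": None, "finish_reason": finish_reason}
        if has_content:
            choice["text"] = chunk_text
        choices.append(choice)
    return {
        "id": chunk_id,
        "object": "text_completion",
        "created": int(time.time()),
        "model": model_uid,
        "choices": choices,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
        },
    }


# Example: the usage-only chunk emitted at the end of a stream when include_usage is set.
print(
    sketch_generate_completion_chunk(
        chunk_text=None,
        finish_reason=None,
        chunk_id=str(uuid.uuid1()),
        model_uid="qwen-vl-chat",
        prompt_tokens=12,
        completion_tokens=34,
        total_tokens=46,
        has_choice=False,
        has_content=False,
    )
)
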
xinference/model/llm/transformers/yi_vl.py

@@ -12,7 +12,6 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  import logging
- import time
  import uuid
  from concurrent.futures import ThreadPoolExecutor
  from threading import Thread
@@ -21,17 +20,14 @@ from typing import Dict, Iterator, List, Optional, Union
  import torch

  from ....model.utils import select_device
- from ....types import (
-     ChatCompletion,
-     ChatCompletionChunk,
-     ChatCompletionMessage,
-     Completion,
-     CompletionChoice,
-     CompletionChunk,
-     CompletionUsage,
- )
+ from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
  from ..llm_family import LLMFamilyV1, LLMSpecV1
- from ..utils import _decode_image
+ from ..utils import (
+     _decode_image,
+     generate_chat_completion,
+     generate_completion_chunk,
+     parse_messages,
+ )
  from .core import PytorchChatModel, PytorchGenerateConfig

  logger = logging.getLogger(__name__)
@@ -105,15 +101,11 @@ class YiVLChatModel(PytorchChatModel):

      def chat(
          self,
-         prompt: Union[str, List[Dict]],
-         system_prompt: Optional[str] = None,
-         chat_history: Optional[List[ChatCompletionMessage]] = None,
+         messages: List[Dict],
          generate_config: Optional[PytorchGenerateConfig] = None,
      ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
          from transformers import TextIteratorStreamer

-         # TODO(codingl2k1): implement stream mode.
-
          if not generate_config:
              generate_config = {}

@@ -134,7 +126,8 @@ class YiVLChatModel(PytorchChatModel):

          # Convert chat history to llava state
          state = conv_templates["mm_default"].copy()
-         for message in chat_history or []:
+         prompt, _, chat_history = parse_messages(messages)
+         for message in chat_history:
              content = self._message_content_to_yi(message["content"])
              state.append_message(message["role"], content)
          state.append_message(state.roles[0], self._message_content_to_yi(prompt))
@@ -190,31 +183,15 @@ class YiVLChatModel(PytorchChatModel):
              it = self._generate_stream(streamer, stop_str, input_ids, include_usage)
              return self._to_chat_completion_chunks(it)
          else:
-             c = self._generate(streamer, stop_str)
-             return self._to_chat_completion(c)
+             return self._generate(streamer, stop_str)

-     def _generate(self, streamer, stop_str) -> Completion:
+     def _generate(self, streamer, stop_str) -> ChatCompletion:
          generated_text = ""
          for new_text in streamer:
              generated_text += new_text
          if generated_text.endswith(stop_str):
              generated_text = generated_text[: -len(stop_str)]
-
-         c = Completion(
-             id=str(uuid.uuid1()),
-             object="text_completion",
-             created=int(time.time()),
-             model=self.model_uid,
-             choices=[
-                 CompletionChoice(
-                     index=0, text=generated_text, finish_reason="stop", logprobs=None
-                 )
-             ],
-             usage=CompletionUsage(
-                 prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
-             ),
-         )
-         return c
+         return generate_chat_completion(self.model_uid, generated_text)

      def _generate_stream(
          self, streamer, stop_str, input_ids, include_usage
@@ -224,54 +201,37 @@ class YiVLChatModel(PytorchChatModel):
          prompt_tokens = len(input_ids[0])
          for i, new_text in enumerate(streamer):
              if not new_text.endswith(stop_str):
-                 completion_choice = CompletionChoice(
-                     text=new_text, index=0, logprobs=None, finish_reason=None
-                 )
-                 chunk = CompletionChunk(
-                     id=completion_id,
-                     object="text_completion",
-                     created=int(time.time()),
-                     model=self.model_uid,
-                     choices=[completion_choice],
-                 )
                  completion_tokens = i
                  total_tokens = prompt_tokens + completion_tokens
-                 completion_usage = CompletionUsage(
+                 yield generate_completion_chunk(
+                     chunk_text=new_text,
+                     finish_reason=None,
+                     chunk_id=completion_id,
+                     model_uid=self.model_uid,
                      prompt_tokens=prompt_tokens,
                      completion_tokens=completion_tokens,
                      total_tokens=total_tokens,
                  )
-                 chunk["usage"] = completion_usage
-                 yield chunk
-
-         completion_choice = CompletionChoice(
-             text="", index=0, logprobs=None, finish_reason="stop"
-         )
-         chunk = CompletionChunk(
-             id=completion_id,
-             object="text_completion",
-             created=int(time.time()),
-             model=self.model_uid,
-             choices=[completion_choice],
-         )
-         completion_usage = CompletionUsage(
+         yield generate_completion_chunk(
+             chunk_text=None,
+             finish_reason="stop",
+             chunk_id=completion_id,
+             model_uid=self.model_uid,
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
+             has_choice=True,
+             has_content=False,
         )
-         chunk["usage"] = completion_usage
-         yield chunk
          if include_usage:
-             chunk = CompletionChunk(
-                 id=completion_id,
-                 object="text_completion",
-                 created=int(time.time()),
-                 model=self.model_uid,
-                 choices=[],
-             )
-             chunk["usage"] = CompletionUsage(
+             yield generate_completion_chunk(
+                 chunk_text=None,
+                 finish_reason=None,
+                 chunk_id=completion_id,
+                 model_uid=self.model_uid,
                  prompt_tokens=prompt_tokens,
                  completion_tokens=completion_tokens,
                  total_tokens=total_tokens,
+                 has_choice=False,
+                 has_content=False,
              )
-             yield chunk
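
YiVLChatModel now derives its prompt and history via parse_messages(messages) instead of receiving them as separate arguments. The helper itself is not shown in this diff; the sketch below is only an assumption about its (prompt, system_prompt, chat_history) return shape, inferred from the call site prompt, _, chat_history = parse_messages(messages).

from typing import Dict, List, Optional, Tuple


def sketch_parse_messages(
    messages: List[Dict],
) -> Tuple[str, Optional[str], List[Dict]]:
    # Assumed behavior: any system message becomes the system prompt, the trailing
    # user message becomes the prompt, and the remaining turns form the history.
    system_prompt = next(
        (m["content"] for m in messages if m["role"] == "system"), None
    )
    non_system = [m for m in messages if m["role"] != "system"]
    prompt = non_system[-1]["content"]
    chat_history = non_system[:-1]
    return prompt, system_prompt, chat_history


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "Describe the attached image."},
]
prompt, system, history = sketch_parse_messages(messages)
print(prompt)   # Describe the attached image.
print(system)   # You are a helpful assistant.
print(history)  # the Hi / Hello! exchange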