xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_compat.py +51 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +209 -40
- xinference/client/restful/restful_client.py +7 -26
- xinference/conftest.py +1 -1
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +1 -1
- xinference/core/chat_interface.py +8 -14
- xinference/core/event.py +1 -1
- xinference/core/image_interface.py +28 -0
- xinference/core/model.py +110 -31
- xinference/core/scheduler.py +37 -37
- xinference/core/status_guard.py +1 -1
- xinference/core/supervisor.py +17 -10
- xinference/core/utils.py +80 -22
- xinference/core/worker.py +17 -16
- xinference/deploy/cmdline.py +8 -16
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/deploy/utils.py +1 -1
- xinference/deploy/worker.py +1 -1
- xinference/model/audio/cosyvoice.py +86 -41
- xinference/model/audio/fish_speech.py +9 -9
- xinference/model/audio/model_spec.json +9 -9
- xinference/model/audio/whisper.py +4 -1
- xinference/model/embedding/core.py +52 -31
- xinference/model/image/core.py +2 -1
- xinference/model/image/model_spec.json +16 -4
- xinference/model/image/model_spec_modelscope.json +16 -4
- xinference/model/image/sdapi.py +136 -0
- xinference/model/image/stable_diffusion/core.py +164 -19
- xinference/model/llm/__init__.py +29 -11
- xinference/model/llm/llama_cpp/core.py +16 -33
- xinference/model/llm/llm_family.json +1011 -1296
- xinference/model/llm/llm_family.py +34 -53
- xinference/model/llm/llm_family_csghub.json +18 -35
- xinference/model/llm/llm_family_modelscope.json +981 -1122
- xinference/model/llm/lmdeploy/core.py +56 -88
- xinference/model/llm/mlx/core.py +46 -69
- xinference/model/llm/sglang/core.py +36 -18
- xinference/model/llm/transformers/chatglm.py +168 -306
- xinference/model/llm/transformers/cogvlm2.py +36 -63
- xinference/model/llm/transformers/cogvlm2_video.py +33 -223
- xinference/model/llm/transformers/core.py +55 -50
- xinference/model/llm/transformers/deepseek_v2.py +340 -0
- xinference/model/llm/transformers/deepseek_vl.py +53 -96
- xinference/model/llm/transformers/glm4v.py +55 -111
- xinference/model/llm/transformers/intern_vl.py +39 -70
- xinference/model/llm/transformers/internlm2.py +32 -54
- xinference/model/llm/transformers/minicpmv25.py +22 -55
- xinference/model/llm/transformers/minicpmv26.py +158 -68
- xinference/model/llm/transformers/omnilmm.py +5 -28
- xinference/model/llm/transformers/qwen2_audio.py +168 -0
- xinference/model/llm/transformers/qwen2_vl.py +234 -0
- xinference/model/llm/transformers/qwen_vl.py +34 -86
- xinference/model/llm/transformers/utils.py +32 -38
- xinference/model/llm/transformers/yi_vl.py +32 -72
- xinference/model/llm/utils.py +280 -554
- xinference/model/llm/vllm/core.py +161 -100
- xinference/model/rerank/core.py +41 -8
- xinference/model/rerank/model_spec.json +7 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -1
- xinference/model/utils.py +1 -31
- xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
- xinference/thirdparty/cosyvoice/cli/model.py +139 -26
- xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
- xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
- xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
- xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
- xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
- xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
- xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
- xinference/thirdparty/cosyvoice/utils/common.py +36 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
- xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
- xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
- xinference/thirdparty/fish_speech/tools/api.py +79 -134
- xinference/thirdparty/fish_speech/tools/commons.py +35 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
- xinference/thirdparty/fish_speech/tools/file.py +17 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
- xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
- xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
- xinference/thirdparty/fish_speech/tools/webui.py +12 -146
- xinference/thirdparty/matcha/VERSION +1 -0
- xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
- xinference/thirdparty/matcha/hifigan/README.md +101 -0
- xinference/thirdparty/omnilmm/LICENSE +201 -0
- xinference/thirdparty/whisper/__init__.py +156 -0
- xinference/thirdparty/whisper/__main__.py +3 -0
- xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
- xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
- xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
- xinference/thirdparty/whisper/audio.py +157 -0
- xinference/thirdparty/whisper/decoding.py +826 -0
- xinference/thirdparty/whisper/model.py +314 -0
- xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
- xinference/thirdparty/whisper/normalizers/basic.py +76 -0
- xinference/thirdparty/whisper/normalizers/english.json +1741 -0
- xinference/thirdparty/whisper/normalizers/english.py +550 -0
- xinference/thirdparty/whisper/timing.py +386 -0
- xinference/thirdparty/whisper/tokenizer.py +395 -0
- xinference/thirdparty/whisper/transcribe.py +605 -0
- xinference/thirdparty/whisper/triton_ops.py +109 -0
- xinference/thirdparty/whisper/utils.py +316 -0
- xinference/thirdparty/whisper/version.py +1 -0
- xinference/types.py +14 -53
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
- xinference/web/ui/build/static/js/main.754740c0.js +3 -0
- xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
- xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +37 -0
- xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
- xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
- xinference/web/ui/node_modules/nunjucks/package.json +112 -0
- xinference/web/ui/package-lock.json +38 -0
- xinference/web/ui/package.json +1 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
- xinference/model/llm/transformers/llama_2.py +0 -108
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
- xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
- xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
- xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
- xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/cogvlm2.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, Iterator, List, Optional, Tuple, Union
@@ -21,17 +20,14 @@ import torch

 from ....core.scheduler import InferenceRequest
 from ....model.utils import select_device
-from ....types import (
-    ChatCompletion,
-    ChatCompletionChunk,
-    ChatCompletionMessage,
-    Completion,
-    CompletionChoice,
-    CompletionChunk,
-    CompletionUsage,
-)
+from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
 from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import
+from ..utils import (
+    _decode_image,
+    generate_chat_completion,
+    generate_completion_chunk,
+    parse_messages,
+)
 from .core import PytorchChatModel, PytorchGenerateConfig
 from .utils import get_max_src_len

@@ -139,9 +135,7 @@ class CogVLM2Model(PytorchChatModel):
         )
         return content, None

-    def _history_content_to_cogvlm2(
-        self, system_prompt: str, chat_history: List[ChatCompletionMessage]
-    ):
+    def _history_content_to_cogvlm2(self, system_prompt: str, chat_history: List[Dict]):
         query = system_prompt
         history: List[Tuple] = []
         pixel_values = None
@@ -163,7 +157,7 @@ class CogVLM2Model(PytorchChatModel):
         self,
         prompt: Union[str, List[Dict]],
         system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        chat_history: Optional[List[Dict]] = None,
     ):
         content, image = self._message_content_to_cogvlm2(prompt)

@@ -184,12 +178,12 @@ class CogVLM2Model(PytorchChatModel):

     def chat(
         self,
-
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         generate_config: Optional[PytorchGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        system_prompt =
+        system_prompt = ""
+        if messages[0]["role"] == "system":
+            system_prompt = messages[0]["content"]
         stream = generate_config.get("stream", False) if generate_config else False

         sanitized_config = {
@@ -199,6 +193,7 @@ class CogVLM2Model(PytorchChatModel):
             else 512,
         }

+        prompt, _, chat_history = parse_messages(messages)
         query, image, history = self.get_query_and_history(
             prompt, system_prompt=system_prompt, chat_history=chat_history
         )
@@ -236,21 +231,7 @@ class CogVLM2Model(PytorchChatModel):
         response = self._tokenizer.decode(outputs[0])
         response = response.split("<|end_of_text|>")[0]

-
-            id=str(uuid.uuid1()),
-            object="text_completion",
-            created=int(time.time()),
-            model=self.model_uid,
-            choices=[
-                CompletionChoice(
-                    index=0, text=response, finish_reason="stop", logprobs=None
-                )
-            ],
-            usage=CompletionUsage(
-                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
-            ),
-        )
-        return self._to_chat_completion(chunk)
+        return generate_chat_completion(self.model_uid, response)

     def _streaming_chat_response(
         self, inputs: Dict, config: Dict
@@ -277,36 +258,26 @@ class CogVLM2Model(PytorchChatModel):

         completion_id = str(uuid.uuid1())
         for new_text in streamer:
-
-
-
-
-
-
-
-
-                    )
-                ],
-                usage=CompletionUsage(
-                    prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
-                ),
+            yield generate_completion_chunk(
+                chunk_text=new_text,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
+                prompt_tokens=-1,
+                completion_tokens=-1,
+                total_tokens=-1,
             )
-
-
-
-
-
-
-
-
-
-
-            choices=[completion_choice],
-            usage=CompletionUsage(
-                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
-            ),
+        yield generate_completion_chunk(
+            chunk_text=None,
+            finish_reason="stop",
+            chunk_id=completion_id,
+            model_uid=self.model_uid,
+            prompt_tokens=-1,
+            completion_tokens=-1,
+            total_tokens=-1,
+            has_choice=True,
+            has_content=False,
         )
-        yield chunk

     @staticmethod
     def build_position_ids(x, attention_mask=None):
@@ -341,7 +312,9 @@ class CogVLM2Model(PytorchChatModel):
     def get_dtype(self):
         return self._torch_type

-    def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
+    def _get_full_prompt(self, messages: List[Dict], tools):
+        prompt, system_prompt, chat_history = parse_messages(messages)
+        system_prompt = system_prompt or ""
         query, image, history = self.get_query_and_history(
             prompt, system_prompt=system_prompt, chat_history=chat_history
         )
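The hunks above migrate `CogVLM2Model.chat` from separate `prompt` / `system_prompt` / `chat_history` arguments to a single OpenAI-style `messages` list, with `parse_messages` recovering the prompt, system prompt, and history that the rest of the model code still consumes. A minimal sketch of that splitting, assuming the usual role/content message shape; `split_messages` below is an illustrative stand-in, not the library's `parse_messages`:

```python
from typing import Dict, List, Optional, Tuple


def split_messages(messages: List[Dict]) -> Tuple[str, Optional[str], List[Dict]]:
    """Return (latest user prompt, system prompt, earlier history)."""
    system_prompt = None
    if messages and messages[0]["role"] == "system":
        # Mirrors the diff: a leading system message becomes the system prompt.
        system_prompt = messages[0]["content"]
        messages = messages[1:]
    prompt = messages[-1]["content"]  # the newest user turn
    history = messages[:-1]  # everything before it
    return prompt, system_prompt, history


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Describe this image."},
]
print(split_messages(messages))
# ('Describe this image.', 'You are a helpful assistant.', [])
```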
xinference/model/llm/transformers/cogvlm2_video.py

@@ -12,28 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, Iterator, List, Optional, Tuple, Union

 import torch

-from ....core.scheduler import InferenceRequest
 from ....model.utils import select_device
-from ....types import (
-    ChatCompletion,
-    ChatCompletionChunk,
-    ChatCompletionMessage,
-    Completion,
-    CompletionChoice,
-    CompletionChunk,
-    CompletionUsage,
-)
+from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
 from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import
+from ..utils import (
+    _decode_image,
+    generate_chat_completion,
+    generate_completion_chunk,
+    parse_messages,
+)
 from .core import PytorchChatModel, PytorchGenerateConfig
-from .utils import get_max_src_len

 logger = logging.getLogger(__name__)

@@ -170,9 +164,7 @@ class CogVLM2VideoModel(PytorchChatModel):
             return text, images, video
         return content, [], None

-    def _history_content_to_cogvlm2(
-        self, system_prompt: str, chat_history: List[ChatCompletionMessage]
-    ):
+    def _history_content_to_cogvlm2(self, system_prompt: str, chat_history: List[Dict]):
         query = system_prompt
         history: List[Tuple] = []
         pixel_values = None
@@ -202,7 +194,7 @@ class CogVLM2VideoModel(PytorchChatModel):
         self,
         prompt: Union[str, List[Dict]],
         system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        chat_history: Optional[List[Dict]] = None,
     ):
         content, image, video = self._message_content_to_cogvlm2(prompt)

@@ -237,12 +229,12 @@ class CogVLM2VideoModel(PytorchChatModel):

     def chat(
         self,
-
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         generate_config: Optional[PytorchGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        system_prompt =
+        system_prompt = ""
+        if messages[0]["role"] == "system":
+            system_prompt = messages[0]["content"]
         stream = generate_config.get("stream", False) if generate_config else False

         sanitized_config = {
@@ -252,6 +244,7 @@ class CogVLM2VideoModel(PytorchChatModel):
             else 512,
         }

+        prompt, _, chat_history = parse_messages(messages)
         query, image, video, history = self.get_query_and_history(
             prompt, system_prompt=system_prompt, chat_history=chat_history
         )
@@ -292,21 +285,7 @@ class CogVLM2VideoModel(PytorchChatModel):
         response = self._tokenizer.decode(outputs[0])
         response = response.split("<|end_of_text|>")[0]

-
-            id=str(uuid.uuid1()),
-            object="text_completion",
-            created=int(time.time()),
-            model=self.model_uid,
-            choices=[
-                CompletionChoice(
-                    index=0, text=response, finish_reason="stop", logprobs=None
-                )
-            ],
-            usage=CompletionUsage(
-                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
-            ),
-        )
-        return self._to_chat_completion(chunk)
+        return generate_chat_completion(self.model_uid, response)

     def _streaming_chat_response(
         self, inputs: Dict, config: Dict
@@ -333,192 +312,23 @@ class CogVLM2VideoModel(PytorchChatModel):

         completion_id = str(uuid.uuid1())
         for new_text in streamer:
-
-
-
-
-
-
-
-
-                    )
-                ],
-                usage=CompletionUsage(
-                    prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
-                ),
+            yield generate_completion_chunk(
+                chunk_text=new_text,
+                finish_reason=None,
+                chunk_id=completion_id,
+                model_uid=self.model_uid,
+                prompt_tokens=-1,
+                completion_tokens=-1,
+                total_tokens=-1,
             )
-
-
-
-
-
-
-
-
-
-
-            choices=[completion_choice],
-            usage=CompletionUsage(
-                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
-            ),
-        )
-        yield chunk
-
-    @staticmethod
-    def build_position_ids(x, attention_mask=None):
-        """
-        Copied from https://huggingface.co/THUDM/cogvlm2-llama3-chinese-chat-19B-int4/blob/main/modeling_cogvlm.py
-        """
-        # Fix: refer to the official open-source code
-        if attention_mask is not None:
-            tmp = x.clone()
-            tmp[~(attention_mask.bool())] = -1
-        else:
-            tmp = x.clone()
-        # image boi eoi token as LANGUAGE_TOKEN_TYPE
-        is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
-        is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (
-            tmp[:, :-1] == LANGUAGE_TOKEN_TYPE
-        )
-        is_boi_eoi[:, 0] |= tmp[:, 0] == VISION_TOKEN_TYPE
-        is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (
-            tmp[:, 1:] == LANGUAGE_TOKEN_TYPE
+        yield generate_completion_chunk(
+            chunk_text=None,
+            finish_reason="stop",
+            chunk_id=completion_id,
+            model_uid=self.model_uid,
+            prompt_tokens=-1,
+            completion_tokens=-1,
+            total_tokens=-1,
+            has_choice=True,
+            has_content=False,
         )
-        is_boi_eoi[:, -1] |= tmp[:, -1] == VISION_TOKEN_TYPE
-        tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
-        # final position ids
-        y = torch.zeros_like(x, dtype=torch.long)
-        y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
-            (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
-        )
-        y = y.cumsum(dim=-1)
-        return y
-
-    def get_dtype(self):
-        return self._torch_type
-
-    def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
-        query, image, video, history = self.get_query_and_history(
-            prompt, system_prompt=system_prompt, chat_history=chat_history
-        )
-
-        if video:
-            image = [video]
-
-        input_by_model: dict = self._model.build_conversation_input_ids(  # type: ignore
-            self._tokenizer,
-            query=query,
-            history=history,
-            images=image,
-            template_version="chat",
-        )
-        return {
-            "input_ids": input_by_model["input_ids"],  # seq_len
-            "token_type_ids": input_by_model["token_type_ids"],  # seq_len
-            "attention_mask": input_by_model["attention_mask"],  # seq_len
-            "images": input_by_model["images"],
-        }
-
-    def prepare_sanitize_generate_config(self, req: InferenceRequest):
-        """
-        See https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B/blob/main/generation_config.json
-        """
-        raw_config = req.inference_kwargs.get("raw_params", {})
-        temperature = raw_config.get("temperature", None)
-        if temperature is None:
-            raw_config["temperature"] = 0.6
-        top_p = raw_config.get("top_p", None)
-        if top_p is None:
-            raw_config["top_p"] = 0.9
-        return raw_config
-
-    def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
-        context_len = self.get_context_len()
-        assert isinstance(prompts[0], dict)
-        images = []
-        max_length = float("-inf")
-        for i, feature in enumerate(prompts):
-            req = req_list[i]
-            if "images" in feature:
-                images.append(feature.pop("images", None))
-            max_src_len = get_max_src_len(context_len, req)
-            input_ids = feature["input_ids"][-max_src_len:]
-            req.prompt_tokens = input_ids.tolist()
-            feature["input_ids"] = input_ids
-            feature["token_type_ids"] = feature["token_type_ids"][-max_src_len:]
-            feature["attention_mask"] = feature["attention_mask"][-max_src_len:]
-            req.extra_kwargs["attention_mask_seq_len"] = feature[
-                "attention_mask"
-            ].shape[0]
-            max_length = max(len(input_ids), max_length)
-
-        def pad_to_max_length_internal(feature, max_len, idx):
-            padding_length = max_len - len(feature["input_ids"])
-            req_list[idx].padding_len = padding_length
-            feature["input_ids"] = torch.cat(
-                [torch.full((padding_length,), 0), feature["input_ids"]]
-            )
-            feature["token_type_ids"] = torch.cat(
-                [
-                    torch.zeros(padding_length, dtype=torch.long),
-                    feature["token_type_ids"],
-                ]
-            )
-            feature["attention_mask"] = torch.cat(
-                [
-                    torch.zeros(padding_length, dtype=torch.long),
-                    feature["attention_mask"],
-                ]
-            )
-            return feature
-
-        features = [
-            pad_to_max_length_internal(feature, max_length, i)
-            for i, feature in enumerate(prompts)
-        ]
-        batch = {
-            key: torch.stack([feature[key] for feature in features])
-            for key in features[0].keys()
-        }
-
-        position_ids = self.build_position_ids(batch["token_type_ids"])
-        batch["position_ids"] = position_ids
-
-        for i in range(len(prompts)):
-            req = req_list[i]
-            req.extra_kwargs["max_position_id"] = position_ids[i : i + 1, -1].item()
-
-        if images:
-            batch["images"] = images
-
-        batch = recur_move_to(
-            batch, self._device, lambda x: isinstance(x, torch.Tensor)
-        )
-        dtype = self.get_dtype()
-        if dtype:
-            batch = recur_move_to(
-                batch,
-                dtype,
-                lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x),
-            )
-        return batch
-
-    def build_decode_token_type_ids(
-        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
-    ):
-        token_type_ids = torch.full(
-            (batch_size, 1), fill_value=1, dtype=torch.long, device=self._device
-        )
-        return token_type_ids
-
-    def build_decode_position_ids(
-        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
-    ):
-        tmp = []
-        for r in reqs:
-            r.extra_kwargs["max_position_id"] += 1
-            tmp.append(r.extra_kwargs["max_position_id"])
-        position_ids = torch.as_tensor(
-            tmp, device=self._device, dtype=torch.long
-        ).unsqueeze(1)
-        return position_ids
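Both CogVLM2 classes now stream through `generate_completion_chunk(...)` instead of hand-assembling `CompletionChunk` / `CompletionUsage` objects, and the terminating chunk is emitted with `has_choice=True, has_content=False`. A rough sketch of a helper with the same keyword interface, inferred only from the call sites in these hunks; the real helper lives in `xinference/model/llm/utils.py` and may differ in detail:

```python
import time
import uuid
from typing import Optional


def make_completion_chunk(
    chunk_text: Optional[str],
    finish_reason: Optional[str],
    chunk_id: str,
    model_uid: str,
    prompt_tokens: int = -1,
    completion_tokens: int = -1,
    total_tokens: int = -1,
    has_choice: bool = True,
    has_content: bool = True,
) -> dict:
    # One OpenAI-style completion chunk; -1 token counts mean "unknown".
    choice = {"index": 0, "logprobs": None, "finish_reason": finish_reason}
    if has_content:
        choice["text"] = chunk_text or ""
    return {
        "id": chunk_id,
        "object": "text_completion",
        "created": int(time.time()),
        "model": model_uid,
        "choices": [choice] if has_choice else [],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
        },
    }


# The terminating chunk from the diff: a choice carrying finish_reason="stop" but no text.
stop_chunk = make_completion_chunk(
    None, "stop", str(uuid.uuid1()), "cogvlm2-video", has_content=False
)
```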
xinference/model/llm/transformers/core.py

@@ -16,7 +16,7 @@ import json
 import logging
 import os
 from functools import lru_cache
-from typing import Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union

 import torch

@@ -29,7 +29,6 @@ from ....device_utils import (
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
-    ChatCompletionMessage,
     Completion,
     CompletionChoice,
     CompletionChunk,
@@ -52,8 +51,6 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
     "chatglm3-128k",
     "glm4-chat",
     "glm4-chat-1m",
-    "llama-2",
-    "llama-2-chat",
     "internlm2-chat",
     "internlm2.5-chat",
     "qwen-vl-chat",
@@ -67,6 +64,13 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
     "MiniCPM-Llama3-V-2_5",
     "MiniCPM-V-2.6",
     "glm-4v",
+    "qwen2-vl-instruct",
+    "qwen2-audio",
+    "qwen2-audio-instruct",
+    "deepseek-v2",
+    "deepseek-v2-chat",
+    "deepseek-v2.5",
+    "deepseek-v2-chat-0628",
 ]


@@ -615,12 +619,17 @@ class PytorchModel(LLM):
                 r.error_msg = str(e)

     def get_builtin_stop_token_ids(self) -> Tuple:
-
-
-
-
-
-
+        from ..utils import get_stop_token_ids_from_config_file
+
+        stop_token_ids = get_stop_token_ids_from_config_file(self.model_path)
+        if stop_token_ids is not None:
+            return tuple(stop_token_ids)
+        else:
+            return (
+                tuple(self.model_family.stop_token_ids)
+                if self.model_family.stop_token_ids
+                else tuple()
+            )

     def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
         for req in req_list:
@@ -693,20 +702,13 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
         generate_config: Optional[PytorchGenerateConfig],
     ) -> PytorchGenerateConfig:
         generate_config = super()._sanitize_generate_config(generate_config)
-        if (
-
-            and self.model_family.prompt_style
-            and self.model_family.prompt_style.stop
-        ):
-            generate_config["stop"] = self.model_family.prompt_style.stop.copy()
+        if (not generate_config.get("stop")) and self.model_family.stop is not None:
+            generate_config["stop"] = self.model_family.stop.copy()
         if (
             generate_config.get("stop_token_ids", None) is None
-            and self.model_family.prompt_style
-            and self.model_family.prompt_style.stop_token_ids
+            and self.model_family.stop_token_ids is not None
         ):
-            generate_config[
-                "stop_token_ids"
-            ] = self.model_family.prompt_style.stop_token_ids.copy()
+            generate_config["stop_token_ids"] = self.model_family.stop_token_ids.copy()

         return generate_config

@@ -725,26 +727,23 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):

     def chat(
         self,
-
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         generate_config: Optional[PytorchGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         tools = generate_config.pop("tools", []) if generate_config else None
-        full_prompt = self._get_full_prompt(prompt, system_prompt, chat_history, tools)
-
-        generate_config = self._sanitize_generate_config(generate_config)
-        # TODO(codingl2k1): qwen hacky to set stop for function call.
         model_family = self.model_family.model_family or self.model_family.model_name
+        full_context_kwargs = {}
         if tools and model_family in QWEN_TOOL_CALL_FAMILY:
-
-
-
-
-
-
-
-
+            full_context_kwargs["tools"] = tools
+        assert self.model_family.chat_template is not None
+        full_prompt = self.get_full_context(
+            messages,
+            self.model_family.chat_template,
+            tokenizer=self._tokenizer,
+            **full_context_kwargs,
+        )
+
+        generate_config = self._sanitize_generate_config(generate_config)

         stream = generate_config.get("stream", False)
         if stream:
@@ -755,22 +754,16 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
         c = self.generate(full_prompt, generate_config)
         assert not isinstance(c, Iterator)
         if tools:
-            return self._tool_calls_completion(
-                self.model_family, self.model_uid, c, tools
-            )
+            return self._tool_calls_completion(self.model_family, self.model_uid, c)
         return self._to_chat_completion(c)

     def load(self):
         super().load()

-    def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
-        assert self.model_family.
-
-
-        prompt_style.system_prompt = system_prompt
-        chat_history = chat_history or []
-        full_prompt = ChatModelMixin.get_prompt(
-            prompt, chat_history, prompt_style, tools=tools
+    def _get_full_prompt(self, messages: List[Dict], tools):
+        assert self.model_family.chat_template is not None
+        full_prompt = self.get_full_context(
+            messages, self.model_family.chat_template, tokenizer=self._tokenizer
         )
         return full_prompt

@@ -779,9 +772,7 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
         for r in req_list:
             try:
                 if not r.stopped and r.is_prefill:
-                    r.full_prompt = self._get_full_prompt(
-                        r.prompt, r.system_prompt, r.chat_history, None
-                    )
+                    r.full_prompt = self._get_full_prompt(r.prompt, None)
             except Exception as e:
                 logger.exception(f"prepare inference error with {e}")
                 r.stopped = True
@@ -790,6 +781,20 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
     def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
         for req in req_list:
             if req.error_msg is None and req.completion:
+                # The `generate` function can be called for some chat models.
+                # So that we cannot convert completion chunk to chat completion chunk.
+                if req.call_ability == "generate":
+                    results = []
+                    for c in req.completion:
+                        if c == "<bos_stream>":
+                            continue
+                        elif c == "<eos_stream>":
+                            break
+                        else:
+                            results.append(c)
+                    req.completion = results
+                    continue
+
                 if req.stream:
                     results = []
                     for i, c in enumerate(req.completion):