xinference 0.11.3__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +143 -6
- xinference/client/restful/restful_client.py +144 -5
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +48 -28
- xinference/core/model.py +160 -19
- xinference/core/scheduler.py +446 -0
- xinference/core/supervisor.py +99 -24
- xinference/core/worker.py +68 -2
- xinference/deploy/cmdline.py +86 -2
- xinference/deploy/test/test_cmdline.py +19 -10
- xinference/isolation.py +9 -2
- xinference/model/audio/__init__.py +14 -1
- xinference/model/audio/chattts.py +84 -0
- xinference/model/audio/core.py +22 -4
- xinference/model/audio/custom.py +6 -4
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/audio/model_spec_modelscope.json +20 -0
- xinference/model/llm/__init__.py +38 -2
- xinference/model/llm/llm_family.json +509 -1
- xinference/model/llm/llm_family.py +86 -1
- xinference/model/llm/llm_family_csghub.json +66 -0
- xinference/model/llm/llm_family_modelscope.json +411 -2
- xinference/model/llm/pytorch/chatglm.py +20 -13
- xinference/model/llm/pytorch/cogvlm2.py +76 -17
- xinference/model/llm/pytorch/core.py +141 -6
- xinference/model/llm/pytorch/glm4v.py +268 -0
- xinference/model/llm/pytorch/minicpmv25.py +232 -0
- xinference/model/llm/pytorch/qwen_vl.py +1 -1
- xinference/model/llm/pytorch/utils.py +405 -8
- xinference/model/llm/utils.py +14 -13
- xinference/model/llm/vllm/core.py +16 -4
- xinference/model/utils.py +8 -2
- xinference/thirdparty/ChatTTS/__init__.py +1 -0
- xinference/thirdparty/ChatTTS/core.py +200 -0
- xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
- xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/infer/api.py +125 -0
- xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
- xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
- xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
- xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
- xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
- xinference/types.py +3 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
- xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
- xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
- xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/METADATA +26 -9
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/RECORD +65 -47
- xinference/web/ui/build/static/css/main.54bca460.css +0 -2
- xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
- xinference/web/ui/build/static/js/main.551aa479.js +0 -3
- xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
- /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
|
@@ -15,8 +15,10 @@
|
|
|
15
15
|
import json
|
|
16
16
|
import logging
|
|
17
17
|
import os
|
|
18
|
+
from functools import lru_cache
|
|
18
19
|
from typing import Iterable, Iterator, List, Optional, Union
|
|
19
20
|
|
|
21
|
+
from ....core.scheduler import InferenceRequest
|
|
20
22
|
from ....device_utils import (
|
|
21
23
|
get_device_preferred_dtype,
|
|
22
24
|
gpu_count,
|
|
@@ -27,6 +29,7 @@ from ....types import (
|
|
|
27
29
|
ChatCompletionChunk,
|
|
28
30
|
ChatCompletionMessage,
|
|
29
31
|
Completion,
|
|
32
|
+
CompletionChoice,
|
|
30
33
|
CompletionChunk,
|
|
31
34
|
CreateCompletionTorch,
|
|
32
35
|
Embedding,
|
|
@@ -40,6 +43,7 @@ from ...utils import select_device
|
|
|
40
43
|
from ..core import LLM
|
|
41
44
|
from ..llm_family import LLMFamilyV1, LLMSpecV1
|
|
42
45
|
from ..utils import ChatModelMixin
|
|
46
|
+
from .utils import get_context_length, get_max_src_len
|
|
43
47
|
|
|
44
48
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
45
49
|
|
|
@@ -53,6 +57,11 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
|
|
|
53
57
|
"chatglm2",
|
|
54
58
|
"chatglm2-32k",
|
|
55
59
|
"chatglm2-128k",
|
|
60
|
+
"chatglm3",
|
|
61
|
+
"chatglm3-32k",
|
|
62
|
+
"chatglm3-128k",
|
|
63
|
+
"glm4-chat",
|
|
64
|
+
"glm4-chat-1m",
|
|
56
65
|
"llama-2",
|
|
57
66
|
"llama-2-chat",
|
|
58
67
|
"internlm2-chat",
|
|
@@ -63,6 +72,8 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
|
|
|
63
72
|
"internvl-chat",
|
|
64
73
|
"mini-internvl-chat",
|
|
65
74
|
"cogvlm2",
|
|
75
|
+
"MiniCPM-Llama3-V-2_5",
|
|
76
|
+
"glm-4v",
|
|
66
77
|
]
|
|
67
78
|
|
|
68
79
|
|
|
@@ -96,6 +107,7 @@ class PytorchModel(LLM):
|
|
|
96
107
|
pytorch_model_config.setdefault("gptq_act_order", False)
|
|
97
108
|
pytorch_model_config.setdefault("device", "auto")
|
|
98
109
|
pytorch_model_config.setdefault("trust_remote_code", True)
|
|
110
|
+
pytorch_model_config.setdefault("max_num_seqs", 16)
|
|
99
111
|
return pytorch_model_config
|
|
100
112
|
|
|
101
113
|
def _sanitize_generate_config(
|
|
@@ -356,6 +368,90 @@ class PytorchModel(LLM):
|
|
|
356
368
|
else:
|
|
357
369
|
return generator_wrapper(prompt, generate_config)
|
|
358
370
|
|
|
371
|
+
def get_context_len(self):
    """Return the model's context length, computed once per instance.

    The value is derived from ``self._model.config`` on first use and stored
    on the instance. This replaces a ``functools.lru_cache`` on the method:
    such a cache is keyed on ``self``, so it holds strong references to up to
    128 instances (and, through them, their loaded ``_model``). It also
    cannot be used on unhashable instances.
    """
    cached = getattr(self, "_cached_context_len", None)
    if cached is None:
        cached = get_context_length(self._model.config)
        self._cached_context_len = cached
    return cached
|
|
374
|
+
|
|
375
|
+
def get_max_num_seqs(self) -> int:
    """Return the ``max_num_seqs`` entry of the PyTorch model config.

    Returns ``None`` when the entry is absent from the config dict.
    """
    model_config = self._pytorch_model_config
    return model_config.get("max_num_seqs")  # type: ignore
|
|
377
|
+
|
|
378
|
+
def prepare_batch_inference(self, req_list: List[InferenceRequest]):
    """Sanitize generate configs and validate requests entering prefill.

    ``sanitized_generate_config`` is filled in once per request (only while it
    is still ``None``). Requests in their prefill step are then checked for:
    a non-negative source-token budget, a positive ``stream_interval`` and a
    ``stop`` value of an iterable type. A request that fails a check is
    marked ``stopped`` with an ``error_msg`` rather than raising.
    """
    for r in req_list:
        if r.sanitized_generate_config is None:
            r.sanitized_generate_config = self._sanitize_generate_config(
                r.generate_config
            )
        if not r.is_prefill:
            continue

        error_msg = None
        # A negative budget means the requested max tokens leave no room for the prompt.
        if get_max_src_len(self.get_context_len(), r) < 0:  # type: ignore
            error_msg = "Max tokens exceeds model's max length"
        elif r.stream_interval <= 0:
            error_msg = "`stream_interval` must be greater than 0"
        else:
            stop_str = r.sanitized_generate_config.get("stop", None)
            # ``str`` is itself Iterable, so a single check covers both forms.
            if stop_str and not isinstance(stop_str, Iterable):
                error_msg = "Invalid `stop` field type"

        if error_msg is not None:
            r.stopped = True
            r.error_msg = error_msg
|
|
403
|
+
|
|
404
|
+
def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
    """Turn the raw streaming output of each request into completion chunks.

    Only successful streaming requests with a non-empty ``completion`` are
    touched; non-streaming and errored requests are left as they are.
    Markers expected in ``req.completion``:

    * ``"<bos_stream>"`` -- an extra leading chunk with empty text is emitted,
      copying id/object/created/model from the entry that follows it;
    * ``"<eos_stream>"`` -- stops the scan; later entries are not copied;
    * any other entry is passed through unchanged.

    When the request has stopped and ``include_usage`` is set, the last entry
    of ``req.completion`` is appended at the end.
    """
    for req in req_list:
        if req.error_msg is not None or not req.stream or not req.completion:
            continue
        raw = req.completion
        results = []
        for i, item in enumerate(raw):
            if item == "<bos_stream>":
                # Guard against a marker with no chunk after it.
                if i + 1 < len(raw):
                    first = raw[i + 1]
                    results.append(
                        CompletionChunk(
                            id=first["id"],
                            object=first["object"],
                            created=first["created"],
                            model=first["model"],
                            choices=[
                                CompletionChoice(
                                    text="",
                                    index=0,
                                    logprobs=None,
                                    finish_reason=None,
                                )
                            ],
                        )
                    )
            elif item == "<eos_stream>":
                break
            else:
                results.append(item)

        if req.stopped and req.include_usage:
            results.append(raw[-1])
        req.completion = results
|
|
438
|
+
|
|
439
|
+
def batch_inference(self, req_list: List[InferenceRequest]):
    """Run one batched inference step over ``req_list``.

    The requests are first validated by ``prepare_batch_inference``. One
    batched step is then delegated to ``batch_inference_one_step`` with the
    model, tokenizer, device and context length. Finally,
    ``handle_batch_inference_results`` post-processes the requests' output.
    """
    from .utils import batch_inference_one_step

    self.prepare_batch_inference(req_list)

    max_context = self.get_context_len()
    assert isinstance(max_context, int)

    step_args = (
        req_list,
        self.model_uid,
        self._model,
        self._tokenizer,
        self._device,
        max_context,
    )
    batch_inference_one_step(*step_args)

    self.handle_batch_inference_results(req_list)
|
|
454
|
+
|
|
359
455
|
def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
|
|
360
456
|
try:
|
|
361
457
|
import torch
|
|
@@ -497,13 +593,8 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
|
|
|
497
593
|
chat_history: Optional[List[ChatCompletionMessage]] = None,
|
|
498
594
|
generate_config: Optional[PytorchGenerateConfig] = None,
|
|
499
595
|
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
|
|
500
|
-
assert self.model_family.prompt_style is not None
|
|
501
|
-
prompt_style = self.model_family.prompt_style.copy()
|
|
502
|
-
if system_prompt:
|
|
503
|
-
prompt_style.system_prompt = system_prompt
|
|
504
|
-
chat_history = chat_history or []
|
|
505
596
|
tools = generate_config.pop("tools", []) if generate_config else None
|
|
506
|
-
full_prompt = self.
|
|
597
|
+
full_prompt = self._get_full_prompt(prompt, system_prompt, chat_history, tools)
|
|
507
598
|
|
|
508
599
|
generate_config = self._sanitize_generate_config(generate_config)
|
|
509
600
|
# TODO(codingl2k1): qwen hacky to set stop for function call.
|
|
@@ -531,3 +622,47 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
|
|
|
531
622
|
self.model_family, self.model_uid, c, tools
|
|
532
623
|
)
|
|
533
624
|
return self._to_chat_completion(c)
|
|
625
|
+
|
|
626
|
+
def load(self):
    """Load the model by delegating to the parent class implementation.

    Nothing chat-specific is done here; the override only forwards the call.
    """
    parent = super()
    parent.load()
|
|
628
|
+
|
|
629
|
+
def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
    """Render ``prompt`` and ``chat_history`` with the family's prompt style.

    The family prompt style is copied before ``system_prompt`` (when given)
    is applied, so the shared family configuration is never mutated. A
    missing history is treated as empty.
    """
    family_style = self.model_family.prompt_style
    assert family_style is not None
    style = family_style.copy()
    if system_prompt:
        style.system_prompt = system_prompt
    history = chat_history if chat_history else []
    return ChatModelMixin.get_prompt(prompt, history, style, tools=tools)
|
|
639
|
+
|
|
640
|
+
def prepare_batch_inference(self, req_list: List[InferenceRequest]):
    """Validate requests, then render the chat prompt for new, active ones.

    The base-class validation runs first. The chat template is rendered only
    for requests that are in their prefill step and not already stopped;
    previously it was rebuilt for every request on every step.
    NOTE(review): this assumes ``full_prompt`` is only consumed at prefill --
    confirm against ``batch_inference_one_step``.
    A failure while rendering is recorded on that request (``stopped`` +
    ``error_msg``), the same way the base validation reports problems, instead
    of aborting the whole batch.
    """
    super().prepare_batch_inference(req_list)
    for r in req_list:
        if r.stopped or not r.is_prefill:
            continue
        try:
            r.full_prompt = self._get_full_prompt(
                r.prompt, r.system_prompt, r.chat_history, None
            )
        except Exception as e:
            logger.exception("Failed to build the chat prompt: %s", e)
            r.stopped = True
            r.error_msg = str(e)
|
|
646
|
+
|
|
647
|
+
def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
    """Convert the raw streaming output of chat requests into chat chunks.

    Only successful streaming requests with a non-empty ``completion`` are
    processed:

    * ``"<bos_stream>"`` adds a leading chunk built by
      ``_get_first_chat_completion_chunk`` from the entry that follows it;
    * ``"<eos_stream>"`` stops the scan;
    * every other entry is converted with ``_to_chat_completion_chunk``.

    When the request has stopped and ``include_usage`` is set, the last entry
    of ``req.completion`` is converted with ``_get_final_chat_completion_chunk``
    and appended.
    """
    for req in req_list:
        if not req.stream or req.error_msg is not None or not req.completion:
            continue
        raw = req.completion
        converted = []
        for i, item in enumerate(raw):
            if item == "<bos_stream>":
                # Guard against a marker with no chunk after it.
                if i + 1 < len(raw):
                    converted.append(self._get_first_chat_completion_chunk(raw[i + 1]))
            elif item == "<eos_stream>":
                break
            else:
                converted.append(self._to_chat_completion_chunk(item))

        if req.stopped and req.include_usage:
            converted.append(self._get_final_chat_completion_chunk(raw[-1]))
        req.completion = converted
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import base64
|
|
15
|
+
import logging
|
|
16
|
+
import time
|
|
17
|
+
import uuid
|
|
18
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
19
|
+
from io import BytesIO
|
|
20
|
+
from threading import Thread
|
|
21
|
+
from typing import Dict, Iterator, List, Optional, Union
|
|
22
|
+
|
|
23
|
+
import requests
|
|
24
|
+
import torch
|
|
25
|
+
from PIL import Image
|
|
26
|
+
|
|
27
|
+
from ....types import (
|
|
28
|
+
ChatCompletion,
|
|
29
|
+
ChatCompletionChunk,
|
|
30
|
+
ChatCompletionMessage,
|
|
31
|
+
Completion,
|
|
32
|
+
CompletionChoice,
|
|
33
|
+
CompletionChunk,
|
|
34
|
+
CompletionUsage,
|
|
35
|
+
)
|
|
36
|
+
from ...utils import select_device
|
|
37
|
+
from ..llm_family import LLMFamilyV1, LLMSpecV1
|
|
38
|
+
from .core import PytorchChatModel, PytorchGenerateConfig
|
|
39
|
+
|
|
40
|
+
logger = logging.getLogger(__name__)


class Glm4VModel(PytorchChatModel):
    """PyTorch chat model for the GLM-4V vision-language family.

    Message content may be a plain string or an OpenAI-style list of
    ``{"type": "text"}`` / ``{"type": "image_url"}`` parts. At most one image
    per message is accepted. The image sent with a turn is the one in the
    current prompt, otherwise the most recent image found in the history.
    """

    # Seconds to wait for a remote image before ``requests`` gives up.
    _IMAGE_DOWNLOAD_TIMEOUT = 60
    # Generation budget used when the request does not set ``max_tokens``.
    _DEFAULT_MAX_NEW_TOKENS = 2048

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._device = None
        self._tokenizer = None
        self._model = None

    @classmethod
    def match(
        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
    ) -> bool:
        """Return True when the family (or model) name contains ``glm-4v``, case-insensitively."""
        family = model_family.model_family or model_family.model_name
        return "glm-4v" in family.lower()

    def load(self):
        """Load the model and tokenizer from ``self.model_path``.

        For a quantization other than ``"none"``, weights are loaded in 4-bit or
        8-bit mode. When the selected device is ``cuda`` and ``self._is_linux()``
        holds, placement is also delegated to ``device_map="auto"``.
        """
        from transformers import AutoModelForCausalLM, AutoTokenizer

        device = self._pytorch_model_config.get("device", "auto")
        self._device = select_device(device)

        kwargs = {"device_map": self._device}
        quantization = self.quantization
        if quantization != "none":
            if self._device == "cuda" and self._is_linux():
                kwargs["device_map"] = "auto"
                self._device = "auto"
            if quantization == "4-bit":
                kwargs["load_in_4bit"] = True
            elif quantization == "8-bit":
                kwargs["load_in_8bit"] = True

        model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            **kwargs,
        )
        self._model = model.eval()

        self._tokenizer = AutoTokenizer.from_pretrained(
            self.model_path, trust_remote_code=True
        )

    def _message_content_to_chat(self, content):
        """Split message ``content`` into ``(text, images)``.

        A plain string is returned unchanged with an empty image list.
        Otherwise ``content`` is a list of parts: ``"text"`` parts are joined
        with single spaces, and the ``"image_url"`` part (if any) is loaded as
        an RGB ``PIL.Image``. The image URL may be a ``data:`` URL with a base64
        payload, an HTTP(S) URL, or a local file path.

        Raises:
            RuntimeError: if the content carries more than one image.
        """

        def _load_image(_url):
            if _url.startswith("data:"):
                logger.info("Parse url by base64 decoder.")
                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
                # e.g. f"data:image/jpeg;base64,{base64_image}"
                # The payload is everything after the first comma; splitting
                # there tolerates extra media-type parameters before it.
                _, data = _url.split(",", 1)
                return Image.open(BytesIO(base64.b64decode(data))).convert("RGB")
            try:
                response = requests.get(_url, timeout=self._IMAGE_DOWNLOAD_TIMEOUT)
            except requests.exceptions.MissingSchema:
                # Not a URL at all: treat it as a local file path.
                return Image.open(_url).convert("RGB")
            # Surface HTTP errors instead of failing later inside PIL.
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")

        if isinstance(content, str):
            return content, []

        texts = []
        image_urls = []
        for part in content:
            part_type = part.get("type")
            if part_type == "text":
                texts.append(part["text"])
            elif part_type == "image_url":
                image_urls.append(part["image_url"]["url"])

        # Reject before downloading anything that would be discarded anyway.
        if len(image_urls) > 1:
            raise RuntimeError("Only one image per message is supported")
        images = [_load_image(url) for url in image_urls]
        return " ".join(texts), images

    def chat(
        self,
        prompt: Union[str, List[Dict]],
        system_prompt: Optional[str] = None,
        chat_history: Optional[List[ChatCompletionMessage]] = None,
        generate_config: Optional[PytorchGenerateConfig] = None,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Generate a reply to ``prompt`` given ``chat_history``.

        Only complete user/assistant pairs from the history are forwarded to
        the model. A single image (from the prompt, else the latest one seen in
        the history) is attached to the final user turn. ``max_tokens`` limits
        the number of generated tokens, not prompt + generated.

        NOTE(review): ``system_prompt`` is accepted for interface compatibility
        but is not forwarded to the model here.
        """
        from transformers import TextIteratorStreamer

        if not generate_config:
            generate_config = {}

        stream = generate_config.get("stream", False)
        content, images_chat = self._message_content_to_chat(prompt)

        msgs = []
        query_to_response: List[Dict] = []
        images_history = []
        for h in chat_history or []:
            role = h["role"]
            content_h, images_tmp = self._message_content_to_chat(h["content"])
            if images_tmp:
                images_history = images_tmp
            if len(query_to_response) == 0 and role == "user":
                query_to_response.append({"role": "user", "content": content_h})
            if len(query_to_response) == 1 and role == "assistant":
                query_to_response.append({"role": "assistant", "content": content_h})
            if len(query_to_response) == 2:
                msgs.extend(query_to_response)
                query_to_response = []

        image = None
        if images_chat:
            image = images_chat[0]
        elif images_history:
            image = images_history[0]
        msgs.append({"role": "user", "content": content, "image": image})

        inputs = self._tokenizer.apply_chat_template(
            msgs,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        )  # chat mode
        inputs = inputs.to(self._model.device)

        # ``max_tokens`` bounds the generated tokens only; transformers'
        # ``max_length`` would also count the prompt, so use ``max_new_tokens``.
        max_new_tokens = generate_config.get("max_tokens")
        if max_new_tokens is None:
            max_new_tokens = self._DEFAULT_MAX_NEW_TOKENS

        generate_kwargs = {
            **inputs,
            "eos_token_id": [151329, 151336, 151338],
            "do_sample": True,
            "max_new_tokens": max_new_tokens,
            "temperature": generate_config.get("temperature", 0.7),
        }
        stop_str = "<|endoftext|>"

        if stream:
            streamer = TextIteratorStreamer(
                tokenizer=self._tokenizer,
                timeout=60,
                skip_prompt=True,
                skip_special_tokens=True,
            )
            generate_kwargs = {
                **generate_kwargs,
                "streamer": streamer,
            }
            t = Thread(target=self._model.generate, kwargs=generate_kwargs)
            t.start()

            it = self.chat_stream(streamer, stop_str)
            return self._to_chat_completion_chunks(it)

        with torch.no_grad():
            outputs = self._model.generate(**generate_kwargs)
            outputs = outputs[:, inputs["input_ids"].shape[1] :]
            # Skip special tokens, as the streaming path does, so that any of the
            # eos ids above (not only ``stop_str``) stays out of the reply text.
            response = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
        if response.endswith(stop_str):
            response = response[: -len(stop_str)]
        c = Completion(
            id=str(uuid.uuid1()),
            object="text_completion",
            created=int(time.time()),
            model=self.model_uid,
            choices=[
                CompletionChoice(
                    index=0, text=response, finish_reason="stop", logprobs=None
                )
            ],
            usage=CompletionUsage(
                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
            ),
        )
        return self._to_chat_completion(c)

    def _build_stream_chunk(self, completion_id, text, finish_reason):
        """Build one streaming ``CompletionChunk`` with placeholder (-1) usage."""
        chunk = CompletionChunk(
            id=completion_id,
            object="text_completion",
            created=int(time.time()),
            model=self.model_uid,
            choices=[
                CompletionChoice(
                    text=text, index=0, logprobs=None, finish_reason=finish_reason
                )
            ],
        )
        chunk["usage"] = CompletionUsage(
            prompt_tokens=-1,
            completion_tokens=-1,
            total_tokens=-1,
        )
        return chunk

    def chat_stream(self, streamer, stop_str) -> Iterator[CompletionChunk]:
        """Yield a chunk per piece of text from ``streamer``, then a ``stop`` chunk.

        A trailing ``stop_str`` is stripped from a piece; a piece that consists
        only of ``stop_str`` is skipped. All chunks share one completion id and
        carry placeholder usage (-1).
        """
        completion_id = str(uuid.uuid1())
        for new_text in streamer:
            if new_text.endswith(stop_str):
                # Keep any text preceding the stop marker instead of dropping it.
                new_text = new_text[: -len(stop_str)]
                if not new_text:
                    continue
            yield self._build_stream_chunk(completion_id, new_text, None)

        yield self._build_stream_chunk(completion_id, "", "stop")
|