xinference-0.11.3-py3-none-any.whl → xinference-0.12.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +69 -0
- xinference/client/restful/restful_client.py +70 -0
- xinference/constants.py +4 -0
- xinference/core/model.py +141 -12
- xinference/core/scheduler.py +428 -0
- xinference/core/supervisor.py +26 -0
- xinference/isolation.py +9 -2
- xinference/model/audio/chattts.py +84 -0
- xinference/model/audio/core.py +10 -3
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/llm/__init__.py +4 -0
- xinference/model/llm/llm_family.json +507 -1
- xinference/model/llm/llm_family_modelscope.json +409 -2
- xinference/model/llm/pytorch/chatglm.py +2 -1
- xinference/model/llm/pytorch/cogvlm2.py +76 -17
- xinference/model/llm/pytorch/core.py +91 -6
- xinference/model/llm/pytorch/glm4v.py +258 -0
- xinference/model/llm/pytorch/minicpmv25.py +232 -0
- xinference/model/llm/pytorch/utils.py +386 -2
- xinference/model/llm/vllm/core.py +6 -0
- xinference/thirdparty/ChatTTS/__init__.py +1 -0
- xinference/thirdparty/ChatTTS/core.py +200 -0
- xinference/types.py +3 -0
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/METADATA +26 -9
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/RECORD +30 -24
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/LICENSE +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/WHEEL +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/top_level.txt +0 -0
|
@@ -17,6 +17,7 @@ import logging
|
|
|
17
17
|
import os
|
|
18
18
|
from typing import Iterable, Iterator, List, Optional, Union
|
|
19
19
|
|
|
20
|
+
from ....core.scheduler import InferenceRequest
|
|
20
21
|
from ....device_utils import (
|
|
21
22
|
get_device_preferred_dtype,
|
|
22
23
|
gpu_count,
|
|
@@ -40,6 +41,7 @@ from ...utils import select_device
|
|
|
40
41
|
from ..core import LLM
|
|
41
42
|
from ..llm_family import LLMFamilyV1, LLMSpecV1
|
|
42
43
|
from ..utils import ChatModelMixin
|
|
44
|
+
from .utils import get_context_length, get_max_src_len
|
|
43
45
|
|
|
44
46
|
logger = logging.getLogger(__name__)
|
|
45
47
|
|
|
@@ -53,6 +55,11 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
|
|
|
53
55
|
"chatglm2",
|
|
54
56
|
"chatglm2-32k",
|
|
55
57
|
"chatglm2-128k",
|
|
58
|
+
"chatglm3",
|
|
59
|
+
"chatglm3-32k",
|
|
60
|
+
"chatglm3-128k",
|
|
61
|
+
"glm4-chat",
|
|
62
|
+
"glm4-chat-1m",
|
|
56
63
|
"llama-2",
|
|
57
64
|
"llama-2-chat",
|
|
58
65
|
"internlm2-chat",
|
|
@@ -63,6 +70,8 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
|
|
|
63
70
|
"internvl-chat",
|
|
64
71
|
"mini-internvl-chat",
|
|
65
72
|
"cogvlm2",
|
|
73
|
+
"MiniCPM-Llama3-V-2_5",
|
|
74
|
+
"glm-4v",
|
|
66
75
|
]
|
|
67
76
|
|
|
68
77
|
|
|
@@ -96,6 +105,7 @@ class PytorchModel(LLM):
|
|
|
96
105
|
pytorch_model_config.setdefault("gptq_act_order", False)
|
|
97
106
|
pytorch_model_config.setdefault("device", "auto")
|
|
98
107
|
pytorch_model_config.setdefault("trust_remote_code", True)
|
|
108
|
+
pytorch_model_config.setdefault("max_num_seqs", 16)
|
|
99
109
|
return pytorch_model_config
|
|
100
110
|
|
|
101
111
|
def _sanitize_generate_config(
|
|
@@ -454,6 +464,7 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
|
|
|
454
464
|
pytorch_model_config,
|
|
455
465
|
peft_model,
|
|
456
466
|
)
|
|
467
|
+
self._context_len = None
|
|
457
468
|
|
|
458
469
|
def _sanitize_generate_config(
|
|
459
470
|
self,
|
|
@@ -497,13 +508,8 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
|
|
|
497
508
|
chat_history: Optional[List[ChatCompletionMessage]] = None,
|
|
498
509
|
generate_config: Optional[PytorchGenerateConfig] = None,
|
|
499
510
|
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
|
|
500
|
-
assert self.model_family.prompt_style is not None
|
|
501
|
-
prompt_style = self.model_family.prompt_style.copy()
|
|
502
|
-
if system_prompt:
|
|
503
|
-
prompt_style.system_prompt = system_prompt
|
|
504
|
-
chat_history = chat_history or []
|
|
505
511
|
tools = generate_config.pop("tools", []) if generate_config else None
|
|
506
|
-
full_prompt = self.
|
|
512
|
+
full_prompt = self._get_full_prompt(prompt, system_prompt, chat_history, tools)
|
|
507
513
|
|
|
508
514
|
generate_config = self._sanitize_generate_config(generate_config)
|
|
509
515
|
# TODO(codingl2k1): qwen hacky to set stop for function call.
|
|
@@ -531,3 +537,82 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
|
|
|
531
537
|
self.model_family, self.model_uid, c, tools
|
|
532
538
|
)
|
|
533
539
|
return self._to_chat_completion(c)
|
|
540
|
+
|
|
541
|
+
def load(self):
    """Load the model through the parent class, then cache its context length.

    The cached ``_context_len`` is read from the loaded model's config and is
    later consumed by ``batch_inference`` when checking request lengths.
    """
    super().load()
    model_config = self._model.config
    self._context_len = get_context_length(model_config)
|
|
544
|
+
|
|
545
|
+
def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
    """Render the complete prompt text for one chat turn.

    The family's prompt style is copied before it is modified, so the
    shared ``model_family`` object is never mutated. A non-empty
    ``system_prompt`` overrides the style's default system prompt. The
    actual formatting is delegated to ``ChatModelMixin.get_prompt``.
    """
    style = self.model_family.prompt_style
    assert style is not None
    style = style.copy()
    if system_prompt:
        style.system_prompt = system_prompt
    history = chat_history or []
    return ChatModelMixin.get_prompt(prompt, history, style, tools=tools)
|
|
555
|
+
|
|
556
|
+
def get_max_num_seqs(self) -> int:
    """Return the configured ``max_num_seqs`` (``None`` if the key is absent)."""
    config = self._pytorch_model_config
    return config.get("max_num_seqs")  # type: ignore
|
|
558
|
+
|
|
559
|
+
def batch_inference(self, req_list: List[InferenceRequest]) -> None:
    """Run one batched generation step over ``req_list``, updating requests in place.

    For each request:
    - a sanitized generate config is computed once and cached on the request;
    - prefill requests are validated (context budget, ``stream_interval``,
      ``stop`` type) and get their full prompt rendered. A request that fails
      validation is marked ``stopped`` with an ``error_msg``, and its remaining
      checks are skipped.

    The whole list, including any requests just marked stopped, is then passed
    to ``batch_inference_one_step``. Whether that helper ignores stopped
    requests is not visible here. Afterwards, for streaming requests without an
    error, the raw ``completion`` entries are converted into chat-completion
    chunks.
    """
    # NOTE(review): function-local import, presumably to avoid an import
    # cycle with .utils — confirm.
    from .utils import batch_inference_one_step

    for r in req_list:
        if r.sanitized_generate_config is None:
            r.sanitized_generate_config = self._sanitize_generate_config(
                r.generate_config
            )
        if r.is_prefill:
            # check some generate params
            # A negative budget means the requested output cannot fit in the
            # model's context length (see the error message below).
            max_src_len = get_max_src_len(self._context_len, r)  # type: ignore
            if max_src_len < 0:
                r.stopped = True
                r.error_msg = "Max tokens exceeds model's max length"
                continue
            if r.stream_interval <= 0:
                r.stopped = True
                r.error_msg = "`stream_interval` must be greater than 0"
                continue
            stop_str = r.sanitized_generate_config.get("stop", None)
            # `str` is itself Iterable, so only truthy non-iterable values are
            # rejected here.
            if stop_str and (
                not (isinstance(stop_str, str) or isinstance(stop_str, Iterable))
            ):
                r.stopped = True
                r.error_msg = "Invalid `stop` field type"
                continue
            # Tools are not forwarded in batched mode (tools=None).
            r.full_prompt = self._get_full_prompt(
                r.prompt, r.system_prompt, r.chat_history, None
            )

    # _context_len is populated by load(); it must be set before batching.
    assert isinstance(self._context_len, int)
    batch_inference_one_step(
        req_list,
        self.model_uid,
        self._model,
        self._tokenizer,
        self._device,
        self._context_len,
    )
    # Convert raw streamed entries into OpenAI-style chat-completion chunks.
    for req in req_list:
        if req.stream and req.error_msg is None:
            if req.completion:
                results = []
                for i, c in enumerate(req.completion):
                    if c == "<bos_stream>":
                        # Marker for the start of a stream: emit the first chunk
                        # built from the following entry. That entry is emitted
                        # again as a normal chunk on the next iteration.
                        # NOTE(review): IndexError if "<bos_stream>" is the
                        # last entry — confirm batch_inference_one_step never
                        # produces that.
                        results.append(
                            self._get_first_chat_completion_chunk(
                                req.completion[i + 1]
                            )
                        )
                    elif c == "<eos_stream>":
                        # End-of-stream marker: stop converting entries.
                        break
                    else:
                        results.append(self._to_chat_completion_chunk(c))

                if req.stopped and req.include_usage:
                    # NOTE(review): uses the raw last completion entry, which
                    # is presumably a usage-bearing completion rather than the
                    # "<eos_stream>" marker — confirm against
                    # batch_inference_one_step.
                    results.append(
                        self._get_final_chat_completion_chunk(req.completion[-1])
                    )
                req.completion = results
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import base64
|
|
15
|
+
import logging
|
|
16
|
+
import time
|
|
17
|
+
import uuid
|
|
18
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
19
|
+
from io import BytesIO
|
|
20
|
+
from threading import Thread
|
|
21
|
+
from typing import Dict, Iterator, List, Optional, Union
|
|
22
|
+
|
|
23
|
+
import requests
|
|
24
|
+
import torch
|
|
25
|
+
from PIL import Image
|
|
26
|
+
|
|
27
|
+
from ....types import (
|
|
28
|
+
ChatCompletion,
|
|
29
|
+
ChatCompletionChunk,
|
|
30
|
+
ChatCompletionMessage,
|
|
31
|
+
Completion,
|
|
32
|
+
CompletionChoice,
|
|
33
|
+
CompletionChunk,
|
|
34
|
+
CompletionUsage,
|
|
35
|
+
)
|
|
36
|
+
from ...utils import select_device
|
|
37
|
+
from ..llm_family import LLMFamilyV1, LLMSpecV1
|
|
38
|
+
from .core import PytorchChatModel, PytorchGenerateConfig
|
|
39
|
+
|
|
40
|
+
logger = logging.getLogger(__name__)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class Glm4VModel(PytorchChatModel):
    """PyTorch chat wrapper for the GLM-4V vision-language family.

    Images arrive as OpenAI-style ``image_url`` content parts (http(s) URL,
    local path, or base64 ``data:`` URL). At most one image per message is
    accepted. Only a single image is forwarded to the model: the one in the
    current prompt if present, otherwise the most recent one found in the chat
    history.
    """

    # Seconds to wait on the image server between bytes (requests timeout).
    _image_download_timeout = 60

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._device = None
        self._tokenizer = None
        self._model = None

    @classmethod
    def match(
        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
    ) -> bool:
        """Return True when the family (or model name) refers to GLM-4V."""
        family = model_family.model_family or model_family.model_name
        return "glm-4v" in family.lower()

    def load(self, **kwargs):
        """Load the GLM-4V model (fp16) and its tokenizer from ``model_path``."""
        from transformers import AutoModelForCausalLM, AutoTokenizer

        device = self._pytorch_model_config.get("device", "auto")
        self._device = select_device(device)
        # For CUDA, let accelerate spread the weights via device_map="auto".
        self._device = "auto" if self._device == "cuda" else self._device

        model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map=self._device,
        )
        self._model = model.eval()

        tokenizer = AutoTokenizer.from_pretrained(
            self.model_path, trust_remote_code=True
        )
        self._tokenizer = tokenizer

    def _message_content_to_chat(self, content):
        """Split message content into ``(text, images)``.

        ``content`` is either a plain string (returned as-is with no images) or
        a list of OpenAI-style parts. Text parts are joined with a single
        space. ``image_url`` parts are loaded as RGB PIL images.

        Raises:
            RuntimeError: if the message carries more than one image.
            ValueError: for a ``data:`` URL that is not base64-encoded.
        """

        def _load_image(_url):
            if _url.startswith("data:"):
                logger.info("Parse url by base64 decoder.")
                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
                # e.g. f"data:image/jpeg;base64,{base64_image}"
                # Split on the first comma only: the media-type header may carry
                # extra ";param" segments, and base64 payloads contain no commas.
                header, _, payload = _url.partition(",")
                if not header.endswith(";base64"):
                    raise ValueError(
                        f"Unsupported data URL (expected base64 encoding): {header}"
                    )
                data = base64.b64decode(payload.encode("utf-8"))
                return Image.open(BytesIO(data)).convert("RGB")
            try:
                response = requests.get(_url, timeout=self._image_download_timeout)
            except requests.exceptions.MissingSchema:
                # Not a URL at all: treat it as a local file path.
                return Image.open(_url).convert("RGB")
            # Surface HTTP errors instead of handing an error page to PIL.
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")

        if isinstance(content, str):
            return content, []

        texts = []
        image_urls = []
        for c in content:
            c_type = c.get("type")
            if c_type == "text":
                texts.append(c["text"])
            elif c_type == "image_url":
                image_urls.append(c["image_url"]["url"])
        # Reject before downloading anything.
        if len(image_urls) > 1:
            raise RuntimeError("Only one image per message is supported")
        with ThreadPoolExecutor() as executor:
            images = list(executor.map(_load_image, image_urls))
        return " ".join(texts), images

    def chat(
        self,
        prompt: Union[str, List[Dict]],
        system_prompt: Optional[str] = None,
        chat_history: Optional[List[ChatCompletionMessage]] = None,
        generate_config: Optional[PytorchGenerateConfig] = None,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Generate a reply to ``prompt`` given ``chat_history``.

        ``system_prompt`` is currently not used by this model wrapper. From
        ``generate_config``, only ``stream``, ``max_tokens`` (the maximum
        number of *generated* tokens, default 2048) and ``temperature``
        (default 0.7) are read.

        Returns a ``ChatCompletion``, or an iterator of ``ChatCompletionChunk``
        when streaming.
        """
        from transformers import TextIteratorStreamer

        if not generate_config:
            generate_config = {}

        stream = generate_config.get("stream", False)
        content, images_chat = self._message_content_to_chat(prompt)

        # Rebuild the history as strict user -> assistant pairs. A message that
        # does not fit the expected alternation is dropped, as is a trailing
        # unanswered user turn.
        msgs = []
        query_to_response: List[Dict] = []
        images_history = []
        for h in chat_history or []:
            role = h["role"]
            content_h, images_tmp = self._message_content_to_chat(h["content"])
            if images_tmp:
                images_history = images_tmp
            if len(query_to_response) == 0 and role == "user":
                query_to_response.append({"role": "user", "content": content_h})
            if len(query_to_response) == 1 and role == "assistant":
                query_to_response.append({"role": "assistant", "content": content_h})
            if len(query_to_response) == 2:
                msgs.extend(query_to_response)
                query_to_response = []

        # Prefer the image in the current prompt, else the latest history image.
        image = None
        if images_chat:
            image = images_chat[0]
        elif images_history:
            image = images_history[0]
        msgs.append({"role": "user", "content": content, "image": image})

        inputs = self._tokenizer.apply_chat_template(
            msgs,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        )  # chat mode
        inputs = inputs.to(self._model.device)

        generate_kwargs = {
            **inputs,
            # NOTE(review): presumably the GLM-4 ids of <|endoftext|>,
            # <|user|> and <|observation|> — verify against the tokenizer.
            "eos_token_id": [151329, 151336, 151338],
            "do_sample": True,
            # `max_tokens` bounds the generated tokens. `max_length` would also
            # count the (possibly long, image-bearing) prompt and could leave
            # no room for any output.
            "max_new_tokens": generate_config.get("max_tokens", 2048),
            "temperature": generate_config.get("temperature", 0.7),
        }
        stop_str = "<|endoftext|>"

        if stream:
            streamer = TextIteratorStreamer(
                tokenizer=self._tokenizer,
                timeout=60,
                skip_prompt=True,
                skip_special_tokens=True,
            )
            generate_kwargs = {
                **generate_kwargs,
                "streamer": streamer,
            }
            # generate() runs in a background thread and feeds the streamer.
            t = Thread(target=self._model.generate, kwargs=generate_kwargs)
            t.start()

            it = self.chat_stream(streamer, stop_str)
            return self._to_chat_completion_chunks(it)

        with torch.no_grad():
            outputs = self._model.generate(**generate_kwargs)
            # Keep only the newly generated tokens.
            outputs = outputs[:, inputs["input_ids"].shape[1] :]
            # Drop special tokens (e.g. whichever end marker stopped
            # generation), matching the streaming path's
            # skip_special_tokens=True.
            response = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
        if response.endswith(stop_str):
            response = response[: -len(stop_str)]
        c = Completion(
            id=str(uuid.uuid1()),
            object="text_completion",
            created=int(time.time()),
            model=self.model_uid,
            choices=[
                CompletionChoice(
                    index=0, text=response, finish_reason="stop", logprobs=None
                )
            ],
            usage=CompletionUsage(
                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
            ),
        )
        return self._to_chat_completion(c)

    def chat_stream(self, streamer, stop_str) -> Iterator[CompletionChunk]:
        """Yield one completion chunk per streamed text piece, then a final
        empty chunk with ``finish_reason="stop"``.

        Pieces ending with ``stop_str`` are skipped. Token usage is not
        tracked, so every chunk reports -1 counts.
        """
        completion_id = str(uuid.uuid1())

        def _make_chunk(text, finish_reason):
            chunk = CompletionChunk(
                id=completion_id,
                object="text_completion",
                created=int(time.time()),
                model=self.model_uid,
                choices=[
                    CompletionChoice(
                        text=text, index=0, logprobs=None, finish_reason=finish_reason
                    )
                ],
            )
            chunk["usage"] = CompletionUsage(
                prompt_tokens=-1,
                completion_tokens=-1,
                total_tokens=-1,
            )
            return chunk

        for new_text in streamer:
            if not new_text.endswith(stop_str):
                yield _make_chunk(new_text, None)

        yield _make_chunk("", "stop")
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import base64
|
|
15
|
+
import json
|
|
16
|
+
import logging
|
|
17
|
+
import time
|
|
18
|
+
import uuid
|
|
19
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
20
|
+
from io import BytesIO
|
|
21
|
+
from typing import Dict, Iterator, List, Optional, Union
|
|
22
|
+
|
|
23
|
+
import requests
|
|
24
|
+
import torch
|
|
25
|
+
from PIL import Image
|
|
26
|
+
|
|
27
|
+
from ....types import (
|
|
28
|
+
ChatCompletion,
|
|
29
|
+
ChatCompletionChunk,
|
|
30
|
+
ChatCompletionMessage,
|
|
31
|
+
Completion,
|
|
32
|
+
CompletionChoice,
|
|
33
|
+
CompletionChunk,
|
|
34
|
+
CompletionUsage,
|
|
35
|
+
)
|
|
36
|
+
from ...utils import select_device
|
|
37
|
+
from ..llm_family import LLMFamilyV1, LLMSpecV1
|
|
38
|
+
from .core import PytorchChatModel, PytorchGenerateConfig
|
|
39
|
+
|
|
40
|
+
logger = logging.getLogger(__name__)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class MiniCPMV25Model(PytorchChatModel):
    """PyTorch chat wrapper for MiniCPM-Llama3-V-2_5.

    Images arrive as OpenAI-style ``image_url`` content parts (http(s) URL,
    local path, or base64 ``data:`` URL). At most one image per message is
    accepted. The image passed to the model is the one in the current prompt
    if present, otherwise the most recent one found in the chat history.
    """

    # Seconds to wait on the image server between bytes (requests timeout).
    _image_download_timeout = 60

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._device = None
        self._tokenizer = None
        self._model = None

    @classmethod
    def match(
        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
    ) -> bool:
        """Return True when the family (or model name) refers to MiniCPM-Llama3-V-2_5."""
        family = model_family.model_family or model_family.model_name
        return "MiniCPM-Llama3-V-2_5".lower() in family.lower()

    def load(self, **kwargs):
        """Load the model, tokenizer and generation config from ``model_path``.

        Raises:
            RuntimeError: when an int4 (bitsandbytes) checkpoint would be run
                on Apple MPS, which is unsupported.
        """
        from transformers import AutoModel, AutoTokenizer
        from transformers.generation import GenerationConfig

        device = self._pytorch_model_config.get("device", "auto")
        self._device = select_device(device)
        # For CUDA, let accelerate spread the weights via device_map="auto".
        self._device = "auto" if self._device == "cuda" else self._device

        if "int4" in self.model_path:
            # Check the resolved device too, in case "auto" resolved to mps.
            # Raise instead of calling exit(), which would kill the hosting
            # process.
            if device == "mps" or self._device == "mps":
                raise RuntimeError(
                    "Running int4 model with bitsandbytes on Mac is not supported right now."
                )
            model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
        else:
            model = AutoModel.from_pretrained(
                self.model_path,
                trust_remote_code=True,
                torch_dtype=torch.float16,
                device_map=self._device,
            )
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_path, trust_remote_code=True
        )
        self._model = model.eval()
        self._tokenizer = tokenizer

        # Specify hyperparameters for generation
        self._model.generation_config = GenerationConfig.from_pretrained(
            self.model_path,
            trust_remote_code=True,
        )

    def _message_content_to_chat(self, content):
        """Split message content into ``(text, images)``.

        ``content`` is either a plain string (returned as-is with no images) or
        a list of OpenAI-style parts. Text parts are joined with a single
        space. ``image_url`` parts are loaded as RGB PIL images.

        Raises:
            RuntimeError: if the message carries more than one image.
            ValueError: for a ``data:`` URL that is not base64-encoded.
        """

        def _load_image(_url):
            if _url.startswith("data:"):
                logger.info("Parse url by base64 decoder.")
                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
                # e.g. f"data:image/jpeg;base64,{base64_image}"
                # Split on the first comma only: the media-type header may carry
                # extra ";param" segments, and base64 payloads contain no commas.
                header, _, payload = _url.partition(",")
                if not header.endswith(";base64"):
                    raise ValueError(
                        f"Unsupported data URL (expected base64 encoding): {header}"
                    )
                data = base64.b64decode(payload.encode("utf-8"))
                return Image.open(BytesIO(data)).convert("RGB")
            try:
                response = requests.get(_url, timeout=self._image_download_timeout)
            except requests.exceptions.MissingSchema:
                # Not a URL at all: treat it as a local file path.
                return Image.open(_url).convert("RGB")
            # Surface HTTP errors instead of handing an error page to PIL.
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")

        if isinstance(content, str):
            return content, []

        texts = []
        image_urls = []
        for c in content:
            c_type = c.get("type")
            if c_type == "text":
                texts.append(c["text"])
            elif c_type == "image_url":
                image_urls.append(c["image_url"]["url"])
        # Reject before downloading anything.
        if len(image_urls) > 1:
            raise RuntimeError("Only one image per message is supported")
        with ThreadPoolExecutor() as executor:
            images = list(executor.map(_load_image, image_urls))
        return " ".join(texts), images

    def chat(
        self,
        prompt: Union[str, List[Dict]],
        system_prompt: Optional[str] = None,
        chat_history: Optional[List[ChatCompletionMessage]] = None,
        generate_config: Optional[PytorchGenerateConfig] = None,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Generate a reply to ``prompt`` given ``chat_history``.

        ``system_prompt`` is currently not used by this model wrapper. All of
        ``generate_config`` (including ``stream``) is forwarded as keyword
        arguments to the model's remote-code ``chat`` method.

        Returns a ``ChatCompletion``, or an iterator of ``ChatCompletionChunk``
        when ``stream`` is set.
        """
        # Normalise None so that **generate_config below cannot raise TypeError.
        generate_config = generate_config or {}
        stream = generate_config.get("stream", False)
        content, images_chat = self._message_content_to_chat(prompt)

        # Rebuild the history as strict user -> assistant pairs. A message that
        # does not fit the expected alternation is dropped, as is a trailing
        # unanswered user turn.
        msgs = []
        query_to_response: List[Dict] = []
        images_history = []
        for h in chat_history or []:
            role = h["role"]
            content_h, images_tmp = self._message_content_to_chat(h["content"])
            if images_tmp:
                images_history = images_tmp
            if len(query_to_response) == 0 and role == "user":
                query_to_response.append({"role": "user", "content": content_h})
            if len(query_to_response) == 1 and role == "assistant":
                query_to_response.append({"role": "assistant", "content": content_h})
            if len(query_to_response) == 2:
                msgs.extend(query_to_response)
                query_to_response = []

        # Prefer the image in the current prompt, else the latest history image.
        image = None
        if images_chat:
            image = images_chat[0]
        elif images_history:
            image = images_history[0]
        msgs.append({"role": "user", "content": content})

        # NOTE(review): with stream=True the remote chat() is expected to return
        # an iterator of text pieces; otherwise the full reply string.
        output = self._model.chat(
            image=image,
            msgs=json.dumps(msgs, ensure_ascii=True),
            tokenizer=self._tokenizer,
            sampling=True,
            **generate_config,
        )
        if stream:
            it = self.chat_stream(output)
            return self._to_chat_completion_chunks(it)

        c = Completion(
            id=str(uuid.uuid1()),
            object="text_completion",
            created=int(time.time()),
            model=self.model_uid,
            choices=[
                CompletionChoice(
                    index=0, text=output, finish_reason="stop", logprobs=None
                )
            ],
            usage=CompletionUsage(
                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
            ),
        )
        return self._to_chat_completion(c)

    def chat_stream(self, chat) -> Iterator[CompletionChunk]:
        """Yield one completion chunk per text piece from ``chat``, then a final
        empty chunk with ``finish_reason="stop"``.

        Token usage is not tracked, so every chunk reports -1 counts.
        """
        completion_id = str(uuid.uuid1())

        def _make_chunk(text, finish_reason):
            chunk = CompletionChunk(
                id=completion_id,
                object="text_completion",
                created=int(time.time()),
                model=self.model_uid,
                choices=[
                    CompletionChoice(
                        text=text, index=0, logprobs=None, finish_reason=finish_reason
                    )
                ],
            )
            chunk["usage"] = CompletionUsage(
                prompt_tokens=-1,
                completion_tokens=-1,
                total_tokens=-1,
            )
            return chunk

        for new_text in chat:
            yield _make_chunk(new_text, None)

        yield _make_chunk("", "stop")
|