xinference 0.11.2.post1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +83 -8
- xinference/client/restful/restful_client.py +70 -0
- xinference/constants.py +8 -0
- xinference/core/__init__.py +0 -2
- xinference/core/cache_tracker.py +22 -1
- xinference/core/chat_interface.py +71 -10
- xinference/core/model.py +141 -12
- xinference/core/scheduler.py +428 -0
- xinference/core/supervisor.py +31 -3
- xinference/core/worker.py +8 -3
- xinference/isolation.py +9 -2
- xinference/model/audio/chattts.py +84 -0
- xinference/model/audio/core.py +10 -3
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +1063 -260
- xinference/model/llm/llm_family_modelscope.json +686 -13
- xinference/model/llm/pytorch/baichuan.py +2 -1
- xinference/model/llm/pytorch/chatglm.py +2 -1
- xinference/model/llm/pytorch/cogvlm2.py +316 -0
- xinference/model/llm/pytorch/core.py +92 -6
- xinference/model/llm/pytorch/glm4v.py +258 -0
- xinference/model/llm/pytorch/intern_vl.py +5 -10
- xinference/model/llm/pytorch/minicpmv25.py +232 -0
- xinference/model/llm/pytorch/utils.py +386 -2
- xinference/model/llm/vllm/core.py +7 -1
- xinference/thirdparty/ChatTTS/__init__.py +1 -0
- xinference/thirdparty/ChatTTS/core.py +200 -0
- xinference/types.py +3 -0
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/METADATA +28 -11
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/RECORD +36 -29
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/LICENSE +0 -0
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/WHEEL +0 -0
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/top_level.txt +0 -0
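
The new model families in this release (the ChatTTS audio model, the CogVLM2, GLM-4V, and MiniCPM-Llama3-V-2.5 vision chat models, and the continuous-batching scheduler under xinference/core/scheduler.py) are surfaced through the existing client API rather than new entry points. As a rough orientation, here is a minimal sketch of launching one of the new vision models with the RESTful client; the endpoint address and generate config are assumptions, and only the "cogvlm2" model name comes from this diff.

```python
# Minimal sketch, not part of the diff: launch one of the new vision chat
# models through the existing RESTful client. Assumes a local xinference
# server on the default endpoint and enough GPU memory for the model.
from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model_uid = client.launch_model(model_name="cogvlm2", model_type="LLM")
model = client.get_model(model_uid)

response = model.chat(
    prompt="Describe this image.",
    generate_config={"max_tokens": 256},
)
print(response["choices"][0]["message"]["content"])
```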

xinference/model/llm/pytorch/baichuan.py
@@ -73,7 +73,8 @@ class BaichuanPytorchChatModel(PytorchChatModel):
     ) -> bool:
         if llm_spec.model_format != "pytorch":
             return False
-        if llm_family.model_name not in ["baichuan-chat", "baichuan-2-chat"]:
+        model_family = llm_family.model_family or llm_family.model_name
+        if model_family not in ["baichuan-chat", "baichuan-2-chat"]:
             return False
         if "chat" not in llm_family.model_ability:
             return False

xinference/model/llm/pytorch/chatglm.py
@@ -82,7 +82,8 @@ class ChatglmPytorchChatModel(PytorchChatModel):
     ) -> bool:
         if llm_spec.model_format != "pytorch":
             return False
-        if "chatglm" not in llm_family.model_name:
+        model_family = llm_family.model_family or llm_family.model_name
+        if "chatglm" not in model_family and "glm4" not in model_family:
             return False
         if "chat" not in llm_family.model_ability:
             return False
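
Both hunks above replace the direct llm_family.model_name check with a lookup that prefers llm_family.model_family, so a custom or fine-tuned registration that points its model_family at a built-in family still matches the right chat implementation. A minimal sketch of the fallback, using a hypothetical registration that is not part of the diff:

```python
# Illustrative only: the names below are hypothetical, not from the diff.
class FakeFamily:
    model_name = "my-baichuan-finetune"  # custom registration name
    model_family = "baichuan-2-chat"     # points at the built-in family

llm_family = FakeFamily()

# Old check: compares model_name directly, so the fine-tune would not match.
old_match = llm_family.model_name in ["baichuan-chat", "baichuan-2-chat"]

# New check: uses model_family and falls back to model_name when it is unset.
family = llm_family.model_family or llm_family.model_name
new_match = family in ["baichuan-chat", "baichuan-2-chat"]

print(old_match, new_match)  # False True
```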

xinference/model/llm/pytorch/cogvlm2.py
@@ -0,0 +1,316 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import logging
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from io import BytesIO
+from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+import requests
+import torch
+from PIL import Image
+
+from ....model.utils import select_device
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    Completion,
+    CompletionChoice,
+    CompletionChunk,
+    CompletionUsage,
+)
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from .core import PytorchChatModel, PytorchGenerateConfig
+
+logger = logging.getLogger(__name__)
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+class CogVLM2Model(PytorchChatModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._torch_type = None
+        self._device = None
+        self._tokenizer = None
+        self._model = None
+
+    @classmethod
+    def match(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        family = model_family.model_family or model_family.model_name
+        if "cogvlm" in family.lower():
+            return True
+        return False
+
+    def load(self, **kwargs):
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        from transformers.generation import GenerationConfig
+
+        device = self._pytorch_model_config.get("device", "auto")
+        self._device = select_device(device)
+        self._torch_type = (
+            torch.bfloat16
+            if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
+            else torch.float16
+        )
+
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+        )
+
+        self._model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            torch_dtype=self._torch_type,
+            trust_remote_code=True,
+            low_cpu_mem_usage=True,
+            device_map="auto",
+        ).eval()
+
+        # Specify hyperparameters for generation
+        self._model.generation_config = GenerationConfig.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+        )
+
+    def _message_content_to_cogvlm2(self, content):
+        def _load_image(_url):
+            if _url.startswith("data:"):
+                logging.info("Parse url by base64 decoder.")
+                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                # e.g. f"data:image/jpeg;base64,{base64_image}"
+                _type, data = _url.split(";")
+                _, ext = _type.split("/")
+                data = data[len("base64,") :]
+                data = base64.b64decode(data.encode("utf-8"))
+                return Image.open(BytesIO(data)).convert("RGB")
+            else:
+                try:
+                    response = requests.get(_url)
+                except requests.exceptions.MissingSchema:
+                    return Image.open(_url).convert("RGB")
+                else:
+                    return Image.open(BytesIO(response.content)).convert("RGB")
+
+        if not isinstance(content, str):
+            texts = []
+            image_urls = []
+            for c in content:
+                c_type = c.get("type")
+                if c_type == "text":
+                    texts.append(c["text"])
+                elif c_type == "image_url":
+                    image_urls.append(c["image_url"]["url"])
+            image_futures = []
+            with ThreadPoolExecutor() as executor:
+                for image_url in image_urls:
+                    fut = executor.submit(_load_image, image_url)
+                    image_futures.append(fut)
+            images = [fut.result() for fut in image_futures]
+            text = " ".join(texts)
+            if len(images) == 0:
+                return text, None
+            elif len(images) == 1:
+                return text, images
+            else:
+                raise RuntimeError(
+                    "Only one image per message is supported by CogVLM2."
+                )
+        return content, None
+
+    def _history_content_to_cogvlm2(
+        self, system_prompt: str, chat_history: List[ChatCompletionMessage]
+    ):
+        def _image_to_piexl_values(image):
+            if image.startswith("data:"):
+                logging.info("Parse url by base64 decoder.")
+                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                # e.g. f"data:image/jpeg;base64,{base64_image}"
+                _type, data = image.split(";")
+                _, ext = _type.split("/")
+                data = data[len("base64,") :]
+                data = base64.b64decode(data.encode("utf-8"))
+                return Image.open(BytesIO(data)).convert("RGB")
+            else:
+                try:
+                    response = requests.get(image)
+                except requests.exceptions.MissingSchema:
+                    return Image.open(image).convert("RGB")
+                else:
+                    return Image.open(BytesIO(response.content)).convert("RGB")
+
+        query = system_prompt
+        history: List[Tuple] = []
+        pixel_values = None
+        for i in range(0, len(chat_history), 2):
+            user = chat_history[i]["content"]
+            if isinstance(user, List):
+                for content in user:
+                    c_type = content.get("type")
+                    if c_type == "text":
+                        user = content["text"]
+                    elif c_type == "image_url" and not pixel_values:
+                        pixel_values = _image_to_piexl_values(
+                            content["image_url"]["url"]
+                        )
+            assistant = chat_history[i + 1]["content"]
+            query = query + f" USER: {user} ASSISTANT:"
+            history.append((query, assistant))
+            query = query + f" {assistant}"
+        return query, history, [pixel_values]
+
+    def chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        system_prompt = system_prompt if system_prompt else ""
+        stream = generate_config.get("stream", False) if generate_config else False
+
+        sanitized_config = {
+            "pad_token_id": 128002,
+            "max_new_tokens": generate_config.get("max_tokens", 512)
+            if generate_config
+            else 512,
+        }
+
+        content, image = self._message_content_to_cogvlm2(prompt)
+
+        history = []
+        query = ""
+        history_image = None
+        if chat_history:
+            query, history, history_image = self._history_content_to_cogvlm2(
+                system_prompt, chat_history
+            )
+
+        if image and history_image:
+            history = []
+            query = system_prompt + f" USER: {content} ASSISTANT:"
+        else:
+            image = image if image else history_image
+            query = query + f" USER: {content} ASSISTANT:"
+
+        input_by_model = self._model.build_conversation_input_ids(
+            self._tokenizer,
+            query=query,
+            history=history,
+            images=image,
+            template_version="chat",
+        )
+
+        inputs = {
+            "input_ids": input_by_model["input_ids"].unsqueeze(0).to(self._device),
+            "token_type_ids": input_by_model["token_type_ids"]
+            .unsqueeze(0)
+            .to(self._device),
+            "attention_mask": input_by_model["attention_mask"]
+            .unsqueeze(0)
+            .to(self._device),
+            "images": [
+                [input_by_model["images"][0].to(self._device).to(self._torch_type)]
+            ]
+            if image is not None
+            else None,
+        }
+
+        if stream:
+            it = self._streaming_chat_response(inputs, sanitized_config)
+            return self._to_chat_completion_chunks(it)
+        else:
+            with torch.no_grad():
+                outputs = self._model.generate(**inputs, **sanitized_config)
+                outputs = outputs[:, inputs["input_ids"].shape[1] :]
+                response = self._tokenizer.decode(outputs[0])
+                response = response.split("<|end_of_text|>")[0]
+
+            chunk = Completion(
+                id=str(uuid.uuid1()),
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[
+                    CompletionChoice(
+                        index=0, text=response, finish_reason="stop", logprobs=None
+                    )
+                ],
+                usage=CompletionUsage(
+                    prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+                ),
+            )
+            return self._to_chat_completion(chunk)
+
+    def _streaming_chat_response(
+        self, inputs: Dict, config: Dict
+    ) -> Iterator[CompletionChunk]:
+        from threading import Thread
+
+        from transformers import TextIteratorStreamer
+
+        streamer = TextIteratorStreamer(
+            self._tokenizer, skip_prompt=True, skip_special_tokens=True
+        )
+        generation_kwargs = {
+            "input_ids": inputs["input_ids"],
+            "attention_mask": inputs["attention_mask"],
+            "token_type_ids": inputs["token_type_ids"],
+            "images": inputs["images"],
+            "max_new_tokens": config["max_new_tokens"],
+            "pad_token_id": config["pad_token_id"],
+            "streamer": streamer,
+        }
+
+        thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        completion_id = str(uuid.uuid1())
+        for new_text in streamer:
+            chunk = CompletionChunk(
+                id=completion_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[
+                    CompletionChoice(
+                        index=0, text=new_text, finish_reason=None, logprobs=None
+                    )
+                ],
+                usage=CompletionUsage(
+                    prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+                ),
+            )
+            yield chunk
+
+        completion_choice = CompletionChoice(
+            text="", index=0, logprobs=None, finish_reason="stop"
+        )
+        chunk = CompletionChunk(
+            id=completion_id,
+            object="text_completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[completion_choice],
+            usage=CompletionUsage(
+                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+            ),
+        )
+        yield chunk
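
For reference, _message_content_to_cogvlm2 above accepts either a plain string or an OpenAI-style content list: text parts are joined, a single image_url part (remote URL, local path, or base64 data URI) is decoded to a PIL image, and supplying more than one image raises a RuntimeError. A hedged sketch of the message shape it expects; the file name and question are placeholders, not from the diff:

```python
# Sketch of the content list CogVLM2Model.chat() parses when an image is
# attached; the image path and question are placeholders.
import base64

with open("example.jpg", "rb") as f:  # any local JPEG
    b64 = base64.b64encode(f.read()).decode("utf-8")

prompt = [
    {"type": "text", "text": "What is shown in this picture?"},
    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
    # A second image_url entry would raise RuntimeError: only one image
    # per message is supported by CogVLM2.
]

# With a launched model handle (see the client sketch near the top):
# completion = model.chat(prompt=prompt, generate_config={"max_tokens": 256})
```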

xinference/model/llm/pytorch/core.py
@@ -17,6 +17,7 @@ import logging
 import os
 from typing import Iterable, Iterator, List, Optional, Union

+from ....core.scheduler import InferenceRequest
 from ....device_utils import (
     get_device_preferred_dtype,
     gpu_count,
@@ -40,6 +41,7 @@ from ...utils import select_device
 from ..core import LLM
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from ..utils import ChatModelMixin
+from .utils import get_context_length, get_max_src_len

 logger = logging.getLogger(__name__)

@@ -53,6 +55,11 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
     "chatglm2",
     "chatglm2-32k",
     "chatglm2-128k",
+    "chatglm3",
+    "chatglm3-32k",
+    "chatglm3-128k",
+    "glm4-chat",
+    "glm4-chat-1m",
     "llama-2",
     "llama-2-chat",
     "internlm2-chat",
@@ -62,6 +69,9 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
     "deepseek-vl-chat",
     "internvl-chat",
     "mini-internvl-chat",
+    "cogvlm2",
+    "MiniCPM-Llama3-V-2_5",
+    "glm-4v",
 ]


@@ -95,6 +105,7 @@ class PytorchModel(LLM):
         pytorch_model_config.setdefault("gptq_act_order", False)
         pytorch_model_config.setdefault("device", "auto")
         pytorch_model_config.setdefault("trust_remote_code", True)
+        pytorch_model_config.setdefault("max_num_seqs", 16)
         return pytorch_model_config

     def _sanitize_generate_config(
@@ -453,6 +464,7 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
             pytorch_model_config,
             peft_model,
         )
+        self._context_len = None

     def _sanitize_generate_config(
         self,
@@ -496,13 +508,8 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
         chat_history: Optional[List[ChatCompletionMessage]] = None,
         generate_config: Optional[PytorchGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        assert self.model_family.prompt_style is not None
-        prompt_style = self.model_family.prompt_style.copy()
-        if system_prompt:
-            prompt_style.system_prompt = system_prompt
-        chat_history = chat_history or []
         tools = generate_config.pop("tools", []) if generate_config else None
-        full_prompt = self.
+        full_prompt = self._get_full_prompt(prompt, system_prompt, chat_history, tools)

         generate_config = self._sanitize_generate_config(generate_config)
         # TODO(codingl2k1): qwen hacky to set stop for function call.
@@ -530,3 +537,82 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
                     self.model_family, self.model_uid, c, tools
                 )
             return self._to_chat_completion(c)
+
+    def load(self):
+        super().load()
+        self._context_len = get_context_length(self._model.config)
+
+    def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
+        assert self.model_family.prompt_style is not None
+        prompt_style = self.model_family.prompt_style.copy()
+        if system_prompt:
+            prompt_style.system_prompt = system_prompt
+        chat_history = chat_history or []
+        full_prompt = ChatModelMixin.get_prompt(
+            prompt, chat_history, prompt_style, tools=tools
+        )
+        return full_prompt
+
+    def get_max_num_seqs(self) -> int:
+        return self._pytorch_model_config.get("max_num_seqs")  # type: ignore
+
+    def batch_inference(self, req_list: List[InferenceRequest]):
+        from .utils import batch_inference_one_step
+
+        for r in req_list:
+            if r.sanitized_generate_config is None:
+                r.sanitized_generate_config = self._sanitize_generate_config(
+                    r.generate_config
+                )
+            if r.is_prefill:
+                # check some generate params
+                max_src_len = get_max_src_len(self._context_len, r)  # type: ignore
+                if max_src_len < 0:
+                    r.stopped = True
+                    r.error_msg = "Max tokens exceeds model's max length"
+                    continue
+                if r.stream_interval <= 0:
+                    r.stopped = True
+                    r.error_msg = "`stream_interval` must be greater than 0"
+                    continue
+                stop_str = r.sanitized_generate_config.get("stop", None)
+                if stop_str and (
+                    not (isinstance(stop_str, str) or isinstance(stop_str, Iterable))
+                ):
+                    r.stopped = True
+                    r.error_msg = "Invalid `stop` field type"
+                    continue
+                r.full_prompt = self._get_full_prompt(
+                    r.prompt, r.system_prompt, r.chat_history, None
+                )
+
+        assert isinstance(self._context_len, int)
+        batch_inference_one_step(
+            req_list,
+            self.model_uid,
+            self._model,
+            self._tokenizer,
+            self._device,
+            self._context_len,
+        )
+        for req in req_list:
+            if req.stream and req.error_msg is None:
+                if req.completion:
+                    results = []
+                    for i, c in enumerate(req.completion):
+                        if c == "<bos_stream>":
+                            results.append(
+                                self._get_first_chat_completion_chunk(
+                                    req.completion[i + 1]
+                                )
+                            )
+                        elif c == "<eos_stream>":
+                            break
+                        else:
+                            results.append(self._to_chat_completion_chunk(c))
+
+                    if req.stopped and req.include_usage:
+                        results.append(
+                            self._get_final_chat_completion_chunk(req.completion[-1])
+                        )
+                    req.completion = results