xinference 0.14.0__py3-none-any.whl → 0.14.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +62 -1
- xinference/client/handlers.py +0 -3
- xinference/client/restful/restful_client.py +51 -134
- xinference/constants.py +1 -0
- xinference/core/chat_interface.py +1 -4
- xinference/core/image_interface.py +33 -5
- xinference/core/model.py +28 -2
- xinference/core/supervisor.py +37 -0
- xinference/core/worker.py +130 -84
- xinference/deploy/cmdline.py +1 -4
- xinference/model/audio/core.py +11 -3
- xinference/model/audio/funasr.py +114 -0
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/audio/model_spec_modelscope.json +21 -0
- xinference/model/audio/whisper.py +1 -1
- xinference/model/core.py +12 -0
- xinference/model/embedding/core.py +6 -6
- xinference/model/image/core.py +3 -4
- xinference/model/image/model_spec.json +41 -13
- xinference/model/image/model_spec_modelscope.json +30 -10
- xinference/model/image/stable_diffusion/core.py +53 -2
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +83 -1
- xinference/model/llm/llm_family_modelscope.json +85 -1
- xinference/model/llm/pytorch/core.py +1 -0
- xinference/model/llm/pytorch/minicpmv26.py +247 -0
- xinference/model/llm/sglang/core.py +72 -34
- xinference/model/llm/vllm/core.py +38 -0
- xinference/model/video/__init__.py +62 -0
- xinference/model/video/core.py +178 -0
- xinference/model/video/diffusers.py +180 -0
- xinference/model/video/model_spec.json +11 -0
- xinference/model/video/model_spec_modelscope.json +12 -0
- xinference/types.py +10 -24
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.af906659.js → main.17ca0398.js} +3 -3
- xinference/web/ui/build/static/js/main.17ca0398.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +1 -0
- {xinference-0.14.0.dist-info → xinference-0.14.1.dist-info}/METADATA +128 -122
- {xinference-0.14.0.dist-info → xinference-0.14.1.dist-info}/RECORD +49 -42
- {xinference-0.14.0.dist-info → xinference-0.14.1.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/js/main.af906659.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2c63090c842376cdd368c3ded88a333ef40d94785747651343040a6f7872a223.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2cd5e4279ad7e13a1f41d486e9fca7756295bfad5bd77d90992f4ac3e10b496d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/70fa8c07463a5fe57c68bf92502910105a8f647371836fe8c3a7408246ca7ba0.json +0 -1
- /xinference/web/ui/build/static/js/{main.af906659.js.LICENSE.txt → main.17ca0398.js.LICENSE.txt} +0 -0
- {xinference-0.14.0.dist-info → xinference-0.14.1.dist-info}/LICENSE +0 -0
- {xinference-0.14.0.dist-info → xinference-0.14.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.0.dist-info → xinference-0.14.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/minicpmv26.py
@@ -0,0 +1,247 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import json
+import logging
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from io import BytesIO
+from typing import Dict, Iterator, List, Optional, Union
+
+import requests
+import torch
+from PIL import Image
+
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    Completion,
+    CompletionChoice,
+    CompletionChunk,
+    CompletionUsage,
+)
+from ...utils import select_device
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from .core import PytorchChatModel, PytorchGenerateConfig
+
+logger = logging.getLogger(__name__)
+
+
+class MiniCPMV26Model(PytorchChatModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._device = None
+        self._tokenizer = None
+        self._model = None
+
+    @classmethod
+    def match(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        family = model_family.model_family or model_family.model_name
+        if "MiniCPM-V-2.6".lower() in family.lower():
+            return True
+        return False
+
+    def _get_model_class(self):
+        from transformers import AutoModel
+
+        return AutoModel
+
+    def load(self, **kwargs):
+        from transformers import AutoModel, AutoTokenizer
+        from transformers.generation import GenerationConfig
+
+        device = self._pytorch_model_config.get("device", "auto")
+        self._device = select_device(device)
+        self._device = (
+            "auto"
+            if self._device == "cuda" and self.quantization is None
+            else self._device
+        )
+
+        if "int4" in self.model_path and device == "mps":
+            logger.error(
+                "Error: running int4 model with bitsandbytes on Mac is not supported right now."
+            )
+            exit()
+
+        if self._check_tensorizer_integrity():
+            self._model, self._tokenizer = self._load_tensorizer()
+            return
+
+        if "int4" in self.model_path:
+            model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
+        else:
+            model = AutoModel.from_pretrained(
+                self.model_path,
+                trust_remote_code=True,
+                torch_dtype=torch.float16,
+                device_map=self._device,
+            )
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+        self._model = model.eval()
+        self._tokenizer = tokenizer
+
+        # Specify hyperparameters for generation
+        self._model.generation_config = GenerationConfig.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+        )
+        self._save_tensorizer()
+
+    def _message_content_to_chat(self, content):
+        def _load_image(_url):
+            if _url.startswith("data:"):
+                logging.info("Parse url by base64 decoder.")
+                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                # e.g. f"data:image/jpeg;base64,{base64_image}"
+                _type, data = _url.split(";")
+                _, ext = _type.split("/")
+                data = data[len("base64,") :]
+                data = base64.b64decode(data.encode("utf-8"))
+                return Image.open(BytesIO(data)).convert("RGB")
+            else:
+                try:
+                    response = requests.get(_url)
+                except requests.exceptions.MissingSchema:
+                    return Image.open(_url).convert("RGB")
+                else:
+                    return Image.open(BytesIO(response.content)).convert("RGB")
+
+        if not isinstance(content, str):
+            texts = []
+            image_urls = []
+            for c in content:
+                c_type = c.get("type")
+                if c_type == "text":
+                    texts.append(c["text"])
+                elif c_type == "image_url":
+                    image_urls.append(c["image_url"]["url"])
+            image_futures = []
+            with ThreadPoolExecutor() as executor:
+                for image_url in image_urls:
+                    fut = executor.submit(_load_image, image_url)
+                    image_futures.append(fut)
+            images = [fut.result() for fut in image_futures]
+            text = " ".join(texts)
+            if len(images) == 0:
+                return text, []
+            elif len(images) == 1:
+                return text, images
+            else:
+                raise RuntimeError("Only one image per message is supported")
+        return content, []
+
+    def chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        stream = generate_config.get("stream", False) if generate_config else False
+        content, images_chat = self._message_content_to_chat(prompt)
+
+        msgs = []
+        query_to_response: List[Dict] = []
+        images_history = []
+        for h in chat_history or []:
+            role = h["role"]
+            content_h, images_tmp = self._message_content_to_chat(h["content"])
+            if images_tmp != []:
+                images_history = images_tmp
+            if len(query_to_response) == 0 and role == "user":
+                query_to_response.append({"role": "user", "content": content_h})
+            if len(query_to_response) == 1 and role == "assistant":
+                query_to_response.append({"role": "assistant", "content": content_h})
+            if len(query_to_response) == 2:
+                msgs.extend(query_to_response)
+                query_to_response = []
+        image = None
+        if len(images_chat) > 0:
+            image = images_chat[0]
+        elif len(images_history) > 0:
+            image = images_history[0]
+        msgs.append({"role": "user", "content": content})
+
+        chat = self._model.chat(
+            image=image,
+            msgs=json.dumps(msgs, ensure_ascii=True),
+            tokenizer=self._tokenizer,
+            sampling=True,
+            **generate_config
+        )
+        if stream:
+            it = self.chat_stream(chat)
+            return self._to_chat_completion_chunks(it)
+        else:
+            c = Completion(
+                id=str(uuid.uuid1()),
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[
+                    CompletionChoice(
+                        index=0, text=chat, finish_reason="stop", logprobs=None
+                    )
+                ],
+                usage=CompletionUsage(
+                    prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+                ),
+            )
+            return self._to_chat_completion(c)
+
+    def chat_stream(self, chat) -> Iterator[CompletionChunk]:
+        completion_id = str(uuid.uuid1())
+        for new_text in chat:
+            completion_choice = CompletionChoice(
+                text=new_text, index=0, logprobs=None, finish_reason=None
+            )
+            chunk = CompletionChunk(
+                id=completion_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[completion_choice],
+            )
+            completion_usage = CompletionUsage(
+                prompt_tokens=-1,
+                completion_tokens=-1,
+                total_tokens=-1,
+            )
+            chunk["usage"] = completion_usage
+            yield chunk
+
+        completion_choice = CompletionChoice(
+            text="", index=0, logprobs=None, finish_reason="stop"
+        )
+        chunk = CompletionChunk(
+            id=completion_id,
+            object="text_completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[completion_choice],
+        )
+        completion_usage = CompletionUsage(
+            prompt_tokens=-1,
+            completion_tokens=-1,
+            total_tokens=-1,
+        )
+        chunk["usage"] = completion_usage
+        yield chunk
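The `_message_content_to_chat` helper above accepts OpenAI-style vision content, so an image can be supplied either as an http(s) URL or as a base64 data URL. Below is a minimal, hedged sketch of calling the model through the RESTful client; the endpoint address, model name, engine and launch parameters are illustrative assumptions rather than part of this diff.

import base64

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed endpoint
model_uid = client.launch_model(  # assumed launch parameters
    model_name="MiniCPM-V-2.6",
    model_engine="transformers",
    model_type="LLM",
)
model = client.get_model(model_uid)

with open("example.jpg", "rb") as f:  # any local image
    b64_image = base64.b64encode(f.read()).decode()

# Content layout matching what _message_content_to_chat parses:
# a list of {"type": "text"} and {"type": "image_url"} items.
response = model.chat(
    prompt=[
        {"type": "text", "text": "What is in this picture?"},
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
        },
    ]
)
print(response["choices"][0]["message"]["content"])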
xinference/model/llm/sglang/core.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import logging
 import time
 import uuid
@@ -122,6 +123,10 @@ class SGLANGModel(LLM):
             **self._model_config,
         )
 
+    def stop(self):
+        logger.info("Stopping SGLang engine")
+        self._engine.shutdown()
+
     def _sanitize_model_config(
         self, model_config: Optional[SGLANGModelConfig]
     ) -> SGLANGModelConfig:
@@ -132,18 +137,20 @@ class SGLANGModel(LLM):
         model_config.setdefault("tokenizer_mode", "auto")
         model_config.setdefault("trust_remote_code", True)
         model_config.setdefault("tp_size", cuda_count)
-        # See https://github.com/sgl-project/sglang/blob/
-        mem_fraction_static = model_config.
+        # See https://github.com/sgl-project/sglang/blob/00023d622a6d484e67ef4a0e444f708b8fc861c8/python/sglang/srt/server_args.py#L100-L109
+        mem_fraction_static = model_config.get("mem_fraction_static")
         if mem_fraction_static is None:
             tp_size = model_config.get("tp_size", cuda_count)
-            if tp_size >=
-                model_config["mem_fraction_static"] = 0.
+            if tp_size >= 16:
+                model_config["mem_fraction_static"] = 0.79
+            elif tp_size >= 8:
+                model_config["mem_fraction_static"] = 0.83
             elif tp_size >= 4:
-                model_config["mem_fraction_static"] = 0.82
-            elif tp_size >= 2:
                 model_config["mem_fraction_static"] = 0.85
+            elif tp_size >= 2:
+                model_config["mem_fraction_static"] = 0.87
             else:
-                model_config["mem_fraction_static"] = 0.
+                model_config["mem_fraction_static"] = 0.88
         model_config.setdefault("log_level", "info")
         model_config.setdefault("attention_reduce_in_fp32", False)
 
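The new tp_size-based defaults above only apply when `mem_fraction_static` is not supplied by the caller. A hedged sketch of overriding it explicitly when launching on the SGLang engine, assuming (as with the other engines) that extra launch kwargs are forwarded into the SGLang model config; the endpoint and model identifiers are illustrative:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed endpoint
model_uid = client.launch_model(
    model_name="qwen2-instruct",  # illustrative model
    model_engine="sglang",
    model_size_in_billions=7,
    quantization="none",
    # An explicit value bypasses the tp_size-based defaulting shown above.
    mem_fraction_static=0.75,
)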
@@ -249,28 +256,64 @@ class SGLANGModel(LLM):
             usage=usage,
         )
 
+    @classmethod
+    def _filter_sampling_params(cls, sampling_params: dict):
+        if not sampling_params.get("lora_name"):
+            sampling_params.pop("lora_name", None)
+        return sampling_params
+
+    async def _stream_generate(self, prompt: str, **sampling_params):
+        import aiohttp
+
+        sampling_params = self._filter_sampling_params(sampling_params)
+        json_data = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+            "stream": True,
+        }
+        pos = 0
+
+        timeout = aiohttp.ClientTimeout(total=3 * 3600)
+        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
+            async with session.post(
+                self._engine.generate_url, json=json_data  # type: ignore
+            ) as response:
+                async for chunk, _ in response.content.iter_chunks():
+                    chunk = chunk.decode("utf-8")
+                    if chunk and chunk.startswith("data:"):
+                        stop = "data: [DONE]\n\n"
+                        need_stop = False
+                        if chunk.endswith(stop):
+                            chunk = chunk[: -len(stop)]
+                            need_stop = True
+                        if chunk:
+                            data = json.loads(chunk[5:].strip("\n"))
+                            cur = data["text"][pos:]
+                            if cur:
+                                yield data["meta_info"], cur
+                            pos += len(cur)
+                            if need_stop:
+                                break
+
+    async def _non_stream_generate(self, prompt: str, **sampling_params) -> dict:
+        import aiohttp
+
+        sampling_params = self._filter_sampling_params(sampling_params)
+        json_data = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+        }
+        async with aiohttp.ClientSession(trust_env=True) as session:
+            async with session.post(
+                self._engine.generate_url, json=json_data  # type: ignore
+            ) as response:
+                return await response.json()
+
     async def async_generate(
         self,
         prompt: str,
         generate_config: Optional[SGLANGGenerateConfig] = None,
     ) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]:
-        try:
-            import sglang as sgl
-            from sglang import assistant, gen, user
-        except ImportError:
-            error_message = "Failed to import module 'sglang'"
-            installation_guide = [
-                "Please make sure 'sglang' is installed. ",
-                "You can install it by `pip install sglang[all]`\n",
-            ]
-
-            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-        @sgl.function
-        def pipeline(s, question):
-            s += user(question)
-            s += assistant(gen("answer"))
-
         sanitized_generate_config = self._sanitize_generate_config(generate_config)
         logger.debug(
             "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
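`_stream_generate` above posts to the SGLang runtime's HTTP generate endpoint and parses an SSE-style stream in which, judging from the `data["text"][pos:]` slicing, each `data:` payload carries the cumulative generated text. A small self-contained sketch of that parsing logic against canned chunks (the payload fields are illustrative):

import json

# Illustrative chunks in the shape the parser above expects; "text" is cumulative.
chunks = [
    'data: {"text": "Hello", "meta_info": {"completion_tokens": 1}}\n\n',
    'data: {"text": "Hello world", "meta_info": {"completion_tokens": 2}}\n\n',
    "data: [DONE]\n\n",
]

pos = 0
for chunk in chunks:
    if not chunk.startswith("data:") or chunk.startswith("data: [DONE]"):
        continue
    data = json.loads(chunk[5:].strip("\n"))
    cur = data["text"][pos:]  # only the newly generated suffix
    if cur:
        print(data["meta_info"], cur)
    pos += len(cur)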
@@ -285,25 +328,20 @@ class SGLANGModel(LLM):
         )
 
         request_id = str(uuid.uuid1())
-        state = pipeline.run(
-            question=prompt,
-            backend=self._engine,
-            stream=stream,
-            **sanitized_generate_config,
-        )
         if not stream:
+            state = await self._non_stream_generate(prompt, **sanitized_generate_config)
             return self._convert_state_to_completion(
                 request_id,
                 model=self.model_uid,
-                output_text=state["
-                meta_info=state
+                output_text=state["text"],
+                meta_info=state["meta_info"],
             )
         else:
 
             async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
                 prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
-                async for
-
+                async for meta_info, out in self._stream_generate(
+                    prompt, **sanitized_generate_config
                 ):
                     chunk = self._convert_state_to_completion_chunk(
                         request_id, self.model_uid, output_text=out
xinference/model/llm/vllm/core.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 import json
 import logging
 import multiprocessing
+import os
 import time
 import uuid
 from typing import (
@@ -240,6 +242,42 @@ class VLLMModel(LLM):
         )
         self._engine = AsyncLLMEngine.from_engine_args(engine_args)
 
+        self._check_health_task = None
+        if hasattr(self._engine, "check_health"):
+            # vLLM introduced `check_health` since v0.4.1
+            self._check_health_task = asyncio.create_task(self._check_healthy())
+
+    def stop(self):
+        # though the vLLM engine will shutdown when deleted,
+        # but some issue e.g. GH#1682 reported
+        # when deleting, the engine exists still
+        logger.info("Stopping vLLM engine")
+        if self._check_health_task:
+            self._check_health_task.cancel()
+        if model_executor := getattr(self._engine.engine, "model_executor", None):
+            model_executor.shutdown()
+        self._engine = None
+
+    async def _check_healthy(self, interval: int = 30):
+        from vllm.engine.async_llm_engine import AsyncEngineDeadError
+
+        logger.debug("Begin to check health of vLLM")
+
+        while self._engine is not None:
+            try:
+                await self._engine.check_health()
+            except (AsyncEngineDeadError, RuntimeError):
+                logger.info("Detecting vLLM is not health, prepare to quit the process")
+                try:
+                    self.stop()
+                except:
+                    # ignore error when stop
+                    pass
+                # Just kill the process and let xinference auto-recover the model
+                os._exit(1)
+            else:
+                await asyncio.sleep(interval)
+
     def _sanitize_model_config(
         self, model_config: Optional[VLLMModelConfig]
     ) -> VLLMModelConfig:
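With the new `stop()` method, shutting a vLLM model down no longer relies on the engine being garbage collected (the comment above references GH#1682). The accompanying core/model.py and core/worker.py changes in this release suggest it is invoked when a model is terminated; a hedged client-side sketch, with endpoint and model identifiers as assumptions:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed endpoint
model_uid = client.launch_model(
    model_name="qwen2-instruct",  # illustrative model
    model_engine="vllm",
    model_size_in_billions=7,
    quantization="none",
)

# ... use the model ...

# Termination should now reach VLLMModel.stop(), cancelling the health-check
# task and shutting down the model executor to free GPU memory promptly.
client.terminate_model(model_uid)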
xinference/model/video/__init__.py
@@ -0,0 +1,62 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import codecs
+import json
+import os
+from itertools import chain
+
+from .core import (
+    BUILTIN_VIDEO_MODELS,
+    MODEL_NAME_TO_REVISION,
+    MODELSCOPE_VIDEO_MODELS,
+    VIDEO_MODEL_DESCRIPTIONS,
+    VideoModelFamilyV1,
+    generate_video_description,
+    get_cache_status,
+    get_video_model_descriptions,
+)
+
+_model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
+_model_spec_modelscope_json = os.path.join(
+    os.path.dirname(__file__), "model_spec_modelscope.json"
+)
+BUILTIN_VIDEO_MODELS.update(
+    dict(
+        (spec["model_name"], VideoModelFamilyV1(**spec))
+        for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
+    )
+)
+for model_name, model_spec in BUILTIN_VIDEO_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+
+MODELSCOPE_VIDEO_MODELS.update(
+    dict(
+        (spec["model_name"], VideoModelFamilyV1(**spec))
+        for spec in json.load(
+            codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
+        )
+    )
+)
+for model_name, model_spec in MODELSCOPE_VIDEO_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+
+# register model description
+for model_name, model_spec in chain(
+    MODELSCOPE_VIDEO_MODELS.items(), BUILTIN_VIDEO_MODELS.items()
+):
+    VIDEO_MODEL_DESCRIPTIONS.update(generate_video_description(model_spec))
+
+del _model_spec_json
+del _model_spec_modelscope_json
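Each entry of model_spec.json / model_spec_modelscope.json is turned into a VideoModelFamilyV1 (defined in core.py below) by the registration code above. A hypothetical entry, written here as the Python dict that json.load would produce; every field value is illustrative, only the field names come from the spec class:

spec = {
    "model_name": "CogVideoX-2b",  # illustrative
    "model_family": "CogVideoX",  # illustrative
    "model_id": "THUDM/CogVideoX-2b",  # illustrative
    "model_revision": "main",  # illustrative
    "model_hub": "huggingface",
    "model_ability": ["text_to_video"],  # illustrative
}

# Mirrors the registration loop above:
# BUILTIN_VIDEO_MODELS[spec["model_name"]] = VideoModelFamilyV1(**spec)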
xinference/model/video/core.py
@@ -0,0 +1,178 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+from collections import defaultdict
+from typing import Dict, List, Literal, Optional, Tuple
+
+from ...constants import XINFERENCE_CACHE_DIR
+from ..core import CacheableModelSpec, ModelDescription
+from ..utils import valid_model_revision
+from .diffusers import DiffUsersVideoModel
+
+MAX_ATTEMPTS = 3
+
+logger = logging.getLogger(__name__)
+
+MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
+VIDEO_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
+BUILTIN_VIDEO_MODELS: Dict[str, "VideoModelFamilyV1"] = {}
+MODELSCOPE_VIDEO_MODELS: Dict[str, "VideoModelFamilyV1"] = {}
+
+
+def get_video_model_descriptions():
+    import copy
+
+    return copy.deepcopy(VIDEO_MODEL_DESCRIPTIONS)
+
+
+class VideoModelFamilyV1(CacheableModelSpec):
+    model_family: str
+    model_name: str
+    model_id: str
+    model_revision: str
+    model_hub: str = "huggingface"
+    model_ability: Optional[List[str]]
+
+
+class VideoModelDescription(ModelDescription):
+    def __init__(
+        self,
+        address: Optional[str],
+        devices: Optional[List[str]],
+        model_spec: VideoModelFamilyV1,
+        model_path: Optional[str] = None,
+    ):
+        super().__init__(address, devices, model_path=model_path)
+        self._model_spec = model_spec
+
+    def to_dict(self):
+        return {
+            "model_type": "video",
+            "address": self.address,
+            "accelerators": self.devices,
+            "model_name": self._model_spec.model_name,
+            "model_family": self._model_spec.model_family,
+            "model_revision": self._model_spec.model_revision,
+            "model_ability": self._model_spec.model_ability,
+        }
+
+    def to_version_info(self):
+        if self._model_path is None:
+            is_cached = get_cache_status(self._model_spec)
+            file_location = get_cache_dir(self._model_spec)
+        else:
+            is_cached = True
+            file_location = self._model_path
+
+        return [
+            {
+                "model_version": self._model_spec.model_name,
+                "model_file_location": file_location,
+                "cache_status": is_cached,
+            }
+        ]
+
+
+def generate_video_description(
+    video_model: VideoModelFamilyV1,
+) -> Dict[str, List[Dict]]:
+    res = defaultdict(list)
+    res[video_model.model_name].extend(
+        VideoModelDescription(None, None, video_model).to_version_info()
+    )
+    return res
+
+
+def match_diffusion(
+    model_name: str,
+    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+) -> VideoModelFamilyV1:
+    from ..utils import download_from_modelscope
+    from . import BUILTIN_VIDEO_MODELS, MODELSCOPE_VIDEO_MODELS
+
+    if download_hub == "modelscope" and model_name in MODELSCOPE_VIDEO_MODELS:
+        logger.debug(f"Video model {model_name} found in ModelScope.")
+        return MODELSCOPE_VIDEO_MODELS[model_name]
+    elif download_hub == "huggingface" and model_name in BUILTIN_VIDEO_MODELS:
+        logger.debug(f"Video model {model_name} found in Huggingface.")
+        return BUILTIN_VIDEO_MODELS[model_name]
+    elif download_from_modelscope() and model_name in MODELSCOPE_VIDEO_MODELS:
+        logger.debug(f"Video model {model_name} found in ModelScope.")
+        return MODELSCOPE_VIDEO_MODELS[model_name]
+    elif model_name in BUILTIN_VIDEO_MODELS:
+        logger.debug(f"Video model {model_name} found in Huggingface.")
+        return BUILTIN_VIDEO_MODELS[model_name]
+    else:
+        raise ValueError(
+            f"Video model {model_name} not found, available"
+            f"model list: {BUILTIN_VIDEO_MODELS.keys()}"
+        )
+
+
+def cache(model_spec: VideoModelFamilyV1):
+    from ..utils import cache
+
+    return cache(model_spec, VideoModelDescription)
+
+
+def get_cache_dir(model_spec: VideoModelFamilyV1):
+    return os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name))
+
+
+def get_cache_status(
+    model_spec: VideoModelFamilyV1,
+) -> bool:
+    cache_dir = get_cache_dir(model_spec)
+    meta_path = os.path.join(cache_dir, "__valid_download")
+
+    model_name = model_spec.model_name
+    if model_name in BUILTIN_VIDEO_MODELS and model_name in MODELSCOPE_VIDEO_MODELS:
+        hf_spec = BUILTIN_VIDEO_MODELS[model_name]
+        ms_spec = MODELSCOPE_VIDEO_MODELS[model_name]
+
+        return any(
+            [
+                valid_model_revision(meta_path, hf_spec.model_revision),
+                valid_model_revision(meta_path, ms_spec.model_revision),
+            ]
+        )
+    else:  # Usually for UT
+        return valid_model_revision(meta_path, model_spec.model_revision)
+
+
+def create_video_model_instance(
+    subpool_addr: str,
+    devices: List[str],
+    model_uid: str,
+    model_name: str,
+    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    model_path: Optional[str] = None,
+    **kwargs,
+) -> Tuple[DiffUsersVideoModel, VideoModelDescription]:
+    model_spec = match_diffusion(model_name, download_hub)
+    if not model_path:
+        model_path = cache(model_spec)
+    assert model_path is not None
+
+    model = DiffUsersVideoModel(
+        model_uid,
+        model_path,
+        model_spec,
+        **kwargs,
+    )
+    model_description = VideoModelDescription(
+        subpool_addr, devices, model_spec, model_path=model_path
+    )
+    return model, model_description
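A hedged sketch of how the helpers above fit together when a worker spins up a video model. The subpool address, devices and model name are placeholders, and the load()/text_to_video() calls on DiffUsersVideoModel are assumptions about diffusers.py (not shown in this section) rather than something this hunk guarantees:

from xinference.model.video.core import create_video_model_instance

model, description = create_video_model_instance(
    subpool_addr="127.0.0.1:37845",  # placeholder worker subpool address
    devices=["0"],
    model_uid="my-video-model",
    model_name="CogVideoX-2b",  # placeholder; must match an entry in model_spec.json
)
print(description.to_dict())  # {"model_type": "video", "model_name": ..., ...}

model.load()  # assumed DiffUsersVideoModel API
video = model.text_to_video("A panda eating bamboo in the forest")  # assumed API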