xinference 0.11.2.post1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +83 -8
- xinference/client/restful/restful_client.py +70 -0
- xinference/constants.py +8 -0
- xinference/core/__init__.py +0 -2
- xinference/core/cache_tracker.py +22 -1
- xinference/core/chat_interface.py +71 -10
- xinference/core/model.py +141 -12
- xinference/core/scheduler.py +428 -0
- xinference/core/supervisor.py +31 -3
- xinference/core/worker.py +8 -3
- xinference/isolation.py +9 -2
- xinference/model/audio/chattts.py +84 -0
- xinference/model/audio/core.py +10 -3
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +1063 -260
- xinference/model/llm/llm_family_modelscope.json +686 -13
- xinference/model/llm/pytorch/baichuan.py +2 -1
- xinference/model/llm/pytorch/chatglm.py +2 -1
- xinference/model/llm/pytorch/cogvlm2.py +316 -0
- xinference/model/llm/pytorch/core.py +92 -6
- xinference/model/llm/pytorch/glm4v.py +258 -0
- xinference/model/llm/pytorch/intern_vl.py +5 -10
- xinference/model/llm/pytorch/minicpmv25.py +232 -0
- xinference/model/llm/pytorch/utils.py +386 -2
- xinference/model/llm/vllm/core.py +7 -1
- xinference/thirdparty/ChatTTS/__init__.py +1 -0
- xinference/thirdparty/ChatTTS/core.py +200 -0
- xinference/types.py +3 -0
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/METADATA +28 -11
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/RECORD +36 -29
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/LICENSE +0 -0
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/WHEEL +0 -0
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json

 version_json = '''
 {
- "date": "2024-
+ "date": "2024-06-07T15:04:33+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "
- "version": "0.
+ "full-revisionid": "55c5636f2b6022842d1827eae373c8e5f162a1a3",
+ "version": "0.12.0"
 }
 ''' # END VERSION_JSON

xinference/api/restful_api.py
CHANGED
@@ -52,7 +52,7 @@ from xoscar.utils import get_next_port

 from .._compat import BaseModel, Field
 from .._version import get_versions
-from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT
+from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT, XINFERENCE_DISABLE_METRICS
 from ..core.event import Event, EventCollectorActor, EventType
 from ..core.supervisor import SupervisorActor
 from ..core.utils import json_dumps

@@ -122,6 +122,14 @@ class TextToImageRequest(BaseModel):
     user: Optional[str] = None


+class SpeechRequest(BaseModel):
+    model: str
+    input: str
+    voice: Optional[str]
+    response_format: Optional[str] = "mp3"
+    speed: Optional[float] = 1.0
+
+
 class RegisterModelRequest(BaseModel):
     model: str
     persist: bool

@@ -337,6 +345,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/models/{model_uid}/requests/{request_id}/abort",
+            self.abort_request,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/models/instance",
             self.launch_model_by_version,

@@ -418,6 +436,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/audio/speech",
+            self.create_speech,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/images/generations",
             self.create_images,

@@ -504,13 +532,19 @@ class RESTfulAPI:
             ),
         )

-
-
-
-
-
-
-
+        if XINFERENCE_DISABLE_METRICS:
+            logger.info(
+                "Supervisor metrics is disabled due to the environment XINFERENCE_DISABLE_METRICS=1"
+            )
+            self._app.include_router(self._router)
+        else:
+            # Clear the global Registry for the MetricsMiddleware, or
+            # the MetricsMiddleware will register duplicated metrics if the port
+            # conflict (This serve method run more than once).
+            REGISTRY.clear()
+            self._app.add_middleware(MetricsMiddleware)
+            self._app.include_router(self._router)
+            self._app.add_route("/metrics", metrics)

         # Check all the routes returns Response.
         # This is to avoid `jsonable_encoder` performance issue:

@@ -1173,6 +1207,38 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))

+    async def create_speech(self, request: Request) -> Response:
+        body = SpeechRequest.parse_obj(await request.json())
+        model_uid = body.model
+        try:
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            out = await model.speech(
+                input=body.input,
+                voice=body.voice,
+                response_format=body.response_format,
+                speed=body.speed,
+            )
+            return Response(media_type="application/octet-stream", content=out)
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_images(self, request: Request) -> Response:
         body = TextToImageRequest.parse_obj(await request.json())
         model_uid = body.model

@@ -1512,6 +1578,15 @@
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))

+    async def abort_request(self, model_uid: str, request_id: str) -> JSONResponse:
+        try:
+            supervisor_ref = await self._get_supervisor_ref()
+            res = await supervisor_ref.abort_request(model_uid, request_id)
+            return JSONResponse(content=res)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def list_vllm_supported_model_families(self) -> JSONResponse:
         try:
             from ..model.llm.vllm.core import (
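The two routes registered above can be exercised directly over HTTP. The following is a minimal, hedged sketch: the endpoint address, the model uid "my-chattts-model", and the request id "req-123" are placeholders, and it assumes an audio model has already been launched on a local supervisor.

# Sketch of calling the new REST routes; endpoint, model uid and request id are placeholders.
import requests

endpoint = "http://127.0.0.1:9997"  # assumed local xinference endpoint

# POST /v1/audio/speech mirrors the SpeechRequest fields: model, input, voice,
# response_format, speed. The response body is the raw audio bytes.
resp = requests.post(
    f"{endpoint}/v1/audio/speech",
    json={
        "model": "my-chattts-model",
        "input": "Hello from xinference 0.12.0",
        "voice": "",
        "response_format": "mp3",
        "speed": 1.0,
    },
)
resp.raise_for_status()
with open("speech.mp3", "wb") as f:
    f.write(resp.content)

# POST /v1/models/{model_uid}/requests/{request_id}/abort cancels a submitted
# request; it only has an effect when transformers batching is enabled.
requests.post(f"{endpoint}/v1/models/my-chattts-model/requests/req-123/abort")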
xinference/client/restful/restful_client.py
CHANGED

@@ -684,6 +684,49 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data

+    def speech(
+        self,
+        input: str,
+        voice: str = "",
+        response_format: str = "mp3",
+        speed: float = 1.0,
+    ):
+        """
+        Generates audio from the input text.
+
+        Parameters
+        ----------
+
+        input: str
+            The text to generate audio for. The maximum length is 4096 characters.
+        voice: str
+            The voice to use when generating the audio.
+        response_format: str
+            The format to audio in.
+        speed: str
+            The speed of the generated audio.
+
+        Returns
+        -------
+        bytes
+            The generated audio binary.
+        """
+        url = f"{self._base_url}/v1/audio/speech"
+        params = {
+            "model": self._model_uid,
+            "input": input,
+            "voice": voice,
+            "response_format": response_format,
+            "speed": speed,
+        }
+        response = requests.post(url, json=params, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to speech the text, detail: {_get_error_string(response)}"
+            )
+
+        return response.content
+

 class Client:
     def __init__(self, base_url, api_key: Optional[str] = None):
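For reference, a minimal client-side sketch of the new speech() handle method. The endpoint and the model uid "my-chattts-model" are placeholders; it assumes an audio model (e.g. ChatTTS) has already been launched so that get_model returns a RESTfulAudioModelHandle.

# Hypothetical usage of RESTfulAudioModelHandle.speech(); endpoint and
# "my-chattts-model" are placeholders for an already-launched audio model.
from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model = client.get_model("my-chattts-model")

audio_bytes = model.speech(
    input="Text to read aloud.",
    voice="",
    response_format="mp3",
    speed=1.0,
)
with open("output.mp3", "wb") as f:
    f.write(audio_bytes)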
@@ -1181,3 +1224,30 @@ class Client:

         response_data = response.json()
         return response_data
+
+    def abort_request(self, model_uid: str, request_id: str):
+        """
+        Abort a request.
+        Abort a submitted request. If the request is finished or not found, this method will be a no-op.
+        Currently, this interface is only supported when batching is enabled for models on transformers backend.
+
+        Parameters
+        ----------
+        model_uid: str
+            Model uid.
+        request_id: str
+            Request id.
+        Returns
+        -------
+        Dict
+            Return empty dict.
+        """
+        url = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort"
+        response = requests.post(url, headers=self._headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to abort request, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
xinference/constants.py
CHANGED
@@ -26,6 +26,8 @@ XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT = "XINFERENCE_HEALTH_CHECK_TIMEOUT"
 XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
 XINFERENCE_ENV_DISABLE_VLLM = "XINFERENCE_DISABLE_VLLM"
 XINFERENCE_ENV_ENABLE_SGLANG = "XINFERENCE_ENABLE_SGLANG"
+XINFERENCE_ENV_DISABLE_METRICS = "XINFERENCE_DISABLE_METRICS"
+XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING = "XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"


 def get_xinference_home() -> str:

@@ -66,3 +68,9 @@ XINFERENCE_DISABLE_HEALTH_CHECK = bool(
 )
 XINFERENCE_DISABLE_VLLM = bool(int(os.environ.get(XINFERENCE_ENV_DISABLE_VLLM, 0)))
 XINFERENCE_ENABLE_SGLANG = bool(int(os.environ.get(XINFERENCE_ENV_ENABLE_SGLANG, 0)))
+XINFERENCE_DISABLE_METRICS = bool(
+    int(os.environ.get(XINFERENCE_ENV_DISABLE_METRICS, 0))
+)
+XINFERENCE_TRANSFORMERS_ENABLE_BATCHING = bool(
+    int(os.environ.get(XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING, 0))
+)
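Both new flags follow the existing bool(int(...)) convention, so they are switched on by setting the variable to "1" before the constants module is imported (i.e. before the server starts). A minimal sketch:

# Sketch: enable the new switches before importing xinference in-process.
import os

os.environ["XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"] = "1"  # opt in to transformers batching
os.environ["XINFERENCE_DISABLE_METRICS"] = "1"               # skip the /metrics route

from xinference import constants  # values are parsed at import time

assert constants.XINFERENCE_TRANSFORMERS_ENABLE_BATCHING is True
assert constants.XINFERENCE_DISABLE_METRICS is True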
xinference/core/__init__.py
CHANGED
xinference/core/cache_tracker.py
CHANGED
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from logging import getLogger
 from typing import Any, Dict, List, Optional

@@ -105,9 +106,29 @@ class CacheTrackerActor(xo.Actor):
         cached_models = []
         for model_name, model_versions in self._model_name_to_version_info.items():
             for version_info in model_versions:
-
+                cache_status = version_info.get("cache_status", None)
+                if cache_status == True:
                     ret = version_info.copy()
                     ret["model_name"] = model_name
+
+                    re_dict = version_info.get("model_file_location", None)
+                    if re_dict is not None and isinstance(re_dict, dict):
+                        if re_dict:
+                            actor_ip_address, path = next(iter(re_dict.items()))
+                        else:
+                            raise ValueError("The dictionary is empty.")
+                    else:
+                        raise ValueError("re_dict must be a non-empty dictionary.")
+
+                    ret["actor_ip_address"] = actor_ip_address
+                    ret["path"] = path
+                    if os.path.isdir(path):
+                        files = os.listdir(path)
+                        resolved_file = os.path.realpath(os.path.join(path, files[0]))
+                        if resolved_file:
+                            ret["real_path"] = os.path.dirname(resolved_file)
+                    else:
+                        ret["real_path"] = os.path.realpath(path)
                     cached_models.append(ret)
         cached_models = sorted(cached_models, key=lambda x: x["model_name"])
         return cached_models
xinference/core/chat_interface.py
CHANGED

@@ -186,8 +186,7 @@ class GradioInterface:
     def build_chat_vl_interface(
         self,
     ) -> "gr.Blocks":
-        def predict(history, bot):
-            logger.debug("Predict model: %s, history: %s", self.model_uid, history)
+        def predict(history, bot, max_tokens, temperature, stream):
             from ..client import RESTfulClient

             client = RESTfulClient(self.endpoint)

@@ -199,10 +198,46 @@ class GradioInterface:
             assert prompt["role"] == "user"
             prompt = prompt["content"]
             # multimodal chat does not support stream.
-
-
-
-
+            if stream:
+                response_content = ""
+                for chunk in model.chat(
+                    prompt=prompt,
+                    chat_history=history[:-1],
+                    generate_config={
+                        "max_tokens": max_tokens,
+                        "temperature": temperature,
+                        "stream": stream,
+                    },
+                ):
+                    assert isinstance(chunk, dict)
+                    delta = chunk["choices"][0]["delta"]
+                    if "content" not in delta:
+                        continue
+                    else:
+                        response_content += delta["content"]
+                        bot[-1][1] = response_content
+                        yield history, bot
+                history.append(
+                    {
+                        "content": response_content,
+                        "role": "assistant",
+                    }
+                )
+                bot[-1][1] = response_content
+                yield history, bot
+            else:
+                response = model.chat(
+                    prompt=prompt,
+                    chat_history=history[:-1],
+                    generate_config={
+                        "max_tokens": max_tokens,
+                        "temperature": temperature,
+                        "stream": stream,
+                    },
+                )
+                history.append(response["choices"][0]["message"])
+                bot[-1][1] = history[-1]["content"]
+                yield history, bot

         def add_text(history, bot, text, image):
             logger.debug("Add text, text: %s, image: %s", text, image)

@@ -217,14 +252,19 @@ class GradioInterface:
                     "role": "user",
                     "content": [
                         {"type": "text", "text": text},
-                        {
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/png;base64,{img_b64_str}"
+                            },
+                        },
                     ],
                 }
             else:
                 display_content = text
                 message = {"role": "user", "content": text}
             history = history + [message]
-            bot = bot + [
+            bot = bot + [[display_content, None]]
             return history, bot, "", None

         def clear_history():

@@ -286,6 +326,19 @@ class GradioInterface:
                     )
                     clear_btn = gr.Button(value="Clear")

+                with gr.Accordion("Additional Inputs", open=False):
+                    max_tokens = gr.Slider(
+                        minimum=1,
+                        maximum=self.context_length,
+                        value=512,
+                        step=1,
+                        label="Max Tokens",
+                    )
+                    temperature = gr.Slider(
+                        minimum=0, maximum=2, value=1, step=0.01, label="Temperature"
+                    )
+                    stream = gr.Checkbox(label="Stream", value=False)
+
             textbox.change(update_button, [textbox], [submit_btn], queue=False)

             textbox.submit(

@@ -293,14 +346,22 @@ class GradioInterface:
                 [state, chatbot, textbox, imagebox],
                 [state, chatbot, textbox, imagebox],
                 queue=False,
-            ).then(
+            ).then(
+                predict,
+                [state, chatbot, max_tokens, temperature, stream],
+                [state, chatbot],
+            )

             submit_btn.click(
                 add_text,
                 [state, chatbot, textbox, imagebox],
                 [state, chatbot, textbox, imagebox],
                 queue=False,
-            ).then(
+            ).then(
+                predict,
+                [state, chatbot, max_tokens, temperature, stream],
+                [state, chatbot],
+            )

             clear_btn.click(
                 clear_history, None, [state, chatbot, textbox, imagebox], queue=False
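The reworked predict() above mirrors what a client would do by hand: build an OpenAI-style content list with a base64 image_url part, then either stream deltas or take the full message. A hedged sketch follows; the endpoint, the model uid "my-vl-model", and the image path are placeholders, and the chat call shape follows the diff above.

# Hypothetical vision-chat call mirroring the Gradio predict() above.
# "my-vl-model" and cat.png are placeholders for a launched vision LLM.
import base64

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model = client.get_model("my-vl-model")

with open("cat.png", "rb") as f:
    img_b64_str = base64.b64encode(f.read()).decode()

prompt = [
    {"type": "text", "text": "What is in this picture?"},
    {
        "type": "image_url",
        "image_url": {"url": f"data:image/png;base64,{img_b64_str}"},
    },
]

response_content = ""
for chunk in model.chat(
    prompt=prompt,
    chat_history=[],
    generate_config={"max_tokens": 512, "temperature": 1, "stream": True},
):
    delta = chunk["choices"][0]["delta"]
    if "content" in delta:
        response_content += delta["content"]
print(response_content)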
xinference/core/model.py
CHANGED
@@ -20,9 +20,14 @@ import os
 import time
 import types
 import weakref
+from asyncio.queues import Queue
+from asyncio.tasks import wait_for
+from concurrent.futures import Future as ConcurrentFuture
 from typing import (
     TYPE_CHECKING,
+    Any,
     AsyncGenerator,
+    AsyncIterator,
     Callable,
     Dict,
     Generator,

@@ -35,6 +40,8 @@ from typing import (
 import sse_starlette.sse
 import xoscar as xo

+from ..constants import XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
+
 if TYPE_CHECKING:
     from .worker import WorkerActor
     from ..model.llm.core import LLM

@@ -125,6 +132,16 @@ class ModelActor(xo.StatelessActor):
         from ..model.llm.pytorch.core import PytorchModel as LLMPytorchModel
         from ..model.llm.vllm.core import VLLMModel as LLMVLLMModel

+        if self.allow_batching():
+            try:
+                assert self._scheduler_ref is not None
+                await xo.destroy_actor(self._scheduler_ref)
+                del self._scheduler_ref
+            except Exception as e:
+                logger.debug(
+                    f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
+                )
+
         if (
             isinstance(self._model, (LLMPytorchModel, LLMVLLMModel))
             and self._model.model_spec.model_format == "pytorch"

@@ -181,9 +198,20 @@ class ModelActor(xo.StatelessActor):
         }
         self._loop: Optional[asyncio.AbstractEventLoop] = None

+        self._scheduler_ref = None
+
     async def __post_create__(self):
         self._loop = asyncio.get_running_loop()

+        if self.allow_batching():
+            from .scheduler import SchedulerActor
+
+            self._scheduler_ref = await xo.create_actor(
+                SchedulerActor,
+                address=self.address,
+                uid=SchedulerActor.gen_uid(self.model_uid(), self._model.rep_id),
+            )
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):

@@ -235,8 +263,22 @@ class ModelActor(xo.StatelessActor):

         return isinstance(self._model, VLLMModel)

-    def
+    def allow_batching(self) -> bool:
+        from ..model.llm.pytorch.core import PytorchChatModel
+
+        return (
+            XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
+            and isinstance(self._model, PytorchChatModel)
+            and self._model.__class__.__name__ == PytorchChatModel.__name__
+        )
+
+    async def load(self):
         self._model.load()
+        if self.allow_batching():
+            await self._scheduler_ref.set_model(self._model)
+            logger.debug(
+                f"Batching enabled for model: {self.model_uid()}, max_num_seqs: {self._model.get_max_num_seqs()}"
+            )

     def model_uid(self):
         return (

@@ -343,6 +385,8 @@ class ModelActor(xo.StatelessActor):
             gen = self._to_json_async_gen(ret)
             self._current_generator = weakref.ref(gen)
             return gen
+        if isinstance(ret, bytes):
+            return ret
         return await asyncio.to_thread(json_dumps, ret)

     @log_async(logger=logger)

@@ -359,6 +403,36 @@ class ModelActor(xo.StatelessActor):
         )
         raise AttributeError(f"Model {self._model.model_spec} is not for generate.")

+    async def _queue_consumer(
+        self, queue: Queue, timeout: Optional[float] = None
+    ) -> AsyncIterator[Any]:
+        from .scheduler import (
+            XINFERENCE_STREAMING_ABORT_FLAG,
+            XINFERENCE_STREAMING_DONE_FLAG,
+            XINFERENCE_STREAMING_ERROR_FLAG,
+        )
+
+        while True:
+            # TODO: timeout setting
+            res = await wait_for(queue.get(), timeout)
+            if res == XINFERENCE_STREAMING_DONE_FLAG:
+                break
+            elif res == XINFERENCE_STREAMING_ABORT_FLAG:
+                raise RuntimeError(
+                    f"This request has been cancelled by another `abort_request` request."
+                )
+            elif isinstance(res, str) and res.startswith(
+                XINFERENCE_STREAMING_ERROR_FLAG
+            ):
+                raise RuntimeError(res[len(XINFERENCE_STREAMING_ERROR_FLAG) :])
+            else:
+                yield res
+
+    @staticmethod
+    def get_stream_from_args(*args) -> bool:
+        assert args[2] is None or isinstance(args[2], dict)
+        return False if args[2] is None else args[2].get("stream", False)
+
     @log_async(logger=logger)
     @request_limit
     @xo.generator

@@ -366,17 +440,46 @@ class ModelActor(xo.StatelessActor):
         start_time = time.time()
         response = None
         try:
-            if
-
-
-
-
-
-
-                self.
-
-
+            if self.allow_batching():
+                stream = self.get_stream_from_args(*args)
+                assert self._scheduler_ref is not None
+                if stream:
+                    assert self._scheduler_ref is not None
+                    queue: Queue[Any] = Queue()
+                    ret = self._queue_consumer(queue)
+                    await self._scheduler_ref.add_request(
+                        prompt, queue, *args, **kwargs
+                    )
+                    gen = self._to_json_async_gen(ret)
+                    self._current_generator = weakref.ref(gen)
+                    return gen
+                else:
+                    from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
+
+                    assert self._loop is not None
+                    future = ConcurrentFuture()
+                    await self._scheduler_ref.add_request(
+                        prompt, future, *args, **kwargs
+                    )
+                    fut = asyncio.wrap_future(future, loop=self._loop)
+                    result = await fut
+                    if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
+                        raise RuntimeError(
+                            f"This request has been cancelled by another `abort_request` request."
+                        )
+                    return await asyncio.to_thread(json_dumps, result)
+            else:
+                if hasattr(self._model, "chat"):
+                    response = await self._call_wrapper(
+                        self._model.chat, prompt, *args, **kwargs
+                    )
+                    return response
+                if hasattr(self._model, "async_chat"):
+                    response = await self._call_wrapper(
+                        self._model.async_chat, prompt, *args, **kwargs
+                    )
+                    return response
+                raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
         finally:
             # For the non stream result.
             record = None

@@ -395,6 +498,15 @@ class ModelActor(xo.StatelessActor):
                     prompt_tokens,
                 )

+    async def abort_request(self, request_id: str) -> str:
+        from .scheduler import AbortRequestMessage
+
+        if self.allow_batching():
+            if self._scheduler_ref is None:
+                return AbortRequestMessage.NOT_FOUND.name
+            return await self._scheduler_ref.abort_request(request_id)
+        return AbortRequestMessage.NO_OP.name
+
     @log_async(logger=logger)
     @request_limit
     async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):

@@ -482,6 +594,23 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating translations."
         )

+    @log_async(logger=logger)
+    @request_limit
+    async def speech(
+        self, input: str, voice: str, response_format: str = "mp3", speed: float = 1.0
+    ):
+        if hasattr(self._model, "speech"):
+            return await self._call_wrapper(
+                self._model.speech,
+                input,
+                voice,
+                response_format,
+                speed,
+            )
+        raise AttributeError(
+            f"Model {self._model.model_spec} is not for creating speech."
+        )
+
     @log_async(logger=logger)
     @request_limit
     async def text_to_image(