xinference 0.11.3__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference has been flagged as potentially problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +143 -6
- xinference/client/restful/restful_client.py +144 -5
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +48 -28
- xinference/core/model.py +160 -19
- xinference/core/scheduler.py +446 -0
- xinference/core/supervisor.py +99 -24
- xinference/core/worker.py +68 -2
- xinference/deploy/cmdline.py +86 -2
- xinference/deploy/test/test_cmdline.py +19 -10
- xinference/isolation.py +9 -2
- xinference/model/audio/__init__.py +14 -1
- xinference/model/audio/chattts.py +84 -0
- xinference/model/audio/core.py +22 -4
- xinference/model/audio/custom.py +6 -4
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/audio/model_spec_modelscope.json +20 -0
- xinference/model/llm/__init__.py +38 -2
- xinference/model/llm/llm_family.json +509 -1
- xinference/model/llm/llm_family.py +86 -1
- xinference/model/llm/llm_family_csghub.json +66 -0
- xinference/model/llm/llm_family_modelscope.json +411 -2
- xinference/model/llm/pytorch/chatglm.py +20 -13
- xinference/model/llm/pytorch/cogvlm2.py +76 -17
- xinference/model/llm/pytorch/core.py +141 -6
- xinference/model/llm/pytorch/glm4v.py +268 -0
- xinference/model/llm/pytorch/minicpmv25.py +232 -0
- xinference/model/llm/pytorch/qwen_vl.py +1 -1
- xinference/model/llm/pytorch/utils.py +405 -8
- xinference/model/llm/utils.py +14 -13
- xinference/model/llm/vllm/core.py +16 -4
- xinference/model/utils.py +8 -2
- xinference/thirdparty/ChatTTS/__init__.py +1 -0
- xinference/thirdparty/ChatTTS/core.py +200 -0
- xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
- xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/infer/api.py +125 -0
- xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
- xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
- xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
- xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
- xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
- xinference/types.py +3 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
- xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
- xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
- xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/METADATA +26 -9
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/RECORD +65 -47
- xinference/web/ui/build/static/css/main.54bca460.css +0 -2
- xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
- xinference/web/ui/build/static/js/main.551aa479.js +0 -3
- xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
- /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
-  "date": "2024-
+  "date": "2024-06-14T17:17:50+0800",
   "dirty": false,
   "error": null,
-  "full-revisionid": "
-  "version": "0.
+  "full-revisionid": "34a57df449f0890415c424802d3596f3c8758412",
+  "version": "0.12.1"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py
CHANGED
@@ -122,6 +122,14 @@ class TextToImageRequest(BaseModel):
     user: Optional[str] = None
 
 
+class SpeechRequest(BaseModel):
+    model: str
+    input: str
+    voice: Optional[str]
+    response_format: Optional[str] = "mp3"
+    speed: Optional[float] = 1.0
+
+
 class RegisterModelRequest(BaseModel):
     model: str
     persist: bool
@@ -337,6 +345,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/models/{model_uid}/requests/{request_id}/abort",
+            self.abort_request,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/models/instance",
             self.launch_model_by_version,
@@ -418,6 +436,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/audio/speech",
+            self.create_speech,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/images/generations",
             self.create_images,
@@ -494,11 +522,31 @@ class RESTfulAPI:
             ),
         )
         self._router.add_api_route(
-            "/v1/
+            "/v1/cache/models",
             self.list_cached_models,
             methods=["GET"],
             dependencies=(
-                [Security(self._auth_service, scopes=["
+                [Security(self._auth_service, scopes=["cache:list"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
+        self._router.add_api_route(
+            "/v1/cache/models/files",
+            self.list_model_files,
+            methods=["GET"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["cache:list"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
+        self._router.add_api_route(
+            "/v1/cache/models",
+            self.confirm_and_remove_model,
+            methods=["DELETE"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["cache:delete"])]
                 if self.is_authenticated()
                 else None
             ),
@@ -1179,6 +1227,38 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def create_speech(self, request: Request) -> Response:
+        body = SpeechRequest.parse_obj(await request.json())
+        model_uid = body.model
+        try:
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            out = await model.speech(
+                input=body.input,
+                voice=body.voice,
+                response_format=body.response_format,
+                speed=body.speed,
+            )
+            return Response(media_type="application/octet-stream", content=out)
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_images(self, request: Request) -> Response:
         body = TextToImageRequest.parse_obj(await request.json())
         model_uid = body.model
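The new endpoint mirrors the OpenAI audio API shape. Below is a minimal sketch of calling it over plain HTTP; the supervisor address, model uid, and output filename are illustrative assumptions, while the payload fields follow the SpeechRequest model introduced above.

import requests

# Hypothetical endpoint and uid of an already-launched audio model.
payload = {
    "model": "my-chattts",
    "input": "Hello from xinference!",
    "voice": "",
    "response_format": "mp3",
    "speed": 1.0,
}
resp = requests.post("http://127.0.0.1:9997/v1/audio/speech", json=payload)
resp.raise_for_status()
with open("speech.mp3", "wb") as f:
    f.write(resp.content)  # the endpoint returns raw audio bytes (application/octet-stream)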
@@ -1341,9 +1421,11 @@ class RESTfulAPI:
         model_family = desc.get("model_family", "")
         function_call_models = [
             "chatglm3",
+            "glm4-chat",
             "gorilla-openfunctions-v1",
             "qwen-chat",
             "qwen1.5-chat",
+            "qwen2-instruct",
         ]
 
         is_qwen = desc.get("model_format") == "ggmlv3" and "qwen-chat" == model_family
@@ -1366,7 +1448,11 @@ class RESTfulAPI:
             )
         if body.tools and body.stream:
             is_vllm = await model.is_vllm_backend()
-            if not is_vllm or model_family not in [
+            if not is_vllm or model_family not in [
+                "qwen-chat",
+                "qwen1.5-chat",
+                "qwen2-instruct",
+            ]:
                 raise HTTPException(
                     status_code=400,
                     detail="Streaming support for tool calls is available only when using vLLM backend and Qwen models.",
@@ -1495,10 +1581,17 @@ class RESTfulAPI:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
-    async def list_cached_models(
+    async def list_cached_models(
+        self, model_name: str = Query(None), worker_ip: str = Query(None)
+    ) -> JSONResponse:
         try:
-            data = await (await self._get_supervisor_ref()).list_cached_models(
-
+            data = await (await self._get_supervisor_ref()).list_cached_models(
+                model_name, worker_ip
+            )
+            resp = {
+                "list": data,
+            }
+            return JSONResponse(content=resp)
         except ValueError as re:
             logger.error(re, exc_info=True)
             raise HTTPException(status_code=400, detail=str(re))
@@ -1518,6 +1611,15 @@ class RESTfulAPI:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def abort_request(self, model_uid: str, request_id: str) -> JSONResponse:
+        try:
+            supervisor_ref = await self._get_supervisor_ref()
+            res = await supervisor_ref.abort_request(model_uid, request_id)
+            return JSONResponse(content=res)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def list_vllm_supported_model_families(self) -> JSONResponse:
         try:
             from ..model.llm.vllm.core import (
@@ -1554,6 +1656,41 @@ class RESTfulAPI:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def list_model_files(
+        self, model_version: str = Query(None), worker_ip: str = Query(None)
+    ) -> JSONResponse:
+        try:
+            data = await (await self._get_supervisor_ref()).list_deletable_models(
+                model_version, worker_ip
+            )
+            response = {
+                "model_version": model_version,
+                "worker_ip": worker_ip,
+                "paths": data,
+            }
+            return JSONResponse(content=response)
+        except ValueError as re:
+            logger.error(re, exc_info=True)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
+    async def confirm_and_remove_model(
+        self, model_version: str = Query(None), worker_ip: str = Query(None)
+    ) -> JSONResponse:
+        try:
+            res = await (await self._get_supervisor_ref()).confirm_and_remove_model(
+                model_version=model_version, worker_ip=worker_ip
+            )
+            return JSONResponse(content={"result": res})
+        except ValueError as re:
+            logger.error(re, exc_info=True)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
 
 def run(
     supervisor_address: str,
xinference/client/restful/restful_client.py
CHANGED

@@ -684,6 +684,49 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def speech(
+        self,
+        input: str,
+        voice: str = "",
+        response_format: str = "mp3",
+        speed: float = 1.0,
+    ):
+        """
+        Generates audio from the input text.
+
+        Parameters
+        ----------
+        input: str
+            The text to generate audio for. The maximum length is 4096 characters.
+        voice: str
+            The voice to use when generating the audio.
+        response_format: str
+            The format of the generated audio.
+        speed: float
+            The speed of the generated audio.
+
+        Returns
+        -------
+        bytes
+            The generated audio binary.
+        """
+        url = f"{self._base_url}/v1/audio/speech"
+        params = {
+            "model": self._model_uid,
+            "input": input,
+            "voice": voice,
+            "response_format": response_format,
+            "speed": speed,
+        }
+        response = requests.post(url, json=params, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to generate speech from the text, detail: {_get_error_string(response)}"
+            )
+
+        return response.content
+
 
 class Client:
     def __init__(self, base_url, api_key: Optional[str] = None):
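A usage sketch for the new client-side method, assuming a ChatTTS-style audio model (added elsewhere in this release) has already been launched and that Client.get_model returns an audio model handle for it; the endpoint and model uid are placeholders.

from xinference.client import Client

client = Client("http://127.0.0.1:9997")   # assumed supervisor endpoint
model = client.get_model("my-chattts")     # placeholder uid of a launched audio model
audio = model.speech(
    input="Text to speech with xinference 0.12.",
    response_format="mp3",
    speed=1.0,
)
with open("out.mp3", "wb") as f:
    f.write(audio)  # speech() returns the raw audio bytes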
@@ -1102,13 +1145,17 @@ class Client:
         response_data = response.json()
         return response_data
 
-    def list_cached_models(
+    def list_cached_models(
+        self, model_name: Optional[str] = None, worker_ip: Optional[str] = None
+    ) -> List[Dict[Any, Any]]:
         """
         Get a list of cached models.
-
         Parameters
         ----------
-
+        model_name: Optional[str]
+            The name of the model.
+        worker_ip: Optional[str]
+            Specify the worker ip where the model is located in a distributed scenario.
 
         Returns
         -------
@@ -1121,16 +1168,81 @@ class Client:
             Raised when the request fails, including the reason for the failure.
         """
 
-        url = f"{self.base_url}/v1/
-
+        url = f"{self.base_url}/v1/cache/models"
+        params = {
+            "model_name": model_name,
+            "worker_ip": worker_ip,
+        }
+        response = requests.get(url, headers=self._headers, params=params)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to list cached model, detail: {_get_error_string(response)}"
             )
 
+        response_data = response.json()
+        response_data = response_data.get("list")
+        return response_data
+
+    def list_deletable_models(
+        self, model_version: str, worker_ip: Optional[str] = None
+    ) -> Dict[str, Any]:
+        """
+        Get the cached models with the model path cached on the server.
+
+        Parameters
+        ----------
+        model_version: str
+            The version of the model.
+        worker_ip: Optional[str]
+            Specify the worker ip where the model is located in a distributed scenario.
+
+        Returns
+        -------
+        Dict[str, Dict[str, str]]
+            Dictionary with keys "model_name" and values model_file_location.
+        """
+        url = f"{self.base_url}/v1/cache/models/files"
+        params = {
+            "model_version": model_version,
+            "worker_ip": worker_ip,
+        }
+        response = requests.get(url, headers=self._headers, params=params)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get paths by model name, detail: {_get_error_string(response)}"
+            )
 
         response_data = response.json()
         return response_data
 
+    def confirm_and_remove_model(
+        self, model_version: str, worker_ip: Optional[str] = None
+    ) -> bool:
+        """
+        Remove the cached models with the model name cached on the server.
+
+        Parameters
+        ----------
+        model_version: str
+            The version of the model.
+        worker_ip: Optional[str]
+            Specify the worker ip where the model is located in a distributed scenario.
+
+        Returns
+        -------
+        bool
+            The response of the server.
+        """
+        url = f"{self.base_url}/v1/cache/models"
+        params = {
+            "model_version": model_version,
+            "worker_ip": worker_ip,
+        }
+        response = requests.delete(url, headers=self._headers, params=params)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to remove cached models, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data.get("result", False)
+
     def get_model_registration(
         self, model_type: str, model_name: str
     ) -> Dict[str, Any]:
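Taken together, these three methods give a cache-management workflow from the client side. A sketch under assumed names (the endpoint, model name, and model_version string are placeholders):

from xinference.client import Client

client = Client("http://127.0.0.1:9997")

# List cached models, optionally filtered by name and/or worker ip.
for item in client.list_cached_models(model_name="qwen2-instruct"):
    print(item["model_name"], item.get("model_file_location"))

# Look up the deletable files for one cached model version ...
version = "qwen2-instruct--7b--pytorch"  # placeholder model_version string
paths = client.list_deletable_models(model_version=version)
print(paths)

# ... and remove them once confirmed.
if paths:
    print("removed:", client.confirm_and_remove_model(model_version=version))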
@@ -1181,3 +1293,30 @@ class Client:
 
         response_data = response.json()
         return response_data
+
+    def abort_request(self, model_uid: str, request_id: str):
+        """
+        Abort a submitted request. If the request is finished or not found, this method is a no-op.
+        Currently, this interface is only supported when batching is enabled for models on the transformers backend.
+
+        Parameters
+        ----------
+        model_uid: str
+            Model uid.
+        request_id: str
+            Request id.
+
+        Returns
+        -------
+        Dict
+            Returns an empty dict.
+        """
+        url = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort"
+        response = requests.post(url, headers=self._headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to abort request, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
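A usage sketch for the new abort API; the endpoint, model uid, and request id are placeholders, and per the docstring the call only takes effect when continuous batching is enabled for the transformers backend (see the new constant in the next file).

from xinference.client import Client

client = Client("http://127.0.0.1:9997")   # assumed supervisor endpoint
# request_id must match an id attached to a still-pending generation request.
result = client.abort_request(model_uid="my-llm", request_id="req-1234")
print(result)  # per the docstring, an empty dict once the abort is handled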
xinference/constants.py
CHANGED
@@ -17,6 +17,7 @@ from pathlib import Path
 
 XINFERENCE_ENV_ENDPOINT = "XINFERENCE_ENDPOINT"
 XINFERENCE_ENV_MODEL_SRC = "XINFERENCE_MODEL_SRC"
+XINFERENCE_ENV_CSG_TOKEN = "XINFERENCE_CSG_TOKEN"
 XINFERENCE_ENV_HOME_PATH = "XINFERENCE_HOME"
 XINFERENCE_ENV_HEALTH_CHECK_FAILURE_THRESHOLD = (
     "XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD"

@@ -27,6 +28,7 @@ XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
 XINFERENCE_ENV_DISABLE_VLLM = "XINFERENCE_DISABLE_VLLM"
 XINFERENCE_ENV_ENABLE_SGLANG = "XINFERENCE_ENABLE_SGLANG"
 XINFERENCE_ENV_DISABLE_METRICS = "XINFERENCE_DISABLE_METRICS"
+XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING = "XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"
 
 
 def get_xinference_home() -> str:

@@ -70,3 +72,6 @@ XINFERENCE_ENABLE_SGLANG = bool(int(os.environ.get(XINFERENCE_ENV_ENABLE_SGLANG,
 XINFERENCE_DISABLE_METRICS = bool(
     int(os.environ.get(XINFERENCE_ENV_DISABLE_METRICS, 0))
 )
+XINFERENCE_TRANSFORMERS_ENABLE_BATCHING = bool(
+    int(os.environ.get(XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING, 0))
+)
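The new switch is parsed like the other boolean flags above, so it has to be set to an integer-like string in the environment before the xinference process starts; presumably it gates the new continuous-batching scheduler added in xinference/core/scheduler.py. A small illustrative sketch of the parsing behaviour:

import os

# Set before launching xinference, e.g.
#   XINFERENCE_TRANSFORMERS_ENABLE_BATCHING=1 xinference-local
os.environ["XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"] = "1"

# Mirrors the bool(int(...)) parsing above: "1" enables, "0" (or unset) disables,
# and a non-integer value would raise a ValueError.
enabled = bool(int(os.environ.get("XINFERENCE_TRANSFORMERS_ENABLE_BATCHING", 0)))
print(enabled)  # True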
xinference/core/cache_tracker.py
CHANGED
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from logging import getLogger
 from typing import Any, Dict, List, Optional
 

@@ -102,33 +101,54 @@ class CacheTrackerActor(xo.Actor):
     def get_model_version_count(self, model_name: str) -> int:
         return len(self.get_model_versions(model_name))
 
-    def list_cached_models(
+    def list_cached_models(
+        self, worker_ip: str, model_name: Optional[str] = None
+    ) -> List[Dict[Any, Any]]:
         cached_models = []
-        for
+        for name, versions in self._model_name_to_version_info.items():
+            # only return the requested model if model_name is given,
+            # else return all cached models
+            if model_name and model_name != name:
+                continue
+            for version_info in versions:
+                cache_status = version_info.get("cache_status", False)
+                # search cached model
+                if cache_status:
+                    res = version_info.copy()
+                    res["model_name"] = name
+                    paths = res.get("model_file_location", {})
+                    # only return the assigned worker's device path
+                    if worker_ip in paths.keys():
+                        res["model_file_location"] = paths[worker_ip]
+                        cached_models.append(res)
+        return cached_models
 
+    def list_deletable_models(self, model_version: str, worker_ip: str) -> str:
+        model_file_location = ""
+        for model, model_versions in self._model_name_to_version_info.items():
+            for version_info in model_versions:
+                # search the assigned model version
+                if model_version == version_info.get("model_version", None):
+                    # check whether it is cached
+                    if version_info.get("cache_status", False):
+                        paths = version_info.get("model_file_location", {})
+                        # only return the assigned worker's device path
+                        if worker_ip in paths.keys():
+                            model_file_location = paths[worker_ip]
+        return model_file_location
 
+    def confirm_and_remove_model(self, model_version: str, worker_ip: str):
+        # find the path to remove
+        rm_path = self.list_deletable_models(model_version, worker_ip)
+        # search _model_name_to_version_info for this path and delete it
+        for model, model_versions in self._model_name_to_version_info.items():
+            for version_info in model_versions:
+                # check whether it is cached
+                if version_info.get("cache_status", False):
+                    paths = version_info.get("model_file_location", {})
+                    # only delete the assigned worker's device path
+                    if worker_ip in paths.keys() and rm_path == paths[worker_ip]:
+                        del paths[worker_ip]
+                        # if no path remains, update the cache status
+                        if not paths:
+                            version_info["cache_status"] = False