xinference 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference has been flagged as possibly problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +108 -14
- xinference/client/restful/restful_client.py +78 -5
- xinference/constants.py +1 -0
- xinference/core/cache_tracker.py +48 -28
- xinference/core/event.py +5 -6
- xinference/core/model.py +59 -42
- xinference/core/scheduler.py +46 -18
- xinference/core/supervisor.py +73 -24
- xinference/core/worker.py +68 -2
- xinference/deploy/cmdline.py +86 -2
- xinference/deploy/test/test_cmdline.py +19 -10
- xinference/model/audio/__init__.py +14 -1
- xinference/model/audio/core.py +12 -1
- xinference/model/audio/custom.py +6 -4
- xinference/model/audio/model_spec_modelscope.json +20 -0
- xinference/model/llm/__init__.py +34 -2
- xinference/model/llm/llm_family.json +8 -2
- xinference/model/llm/llm_family.py +86 -1
- xinference/model/llm/llm_family_csghub.json +66 -0
- xinference/model/llm/llm_family_modelscope.json +8 -2
- xinference/model/llm/pytorch/chatglm.py +41 -12
- xinference/model/llm/pytorch/core.py +128 -88
- xinference/model/llm/pytorch/glm4v.py +24 -3
- xinference/model/llm/pytorch/internlm2.py +15 -0
- xinference/model/llm/pytorch/qwen_vl.py +1 -1
- xinference/model/llm/pytorch/utils.py +69 -189
- xinference/model/llm/utils.py +27 -14
- xinference/model/llm/vllm/core.py +10 -4
- xinference/model/rerank/core.py +35 -6
- xinference/model/utils.py +8 -2
- xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
- xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/infer/api.py +125 -0
- xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
- xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
- xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
- xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
- xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
- xinference/types.py +28 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.4bafd904.css +2 -0
- xinference/web/ui/build/static/css/main.4bafd904.css.map +1 -0
- xinference/web/ui/build/static/js/main.b80d9c08.js +3 -0
- xinference/web/ui/build/static/js/main.b80d9c08.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0c2fb5375667931c4a331c99e0d87dc145e8f327cea3f44d6e56f54c7c1d4020.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/16537795de12c61903b6110c241f62a7855b2d0fc1e7c3d1faa347267f3a6893.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/17b8f071491402d70b146532358b1a612226e5dc7b3e8755a1322d27b4680cee.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/395409bd005e19d48b437c48d88e5126c7865ba9631fe98535333c952e383dc5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/43991bb67c3136863e6fb37f796466b12eb547a1465408cc77820fddafb3bed3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/{15e2cf8cd8d0989719b6349428ff576f9009ff4c2dcc52378be0bd938e82495e.json → 935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json} +1 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a7109d4425e3d94ca2726fc7020fd33bf5030afd4c9cf4bf71e21776cd70646a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/METADATA +1 -1
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/RECORD +69 -56
- xinference/web/ui/build/static/css/main.54bca460.css +0 -2
- xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
- xinference/web/ui/build/static/js/main.551aa479.js +0 -3
- xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2c63e940b945fd5817157e08a42b889b30d668ea4c91332f48ef2b1b9d26f520.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3c2f277c93c5f1638e08db38df0d0fb4e58d1c5571aea03241a5c04ff4094704.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4135fe8745434cbce6438d1ebfa47422e0c77d884db4edc75c8bf32ea1d50621.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4de0a71074f9cbe1e7862750dcdd08cbc1bae7d9d9849a78b1783ca670017b3c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/9cfd33238ca43e5bf9fc7e442690e8cc6027c73553db36de87e3597ed524ee4b.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e6eccc9aa641e7da833492e27846dc965f9750281420977dc84654ca6ed221e4.json +0 -1
- /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.b80d9c08.js.LICENSE.txt} +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/LICENSE +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/WHEEL +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/entry_points.txt +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-06-
+ "date": "2024-06-21T15:34:17+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "
- "version": "0.12.
+ "full-revisionid": "5cef7c3d4bb0c5208d262fc3ffb7d7083724de1c",
+ "version": "0.12.2"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py
CHANGED
@@ -109,6 +109,7 @@ class RerankRequest(BaseModel):
     documents: List[str]
     top_n: Optional[int] = None
     return_documents: Optional[bool] = False
+    return_len: Optional[bool] = False
    max_chunks_per_doc: Optional[int] = None
 
 
@@ -522,11 +523,31 @@ class RESTfulAPI:
             ),
         )
         self._router.add_api_route(
-            "/v1/
+            "/v1/cache/models",
             self.list_cached_models,
             methods=["GET"],
             dependencies=(
-                [Security(self._auth_service, scopes=["
+                [Security(self._auth_service, scopes=["cache:list"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
+        self._router.add_api_route(
+            "/v1/cache/models/files",
+            self.list_model_files,
+            methods=["GET"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["cache:list"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
+        self._router.add_api_route(
+            "/v1/cache/models",
+            self.confirm_and_remove_model,
+            methods=["DELETE"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["cache:delete"])]
                 if self.is_authenticated()
                 else None
             ),
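The three routes registered above put cache management on the HTTP surface: GET /v1/cache/models lists cached models, GET /v1/cache/models/files resolves a model version to its on-disk files, and DELETE /v1/cache/models removes them; when authentication is on they require the cache:list and cache:delete scopes. A minimal sketch of the listing call with plain requests (host, port, and model name are assumptions for illustration; add an Authorization header if auth is enabled):

import requests

base_url = "http://127.0.0.1:9997"   # assumed local endpoint

# GET /v1/cache/models returns {"list": [...]} per the handler further below.
resp = requests.get(f"{base_url}/v1/cache/models",
                    params={"model_name": "qwen2-instruct"})
resp.raise_for_status()
for entry in resp.json()["list"]:
    print(entry["model_name"], entry.get("model_file_location"))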
@@ -961,7 +982,8 @@ class RESTfulAPI:
         return JSONResponse(content=self._supervisor_address)
 
     async def create_completion(self, request: Request) -> Response:
-
+        raw_body = await request.json()
+        body = CreateCompletionRequest.parse_obj(raw_body)
         exclude = {
             "prompt",
             "model",
@@ -971,6 +993,7 @@ class RESTfulAPI:
             "logit_bias_type",
             "user",
         }
+        raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
         kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
         # TODO: Decide if this default value override is necessary #1061
@@ -1000,7 +1023,9 @@ class RESTfulAPI:
         iterator = None
         try:
             try:
-                iterator = await model.generate(
+                iterator = await model.generate(
+                    body.prompt, kwargs, raw_params=raw_kwargs
+                )
             except RuntimeError as re:
                 self.handle_request_limit_error(re)
             async for item in iterator:
@@ -1020,7 +1045,7 @@ class RESTfulAPI:
             return EventSourceResponse(stream_results())
         else:
             try:
-                data = await model.generate(body.prompt, kwargs)
+                data = await model.generate(body.prompt, kwargs, raw_params=raw_kwargs)
                 return Response(data, media_type="application/json")
             except Exception as e:
                 logger.error(e, exc_info=True)
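The pattern in these hunks: the handler now keeps the raw JSON body alongside the validated pydantic object, and every field outside the exclude set rides along to the model call as raw_params, so backend-specific options survive schema validation. A hedged sketch of what a caller can now send (the extra field name is purely hypothetical; whether a given backend honors it is model-dependent):

import requests

resp = requests.post(
    "http://127.0.0.1:9997/v1/completions",   # assumed local endpoint
    json={
        "model": "my-model-uid",              # hypothetical model uid
        "prompt": "Hello",
        "max_tokens": 32,
        "some_backend_option": 1,             # hypothetical non-schema field,
                                              # forwarded via raw_params
    },
)
print(resp.json())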
@@ -1092,6 +1117,7 @@ class RESTfulAPI:
             top_n=body.top_n,
             max_chunks_per_doc=body.max_chunks_per_doc,
             return_documents=body.return_documents,
+            return_len=body.return_len,
             **kwargs,
         )
         return Response(scores, media_type="application/json")
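Together with the new RerankRequest.return_len field in the first hunk, the rerank handler can now report token counts alongside scores. A minimal sketch against the REST endpoint (host, port, and model uid are placeholders):

import requests

resp = requests.post(
    "http://127.0.0.1:9997/v1/rerank",
    json={
        "model": "my-rerank-model",          # hypothetical rerank model uid
        "query": "What is deep learning?",
        "documents": ["Deep learning is ...", "Cooking pasta ..."],
        "return_documents": True,
        "return_len": True,                  # new in this release
    },
)
print(resp.json())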
@@ -1321,7 +1347,8 @@ class RESTfulAPI:
             raise HTTPException(status_code=500, detail=str(e))
 
     async def create_chat_completion(self, request: Request) -> Response:
-
+        raw_body = await request.json()
+        body = CreateChatCompletion.parse_obj(raw_body)
         exclude = {
             "prompt",
             "model",
@@ -1331,6 +1358,7 @@ class RESTfulAPI:
             "logit_bias_type",
             "user",
         }
+        raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
         kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
         # TODO: Decide if this default value override is necessary #1061
@@ -1401,9 +1429,13 @@ class RESTfulAPI:
         model_family = desc.get("model_family", "")
         function_call_models = [
             "chatglm3",
+            "glm4-chat",
             "gorilla-openfunctions-v1",
             "qwen-chat",
             "qwen1.5-chat",
+            "qwen1.5-moe-chat",
+            "qwen2-instruct",
+            "qwen2-moe-instruct",
         ]
 
         is_qwen = desc.get("model_format") == "ggmlv3" and "qwen-chat" == model_family
@@ -1426,7 +1458,13 @@ class RESTfulAPI:
         )
         if body.tools and body.stream:
             is_vllm = await model.is_vllm_backend()
-            if not is_vllm or model_family not in [
+            if not is_vllm or model_family not in [
+                "qwen-chat",
+                "qwen1.5-chat",
+                "qwen1.5-moe-chat",
+                "qwen2-instruct",
+                "qwen2-moe-instruct",
+            ]:
                 raise HTTPException(
                     status_code=400,
                     detail="Streaming support for tool calls is available only when using vLLM backend and Qwen models.",
@@ -1439,10 +1477,16 @@ class RESTfulAPI:
         try:
             try:
                 if is_qwen:
-                    iterator = await model.chat(
+                    iterator = await model.chat(
+                        prompt, chat_history, kwargs, raw_params=raw_kwargs
+                    )
                 else:
                     iterator = await model.chat(
-                        prompt,
+                        prompt,
+                        system_prompt,
+                        chat_history,
+                        kwargs,
+                        raw_params=raw_kwargs,
                     )
             except RuntimeError as re:
                 await self._report_error_event(model_uid, str(re))
@@ -1472,9 +1516,17 @@ class RESTfulAPI:
         else:
             try:
                 if is_qwen:
-                    data = await model.chat(
+                    data = await model.chat(
+                        prompt, chat_history, kwargs, raw_params=raw_kwargs
+                    )
                 else:
-                    data = await model.chat(
+                    data = await model.chat(
+                        prompt,
+                        system_prompt,
+                        chat_history,
+                        kwargs,
+                        raw_params=raw_kwargs,
+                    )
                 return Response(content=data, media_type="application/json")
             except Exception as e:
                 logger.error(e, exc_info=True)
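create_chat_completion gets the same raw_body/raw_kwargs treatment, glm4-chat and the qwen1.5/qwen2 families join the function-call allowlist, and streaming tool calls now pass the guard only for Qwen-family models served by the vLLM backend; anything else trips the 400 above. A hedged sketch of a request that should pass the guard (model uid, endpoint, and the tool schema are placeholders; the model must actually run on vLLM):

import requests

resp = requests.post(
    "http://127.0.0.1:9997/v1/chat/completions",
    json={
        "model": "qwen2-instruct",          # must resolve to a vLLM-backed model
        "messages": [{"role": "user", "content": "Weather in Paris?"}],
        "stream": True,
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",      # hypothetical tool
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                },
            },
        }],
    },
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode())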
@@ -1555,10 +1607,17 @@ class RESTfulAPI:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
-    async def list_cached_models(
+    async def list_cached_models(
+        self, model_name: str = Query(None), worker_ip: str = Query(None)
+    ) -> JSONResponse:
         try:
-            data = await (await self._get_supervisor_ref()).list_cached_models(
-
+            data = await (await self._get_supervisor_ref()).list_cached_models(
+                model_name, worker_ip
+            )
+            resp = {
+                "list": data,
+            }
+            return JSONResponse(content=resp)
         except ValueError as re:
             logger.error(re, exc_info=True)
             raise HTTPException(status_code=400, detail=str(re))
@@ -1623,6 +1682,41 @@ class RESTfulAPI:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def list_model_files(
+        self, model_version: str = Query(None), worker_ip: str = Query(None)
+    ) -> JSONResponse:
+        try:
+            data = await (await self._get_supervisor_ref()).list_deletable_models(
+                model_version, worker_ip
+            )
+            response = {
+                "model_version": model_version,
+                "worker_ip": worker_ip,
+                "paths": data,
+            }
+            return JSONResponse(content=response)
+        except ValueError as re:
+            logger.error(re, exc_info=True)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
+    async def confirm_and_remove_model(
+        self, model_version: str = Query(None), worker_ip: str = Query(None)
+    ) -> JSONResponse:
+        try:
+            res = await (await self._get_supervisor_ref()).confirm_and_remove_model(
+                model_version=model_version, worker_ip=worker_ip
+            )
+            return JSONResponse(content={"result": res})
+        except ValueError as re:
+            logger.error(re, exc_info=True)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
 
 def run(
     supervisor_address: str,
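End to end, the new cache surface supports: list what is cached, resolve a model_version to its deletable paths, then issue the DELETE. A hedged sketch with plain requests (endpoint and version string are placeholders; real model_version values come from the listing call):

import requests

base_url = "http://127.0.0.1:9997"        # assumed local endpoint
version = "my-model-version"              # hypothetical model_version string

# Which files would be removed, and where?
files = requests.get(f"{base_url}/v1/cache/models/files",
                     params={"model_version": version}).json()
print(files["paths"])

# Actually remove them; the handler returns {"result": ...}.
res = requests.delete(f"{base_url}/v1/cache/models",
                      params={"model_version": version}).json()
print(res["result"])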
xinference/client/restful/restful_client.py
CHANGED
@@ -135,6 +135,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
         top_n: Optional[int] = None,
         max_chunks_per_doc: Optional[int] = None,
         return_documents: Optional[bool] = None,
+        return_len: Optional[bool] = None,
         **kwargs,
     ):
         """
@@ -152,6 +153,8 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
             The maximum number of chunks derived from a document
         return_documents: bool
             if return documents
+        return_len: bool
+            if return tokens len
         Returns
         -------
         Scores
@@ -170,6 +173,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
             "top_n": top_n,
             "max_chunks_per_doc": max_chunks_per_doc,
             "return_documents": return_documents,
+            "return_len": return_len,
         }
         request_body.update(kwargs)
         response = requests.post(url, json=request_body, headers=self.auth_headers)
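With the handle accepting return_len, a client-side rerank call can request token counts directly. A hedged usage sketch (endpoint and model uid are placeholders):

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")   # assumed local endpoint
model = client.get_model("my-rerank-model")       # hypothetical rerank model uid

result = model.rerank(
    documents=["Deep learning is ...", "Cooking pasta ..."],
    query="What is deep learning?",
    return_documents=True,
    return_len=True,                              # new in this release
)
print(result)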
@@ -1145,13 +1149,17 @@ class Client:
         response_data = response.json()
         return response_data
 
-    def list_cached_models(
+    def list_cached_models(
+        self, model_name: Optional[str] = None, worker_ip: Optional[str] = None
+    ) -> List[Dict[Any, Any]]:
         """
         Get a list of cached models.
-
         Parameters
         ----------
-
+        model_name: Optional[str]
+            The name of model.
+        worker_ip: Optional[str]
+            Specify the worker ip where the model is located in a distributed scenario.
 
         Returns
         -------
@@ -1164,16 +1172,81 @@ class Client:
             Raised when the request fails, including the reason for the failure.
         """
 
-        url = f"{self.base_url}/v1/
-
+        url = f"{self.base_url}/v1/cache/models"
+        params = {
+            "model_name": model_name,
+            "worker_ip": worker_ip,
+        }
+        response = requests.get(url, headers=self._headers, params=params)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to list cached model, detail: {_get_error_string(response)}"
             )
+        response_data = response.json()
+        response_data = response_data.get("list")
+        return response_data
+
+    def list_deletable_models(
+        self, model_version: str, worker_ip: Optional[str] = None
+    ) -> Dict[str, Any]:
+        """
+        Get the cached models with the model path cached on the server.
+        Parameters
+        ----------
+        model_version: str
+            The version of the model.
+        worker_ip: Optional[str]
+            Specify the worker ip where the model is located in a distributed scenario.
+        Returns
+        -------
+        Dict[str, Dict[str,str]]]
+            Dictionary with keys "model_name" and values model_file_location.
+        """
+        url = f"{self.base_url}/v1/cache/models/files"
+        params = {
+            "model_version": model_version,
+            "worker_ip": worker_ip,
+        }
+        response = requests.get(url, headers=self._headers, params=params)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get paths by model name, detail: {_get_error_string(response)}"
+            )
 
         response_data = response.json()
         return response_data
 
+    def confirm_and_remove_model(
+        self, model_version: str, worker_ip: Optional[str] = None
+    ) -> bool:
+        """
+        Remove the cached models with the model name cached on the server.
+        Parameters
+        ----------
+        model_version: str
+            The version of the model.
+        worker_ip: Optional[str]
+            Specify the worker ip where the model is located in a distributed scenario.
+        Returns
+        -------
+        str
+            The response of the server.
+        """
+        url = f"{self.base_url}/v1/cache/models"
+        params = {
+            "model_version": model_version,
+            "worker_ip": worker_ip,
+        }
+        response = requests.delete(url, headers=self._headers, params=params)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to remove cached models, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data.get("result", False)
+
     def get_model_registration(
         self, model_type: str, model_name: str
     ) -> Dict[str, Any]:
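A hedged sketch of the three client methods used together (endpoint is a placeholder; model_version values come from what list_cached_models returns):

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")   # assumed local endpoint

# Inspect the cache, optionally filtered by model name and worker.
for m in client.list_cached_models(model_name="qwen2-instruct"):
    print(m["model_name"], m.get("model_version"))

# Resolve one version to its on-disk location, then remove it.
version = "my-model-version"                      # pick a real value from above
print(client.list_deletable_models(version))
if client.confirm_and_remove_model(version):
    print("cache entry removed")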
xinference/constants.py
CHANGED
@@ -17,6 +17,7 @@ from pathlib import Path
 
 XINFERENCE_ENV_ENDPOINT = "XINFERENCE_ENDPOINT"
 XINFERENCE_ENV_MODEL_SRC = "XINFERENCE_MODEL_SRC"
+XINFERENCE_ENV_CSG_TOKEN = "XINFERENCE_CSG_TOKEN"
 XINFERENCE_ENV_HOME_PATH = "XINFERENCE_HOME"
 XINFERENCE_ENV_HEALTH_CHECK_FAILURE_THRESHOLD = (
     "XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD"
xinference/core/cache_tracker.py
CHANGED
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from logging import getLogger
 from typing import Any, Dict, List, Optional
 
@@ -102,33 +101,54 @@ class CacheTrackerActor(xo.Actor):
     def get_model_version_count(self, model_name: str) -> int:
         return len(self.get_model_versions(model_name))
 
-    def list_cached_models(
+    def list_cached_models(
+        self, worker_ip: str, model_name: Optional[str] = None
+    ) -> List[Dict[Any, Any]]:
         cached_models = []
-        for
-
-
-
-
-
+        for name, versions in self._model_name_to_version_info.items():
+            # only return assigned cached model if model_name is not none
+            # else return all cached model
+            if model_name and model_name != name:
+                continue
+            for version_info in versions:
+                cache_status = version_info.get("cache_status", False)
+                # search cached model
+                if cache_status:
+                    res = version_info.copy()
+                    res["model_name"] = name
+                    paths = res.get("model_file_location", {})
+                    # only return assigned worker's device path
+                    if worker_ip in paths.keys():
+                        res["model_file_location"] = paths[worker_ip]
+                        cached_models.append(res)
+        return cached_models
 
-
-
-
-
-
-
-
-
+    def list_deletable_models(self, model_version: str, worker_ip: str) -> str:
+        model_file_location = ""
+        for model, model_versions in self._model_name_to_version_info.items():
+            for version_info in model_versions:
+                # search assign model version
+                if model_version == version_info.get("model_version", None):
+                    # check if exist
+                    if version_info.get("cache_status", False):
+                        paths = version_info.get("model_file_location", {})
+                        # only return assigned worker's device path
+                        if worker_ip in paths.keys():
+                            model_file_location = paths[worker_ip]
+        return model_file_location
 
-
-
-
-
-
-
-
-
-
-
-
-
+    def confirm_and_remove_model(self, model_version: str, worker_ip: str):
+        # find remove path
+        rm_path = self.list_deletable_models(model_version, worker_ip)
+        # search _model_name_to_version_info if exist this path, and delete
+        for model, model_versions in self._model_name_to_version_info.items():
+            for version_info in model_versions:
+                # check if exist
+                if version_info.get("cache_status", False):
+                    paths = version_info.get("model_file_location", {})
+                    # only delete assigned worker's device path
+                    if worker_ip in paths.keys() and rm_path == paths[worker_ip]:
+                        del paths[worker_ip]
+                        # if path is empty, update cache status
+                        if not paths:
+                            version_info["cache_status"] = False
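The bookkeeping these three methods walk is, as far as the hunks show, a mapping from model name to a list of version records, where model_file_location maps a worker IP to that worker's local path. A hedged illustration of the shape (every value here is invented):

# Hypothetical contents of CacheTrackerActor._model_name_to_version_info:
_model_name_to_version_info = {
    "qwen2-instruct": [
        {
            "model_version": "qwen2-instruct--7B--pytorch--4-bit",   # invented
            "cache_status": True,
            "model_file_location": {
                "192.168.1.10": "/root/.xinference/cache/qwen2-instruct",  # invented
            },
        },
    ],
}
# confirm_and_remove_model(<version>, "192.168.1.10") deletes that worker's
# path entry and flips cache_status to False once no paths remain.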
xinference/core/event.py
CHANGED
@@ -12,8 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import
-from collections import defaultdict
+from collections import defaultdict, deque
 from enum import Enum
 from typing import Dict, List, TypedDict
 
@@ -37,8 +36,8 @@ class Event(TypedDict):
 class EventCollectorActor(xo.StatelessActor):
     def __init__(self):
         super().__init__()
-        self._model_uid_to_events: Dict[str,
-            lambda:
+        self._model_uid_to_events: Dict[str, deque] = defaultdict(  # type: ignore
+            lambda: deque(maxlen=MAX_EVENT_COUNT_PER_MODEL)
         )
 
     @classmethod
@@ -50,7 +49,7 @@ class EventCollectorActor(xo.StatelessActor):
         if event_queue is None:
             return []
         else:
-            return [dict(e, event_type=e["event_type"].name) for e in event_queue
+            return [dict(e, event_type=e["event_type"].name) for e in iter(event_queue)]
 
     def report_event(self, model_uid: str, event: Event):
-        self._model_uid_to_events[model_uid].
+        self._model_uid_to_events[model_uid].append(event)
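Switching the per-model container to collections.deque(maxlen=...) turns the collector into a ring buffer: once a model has MAX_EVENT_COUNT_PER_MODEL events, appending a new one silently drops the oldest. A minimal standalone sketch of that behavior (the count of 3 is illustrative, not the real constant):

from collections import defaultdict, deque

MAX_EVENT_COUNT_PER_MODEL = 3          # illustrative; the real constant differs
events = defaultdict(lambda: deque(maxlen=MAX_EVENT_COUNT_PER_MODEL))

for i in range(5):
    events["model-a"].append({"event_type": "INFO", "event_content": f"e{i}"})

# Only the 3 newest survive; e0 and e1 were evicted from the left.
print(list(events["model-a"]))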
xinference/core/model.py
CHANGED
@@ -264,12 +264,14 @@ class ModelActor(xo.StatelessActor):
         return isinstance(self._model, VLLMModel)
 
     def allow_batching(self) -> bool:
-        from ..model.llm.pytorch.core import
+        from ..model.llm.pytorch.core import PytorchModel
+
+        model_ability = self._model_description.get("model_ability", [])
 
         return (
             XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
-            and isinstance(self._model,
-            and
+            and isinstance(self._model, PytorchModel)
+            and "vision" not in model_ability
         )
 
     async def load(self):
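allow_batching() now also refuses vision models: continuous batching stays opt-in via the XINFERENCE_TRANSFORMERS_ENABLE_BATCHING environment variable and only applies to PytorchModel instances whose abilities exclude "vision". A hedged sketch of opting in before starting a worker (the accepted truthy value is an assumption; check the constants module of your version):

import os

# Assumed opt-in switch for transformers continuous batching.
os.environ["XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"] = "1"
# ... then launch xinference in the same environment, e.g. via its CLI.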
@@ -393,18 +395,25 @@ class ModelActor(xo.StatelessActor):
     @request_limit
     @xo.generator
     async def generate(self, prompt: str, *args, **kwargs):
-        if
-            return await self.
-
-        )
-        if hasattr(self._model, "async_generate"):
-            return await self._call_wrapper(
-                self._model.async_generate, prompt, *args, **kwargs
+        if self.allow_batching():
+            return await self.handle_batching_request(
+                prompt, "generate", *args, **kwargs
             )
-
+        else:
+            kwargs.pop("raw_params", None)
+            if hasattr(self._model, "generate"):
+                return await self._call_wrapper(
+                    self._model.generate, prompt, *args, **kwargs
+                )
+            if hasattr(self._model, "async_generate"):
+                return await self._call_wrapper(
+                    self._model.async_generate, prompt, *args, **kwargs
+                )
+            raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
 
+    @staticmethod
     async def _queue_consumer(
-
+        queue: Queue, timeout: Optional[float] = None
     ) -> AsyncIterator[Any]:
         from .scheduler import (
             XINFERENCE_STREAMING_ABORT_FLAG,
@@ -429,9 +438,38 @@ class ModelActor(xo.StatelessActor):
             yield res
 
     @staticmethod
-    def
-
-
+    def _get_stream_from_args(ability: str, *args) -> bool:
+        if ability == "chat":
+            assert args[2] is None or isinstance(args[2], dict)
+            return False if args[2] is None else args[2].get("stream", False)
+        else:
+            assert args[0] is None or isinstance(args[0], dict)
+            return False if args[0] is None else args[0].get("stream", False)
+
+    async def handle_batching_request(self, prompt: str, ability: str, *args, **kwargs):
+        stream = self._get_stream_from_args(ability, *args)
+        assert self._scheduler_ref is not None
+        if stream:
+            assert self._scheduler_ref is not None
+            queue: Queue[Any] = Queue()
+            ret = self._queue_consumer(queue)
+            await self._scheduler_ref.add_request(prompt, queue, *args, **kwargs)
+            gen = self._to_json_async_gen(ret)
+            self._current_generator = weakref.ref(gen)
+            return gen
+        else:
+            from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
+
+            assert self._loop is not None
+            future = ConcurrentFuture()
+            await self._scheduler_ref.add_request(prompt, future, *args, **kwargs)
+            fut = asyncio.wrap_future(future, loop=self._loop)
+            result = await fut
+            if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
+                raise RuntimeError(
+                    f"This request has been cancelled by another `abort_request` request."
+                )
+            return await asyncio.to_thread(json_dumps, result)
 
     @log_async(logger=logger)
     @request_limit
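handle_batching_request unifies what generate and chat previously duplicated: a streaming request hands the scheduler a Queue and consumes it as an async generator, while a non-streaming one hands it a Future and awaits the wrapped result, with sentinel flags signalling aborts on both paths. A stripped-down, self-contained sketch of that producer/consumer shape (the flag name mirrors the sentinels; everything here is illustrative, not the actor implementation):

import asyncio

STREAMING_ABORT_FLAG = "<abort>"   # stands in for XINFERENCE_STREAMING_ABORT_FLAG

async def queue_consumer(q: asyncio.Queue):
    # Drain results until the producer signals completion or abort.
    while True:
        item = await q.get()
        if item is None:               # producer finished
            return
        if item == STREAMING_ABORT_FLAG:
            raise RuntimeError("request aborted")
        yield item

async def producer(q: asyncio.Queue):
    for tok in ["hel", "lo"]:
        await q.put(tok)
    await q.put(None)

async def main():
    q: asyncio.Queue = asyncio.Queue()
    asyncio.create_task(producer(q))
    async for chunk in queue_consumer(q):
        print(chunk)

asyncio.run(main())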
@@ -441,34 +479,11 @@ class ModelActor(xo.StatelessActor):
         response = None
         try:
             if self.allow_batching():
-
-
-
-                assert self._scheduler_ref is not None
-                queue: Queue[Any] = Queue()
-                ret = self._queue_consumer(queue)
-                await self._scheduler_ref.add_request(
-                    prompt, queue, *args, **kwargs
-                )
-                gen = self._to_json_async_gen(ret)
-                self._current_generator = weakref.ref(gen)
-                return gen
-            else:
-                from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
-
-                assert self._loop is not None
-                future = ConcurrentFuture()
-                await self._scheduler_ref.add_request(
-                    prompt, future, *args, **kwargs
-                )
-                fut = asyncio.wrap_future(future, loop=self._loop)
-                result = await fut
-                if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
-                    raise RuntimeError(
-                        f"This request has been cancelled by another `abort_request` request."
-                    )
-                return await asyncio.to_thread(json_dumps, result)
+                return await self.handle_batching_request(
+                    prompt, "chat", *args, **kwargs
+                )
             else:
+                kwargs.pop("raw_params", None)
                 if hasattr(self._model, "chat"):
                     response = await self._call_wrapper(
                         self._model.chat, prompt, *args, **kwargs
@@ -528,6 +543,7 @@ class ModelActor(xo.StatelessActor):
         top_n: Optional[int],
         max_chunks_per_doc: Optional[int],
         return_documents: Optional[bool],
+        return_len: Optional[bool],
         *args,
         **kwargs,
     ):
@@ -539,6 +555,7 @@ class ModelActor(xo.StatelessActor):
             top_n,
             max_chunks_per_doc,
             return_documents,
+            return_len,
             *args,
             **kwargs,
         )