xinference 0.12.3__py3-none-any.whl → 0.13.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +56 -8
- xinference/client/restful/restful_client.py +49 -4
- xinference/core/model.py +36 -4
- xinference/core/scheduler.py +2 -0
- xinference/core/supervisor.py +132 -15
- xinference/core/worker.py +239 -53
- xinference/deploy/cmdline.py +5 -0
- xinference/deploy/utils.py +33 -2
- xinference/model/audio/chattts.py +6 -6
- xinference/model/audio/core.py +23 -15
- xinference/model/core.py +12 -3
- xinference/model/embedding/core.py +25 -16
- xinference/model/flexible/__init__.py +40 -0
- xinference/model/flexible/core.py +228 -0
- xinference/model/flexible/launchers/__init__.py +15 -0
- xinference/model/flexible/launchers/transformers_launcher.py +63 -0
- xinference/model/flexible/utils.py +33 -0
- xinference/model/image/core.py +18 -14
- xinference/model/image/custom.py +1 -1
- xinference/model/llm/__init__.py +5 -2
- xinference/model/llm/core.py +3 -2
- xinference/model/llm/ggml/llamacpp.py +1 -10
- xinference/model/llm/llm_family.json +292 -36
- xinference/model/llm/llm_family.py +102 -53
- xinference/model/llm/llm_family_modelscope.json +247 -27
- xinference/model/llm/mlx/__init__.py +13 -0
- xinference/model/llm/mlx/core.py +408 -0
- xinference/model/llm/pytorch/chatglm.py +2 -9
- xinference/model/llm/pytorch/cogvlm2.py +206 -21
- xinference/model/llm/pytorch/core.py +213 -120
- xinference/model/llm/pytorch/glm4v.py +171 -15
- xinference/model/llm/pytorch/qwen_vl.py +168 -7
- xinference/model/llm/pytorch/utils.py +53 -62
- xinference/model/llm/utils.py +28 -7
- xinference/model/rerank/core.py +29 -25
- xinference/thirdparty/deepseek_vl/serve/__init__.py +13 -0
- xinference/thirdparty/deepseek_vl/serve/app_deepseek.py +510 -0
- xinference/thirdparty/deepseek_vl/serve/app_modules/__init__.py +13 -0
- xinference/thirdparty/deepseek_vl/serve/app_modules/gradio_utils.py +94 -0
- xinference/thirdparty/deepseek_vl/serve/app_modules/overwrites.py +81 -0
- xinference/thirdparty/deepseek_vl/serve/app_modules/presets.py +96 -0
- xinference/thirdparty/deepseek_vl/serve/app_modules/utils.py +229 -0
- xinference/thirdparty/deepseek_vl/serve/inference.py +170 -0
- xinference/types.py +0 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.95c1d652.js +3 -0
- xinference/web/ui/build/static/js/main.95c1d652.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2c63090c842376cdd368c3ded88a333ef40d94785747651343040a6f7872a223.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/70fa8c07463a5fe57c68bf92502910105a8f647371836fe8c3a7408246ca7ba0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +1 -0
- {xinference-0.12.3.dist-info → xinference-0.13.1.dist-info}/METADATA +10 -11
- {xinference-0.12.3.dist-info → xinference-0.13.1.dist-info}/RECORD +71 -69
- xinference/model/llm/ggml/chatglm.py +0 -457
- xinference/thirdparty/ChatTTS/__init__.py +0 -1
- xinference/thirdparty/ChatTTS/core.py +0 -200
- xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/experimental/llm.py +0 -40
- xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/infer/api.py +0 -125
- xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/model/dvae.py +0 -155
- xinference/thirdparty/ChatTTS/model/gpt.py +0 -265
- xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/utils/gpu_utils.py +0 -23
- xinference/thirdparty/ChatTTS/utils/infer_utils.py +0 -141
- xinference/thirdparty/ChatTTS/utils/io_utils.py +0 -14
- xinference/web/ui/build/static/js/main.77dd47c3.js +0 -3
- xinference/web/ui/build/static/js/main.77dd47c3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0cd591866aa345566e0b63fb51ff2043e163a770af6fdc2f3bad395d046353e2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/37c1476717199863bbba1530e3513a9368f8f73001b75b4a85c2075956308027.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3fa1f69162f9c6dc0f6a6e21b64d49d6b8e6fa8dfa59a82cf829931c5f97d99f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/46edc1fe657dfedb2e673148332bb442c6eb98f09f2592c389209e376510afa5.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/62e257ed9016471035fa1a7da57c9e2a4250974ed566b4d1295873d747c68eb2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/82db357f3fd5b32215d747ee593f69ff06c95ad6cde37f71a96c8290aaab64c0.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bc6da27195ec4607bb472bf61f97c928ad4966fa64e4c2247661bedb7400abba.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e606671420d2937102c3c34b4b04056c11736408c1d3347b8cf42dfe61fb394b.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f118f99c22b713c678c1209c4e1dd43fe86e3f6e801a4c0c35d3bbf41fd05fe6.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +0 -1
- /xinference/web/ui/build/static/js/{main.77dd47c3.js.LICENSE.txt → main.95c1d652.js.LICENSE.txt} +0 -0
- {xinference-0.12.3.dist-info → xinference-0.13.1.dist-info}/LICENSE +0 -0
- {xinference-0.12.3.dist-info → xinference-0.13.1.dist-info}/WHEEL +0 -0
- {xinference-0.12.3.dist-info → xinference-0.13.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.12.3.dist-info → xinference-0.13.1.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2024-
|
|
11
|
+
"date": "2024-07-12T17:56:13+0800",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "0.
|
|
14
|
+
"full-revisionid": "5e3f254d48383f37d849dd16db564ad9449e5163",
|
|
15
|
+
"version": "0.13.1"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
xinference/api/restful_api.py
CHANGED
|
@@ -133,6 +133,7 @@ class SpeechRequest(BaseModel):
|
|
|
133
133
|
|
|
134
134
|
class RegisterModelRequest(BaseModel):
|
|
135
135
|
model: str
|
|
136
|
+
worker_ip: Optional[str]
|
|
136
137
|
persist: bool
|
|
137
138
|
|
|
138
139
|
|
|
@@ -501,6 +502,16 @@ class RESTfulAPI:
|
|
|
501
502
|
else None
|
|
502
503
|
),
|
|
503
504
|
)
|
|
505
|
+
self._router.add_api_route(
|
|
506
|
+
"/v1/flexible/infers",
|
|
507
|
+
self.create_flexible_infer,
|
|
508
|
+
methods=["POST"],
|
|
509
|
+
dependencies=(
|
|
510
|
+
[Security(self._auth_service, scopes=["models:read"])]
|
|
511
|
+
if self.is_authenticated()
|
|
512
|
+
else None
|
|
513
|
+
),
|
|
514
|
+
)
|
|
504
515
|
|
|
505
516
|
# for custom models
|
|
506
517
|
self._router.add_api_route(
|
|
@@ -772,6 +783,7 @@ class RESTfulAPI:
|
|
|
772
783
|
peft_model_config = payload.get("peft_model_config", None)
|
|
773
784
|
worker_ip = payload.get("worker_ip", None)
|
|
774
785
|
gpu_idx = payload.get("gpu_idx", None)
|
|
786
|
+
download_hub = payload.get("download_hub", None)
|
|
775
787
|
|
|
776
788
|
exclude_keys = {
|
|
777
789
|
"model_uid",
|
|
@@ -787,6 +799,7 @@ class RESTfulAPI:
|
|
|
787
799
|
"peft_model_config",
|
|
788
800
|
"worker_ip",
|
|
789
801
|
"gpu_idx",
|
|
802
|
+
"download_hub",
|
|
790
803
|
}
|
|
791
804
|
|
|
792
805
|
kwargs = {
|
|
@@ -834,9 +847,9 @@ class RESTfulAPI:
|
|
|
834
847
|
peft_model_config=peft_model_config,
|
|
835
848
|
worker_ip=worker_ip,
|
|
836
849
|
gpu_idx=gpu_idx,
|
|
850
|
+
download_hub=download_hub,
|
|
837
851
|
**kwargs,
|
|
838
852
|
)
|
|
839
|
-
|
|
840
853
|
except ValueError as ve:
|
|
841
854
|
logger.error(str(ve), exc_info=True)
|
|
842
855
|
raise HTTPException(status_code=400, detail=str(ve))
|
|
@@ -1397,6 +1410,40 @@ class RESTfulAPI:
|
|
|
1397
1410
|
await self._report_error_event(model_uid, str(e))
|
|
1398
1411
|
raise HTTPException(status_code=500, detail=str(e))
|
|
1399
1412
|
|
|
1413
|
+
async def create_flexible_infer(self, request: Request) -> Response:
|
|
1414
|
+
payload = await request.json()
|
|
1415
|
+
|
|
1416
|
+
model_uid = payload.get("model")
|
|
1417
|
+
|
|
1418
|
+
exclude = {
|
|
1419
|
+
"model",
|
|
1420
|
+
}
|
|
1421
|
+
kwargs = {key: value for key, value in payload.items() if key not in exclude}
|
|
1422
|
+
|
|
1423
|
+
try:
|
|
1424
|
+
model = await (await self._get_supervisor_ref()).get_model(model_uid)
|
|
1425
|
+
except ValueError as ve:
|
|
1426
|
+
logger.error(str(ve), exc_info=True)
|
|
1427
|
+
await self._report_error_event(model_uid, str(ve))
|
|
1428
|
+
raise HTTPException(status_code=400, detail=str(ve))
|
|
1429
|
+
except Exception as e:
|
|
1430
|
+
logger.error(e, exc_info=True)
|
|
1431
|
+
await self._report_error_event(model_uid, str(e))
|
|
1432
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
1433
|
+
|
|
1434
|
+
try:
|
|
1435
|
+
result = await model.infer(**kwargs)
|
|
1436
|
+
return Response(result, media_type="application/json")
|
|
1437
|
+
except RuntimeError as re:
|
|
1438
|
+
logger.error(re, exc_info=True)
|
|
1439
|
+
await self._report_error_event(model_uid, str(re))
|
|
1440
|
+
self.handle_request_limit_error(re)
|
|
1441
|
+
raise HTTPException(status_code=400, detail=str(re))
|
|
1442
|
+
except Exception as e:
|
|
1443
|
+
logger.error(e, exc_info=True)
|
|
1444
|
+
await self._report_error_event(model_uid, str(e))
|
|
1445
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
1446
|
+
|
|
1400
1447
|
async def create_chat_completion(self, request: Request) -> Response:
|
|
1401
1448
|
raw_body = await request.json()
|
|
1402
1449
|
body = CreateChatCompletion.parse_obj(raw_body)
|
|
@@ -1477,14 +1524,14 @@ class RESTfulAPI:
|
|
|
1477
1524
|
await self._report_error_event(model_uid, str(e))
|
|
1478
1525
|
raise HTTPException(status_code=500, detail=str(e))
|
|
1479
1526
|
|
|
1480
|
-
from ..model.llm.utils import QWEN_TOOL_CALL_FAMILY
|
|
1527
|
+
from ..model.llm.utils import GLM4_TOOL_CALL_FAMILY, QWEN_TOOL_CALL_FAMILY
|
|
1481
1528
|
|
|
1482
1529
|
model_family = desc.get("model_family", "")
|
|
1483
|
-
function_call_models =
|
|
1484
|
-
"chatglm3",
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
|
|
1530
|
+
function_call_models = (
|
|
1531
|
+
["chatglm3", "gorilla-openfunctions-v1"]
|
|
1532
|
+
+ QWEN_TOOL_CALL_FAMILY
|
|
1533
|
+
+ GLM4_TOOL_CALL_FAMILY
|
|
1534
|
+
)
|
|
1488
1535
|
|
|
1489
1536
|
is_qwen = desc.get("model_format") == "ggmlv3" and "qwen-chat" == model_family
|
|
1490
1537
|
|
|
@@ -1593,11 +1640,12 @@ class RESTfulAPI:
|
|
|
1593
1640
|
async def register_model(self, model_type: str, request: Request) -> JSONResponse:
|
|
1594
1641
|
body = RegisterModelRequest.parse_obj(await request.json())
|
|
1595
1642
|
model = body.model
|
|
1643
|
+
worker_ip = body.worker_ip
|
|
1596
1644
|
persist = body.persist
|
|
1597
1645
|
|
|
1598
1646
|
try:
|
|
1599
1647
|
await (await self._get_supervisor_ref()).register_model(
|
|
1600
|
-
model_type, model, persist
|
|
1648
|
+
model_type, model, persist, worker_ip
|
|
1601
1649
|
)
|
|
1602
1650
|
except ValueError as re:
|
|
1603
1651
|
logger.error(re, exc_info=True)
|
|
@@ -182,8 +182,6 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
|
|
|
182
182
|
f"Failed to rerank documents, detail: {response.json()['detail']}"
|
|
183
183
|
)
|
|
184
184
|
response_data = response.json()
|
|
185
|
-
for r in response_data["results"]:
|
|
186
|
-
r["document"] = documents[r["index"]]
|
|
187
185
|
return response_data
|
|
188
186
|
|
|
189
187
|
|
|
@@ -732,6 +730,41 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
|
|
|
732
730
|
return response.content
|
|
733
731
|
|
|
734
732
|
|
|
733
|
+
class RESTfulFlexibleModelHandle(RESTfulModelHandle):
|
|
734
|
+
def infer(
|
|
735
|
+
self,
|
|
736
|
+
**kwargs,
|
|
737
|
+
):
|
|
738
|
+
"""
|
|
739
|
+
Call flexible model.
|
|
740
|
+
|
|
741
|
+
Parameters
|
|
742
|
+
----------
|
|
743
|
+
|
|
744
|
+
kwargs: dict
|
|
745
|
+
The inference arguments.
|
|
746
|
+
|
|
747
|
+
|
|
748
|
+
Returns
|
|
749
|
+
-------
|
|
750
|
+
bytes
|
|
751
|
+
The inference result.
|
|
752
|
+
"""
|
|
753
|
+
url = f"{self._base_url}/v1/flexible/infers"
|
|
754
|
+
params = {
|
|
755
|
+
"model": self._model_uid,
|
|
756
|
+
}
|
|
757
|
+
params.update(kwargs)
|
|
758
|
+
|
|
759
|
+
response = requests.post(url, json=params, headers=self.auth_headers)
|
|
760
|
+
if response.status_code != 200:
|
|
761
|
+
raise RuntimeError(
|
|
762
|
+
f"Failed to predict, detail: {_get_error_string(response)}"
|
|
763
|
+
)
|
|
764
|
+
|
|
765
|
+
return response.content
|
|
766
|
+
|
|
767
|
+
|
|
735
768
|
class Client:
|
|
736
769
|
def __init__(self, base_url, api_key: Optional[str] = None):
|
|
737
770
|
self.base_url = base_url
|
|
@@ -1011,6 +1044,10 @@ class Client:
|
|
|
1011
1044
|
return RESTfulAudioModelHandle(
|
|
1012
1045
|
model_uid, self.base_url, auth_headers=self._headers
|
|
1013
1046
|
)
|
|
1047
|
+
elif desc["model_type"] == "flexible":
|
|
1048
|
+
return RESTfulFlexibleModelHandle(
|
|
1049
|
+
model_uid, self.base_url, auth_headers=self._headers
|
|
1050
|
+
)
|
|
1014
1051
|
else:
|
|
1015
1052
|
raise ValueError(f"Unknown model type:{desc['model_type']}")
|
|
1016
1053
|
|
|
@@ -1064,7 +1101,13 @@ class Client:
|
|
|
1064
1101
|
)
|
|
1065
1102
|
return response.json()
|
|
1066
1103
|
|
|
1067
|
-
def register_model(
|
|
1104
|
+
def register_model(
|
|
1105
|
+
self,
|
|
1106
|
+
model_type: str,
|
|
1107
|
+
model: str,
|
|
1108
|
+
persist: bool,
|
|
1109
|
+
worker_ip: Optional[str] = None,
|
|
1110
|
+
):
|
|
1068
1111
|
"""
|
|
1069
1112
|
Register a custom model.
|
|
1070
1113
|
|
|
@@ -1074,6 +1117,8 @@ class Client:
|
|
|
1074
1117
|
The type of model.
|
|
1075
1118
|
model: str
|
|
1076
1119
|
The model definition. (refer to: https://inference.readthedocs.io/en/latest/models/custom.html)
|
|
1120
|
+
worker_ip: Optional[str]
|
|
1121
|
+
The IP address of the worker on which the model is running.
|
|
1077
1122
|
persist: bool
|
|
1078
1123
|
|
|
1079
1124
|
|
|
@@ -1083,7 +1128,7 @@ class Client:
|
|
|
1083
1128
|
Report failure to register the custom model. Provide details of failure through error message.
|
|
1084
1129
|
"""
|
|
1085
1130
|
url = f"{self.base_url}/v1/model_registrations/{model_type}"
|
|
1086
|
-
request_body = {"model": model, "persist": persist}
|
|
1131
|
+
request_body = {"model": model, "worker_ip": worker_ip, "persist": persist}
|
|
1087
1132
|
response = requests.post(url, json=request_body, headers=self._headers)
|
|
1088
1133
|
if response.status_code != 200:
|
|
1089
1134
|
raise RuntimeError(
|
xinference/core/model.py
CHANGED
|
@@ -65,6 +65,9 @@ except ImportError:
|
|
|
65
65
|
OutOfMemoryError = _OutOfMemoryError
|
|
66
66
|
|
|
67
67
|
|
|
68
|
+
XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = ["qwen-vl-chat", "cogvlm2", "glm-4v"]
|
|
69
|
+
|
|
70
|
+
|
|
68
71
|
def request_limit(fn):
|
|
69
72
|
"""
|
|
70
73
|
Used by ModelActor.
|
|
@@ -268,11 +271,25 @@ class ModelActor(xo.StatelessActor):
|
|
|
268
271
|
|
|
269
272
|
model_ability = self._model_description.get("model_ability", [])
|
|
270
273
|
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
and isinstance(self._model, PytorchModel)
|
|
274
|
-
and "vision" not in model_ability
|
|
274
|
+
condition = XINFERENCE_TRANSFORMERS_ENABLE_BATCHING and isinstance(
|
|
275
|
+
self._model, PytorchModel
|
|
275
276
|
)
|
|
277
|
+
if condition and "vision" in model_ability:
|
|
278
|
+
if (
|
|
279
|
+
self._model.model_family.model_name
|
|
280
|
+
in XINFERENCE_BATCHING_ALLOWED_VISION_MODELS
|
|
281
|
+
or self._model.model_family.model_family
|
|
282
|
+
in XINFERENCE_BATCHING_ALLOWED_VISION_MODELS
|
|
283
|
+
):
|
|
284
|
+
return True
|
|
285
|
+
else:
|
|
286
|
+
logger.warning(
|
|
287
|
+
f"Currently for multimodal models, "
|
|
288
|
+
f"xinference only supports {', '.join(XINFERENCE_BATCHING_ALLOWED_VISION_MODELS)} for batching. "
|
|
289
|
+
f"Your model {self._model.model_family.model_name} with model family {self._model.model_family.model_family} is disqualified."
|
|
290
|
+
)
|
|
291
|
+
return False
|
|
292
|
+
return condition
|
|
276
293
|
|
|
277
294
|
async def load(self):
|
|
278
295
|
self._model.load()
|
|
@@ -680,6 +697,21 @@ class ModelActor(xo.StatelessActor):
|
|
|
680
697
|
f"Model {self._model.model_spec} is not for creating image."
|
|
681
698
|
)
|
|
682
699
|
|
|
700
|
+
@log_async(logger=logger)
|
|
701
|
+
@request_limit
|
|
702
|
+
async def infer(
|
|
703
|
+
self,
|
|
704
|
+
**kwargs,
|
|
705
|
+
):
|
|
706
|
+
if hasattr(self._model, "infer"):
|
|
707
|
+
return await self._call_wrapper(
|
|
708
|
+
self._model.infer,
|
|
709
|
+
**kwargs,
|
|
710
|
+
)
|
|
711
|
+
raise AttributeError(
|
|
712
|
+
f"Model {self._model.model_spec} is not for flexible infer."
|
|
713
|
+
)
|
|
714
|
+
|
|
683
715
|
async def record_metrics(self, name, op, kwargs):
|
|
684
716
|
worker_ref = await self._get_worker_ref()
|
|
685
717
|
await worker_ref.record_metrics(name, op, kwargs)
|
xinference/core/scheduler.py
CHANGED
|
@@ -82,6 +82,8 @@ class InferenceRequest:
|
|
|
82
82
|
# Record error message when this request has error.
|
|
83
83
|
# Must set stopped=True when this field is set.
|
|
84
84
|
self.error_msg: Optional[str] = None
|
|
85
|
+
# For compatibility. Record some extra parameters for some special cases.
|
|
86
|
+
self.extra_kwargs = {}
|
|
85
87
|
|
|
86
88
|
# check the integrity of args passed upstream
|
|
87
89
|
self._check_args()
|
xinference/core/supervisor.py
CHANGED
|
@@ -20,7 +20,17 @@ import time
|
|
|
20
20
|
import typing
|
|
21
21
|
from dataclasses import dataclass
|
|
22
22
|
from logging import getLogger
|
|
23
|
-
from typing import
|
|
23
|
+
from typing import (
|
|
24
|
+
TYPE_CHECKING,
|
|
25
|
+
Any,
|
|
26
|
+
Dict,
|
|
27
|
+
Iterator,
|
|
28
|
+
List,
|
|
29
|
+
Literal,
|
|
30
|
+
Optional,
|
|
31
|
+
Tuple,
|
|
32
|
+
Union,
|
|
33
|
+
)
|
|
24
34
|
|
|
25
35
|
import xoscar as xo
|
|
26
36
|
|
|
@@ -50,6 +60,7 @@ from .utils import (
|
|
|
50
60
|
if TYPE_CHECKING:
|
|
51
61
|
from ..model.audio import AudioModelFamilyV1
|
|
52
62
|
from ..model.embedding import EmbeddingModelSpec
|
|
63
|
+
from ..model.flexible import FlexibleModelSpec
|
|
53
64
|
from ..model.image import ImageModelFamilyV1
|
|
54
65
|
from ..model.llm import LLMFamilyV1
|
|
55
66
|
from ..model.rerank import RerankModelSpec
|
|
@@ -153,6 +164,13 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
153
164
|
register_embedding,
|
|
154
165
|
unregister_embedding,
|
|
155
166
|
)
|
|
167
|
+
from ..model.flexible import (
|
|
168
|
+
FlexibleModelSpec,
|
|
169
|
+
generate_flexible_model_description,
|
|
170
|
+
get_flexible_model_descriptions,
|
|
171
|
+
register_flexible_model,
|
|
172
|
+
unregister_flexible_model,
|
|
173
|
+
)
|
|
156
174
|
from ..model.image import (
|
|
157
175
|
CustomImageModelFamilyV1,
|
|
158
176
|
generate_image_description,
|
|
@@ -206,6 +224,12 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
206
224
|
unregister_audio,
|
|
207
225
|
generate_audio_description,
|
|
208
226
|
),
|
|
227
|
+
"flexible": (
|
|
228
|
+
FlexibleModelSpec,
|
|
229
|
+
register_flexible_model,
|
|
230
|
+
unregister_flexible_model,
|
|
231
|
+
generate_flexible_model_description,
|
|
232
|
+
),
|
|
209
233
|
}
|
|
210
234
|
|
|
211
235
|
# record model version
|
|
@@ -215,6 +239,7 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
215
239
|
model_version_infos.update(get_rerank_model_descriptions())
|
|
216
240
|
model_version_infos.update(get_image_model_descriptions())
|
|
217
241
|
model_version_infos.update(get_audio_model_descriptions())
|
|
242
|
+
model_version_infos.update(get_flexible_model_descriptions())
|
|
218
243
|
await self._cache_tracker_ref.record_model_version(
|
|
219
244
|
model_version_infos, self.address
|
|
220
245
|
)
|
|
@@ -459,6 +484,27 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
459
484
|
res["model_instance_count"] = instance_cnt
|
|
460
485
|
return res
|
|
461
486
|
|
|
487
|
+
async def _to_flexible_model_reg(
|
|
488
|
+
self, model_spec: "FlexibleModelSpec", is_builtin: bool
|
|
489
|
+
) -> Dict[str, Any]:
|
|
490
|
+
instance_cnt = await self.get_instance_count(model_spec.model_name)
|
|
491
|
+
version_cnt = await self.get_model_version_count(model_spec.model_name)
|
|
492
|
+
|
|
493
|
+
if self.is_local_deployment():
|
|
494
|
+
res = {
|
|
495
|
+
**model_spec.dict(),
|
|
496
|
+
"cache_status": True,
|
|
497
|
+
"is_builtin": is_builtin,
|
|
498
|
+
}
|
|
499
|
+
else:
|
|
500
|
+
res = {
|
|
501
|
+
**model_spec.dict(),
|
|
502
|
+
"is_builtin": is_builtin,
|
|
503
|
+
}
|
|
504
|
+
res["model_version_count"] = version_cnt
|
|
505
|
+
res["model_instance_count"] = instance_cnt
|
|
506
|
+
return res
|
|
507
|
+
|
|
462
508
|
@log_async(logger=logger)
|
|
463
509
|
async def list_model_registrations(
|
|
464
510
|
self, model_type: str, detailed: bool = False
|
|
@@ -467,10 +513,15 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
467
513
|
assert isinstance(item["model_name"], str)
|
|
468
514
|
return item.get("model_name").lower()
|
|
469
515
|
|
|
516
|
+
ret = []
|
|
517
|
+
if not self.is_local_deployment():
|
|
518
|
+
workers = list(self._worker_address_to_worker.values())
|
|
519
|
+
for worker in workers:
|
|
520
|
+
ret.extend(await worker.list_model_registrations(model_type, detailed))
|
|
521
|
+
|
|
470
522
|
if model_type == "LLM":
|
|
471
523
|
from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families
|
|
472
524
|
|
|
473
|
-
ret = []
|
|
474
525
|
for family in BUILTIN_LLM_FAMILIES:
|
|
475
526
|
if detailed:
|
|
476
527
|
ret.append(await self._to_llm_reg(family, True))
|
|
@@ -489,7 +540,6 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
489
540
|
from ..model.embedding import BUILTIN_EMBEDDING_MODELS
|
|
490
541
|
from ..model.embedding.custom import get_user_defined_embeddings
|
|
491
542
|
|
|
492
|
-
ret = []
|
|
493
543
|
for model_name, family in BUILTIN_EMBEDDING_MODELS.items():
|
|
494
544
|
if detailed:
|
|
495
545
|
ret.append(
|
|
@@ -514,7 +564,6 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
514
564
|
from ..model.image import BUILTIN_IMAGE_MODELS
|
|
515
565
|
from ..model.image.custom import get_user_defined_images
|
|
516
566
|
|
|
517
|
-
ret = []
|
|
518
567
|
for model_name, family in BUILTIN_IMAGE_MODELS.items():
|
|
519
568
|
if detailed:
|
|
520
569
|
ret.append(await self._to_image_model_reg(family, is_builtin=True))
|
|
@@ -537,7 +586,6 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
537
586
|
from ..model.audio import BUILTIN_AUDIO_MODELS
|
|
538
587
|
from ..model.audio.custom import get_user_defined_audios
|
|
539
588
|
|
|
540
|
-
ret = []
|
|
541
589
|
for model_name, family in BUILTIN_AUDIO_MODELS.items():
|
|
542
590
|
if detailed:
|
|
543
591
|
ret.append(await self._to_audio_model_reg(family, is_builtin=True))
|
|
@@ -560,7 +608,6 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
560
608
|
from ..model.rerank import BUILTIN_RERANK_MODELS
|
|
561
609
|
from ..model.rerank.custom import get_user_defined_reranks
|
|
562
610
|
|
|
563
|
-
ret = []
|
|
564
611
|
for model_name, family in BUILTIN_RERANK_MODELS.items():
|
|
565
612
|
if detailed:
|
|
566
613
|
ret.append(await self._to_rerank_model_reg(family, is_builtin=True))
|
|
@@ -577,13 +624,38 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
577
624
|
{"model_name": model_spec.model_name, "is_builtin": False}
|
|
578
625
|
)
|
|
579
626
|
|
|
627
|
+
ret.sort(key=sort_helper)
|
|
628
|
+
return ret
|
|
629
|
+
elif model_type == "flexible":
|
|
630
|
+
from ..model.flexible import get_flexible_models
|
|
631
|
+
|
|
632
|
+
ret = []
|
|
633
|
+
|
|
634
|
+
for model_spec in get_flexible_models():
|
|
635
|
+
if detailed:
|
|
636
|
+
ret.append(
|
|
637
|
+
await self._to_flexible_model_reg(model_spec, is_builtin=False)
|
|
638
|
+
)
|
|
639
|
+
else:
|
|
640
|
+
ret.append(
|
|
641
|
+
{"model_name": model_spec.model_name, "is_builtin": False}
|
|
642
|
+
)
|
|
643
|
+
|
|
580
644
|
ret.sort(key=sort_helper)
|
|
581
645
|
return ret
|
|
582
646
|
else:
|
|
583
647
|
raise ValueError(f"Unsupported model type: {model_type}")
|
|
584
648
|
|
|
585
649
|
@log_sync(logger=logger)
|
|
586
|
-
def get_model_registration(self, model_type: str, model_name: str) -> Any:
|
|
650
|
+
async def get_model_registration(self, model_type: str, model_name: str) -> Any:
|
|
651
|
+
# search in worker first
|
|
652
|
+
if not self.is_local_deployment():
|
|
653
|
+
workers = list(self._worker_address_to_worker.values())
|
|
654
|
+
for worker in workers:
|
|
655
|
+
f = await worker.get_model_registration(model_type, model_name)
|
|
656
|
+
if f is not None:
|
|
657
|
+
return f
|
|
658
|
+
|
|
587
659
|
if model_type == "LLM":
|
|
588
660
|
from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families
|
|
589
661
|
|
|
@@ -626,6 +698,13 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
626
698
|
if f.model_name == model_name:
|
|
627
699
|
return f
|
|
628
700
|
raise ValueError(f"Model {model_name} not found")
|
|
701
|
+
elif model_type == "flexible":
|
|
702
|
+
from ..model.flexible import get_flexible_models
|
|
703
|
+
|
|
704
|
+
for f in get_flexible_models():
|
|
705
|
+
if f.model_name == model_name:
|
|
706
|
+
return f
|
|
707
|
+
raise ValueError(f"Model {model_name} not found")
|
|
629
708
|
else:
|
|
630
709
|
raise ValueError(f"Unsupported model type: {model_type}")
|
|
631
710
|
|
|
@@ -635,6 +714,13 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
635
714
|
|
|
636
715
|
from ..model.llm.llm_family import LLM_ENGINES
|
|
637
716
|
|
|
717
|
+
# search in worker first
|
|
718
|
+
workers = list(self._worker_address_to_worker.values())
|
|
719
|
+
for worker in workers:
|
|
720
|
+
res = await worker.query_engines_by_model_name(model_name)
|
|
721
|
+
if res is not None:
|
|
722
|
+
return res
|
|
723
|
+
|
|
638
724
|
if model_name not in LLM_ENGINES:
|
|
639
725
|
raise ValueError(f"Model {model_name} not found")
|
|
640
726
|
|
|
@@ -648,7 +734,13 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
648
734
|
return engine_params
|
|
649
735
|
|
|
650
736
|
@log_async(logger=logger)
|
|
651
|
-
async def register_model(
|
|
737
|
+
async def register_model(
|
|
738
|
+
self,
|
|
739
|
+
model_type: str,
|
|
740
|
+
model: str,
|
|
741
|
+
persist: bool,
|
|
742
|
+
worker_ip: Optional[str] = None,
|
|
743
|
+
):
|
|
652
744
|
if model_type in self._custom_register_type_to_cls:
|
|
653
745
|
(
|
|
654
746
|
model_spec_cls,
|
|
@@ -657,10 +749,21 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
657
749
|
generate_fn,
|
|
658
750
|
) = self._custom_register_type_to_cls[model_type]
|
|
659
751
|
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
752
|
+
target_ip_worker_ref = (
|
|
753
|
+
self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
|
|
754
|
+
)
|
|
755
|
+
if (
|
|
756
|
+
worker_ip is not None
|
|
757
|
+
and not self.is_local_deployment()
|
|
758
|
+
and target_ip_worker_ref is None
|
|
759
|
+
):
|
|
760
|
+
raise ValueError(
|
|
761
|
+
f"Worker ip address {worker_ip} is not in the cluster."
|
|
762
|
+
)
|
|
763
|
+
|
|
764
|
+
if target_ip_worker_ref:
|
|
765
|
+
await target_ip_worker_ref.register_model(model_type, model, persist)
|
|
766
|
+
return
|
|
664
767
|
|
|
665
768
|
model_spec = model_spec_cls.parse_raw(model)
|
|
666
769
|
try:
|
|
@@ -668,6 +771,8 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
668
771
|
await self._cache_tracker_ref.record_model_version(
|
|
669
772
|
generate_fn(model_spec), self.address
|
|
670
773
|
)
|
|
774
|
+
except ValueError as e:
|
|
775
|
+
raise e
|
|
671
776
|
except Exception as e:
|
|
672
777
|
unregister_fn(model_spec.model_name, raise_error=False)
|
|
673
778
|
raise e
|
|
@@ -678,13 +783,14 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
678
783
|
async def unregister_model(self, model_type: str, model_name: str):
|
|
679
784
|
if model_type in self._custom_register_type_to_cls:
|
|
680
785
|
_, _, unregister_fn, _ = self._custom_register_type_to_cls[model_type]
|
|
681
|
-
unregister_fn(model_name)
|
|
682
|
-
await self._cache_tracker_ref.unregister_model_version(model_name)
|
|
786
|
+
unregister_fn(model_name, False)
|
|
683
787
|
|
|
684
788
|
if not self.is_local_deployment():
|
|
685
789
|
workers = list(self._worker_address_to_worker.values())
|
|
686
790
|
for worker in workers:
|
|
687
|
-
await worker.unregister_model(model_name)
|
|
791
|
+
await worker.unregister_model(model_type, model_name)
|
|
792
|
+
|
|
793
|
+
await self._cache_tracker_ref.unregister_model_version(model_name)
|
|
688
794
|
else:
|
|
689
795
|
raise ValueError(f"Unsupported model type: {model_type}")
|
|
690
796
|
|
|
@@ -752,8 +858,17 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
752
858
|
peft_model_config: Optional[PeftModelConfig] = None,
|
|
753
859
|
worker_ip: Optional[str] = None,
|
|
754
860
|
gpu_idx: Optional[Union[int, List[int]]] = None,
|
|
861
|
+
download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
|
|
755
862
|
**kwargs,
|
|
756
863
|
) -> str:
|
|
864
|
+
# search in worker first
|
|
865
|
+
if not self.is_local_deployment():
|
|
866
|
+
workers = list(self._worker_address_to_worker.values())
|
|
867
|
+
for worker in workers:
|
|
868
|
+
res = await worker.get_model_registration(model_type, model_name)
|
|
869
|
+
if res is not None:
|
|
870
|
+
worker_ip = worker.address.split(":")[0]
|
|
871
|
+
|
|
757
872
|
target_ip_worker_ref = (
|
|
758
873
|
self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
|
|
759
874
|
)
|
|
@@ -806,6 +921,7 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
806
921
|
)
|
|
807
922
|
replica_gpu_idx = assign_replica_gpu(_replica_model_uid, gpu_idx)
|
|
808
923
|
nonlocal model_type
|
|
924
|
+
|
|
809
925
|
worker_ref = (
|
|
810
926
|
target_ip_worker_ref
|
|
811
927
|
if target_ip_worker_ref is not None
|
|
@@ -825,6 +941,7 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
825
941
|
request_limits=request_limits,
|
|
826
942
|
peft_model_config=peft_model_config,
|
|
827
943
|
gpu_idx=replica_gpu_idx,
|
|
944
|
+
download_hub=download_hub,
|
|
828
945
|
**kwargs,
|
|
829
946
|
)
|
|
830
947
|
self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
|