xinference 0.12.0__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +74 -6
- xinference/client/restful/restful_client.py +74 -5
- xinference/constants.py +1 -0
- xinference/core/cache_tracker.py +48 -28
- xinference/core/model.py +54 -42
- xinference/core/scheduler.py +34 -16
- xinference/core/supervisor.py +73 -24
- xinference/core/worker.py +68 -2
- xinference/deploy/cmdline.py +86 -2
- xinference/deploy/test/test_cmdline.py +19 -10
- xinference/model/audio/__init__.py +14 -1
- xinference/model/audio/core.py +12 -1
- xinference/model/audio/custom.py +6 -4
- xinference/model/audio/model_spec_modelscope.json +20 -0
- xinference/model/llm/__init__.py +34 -2
- xinference/model/llm/llm_family.json +2 -0
- xinference/model/llm/llm_family.py +86 -1
- xinference/model/llm/llm_family_csghub.json +66 -0
- xinference/model/llm/llm_family_modelscope.json +2 -0
- xinference/model/llm/pytorch/chatglm.py +18 -12
- xinference/model/llm/pytorch/core.py +92 -42
- xinference/model/llm/pytorch/glm4v.py +13 -3
- xinference/model/llm/pytorch/qwen_vl.py +1 -1
- xinference/model/llm/pytorch/utils.py +27 -14
- xinference/model/llm/utils.py +14 -13
- xinference/model/llm/vllm/core.py +10 -4
- xinference/model/utils.py +8 -2
- xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
- xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/infer/api.py +125 -0
- xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
- xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
- xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
- xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
- xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
- xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
- xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
- xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/METADATA +1 -1
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/RECORD +57 -45
- xinference/web/ui/build/static/css/main.54bca460.css +0 -2
- xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
- xinference/web/ui/build/static/js/main.551aa479.js +0 -3
- xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
- /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 version_json = '''
 {
-    "date": "2024-06-…
+    "date": "2024-06-14T17:17:50+0800",
     "dirty": false,
     "error": null,
-    "full-revisionid": "…
-    "version": "0.12.…
+    "full-revisionid": "34a57df449f0890415c424802d3596f3c8758412",
+    "version": "0.12.1"
 }
 ''' # END VERSION_JSON
xinference/api/restful_api.py
CHANGED
@@ -522,11 +522,31 @@ class RESTfulAPI:
            ),
        )
        self._router.add_api_route(
-           "/v1/…
+           "/v1/cache/models",
            self.list_cached_models,
            methods=["GET"],
            dependencies=(
-               [Security(self._auth_service, scopes=["…
+               [Security(self._auth_service, scopes=["cache:list"])]
+               if self.is_authenticated()
+               else None
+           ),
+       )
+       self._router.add_api_route(
+           "/v1/cache/models/files",
+           self.list_model_files,
+           methods=["GET"],
+           dependencies=(
+               [Security(self._auth_service, scopes=["cache:list"])]
+               if self.is_authenticated()
+               else None
+           ),
+       )
+       self._router.add_api_route(
+           "/v1/cache/models",
+           self.confirm_and_remove_model,
+           methods=["DELETE"],
+           dependencies=(
+               [Security(self._auth_service, scopes=["cache:delete"])]
                if self.is_authenticated()
                else None
            ),
@@ -1401,9 +1421,11 @@ class RESTfulAPI:
        model_family = desc.get("model_family", "")
        function_call_models = [
            "chatglm3",
+           "glm4-chat",
            "gorilla-openfunctions-v1",
            "qwen-chat",
            "qwen1.5-chat",
+           "qwen2-instruct",
        ]

        is_qwen = desc.get("model_format") == "ggmlv3" and "qwen-chat" == model_family
@@ -1426,7 +1448,11 @@ class RESTfulAPI:
        )
        if body.tools and body.stream:
            is_vllm = await model.is_vllm_backend()
-           if not is_vllm or model_family not in […
+           if not is_vllm or model_family not in [
+               "qwen-chat",
+               "qwen1.5-chat",
+               "qwen2-instruct",
+           ]:
                raise HTTPException(
                    status_code=400,
                    detail="Streaming support for tool calls is available only when using vLLM backend and Qwen models.",
@@ -1555,10 +1581,17 @@ class RESTfulAPI:
            logger.error(e, exc_info=True)
            raise HTTPException(status_code=500, detail=str(e))

-   async def list_cached_models(…
+   async def list_cached_models(
+       self, model_name: str = Query(None), worker_ip: str = Query(None)
+   ) -> JSONResponse:
        try:
-           data = await (await self._get_supervisor_ref()).list_cached_models(…
+           data = await (await self._get_supervisor_ref()).list_cached_models(
+               model_name, worker_ip
+           )
+           resp = {
+               "list": data,
+           }
+           return JSONResponse(content=resp)
        except ValueError as re:
            logger.error(re, exc_info=True)
            raise HTTPException(status_code=400, detail=str(re))
@@ -1623,6 +1656,41 @@ class RESTfulAPI:
            logger.error(e, exc_info=True)
            raise HTTPException(status_code=500, detail=str(e))

+   async def list_model_files(
+       self, model_version: str = Query(None), worker_ip: str = Query(None)
+   ) -> JSONResponse:
+       try:
+           data = await (await self._get_supervisor_ref()).list_deletable_models(
+               model_version, worker_ip
+           )
+           response = {
+               "model_version": model_version,
+               "worker_ip": worker_ip,
+               "paths": data,
+           }
+           return JSONResponse(content=response)
+       except ValueError as re:
+           logger.error(re, exc_info=True)
+           raise HTTPException(status_code=400, detail=str(re))
+       except Exception as e:
+           logger.error(e, exc_info=True)
+           raise HTTPException(status_code=500, detail=str(e))
+
+   async def confirm_and_remove_model(
+       self, model_version: str = Query(None), worker_ip: str = Query(None)
+   ) -> JSONResponse:
+       try:
+           res = await (await self._get_supervisor_ref()).confirm_and_remove_model(
+               model_version=model_version, worker_ip=worker_ip
+           )
+           return JSONResponse(content={"result": res})
+       except ValueError as re:
+           logger.error(re, exc_info=True)
+           raise HTTPException(status_code=400, detail=str(re))
+       except Exception as e:
+           logger.error(e, exc_info=True)
+           raise HTTPException(status_code=500, detail=str(e))
+

def run(
    supervisor_address: str,
xinference/client/restful/restful_client.py
CHANGED
@@ -1145,13 +1145,17 @@ class Client:
        response_data = response.json()
        return response_data

-   def list_cached_models(…
+   def list_cached_models(
+       self, model_name: Optional[str] = None, worker_ip: Optional[str] = None
+   ) -> List[Dict[Any, Any]]:
        """
        Get a list of cached models.
-
        Parameters
        ----------
-
+       model_name: Optional[str]
+           The name of model.
+       worker_ip: Optional[str]
+           Specify the worker ip where the model is located in a distributed scenario.

        Returns
        -------
@@ -1164,16 +1168,81 @@ class Client:
        Raised when the request fails, including the reason for the failure.
        """

-       url = f"{self.base_url}/v1/…
-       …
+       url = f"{self.base_url}/v1/cache/models"
+       params = {
+           "model_name": model_name,
+           "worker_ip": worker_ip,
+       }
+       response = requests.get(url, headers=self._headers, params=params)
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to list cached model, detail: {_get_error_string(response)}"
            )
+       response_data = response.json()
+       response_data = response_data.get("list")
+       return response_data
+
+   def list_deletable_models(
+       self, model_version: str, worker_ip: Optional[str] = None
+   ) -> Dict[str, Any]:
+       """
+       Get the cached models with the model path cached on the server.
+       Parameters
+       ----------
+       model_version: str
+           The version of the model.
+       worker_ip: Optional[str]
+           Specify the worker ip where the model is located in a distributed scenario.
+       Returns
+       -------
+       Dict[str, Dict[str,str]]]
+           Dictionary with keys "model_name" and values model_file_location.
+       """
+       url = f"{self.base_url}/v1/cache/models/files"
+       params = {
+           "model_version": model_version,
+           "worker_ip": worker_ip,
+       }
+       response = requests.get(url, headers=self._headers, params=params)
+       if response.status_code != 200:
+           raise RuntimeError(
+               f"Failed to get paths by model name, detail: {_get_error_string(response)}"
+           )

        response_data = response.json()
        return response_data

+   def confirm_and_remove_model(
+       self, model_version: str, worker_ip: Optional[str] = None
+   ) -> bool:
+       """
+       Remove the cached models with the model name cached on the server.
+       Parameters
+       ----------
+       model_version: str
+           The version of the model.
+       worker_ip: Optional[str]
+           Specify the worker ip where the model is located in a distributed scenario.
+       Returns
+       -------
+       str
+           The response of the server.
+       """
+       url = f"{self.base_url}/v1/cache/models"
+       params = {
+           "model_version": model_version,
+           "worker_ip": worker_ip,
+       }
+       response = requests.delete(url, headers=self._headers, params=params)
+       if response.status_code != 200:
+           raise RuntimeError(
+               f"Failed to remove cached models, detail: {_get_error_string(response)}"
+           )
+
+       response_data = response.json()
+       return response_data.get("result", False)
+
    def get_model_registration(
        self, model_type: str, model_name: str
    ) -> Dict[str, Any]:
xinference/constants.py
CHANGED
@@ -17,6 +17,7 @@ from pathlib import Path

 XINFERENCE_ENV_ENDPOINT = "XINFERENCE_ENDPOINT"
 XINFERENCE_ENV_MODEL_SRC = "XINFERENCE_MODEL_SRC"
+XINFERENCE_ENV_CSG_TOKEN = "XINFERENCE_CSG_TOKEN"
 XINFERENCE_ENV_HOME_PATH = "XINFERENCE_HOME"
 XINFERENCE_ENV_HEALTH_CHECK_FAILURE_THRESHOLD = (
     "XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD"
xinference/core/cache_tracker.py
CHANGED
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from logging import getLogger
 from typing import Any, Dict, List, Optional

@@ -102,33 +101,54 @@ class CacheTrackerActor(xo.Actor):
    def get_model_version_count(self, model_name: str) -> int:
        return len(self.get_model_versions(model_name))

-   def list_cached_models(…
+   def list_cached_models(
+       self, worker_ip: str, model_name: Optional[str] = None
+   ) -> List[Dict[Any, Any]]:
        cached_models = []
-       for …
-           …
+       for name, versions in self._model_name_to_version_info.items():
+           # only return assigned cached model if model_name is not none
+           # else return all cached model
+           if model_name and model_name != name:
+               continue
+           for version_info in versions:
+               cache_status = version_info.get("cache_status", False)
+               # search cached model
+               if cache_status:
+                   res = version_info.copy()
+                   res["model_name"] = name
+                   paths = res.get("model_file_location", {})
+                   # only return assigned worker's device path
+                   if worker_ip in paths.keys():
+                       res["model_file_location"] = paths[worker_ip]
+                       cached_models.append(res)
+       return cached_models

-   …
+   def list_deletable_models(self, model_version: str, worker_ip: str) -> str:
+       model_file_location = ""
+       for model, model_versions in self._model_name_to_version_info.items():
+           for version_info in model_versions:
+               # search assign model version
+               if model_version == version_info.get("model_version", None):
+                   # check if exist
+                   if version_info.get("cache_status", False):
+                       paths = version_info.get("model_file_location", {})
+                       # only return assigned worker's device path
+                       if worker_ip in paths.keys():
+                           model_file_location = paths[worker_ip]
+       return model_file_location

-   …
+   def confirm_and_remove_model(self, model_version: str, worker_ip: str):
+       # find remove path
+       rm_path = self.list_deletable_models(model_version, worker_ip)
+       # search _model_name_to_version_info if exist this path, and delete
+       for model, model_versions in self._model_name_to_version_info.items():
+           for version_info in model_versions:
+               # check if exist
+               if version_info.get("cache_status", False):
+                   paths = version_info.get("model_file_location", {})
+                   # only delete assigned worker's device path
+                   if worker_ip in paths.keys() and rm_path == paths[worker_ip]:
+                       del paths[worker_ip]
+                       # if path is empty, update cache status
+                       if not paths:
+                           version_info["cache_status"] = False
xinference/core/model.py
CHANGED
@@ -264,12 +264,13 @@ class ModelActor(xo.StatelessActor):
        return isinstance(self._model, VLLMModel)

    def allow_batching(self) -> bool:
-       from ..model.llm.pytorch.core import PytorchChatModel
+       from ..model.llm.pytorch.core import PytorchChatModel, PytorchModel

        return (
            XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
-           and isinstance(self._model, …
-           and self._model.__class__.__name__ …
+           and isinstance(self._model, PytorchModel)
+           and self._model.__class__.__name__
+           in (PytorchChatModel.__name__, PytorchModel.__name__)
        )

    async def load(self):
@@ -393,18 +394,24 @@ class ModelActor(xo.StatelessActor):
    @request_limit
    @xo.generator
    async def generate(self, prompt: str, *args, **kwargs):
-       if …
-           return await self.…
-           )
-       if hasattr(self._model, "async_generate"):
-           return await self._call_wrapper(
-               self._model.async_generate, prompt, *args, **kwargs
-           )
-       …
+       if self.allow_batching():
+           return await self.handle_batching_request(
+               prompt, "generate", *args, **kwargs
+           )
+       else:
+           if hasattr(self._model, "generate"):
+               return await self._call_wrapper(
+                   self._model.generate, prompt, *args, **kwargs
+               )
+           if hasattr(self._model, "async_generate"):
+               return await self._call_wrapper(
+                   self._model.async_generate, prompt, *args, **kwargs
+               )
+           raise AttributeError(f"Model {self._model.model_spec} is not for generate.")

+   @staticmethod
    async def _queue_consumer(
-       …
+       queue: Queue, timeout: Optional[float] = None
    ) -> AsyncIterator[Any]:
        from .scheduler import (
            XINFERENCE_STREAMING_ABORT_FLAG,
@@ -429,9 +436,38 @@ class ModelActor(xo.StatelessActor):
            yield res

    @staticmethod
-   def …
-       …
+   def _get_stream_from_args(ability: str, *args) -> bool:
+       if ability == "chat":
+           assert args[2] is None or isinstance(args[2], dict)
+           return False if args[2] is None else args[2].get("stream", False)
+       else:
+           assert args[0] is None or isinstance(args[0], dict)
+           return False if args[0] is None else args[0].get("stream", False)
+
+   async def handle_batching_request(self, prompt: str, ability: str, *args, **kwargs):
+       stream = self._get_stream_from_args(ability, *args)
+       assert self._scheduler_ref is not None
+       if stream:
+           assert self._scheduler_ref is not None
+           queue: Queue[Any] = Queue()
+           ret = self._queue_consumer(queue)
+           await self._scheduler_ref.add_request(prompt, queue, *args, **kwargs)
+           gen = self._to_json_async_gen(ret)
+           self._current_generator = weakref.ref(gen)
+           return gen
+       else:
+           from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
+
+           assert self._loop is not None
+           future = ConcurrentFuture()
+           await self._scheduler_ref.add_request(prompt, future, *args, **kwargs)
+           fut = asyncio.wrap_future(future, loop=self._loop)
+           result = await fut
+           if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
+               raise RuntimeError(
+                   f"This request has been cancelled by another `abort_request` request."
+               )
+           return await asyncio.to_thread(json_dumps, result)

    @log_async(logger=logger)
    @request_limit
@@ -441,33 +477,9 @@ class ModelActor(xo.StatelessActor):
        response = None
        try:
            if self.allow_batching():
-               …
-               assert self._scheduler_ref is not None
-               queue: Queue[Any] = Queue()
-               ret = self._queue_consumer(queue)
-               await self._scheduler_ref.add_request(
-                   prompt, queue, *args, **kwargs
-               )
-               gen = self._to_json_async_gen(ret)
-               self._current_generator = weakref.ref(gen)
-               return gen
-               else:
-                   from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
-
-                   assert self._loop is not None
-                   future = ConcurrentFuture()
-                   await self._scheduler_ref.add_request(
-                       prompt, future, *args, **kwargs
-                   )
-                   fut = asyncio.wrap_future(future, loop=self._loop)
-                   result = await fut
-                   if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
-                       raise RuntimeError(
-                           f"This request has been cancelled by another `abort_request` request."
-                       )
-                   return await asyncio.to_thread(json_dumps, result)
+               return await self.handle_batching_request(
+                   prompt, "chat", *args, **kwargs
+               )
            else:
                if hasattr(self._model, "chat"):
                    response = await self._call_wrapper(
xinference/core/scheduler.py
CHANGED
@@ -15,6 +15,7 @@
 import asyncio
 import functools
 import logging
+import uuid
 from collections import deque
 from enum import Enum
 from typing import List, Optional, Set
@@ -50,9 +51,9 @@ class InferenceRequest:
        self._new_tokens = []
        # kv_cache used in decode phase
        self._kv_cache = None
-       # use passed args from …
+       # use passed args from upstream interface
        self._inference_args = args
-       # use passed kwargs from …
+       # use passed kwargs from upstream interface, basically not used for now
        self._inference_kwargs = kwargs
        # should this request be stopped
        self._stopped = False
@@ -63,6 +64,8 @@ class InferenceRequest:
        self._aborted = False
        # sanitized generate config
        self._sanitized_generate_config = None
+       # Chunk id for results. In stream mode, all the chunk ids should be same.
+       self._stream_chunk_id = str(uuid.uuid4())
        # Use in stream mode
        self.last_output_length = 0
        # inference results,
@@ -81,19 +84,26 @@ class InferenceRequest:
        self._check_args()

    def _check_args(self):
-       …
-       assert … self._inference_args[0] …
-       assert … self._inference_args[1] …
-       assert … self._inference_args[2] …
+       # chat
+       if len(self._inference_args) == 3:
+           # system prompt
+           assert self._inference_args[0] is None or isinstance(
+               self._inference_args[0], str
+           )
+           # chat history
+           assert self._inference_args[1] is None or isinstance(
+               self._inference_args[1], list
+           )
+           # generate config
+           assert self._inference_args[2] is None or isinstance(
+               self._inference_args[2], dict
+           )
+       else:  # generate
+           assert len(self._inference_args) == 1
+           # generate config
+           assert self._inference_args[0] is None or isinstance(
+               self._inference_args[0], dict
+           )

    @property
    def prompt(self):
@@ -148,7 +158,11 @@ class InferenceRequest:
    @property
    def generate_config(self):
-       return …
+       return (
+           self._inference_args[2]
+           if len(self._inference_args) == 3
+           else self._inference_args[0]
+       )

    @property
    def sanitized_generate_config(self):
@@ -174,6 +188,10 @@ class InferenceRequest:
    def finish_reason(self, value: Optional[str]):
        self._finish_reason = value

+   @property
+   def chunk_id(self):
+       return self._stream_chunk_id
+
    @property
    def stream(self) -> bool:
        return (
xinference/core/supervisor.py
CHANGED
@@ -982,32 +982,31 @@ class SupervisorActor(xo.StatelessActor):
        )

    @log_async(logger=logger)
-   async def list_cached_models(…
+   async def list_cached_models(
+       self, model_name: Optional[str] = None, worker_ip: Optional[str] = None
+   ) -> List[Dict[str, Any]]:
+       target_ip_worker_ref = (
+           self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
+       )
+       if (
+           worker_ip is not None
+           and not self.is_local_deployment()
+           and target_ip_worker_ref is None
+       ):
+           raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
+
+       # search assigned worker and return
+       if target_ip_worker_ref:
+           cached_models = await target_ip_worker_ref.list_cached_models(model_name)
+           cached_models = sorted(cached_models, key=lambda x: x["model_name"])
+           return cached_models
+
+       # search all worker
        cached_models = []
        for worker in self._worker_address_to_worker.values():
-           …
-           model_format = model_version.get("model_format", None)
-           model_size_in_billions = model_version.get(
-               "model_size_in_billions", None
-           )
-           quantizations = model_version.get("quantization", None)
-           actor_ip_address = model_version.get("actor_ip_address", None)
-           path = model_version.get("path", None)
-           real_path = model_version.get("real_path", None)
-
-           cache_entry = {
-               "model_name": model_name,
-               "model_format": model_format,
-               "model_size_in_billions": model_size_in_billions,
-               "quantizations": quantizations,
-               "path": path,
-               "Actor IP Address": actor_ip_address,
-               "real_path": real_path,
-           }
-
-           cached_models.append(cache_entry)
+           res = await worker.list_cached_models(model_name)
+           cached_models.extend(res)
+       cached_models = sorted(cached_models, key=lambda x: x["model_name"])
        return cached_models

    @log_async(logger=logger)
@@ -1083,6 +1082,56 @@ class SupervisorActor(xo.StatelessActor):
        worker_status.update_time = time.time()
        worker_status.status = status

+   async def list_deletable_models(
+       self, model_version: str, worker_ip: Optional[str] = None
+   ) -> List[str]:
+       target_ip_worker_ref = (
+           self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
+       )
+       if (
+           worker_ip is not None
+           and not self.is_local_deployment()
+           and target_ip_worker_ref is None
+       ):
+           raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
+
+       ret = []
+       if target_ip_worker_ref:
+           ret = await target_ip_worker_ref.list_deletable_models(
+               model_version=model_version,
+           )
+           return ret
+
+       for worker in self._worker_address_to_worker.values():
+           path = await worker.list_deletable_models(model_version=model_version)
+           ret.extend(path)
+       return ret
+
+   async def confirm_and_remove_model(
+       self, model_version: str, worker_ip: Optional[str] = None
+   ) -> bool:
+       target_ip_worker_ref = (
+           self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
+       )
+       if (
+           worker_ip is not None
+           and not self.is_local_deployment()
+           and target_ip_worker_ref is None
+       ):
+           raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
+
+       if target_ip_worker_ref:
+           ret = await target_ip_worker_ref.confirm_and_remove_model(
+               model_version=model_version,
+           )
+           return ret
+       ret = True
+       for worker in self._worker_address_to_worker.values():
+           ret = ret and await worker.confirm_and_remove_model(
+               model_version=model_version,
+           )
+       return ret
+
    @staticmethod
    def record_metrics(name, op, kwargs):
        record_metrics(name, op, kwargs)