xinference 0.16.2__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +62 -11
- xinference/client/restful/restful_client.py +8 -2
- xinference/conftest.py +0 -8
- xinference/constants.py +2 -0
- xinference/core/model.py +44 -5
- xinference/core/supervisor.py +13 -7
- xinference/core/utils.py +76 -12
- xinference/core/worker.py +5 -4
- xinference/deploy/cmdline.py +5 -0
- xinference/deploy/utils.py +7 -4
- xinference/model/audio/model_spec.json +2 -2
- xinference/model/image/stable_diffusion/core.py +5 -2
- xinference/model/llm/core.py +1 -3
- xinference/model/llm/llm_family.json +263 -4
- xinference/model/llm/llm_family_modelscope.json +302 -0
- xinference/model/llm/mlx/core.py +45 -2
- xinference/model/llm/vllm/core.py +2 -1
- xinference/model/rerank/core.py +11 -4
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/api.py +578 -75
- xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
- xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
- xinference/thirdparty/fish_speech/tools/schema.py +187 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
- xinference/thirdparty/fish_speech/tools/webui.py +138 -75
- {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/METADATA +26 -3
- {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/RECORD +49 -56
- {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/WHEEL +1 -1
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/commons.py +0 -35
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/LICENSE +0 -0
- {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-11-
+ "date": "2024-11-15T17:33:11+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "
- "version": "0.
+ "full-revisionid": "4c96475b8f90e354aa1b47856fda4db098b62b65",
+ "version": "1.0.0"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py
CHANGED
@@ -52,10 +52,14 @@ from xoscar.utils import get_next_port
 
 from .._compat import BaseModel, Field
 from .._version import get_versions
-from ..constants import
+from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    XINFERENCE_DEFAULT_ENDPOINT_PORT,
+    XINFERENCE_DISABLE_METRICS,
+)
 from ..core.event import Event, EventCollectorActor, EventType
 from ..core.supervisor import SupervisorActor
-from ..core.utils import json_dumps
+from ..core.utils import CancelMixin, json_dumps
 from ..types import (
     ChatCompletion,
     Completion,
@@ -111,6 +115,7 @@ class RerankRequest(BaseModel):
     return_documents: Optional[bool] = False
     return_len: Optional[bool] = False
     max_chunks_per_doc: Optional[int] = None
+    kwargs: Optional[str] = None
 
 
 class TextToImageRequest(BaseModel):
@@ -206,7 +211,7 @@ class BuildGradioImageInterfaceRequest(BaseModel):
     model_ability: List[str]
 
 
-class RESTfulAPI:
+class RESTfulAPI(CancelMixin):
     def __init__(
         self,
         supervisor_address: str,
@@ -1311,11 +1316,6 @@ class RESTfulAPI:
         payload = await request.json()
         body = RerankRequest.parse_obj(payload)
         model_uid = body.model
-        kwargs = {
-            key: value
-            for key, value in payload.items()
-            if key not in RerankRequest.__annotations__.keys()
-        }
 
         try:
             model = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -1329,6 +1329,10 @@
             raise HTTPException(status_code=500, detail=str(e))
 
         try:
+            if body.kwargs is not None:
+                parsed_kwargs = json.loads(body.kwargs)
+            else:
+                parsed_kwargs = {}
             scores = await model.rerank(
                 body.documents,
                 body.query,
@@ -1336,7 +1340,7 @@
                 max_chunks_per_doc=body.max_chunks_per_doc,
                 return_documents=body.return_documents,
                 return_len=body.return_len,
-                **
+                **parsed_kwargs,
             )
             return Response(scores, media_type="application/json")
         except RuntimeError as re:
@@ -1531,8 +1535,11 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             kwargs = json.loads(body.kwargs) if body.kwargs else {}
+            request_id = kwargs.get("request_id")
+            self._add_running_task(request_id)
             image_list = await model.text_to_image(
                 prompt=body.prompt,
                 n=body.n,
@@ -1541,6 +1548,11 @@
                 **kwargs,
             )
             return Response(content=image_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1686,11 +1698,14 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             if kwargs is not None:
                 parsed_kwargs = json.loads(kwargs)
             else:
                 parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
             image_list = await model_ref.image_to_image(
                 image=Image.open(image.file),
                 prompt=prompt,
@@ -1701,6 +1716,11 @@
                 **parsed_kwargs,
             )
             return Response(content=image_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1734,11 +1754,14 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             if kwargs is not None:
                 parsed_kwargs = json.loads(kwargs)
             else:
                 parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
             im = Image.open(image.file)
             mask_im = Image.open(mask_image.file)
             if not size:
@@ -1755,6 +1778,11 @@
                 **parsed_kwargs,
             )
             return Response(content=image_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1782,17 +1810,25 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             if kwargs is not None:
                 parsed_kwargs = json.loads(kwargs)
             else:
                 parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
             im = Image.open(image.file)
             text = await model_ref.ocr(
                 image=im,
                 **parsed_kwargs,
             )
             return Response(content=text, media_type="text/plain")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -2111,10 +2147,25 @@
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
-    async def abort_request(
+    async def abort_request(
+        self, request: Request, model_uid: str, request_id: str
+    ) -> JSONResponse:
         try:
+            payload = await request.json()
+            block_duration = payload.get(
+                "block_duration", XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION
+            )
+            logger.info(
+                "Abort request with model uid: %s, request id: %s, block duration: %s",
+                model_uid,
+                request_id,
+                block_duration,
+            )
             supervisor_ref = await self._get_supervisor_ref()
-            res = await supervisor_ref.abort_request(
+            res = await supervisor_ref.abort_request(
+                model_uid, request_id, block_duration
+            )
+            self._cancel_running_task(request_id, block_duration)
             return JSONResponse(content=res)
         except Exception as e:
             logger.error(e, exc_info=True)
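For context (not part of the diff): the rerank endpoint no longer scrapes unknown top-level keys from the payload; extra options now travel in the new `kwargs` field as a JSON-encoded string that the handler decodes with `json.loads` and forwards to `rerank()`. A minimal sketch of a raw HTTP call under that contract — the host, port, model uid and the extra option shown are assumptions, as is the `/v1/rerank` route:

import json
import requests

# Hypothetical endpoint and model uid; adjust to your deployment.
payload = {
    "model": "my-rerank-model",
    "query": "What is the capital of France?",
    "documents": ["Paris is the capital of France.", "Berlin is in Germany."],
    "return_documents": True,
    # Extra options are JSON-encoded into the new `kwargs` field; the server
    # parses them with json.loads(body.kwargs) before calling rerank().
    "kwargs": json.dumps({"request_id": "rerank-demo-1"}),
}
resp = requests.post("http://127.0.0.1:9997/v1/rerank", json=payload)
print(resp.json())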
xinference/client/restful/restful_client.py
CHANGED
@@ -174,6 +174,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
             "max_chunks_per_doc": max_chunks_per_doc,
             "return_documents": return_documents,
             "return_len": return_len,
+            "kwargs": json.dumps(kwargs),
         }
         request_body.update(kwargs)
         response = requests.post(url, json=request_body, headers=self.auth_headers)
@@ -1357,7 +1358,7 @@ class Client:
         response_data = response.json()
         return response_data
 
-    def abort_request(self, model_uid: str, request_id: str):
+    def abort_request(self, model_uid: str, request_id: str, block_duration: int = 30):
         """
         Abort a request.
         Abort a submitted request. If the request is finished or not found, this method will be a no-op.
@@ -1369,13 +1370,18 @@
             Model uid.
         request_id: str
             Request id.
+        block_duration: int
+            The duration to make the request id abort. If set to 0, the abort_request will be immediate, which may
+            prevent it from taking effect if it arrives before the request operation.
         Returns
         -------
         Dict
             Return empty dict.
         """
         url = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort"
-        response = requests.post(
+        response = requests.post(
+            url, headers=self._headers, json={"block_duration": block_duration}
+        )
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to abort request, detail: {_get_error_string(response)}"
xinference/conftest.py
CHANGED
@@ -58,10 +58,6 @@ TEST_LOGGING_CONF = {
             "propagate": False,
         }
     },
-    "root": {
-        "level": "WARN",
-        "handlers": ["stream_handler"],
-    },
 }
 
 TEST_LOG_FILE_PATH = get_log_file(f"test_{get_timestamp_ms()}")
@@ -102,10 +98,6 @@ TEST_FILE_LOGGING_CONF = {
             "propagate": False,
         }
     },
-    "root": {
-        "level": "WARN",
-        "handlers": ["stream_handler", "file_handler"],
-    },
 }
 
 
xinference/constants.py
CHANGED
xinference/core/model.py
CHANGED
@@ -40,7 +40,11 @@ from typing import (
 import sse_starlette.sse
 import xoscar as xo
 
-from ..constants import
+from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    XINFERENCE_LAUNCH_MODEL_RETRY,
+    XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE,
+)
 
 if TYPE_CHECKING:
     from .progress_tracker import ProgressTrackerActor
@@ -54,7 +58,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 from ..device_utils import empty_cache
-from .utils import json_dumps, log_async
+from .utils import CancelMixin, json_dumps, log_async
 
 try:
     from torch.cuda import OutOfMemoryError
@@ -133,7 +137,9 @@ def oom_check(fn):
     return _wrapper
 
 
-class ModelActor(xo.StatelessActor):
+class ModelActor(xo.StatelessActor, CancelMixin):
+    _replica_model_uid: Optional[str]
+
     @classmethod
     def gen_uid(cls, model: "LLM"):
         return f"{model.__class__}-model-actor"
@@ -192,6 +198,7 @@ class ModelActor(xo.StatelessActor):
         supervisor_address: str,
         worker_address: str,
         model: "LLM",
+        replica_model_uid: str,
         model_description: Optional["ModelDescription"] = None,
         request_limits: Optional[int] = None,
     ):
@@ -203,6 +210,7 @@ class ModelActor(xo.StatelessActor):
 
         self._supervisor_address = supervisor_address
         self._worker_address = worker_address
+        self._replica_model_uid = replica_model_uid
         self._model = model
         self._model_description = (
             model_description.to_dict() if model_description else {}
@@ -257,6 +265,9 @@ class ModelActor(xo.StatelessActor):
             uid=FluxBatchSchedulerActor.gen_uid(self.model_uid()),
         )
 
+    def __repr__(self) -> str:
+        return f"ModelActor({self._replica_model_uid})"
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):
@@ -374,7 +385,28 @@ class ModelActor(xo.StatelessActor):
         return condition
 
     async def load(self):
-
+        try:
+            # Change process title for model
+            import setproctitle
+
+            setproctitle.setproctitle(f"Model: {self._replica_model_uid}")
+        except ImportError:
+            pass
+        i = 0
+        while True:
+            i += 1
+            try:
+                self._model.load()
+                break
+            except Exception as e:
+                if (
+                    i < XINFERENCE_LAUNCH_MODEL_RETRY
+                    and str(e).find("busy or unavailable") >= 0
+                ):
+                    await asyncio.sleep(5)
+                    logger.warning("Retry to load model {model_uid}: %d times", i)
+                    continue
+                raise
         if self.allow_batching():
             await self._scheduler_ref.set_model(self._model)
             logger.debug(
@@ -385,6 +417,7 @@ class ModelActor(xo.StatelessActor):
             logger.debug(
                 f"Batching enabled for model: {self.model_uid()}, max_num_images: {self._model.get_max_num_images_for_batching()}"
             )
+        logger.info(f"{self} loaded")
 
     def model_uid(self):
         return (
@@ -521,6 +554,7 @@ class ModelActor(xo.StatelessActor):
 
     @oom_check
     async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs):
+        self._add_running_task(kwargs.get("request_id"))
         if self._lock is None:
             if inspect.iscoroutinefunction(fn):
                 ret = await fn(*args, **kwargs)
@@ -729,9 +763,14 @@ class ModelActor(xo.StatelessActor):
             prompt_tokens,
         )
 
-    async def abort_request(
+    async def abort_request(
+        self,
+        request_id: str,
+        block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    ) -> str:
         from .utils import AbortRequestMessage
 
+        self._cancel_running_task(request_id, block_duration)
         if self.allow_batching():
             if self._scheduler_ref is None:
                 return AbortRequestMessage.NOT_FOUND.name
xinference/core/supervisor.py
CHANGED
@@ -35,6 +35,7 @@ from typing import (
 import xoscar as xo
 
 from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
     XINFERENCE_DISABLE_HEALTH_CHECK,
     XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
     XINFERENCE_HEALTH_CHECK_INTERVAL,
@@ -970,7 +971,7 @@
                 raise ValueError(
                     f"Model is already in the model list, uid: {_replica_model_uid}"
                 )
-            replica_gpu_idx = assign_replica_gpu(_replica_model_uid, gpu_idx)
+            replica_gpu_idx = assign_replica_gpu(_replica_model_uid, replica, gpu_idx)
             nonlocal model_type
 
             worker_ref = (
@@ -1084,7 +1085,7 @@
                 dead_models,
             )
             for replica_model_uid in dead_models:
-                model_uid, _
+                model_uid, _ = parse_replica_model_uid(replica_model_uid)
                 self._model_uid_to_replica_info.pop(model_uid, None)
                 self._replica_model_uid_to_worker.pop(
                     replica_model_uid, None
@@ -1137,7 +1138,7 @@
             raise ValueError(f"Model not found in the model list, uid: {model_uid}")
 
         replica_model_uid = build_replica_model_uid(
-            model_uid,
+            model_uid, next(replica_info.scheduler)
         )
 
         worker_ref = self._replica_model_uid_to_worker.get(replica_model_uid, None)
@@ -1154,7 +1155,7 @@
             raise ValueError(f"Model not found in the model list, uid: {model_uid}")
         # Use rep id 0 to instead of next(replica_info.scheduler) to avoid
         # consuming the generator.
-        replica_model_uid = build_replica_model_uid(model_uid,
+        replica_model_uid = build_replica_model_uid(model_uid, 0)
         worker_ref = self._replica_model_uid_to_worker.get(replica_model_uid, None)
         if worker_ref is None:
             raise ValueError(
@@ -1213,7 +1214,12 @@
         return cached_models
 
     @log_async(logger=logger)
-    async def abort_request(
+    async def abort_request(
+        self,
+        model_uid: str,
+        request_id: str,
+        block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    ) -> Dict:
         from .scheduler import AbortRequestMessage
 
         res = {"msg": AbortRequestMessage.NO_OP.name}
@@ -1228,7 +1234,7 @@
             if worker_ref is None:
                 continue
             model_ref = await worker_ref.get_model(model_uid=rep_mid)
-            result_info = await model_ref.abort_request(request_id)
+            result_info = await model_ref.abort_request(request_id, block_duration)
             res["msg"] = result_info
             if result_info == AbortRequestMessage.DONE.name:
                 break
@@ -1260,7 +1266,7 @@
                 uids_to_remove.append(model_uid)
 
         for replica_model_uid in uids_to_remove:
-            model_uid, _
+            model_uid, _ = parse_replica_model_uid(replica_model_uid)
            self._model_uid_to_replica_info.pop(model_uid, None)
            self._replica_model_uid_to_worker.pop(replica_model_uid, None)
xinference/core/utils.py
CHANGED
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio
 import logging
 import os
 import random
 import string
 import uuid
+import weakref
 from enum import Enum
 from typing import Dict, Generator, List, Optional, Tuple, Union
 
@@ -23,7 +25,10 @@ import orjson
 from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown
 
 from .._compat import BaseModel
-from ..constants import
+from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    XINFERENCE_LOG_ARG_MAX_LENGTH,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -49,13 +54,20 @@ def log_async(
 ):
     import time
     from functools import wraps
+    from inspect import signature
 
     def decorator(func):
         func_name = func.__name__
+        sig = signature(func)
 
         @wraps(func)
         async def wrapped(*args, **kwargs):
-
+            try:
+                bound_args = sig.bind_partial(*args, **kwargs)
+                arguments = bound_args.arguments
+            except TypeError:
+                arguments = {}
+            request_id_str = arguments.get("request_id", "")
             if not request_id_str:
                 request_id_str = uuid.uuid1()
             if func_name == "text_to_image":
@@ -146,27 +158,26 @@ def iter_replica_model_uid(model_uid: str, replica: int) -> Generator[str, None,
     """
     replica = int(replica)
     for rep_id in range(replica):
-        yield f"{model_uid}-{
+        yield f"{model_uid}-{rep_id}"
 
 
-def build_replica_model_uid(model_uid: str,
+def build_replica_model_uid(model_uid: str, rep_id: int) -> str:
     """
     Build a replica model uid.
     """
-    return f"{model_uid}-{
+    return f"{model_uid}-{rep_id}"
 
 
-def parse_replica_model_uid(replica_model_uid: str) -> Tuple[str, int
+def parse_replica_model_uid(replica_model_uid: str) -> Tuple[str, int]:
     """
-    Parse replica model uid to model uid
+    Parse replica model uid to model uid and rep id.
     """
     parts = replica_model_uid.split("-")
     if len(parts) == 1:
-        return replica_model_uid, -1
+        return replica_model_uid, -1
     rep_id = int(parts.pop())
-    replica = int(parts.pop())
     model_uid = "-".join(parts)
-    return model_uid,
+    return model_uid, rep_id
 
 
 def is_valid_model_uid(model_uid: str) -> bool:
@@ -261,12 +272,65 @@ def get_nvidia_gpu_info() -> Dict:
 
 
 def assign_replica_gpu(
-    _replica_model_uid: str, gpu_idx: Union[int, List[int]]
+    _replica_model_uid: str, replica: int, gpu_idx: Union[int, List[int]]
 ) -> List[int]:
-    model_uid,
+    model_uid, rep_id = parse_replica_model_uid(_replica_model_uid)
     rep_id, replica = int(rep_id), int(replica)
     if isinstance(gpu_idx, int):
         gpu_idx = [gpu_idx]
     if isinstance(gpu_idx, list) and gpu_idx:
         return gpu_idx[rep_id::replica]
     return gpu_idx
+
+
+class CancelMixin:
+    _CANCEL_TASK_NAME = "abort_block"
+
+    def __init__(self):
+        self._running_tasks: weakref.WeakValueDictionary[
+            str, asyncio.Task
+        ] = weakref.WeakValueDictionary()
+
+    def _add_running_task(self, request_id: Optional[str]):
+        """Add current asyncio task to the running task.
+
+        :param request_id: The corresponding request id.
+        """
+        if request_id is None:
+            return
+        running_task = self._running_tasks.get(request_id)
+        if running_task is not None:
+            if running_task.get_name() == self._CANCEL_TASK_NAME:
+                raise Exception(f"The request has been aborted: {request_id}")
+            raise Exception(f"Duplicate request id: {request_id}")
+        current_task = asyncio.current_task()
+        assert current_task is not None
+        self._running_tasks[request_id] = current_task
+
+    def _cancel_running_task(
+        self,
+        request_id: Optional[str],
+        block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    ):
+        """Cancel the running asyncio task.
+
+        :param request_id: The request id to cancel.
+        :param block_duration: The duration seconds to ensure the request can't be executed.
+        """
+        if request_id is None:
+            return
+        running_task = self._running_tasks.pop(request_id, None)
+        if running_task is not None:
+            running_task.cancel()
+
+        async def block_task():
+            """This task is for blocking the request for a duration."""
+            try:
+                await asyncio.sleep(block_duration)
+                logger.info("Abort block end for request: %s", request_id)
+            except asyncio.CancelledError:
+                logger.info("Abort block is cancelled for request: %s", request_id)
+
+        if block_duration > 0:
+            logger.info("Abort block start for request: %s", request_id)
+            self._running_tasks[request_id] = asyncio.create_task(
+                block_task(), name=self._CANCEL_TASK_NAME
+            )
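For illustration only (not part of the diff): a self-contained sketch of how the `CancelMixin` above is meant to be used — one coroutine registers its task under a request id, another cancels it and blocks the id for a while. `SlowHandler`, the timings and the request id are made up:

import asyncio

from xinference.core.utils import CancelMixin


class SlowHandler(CancelMixin):
    """Made-up handler standing in for an actor or API method."""

    async def do_work(self, request_id: str) -> str:
        self._add_running_task(request_id)  # register the current asyncio task
        await asyncio.sleep(10)             # stands in for a long-running model call
        return "done"


async def main():
    handler = SlowHandler()
    task = asyncio.create_task(handler.do_work("req-1"))
    await asyncio.sleep(0.1)
    # Cancel the running task and block "req-1" for 5 more seconds, so that a
    # retry with the same id would raise "The request has been aborted".
    handler._cancel_running_task("req-1", block_duration=5)
    try:
        await task
    except asyncio.CancelledError:
        print("req-1 was cancelled")


asyncio.run(main())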
xinference/core/worker.py
CHANGED
@@ -157,7 +157,7 @@ class WorkerActor(xo.StatelessActor):
                 model_uid,
                 recover_count - 1,
             )
-            event_model_uid, _
+            event_model_uid, _ = parse_replica_model_uid(model_uid)
             try:
                 if self._event_collector_ref is not None:
                     await self._event_collector_ref.report_event(
@@ -377,7 +377,7 @@
         return len(self._model_uid_to_model)
 
     async def is_model_vllm_backend(self, model_uid: str) -> bool:
-        _model_uid, _
+        _model_uid, _ = parse_replica_model_uid(model_uid)
         supervisor_ref = await self.get_supervisor_ref()
         model_ref = await supervisor_ref.get_model(_model_uid)
         return await model_ref.is_vllm_backend()
@@ -800,7 +800,7 @@
         launch_args.update(kwargs)
 
         try:
-            origin_uid, _
+            origin_uid, _ = parse_replica_model_uid(model_uid)
         except Exception as e:
             logger.exception(e)
             raise
@@ -889,6 +889,7 @@
             uid=model_uid,
             supervisor_address=self._supervisor_address,
             worker_address=self.address,
+            replica_model_uid=model_uid,
             model=model,
             model_description=model_description,
             request_limits=request_limits,
@@ -926,7 +927,7 @@
         # Terminate model while its launching is not allow
         if model_uid in self._model_uid_launching_guard:
             raise ValueError(f"{model_uid} is launching")
-        origin_uid, _
+        origin_uid, _ = parse_replica_model_uid(model_uid)
         try:
             _ = await self.get_supervisor_ref()
             if self._event_collector_ref is not None: