xinference 0.16.3__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +62 -11
- xinference/client/restful/restful_client.py +8 -2
- xinference/constants.py +1 -0
- xinference/core/model.py +10 -3
- xinference/core/supervisor.py +8 -2
- xinference/core/utils.py +67 -2
- xinference/model/audio/model_spec.json +1 -1
- xinference/model/image/stable_diffusion/core.py +5 -2
- xinference/model/llm/llm_family.json +176 -4
- xinference/model/llm/llm_family_modelscope.json +211 -0
- xinference/model/llm/mlx/core.py +45 -2
- xinference/model/rerank/core.py +11 -4
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/api.py +578 -75
- xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
- xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
- xinference/thirdparty/fish_speech/tools/schema.py +187 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
- xinference/thirdparty/fish_speech/tools/webui.py +138 -75
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/METADATA +23 -1
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/RECORD +43 -50
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/WHEEL +1 -1
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/commons.py +0 -35
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/LICENSE +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2024-11-
|
|
11
|
+
"date": "2024-11-15T17:33:11+0800",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "0.
|
|
14
|
+
"full-revisionid": "4c96475b8f90e354aa1b47856fda4db098b62b65",
|
|
15
|
+
"version": "1.0.0"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
xinference/api/restful_api.py
CHANGED
|
@@ -52,10 +52,14 @@ from xoscar.utils import get_next_port
|
|
|
52
52
|
|
|
53
53
|
from .._compat import BaseModel, Field
|
|
54
54
|
from .._version import get_versions
|
|
55
|
-
from ..constants import
|
|
55
|
+
from ..constants import (
|
|
56
|
+
XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
|
|
57
|
+
XINFERENCE_DEFAULT_ENDPOINT_PORT,
|
|
58
|
+
XINFERENCE_DISABLE_METRICS,
|
|
59
|
+
)
|
|
56
60
|
from ..core.event import Event, EventCollectorActor, EventType
|
|
57
61
|
from ..core.supervisor import SupervisorActor
|
|
58
|
-
from ..core.utils import json_dumps
|
|
62
|
+
from ..core.utils import CancelMixin, json_dumps
|
|
59
63
|
from ..types import (
|
|
60
64
|
ChatCompletion,
|
|
61
65
|
Completion,
|
|
@@ -111,6 +115,7 @@ class RerankRequest(BaseModel):
|
|
|
111
115
|
return_documents: Optional[bool] = False
|
|
112
116
|
return_len: Optional[bool] = False
|
|
113
117
|
max_chunks_per_doc: Optional[int] = None
|
|
118
|
+
kwargs: Optional[str] = None
|
|
114
119
|
|
|
115
120
|
|
|
116
121
|
class TextToImageRequest(BaseModel):
|
|
@@ -206,7 +211,7 @@ class BuildGradioImageInterfaceRequest(BaseModel):
|
|
|
206
211
|
model_ability: List[str]
|
|
207
212
|
|
|
208
213
|
|
|
209
|
-
class RESTfulAPI:
|
|
214
|
+
class RESTfulAPI(CancelMixin):
|
|
210
215
|
def __init__(
|
|
211
216
|
self,
|
|
212
217
|
supervisor_address: str,
|
|
@@ -1311,11 +1316,6 @@ class RESTfulAPI:
|
|
|
1311
1316
|
payload = await request.json()
|
|
1312
1317
|
body = RerankRequest.parse_obj(payload)
|
|
1313
1318
|
model_uid = body.model
|
|
1314
|
-
kwargs = {
|
|
1315
|
-
key: value
|
|
1316
|
-
for key, value in payload.items()
|
|
1317
|
-
if key not in RerankRequest.__annotations__.keys()
|
|
1318
|
-
}
|
|
1319
1319
|
|
|
1320
1320
|
try:
|
|
1321
1321
|
model = await (await self._get_supervisor_ref()).get_model(model_uid)
|
|
@@ -1329,6 +1329,10 @@ class RESTfulAPI:
|
|
|
1329
1329
|
raise HTTPException(status_code=500, detail=str(e))
|
|
1330
1330
|
|
|
1331
1331
|
try:
|
|
1332
|
+
if body.kwargs is not None:
|
|
1333
|
+
parsed_kwargs = json.loads(body.kwargs)
|
|
1334
|
+
else:
|
|
1335
|
+
parsed_kwargs = {}
|
|
1332
1336
|
scores = await model.rerank(
|
|
1333
1337
|
body.documents,
|
|
1334
1338
|
body.query,
|
|
@@ -1336,7 +1340,7 @@ class RESTfulAPI:
|
|
|
1336
1340
|
max_chunks_per_doc=body.max_chunks_per_doc,
|
|
1337
1341
|
return_documents=body.return_documents,
|
|
1338
1342
|
return_len=body.return_len,
|
|
1339
|
-
**
|
|
1343
|
+
**parsed_kwargs,
|
|
1340
1344
|
)
|
|
1341
1345
|
return Response(scores, media_type="application/json")
|
|
1342
1346
|
except RuntimeError as re:
|
|
@@ -1531,8 +1535,11 @@ class RESTfulAPI:
|
|
|
1531
1535
|
await self._report_error_event(model_uid, str(e))
|
|
1532
1536
|
raise HTTPException(status_code=500, detail=str(e))
|
|
1533
1537
|
|
|
1538
|
+
request_id = None
|
|
1534
1539
|
try:
|
|
1535
1540
|
kwargs = json.loads(body.kwargs) if body.kwargs else {}
|
|
1541
|
+
request_id = kwargs.get("request_id")
|
|
1542
|
+
self._add_running_task(request_id)
|
|
1536
1543
|
image_list = await model.text_to_image(
|
|
1537
1544
|
prompt=body.prompt,
|
|
1538
1545
|
n=body.n,
|
|
@@ -1541,6 +1548,11 @@ class RESTfulAPI:
|
|
|
1541
1548
|
**kwargs,
|
|
1542
1549
|
)
|
|
1543
1550
|
return Response(content=image_list, media_type="application/json")
|
|
1551
|
+
except asyncio.CancelledError:
|
|
1552
|
+
err_str = f"The request has been cancelled: {request_id}"
|
|
1553
|
+
logger.error(err_str)
|
|
1554
|
+
await self._report_error_event(model_uid, err_str)
|
|
1555
|
+
raise HTTPException(status_code=409, detail=err_str)
|
|
1544
1556
|
except RuntimeError as re:
|
|
1545
1557
|
logger.error(re, exc_info=True)
|
|
1546
1558
|
await self._report_error_event(model_uid, str(re))
|
|
@@ -1686,11 +1698,14 @@ class RESTfulAPI:
|
|
|
1686
1698
|
await self._report_error_event(model_uid, str(e))
|
|
1687
1699
|
raise HTTPException(status_code=500, detail=str(e))
|
|
1688
1700
|
|
|
1701
|
+
request_id = None
|
|
1689
1702
|
try:
|
|
1690
1703
|
if kwargs is not None:
|
|
1691
1704
|
parsed_kwargs = json.loads(kwargs)
|
|
1692
1705
|
else:
|
|
1693
1706
|
parsed_kwargs = {}
|
|
1707
|
+
request_id = parsed_kwargs.get("request_id")
|
|
1708
|
+
self._add_running_task(request_id)
|
|
1694
1709
|
image_list = await model_ref.image_to_image(
|
|
1695
1710
|
image=Image.open(image.file),
|
|
1696
1711
|
prompt=prompt,
|
|
@@ -1701,6 +1716,11 @@ class RESTfulAPI:
|
|
|
1701
1716
|
**parsed_kwargs,
|
|
1702
1717
|
)
|
|
1703
1718
|
return Response(content=image_list, media_type="application/json")
|
|
1719
|
+
except asyncio.CancelledError:
|
|
1720
|
+
err_str = f"The request has been cancelled: {request_id}"
|
|
1721
|
+
logger.error(err_str)
|
|
1722
|
+
await self._report_error_event(model_uid, err_str)
|
|
1723
|
+
raise HTTPException(status_code=409, detail=err_str)
|
|
1704
1724
|
except RuntimeError as re:
|
|
1705
1725
|
logger.error(re, exc_info=True)
|
|
1706
1726
|
await self._report_error_event(model_uid, str(re))
|
|
@@ -1734,11 +1754,14 @@ class RESTfulAPI:
|
|
|
1734
1754
|
await self._report_error_event(model_uid, str(e))
|
|
1735
1755
|
raise HTTPException(status_code=500, detail=str(e))
|
|
1736
1756
|
|
|
1757
|
+
request_id = None
|
|
1737
1758
|
try:
|
|
1738
1759
|
if kwargs is not None:
|
|
1739
1760
|
parsed_kwargs = json.loads(kwargs)
|
|
1740
1761
|
else:
|
|
1741
1762
|
parsed_kwargs = {}
|
|
1763
|
+
request_id = parsed_kwargs.get("request_id")
|
|
1764
|
+
self._add_running_task(request_id)
|
|
1742
1765
|
im = Image.open(image.file)
|
|
1743
1766
|
mask_im = Image.open(mask_image.file)
|
|
1744
1767
|
if not size:
|
|
@@ -1755,6 +1778,11 @@ class RESTfulAPI:
|
|
|
1755
1778
|
**parsed_kwargs,
|
|
1756
1779
|
)
|
|
1757
1780
|
return Response(content=image_list, media_type="application/json")
|
|
1781
|
+
except asyncio.CancelledError:
|
|
1782
|
+
err_str = f"The request has been cancelled: {request_id}"
|
|
1783
|
+
logger.error(err_str)
|
|
1784
|
+
await self._report_error_event(model_uid, err_str)
|
|
1785
|
+
raise HTTPException(status_code=409, detail=err_str)
|
|
1758
1786
|
except RuntimeError as re:
|
|
1759
1787
|
logger.error(re, exc_info=True)
|
|
1760
1788
|
await self._report_error_event(model_uid, str(re))
|
|
@@ -1782,17 +1810,25 @@ class RESTfulAPI:
|
|
|
1782
1810
|
await self._report_error_event(model_uid, str(e))
|
|
1783
1811
|
raise HTTPException(status_code=500, detail=str(e))
|
|
1784
1812
|
|
|
1813
|
+
request_id = None
|
|
1785
1814
|
try:
|
|
1786
1815
|
if kwargs is not None:
|
|
1787
1816
|
parsed_kwargs = json.loads(kwargs)
|
|
1788
1817
|
else:
|
|
1789
1818
|
parsed_kwargs = {}
|
|
1819
|
+
request_id = parsed_kwargs.get("request_id")
|
|
1820
|
+
self._add_running_task(request_id)
|
|
1790
1821
|
im = Image.open(image.file)
|
|
1791
1822
|
text = await model_ref.ocr(
|
|
1792
1823
|
image=im,
|
|
1793
1824
|
**parsed_kwargs,
|
|
1794
1825
|
)
|
|
1795
1826
|
return Response(content=text, media_type="text/plain")
|
|
1827
|
+
except asyncio.CancelledError:
|
|
1828
|
+
err_str = f"The request has been cancelled: {request_id}"
|
|
1829
|
+
logger.error(err_str)
|
|
1830
|
+
await self._report_error_event(model_uid, err_str)
|
|
1831
|
+
raise HTTPException(status_code=409, detail=err_str)
|
|
1796
1832
|
except RuntimeError as re:
|
|
1797
1833
|
logger.error(re, exc_info=True)
|
|
1798
1834
|
await self._report_error_event(model_uid, str(re))
|
|
@@ -2111,10 +2147,25 @@ class RESTfulAPI:
|
|
|
2111
2147
|
logger.error(e, exc_info=True)
|
|
2112
2148
|
raise HTTPException(status_code=500, detail=str(e))
|
|
2113
2149
|
|
|
2114
|
-
async def abort_request(
|
|
2150
|
+
async def abort_request(
|
|
2151
|
+
self, request: Request, model_uid: str, request_id: str
|
|
2152
|
+
) -> JSONResponse:
|
|
2115
2153
|
try:
|
|
2154
|
+
payload = await request.json()
|
|
2155
|
+
block_duration = payload.get(
|
|
2156
|
+
"block_duration", XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION
|
|
2157
|
+
)
|
|
2158
|
+
logger.info(
|
|
2159
|
+
"Abort request with model uid: %s, request id: %s, block duration: %s",
|
|
2160
|
+
model_uid,
|
|
2161
|
+
request_id,
|
|
2162
|
+
block_duration,
|
|
2163
|
+
)
|
|
2116
2164
|
supervisor_ref = await self._get_supervisor_ref()
|
|
2117
|
-
res = await supervisor_ref.abort_request(
|
|
2165
|
+
res = await supervisor_ref.abort_request(
|
|
2166
|
+
model_uid, request_id, block_duration
|
|
2167
|
+
)
|
|
2168
|
+
self._cancel_running_task(request_id, block_duration)
|
|
2118
2169
|
return JSONResponse(content=res)
|
|
2119
2170
|
except Exception as e:
|
|
2120
2171
|
logger.error(e, exc_info=True)
|
xinference/client/restful/restful_client.py
CHANGED
|
@@ -174,6 +174,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
|
|
|
174
174
|
"max_chunks_per_doc": max_chunks_per_doc,
|
|
175
175
|
"return_documents": return_documents,
|
|
176
176
|
"return_len": return_len,
|
|
177
|
+
"kwargs": json.dumps(kwargs),
|
|
177
178
|
}
|
|
178
179
|
request_body.update(kwargs)
|
|
179
180
|
response = requests.post(url, json=request_body, headers=self.auth_headers)
|
|
@@ -1357,7 +1358,7 @@ class Client:
|
|
|
1357
1358
|
response_data = response.json()
|
|
1358
1359
|
return response_data
|
|
1359
1360
|
|
|
1360
|
-
def abort_request(self, model_uid: str, request_id: str):
|
|
1361
|
+
def abort_request(self, model_uid: str, request_id: str, block_duration: int = 30):
|
|
1361
1362
|
"""
|
|
1362
1363
|
Abort a request.
|
|
1363
1364
|
Abort a submitted request. If the request is finished or not found, this method will be a no-op.
|
|
@@ -1369,13 +1370,18 @@ class Client:
|
|
|
1369
1370
|
Model uid.
|
|
1370
1371
|
request_id: str
|
|
1371
1372
|
Request id.
|
|
1373
|
+
block_duration: int
|
|
1374
|
+
The duration to make the request id abort. If set to 0, the abort_request will be immediate, which may
|
|
1375
|
+
prevent it from taking effect if it arrives before the request operation.
|
|
1372
1376
|
Returns
|
|
1373
1377
|
-------
|
|
1374
1378
|
Dict
|
|
1375
1379
|
Return empty dict.
|
|
1376
1380
|
"""
|
|
1377
1381
|
url = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort"
|
|
1378
|
-
response = requests.post(
|
|
1382
|
+
response = requests.post(
|
|
1383
|
+
url, headers=self._headers, json={"block_duration": block_duration}
|
|
1384
|
+
)
|
|
1379
1385
|
if response.status_code != 200:
|
|
1380
1386
|
raise RuntimeError(
|
|
1381
1387
|
f"Failed to abort request, detail: {_get_error_string(response)}"
|
xinference/constants.py
CHANGED
xinference/core/model.py
CHANGED
|
@@ -41,6 +41,7 @@ import sse_starlette.sse
|
|
|
41
41
|
import xoscar as xo
|
|
42
42
|
|
|
43
43
|
from ..constants import (
|
|
44
|
+
XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
|
|
44
45
|
XINFERENCE_LAUNCH_MODEL_RETRY,
|
|
45
46
|
XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE,
|
|
46
47
|
)
|
|
@@ -57,7 +58,7 @@ import logging
|
|
|
57
58
|
logger = logging.getLogger(__name__)
|
|
58
59
|
|
|
59
60
|
from ..device_utils import empty_cache
|
|
60
|
-
from .utils import json_dumps, log_async
|
|
61
|
+
from .utils import CancelMixin, json_dumps, log_async
|
|
61
62
|
|
|
62
63
|
try:
|
|
63
64
|
from torch.cuda import OutOfMemoryError
|
|
@@ -136,7 +137,7 @@ def oom_check(fn):
|
|
|
136
137
|
return _wrapper
|
|
137
138
|
|
|
138
139
|
|
|
139
|
-
class ModelActor(xo.StatelessActor):
|
|
140
|
+
class ModelActor(xo.StatelessActor, CancelMixin):
|
|
140
141
|
_replica_model_uid: Optional[str]
|
|
141
142
|
|
|
142
143
|
@classmethod
|
|
@@ -553,6 +554,7 @@ class ModelActor(xo.StatelessActor):
|
|
|
553
554
|
|
|
554
555
|
@oom_check
|
|
555
556
|
async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs):
|
|
557
|
+
self._add_running_task(kwargs.get("request_id"))
|
|
556
558
|
if self._lock is None:
|
|
557
559
|
if inspect.iscoroutinefunction(fn):
|
|
558
560
|
ret = await fn(*args, **kwargs)
|
|
@@ -761,9 +763,14 @@ class ModelActor(xo.StatelessActor):
|
|
|
761
763
|
prompt_tokens,
|
|
762
764
|
)
|
|
763
765
|
|
|
764
|
-
async def abort_request(
|
|
766
|
+
async def abort_request(
|
|
767
|
+
self,
|
|
768
|
+
request_id: str,
|
|
769
|
+
block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
|
|
770
|
+
) -> str:
|
|
765
771
|
from .utils import AbortRequestMessage
|
|
766
772
|
|
|
773
|
+
self._cancel_running_task(request_id, block_duration)
|
|
767
774
|
if self.allow_batching():
|
|
768
775
|
if self._scheduler_ref is None:
|
|
769
776
|
return AbortRequestMessage.NOT_FOUND.name
|
xinference/core/supervisor.py
CHANGED
|
@@ -35,6 +35,7 @@ from typing import (
|
|
|
35
35
|
import xoscar as xo
|
|
36
36
|
|
|
37
37
|
from ..constants import (
|
|
38
|
+
XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
|
|
38
39
|
XINFERENCE_DISABLE_HEALTH_CHECK,
|
|
39
40
|
XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
|
|
40
41
|
XINFERENCE_HEALTH_CHECK_INTERVAL,
|
|
@@ -1213,7 +1214,12 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
1213
1214
|
return cached_models
|
|
1214
1215
|
|
|
1215
1216
|
@log_async(logger=logger)
|
|
1216
|
-
async def abort_request(
|
|
1217
|
+
async def abort_request(
|
|
1218
|
+
self,
|
|
1219
|
+
model_uid: str,
|
|
1220
|
+
request_id: str,
|
|
1221
|
+
block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
|
|
1222
|
+
) -> Dict:
|
|
1217
1223
|
from .scheduler import AbortRequestMessage
|
|
1218
1224
|
|
|
1219
1225
|
res = {"msg": AbortRequestMessage.NO_OP.name}
|
|
@@ -1228,7 +1234,7 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
1228
1234
|
if worker_ref is None:
|
|
1229
1235
|
continue
|
|
1230
1236
|
model_ref = await worker_ref.get_model(model_uid=rep_mid)
|
|
1231
|
-
result_info = await model_ref.abort_request(request_id)
|
|
1237
|
+
result_info = await model_ref.abort_request(request_id, block_duration)
|
|
1232
1238
|
res["msg"] = result_info
|
|
1233
1239
|
if result_info == AbortRequestMessage.DONE.name:
|
|
1234
1240
|
break
|
xinference/core/utils.py
CHANGED
|
@@ -11,11 +11,13 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
import asyncio
|
|
14
15
|
import logging
|
|
15
16
|
import os
|
|
16
17
|
import random
|
|
17
18
|
import string
|
|
18
19
|
import uuid
|
|
20
|
+
import weakref
|
|
19
21
|
from enum import Enum
|
|
20
22
|
from typing import Dict, Generator, List, Optional, Tuple, Union
|
|
21
23
|
|
|
@@ -23,7 +25,10 @@ import orjson
|
|
|
23
25
|
from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown
|
|
24
26
|
|
|
25
27
|
from .._compat import BaseModel
|
|
26
|
-
from ..constants import
|
|
28
|
+
from ..constants import (
|
|
29
|
+
XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
|
|
30
|
+
XINFERENCE_LOG_ARG_MAX_LENGTH,
|
|
31
|
+
)
|
|
27
32
|
|
|
28
33
|
logger = logging.getLogger(__name__)
|
|
29
34
|
|
|
@@ -49,13 +54,20 @@ def log_async(
|
|
|
49
54
|
):
|
|
50
55
|
import time
|
|
51
56
|
from functools import wraps
|
|
57
|
+
from inspect import signature
|
|
52
58
|
|
|
53
59
|
def decorator(func):
|
|
54
60
|
func_name = func.__name__
|
|
61
|
+
sig = signature(func)
|
|
55
62
|
|
|
56
63
|
@wraps(func)
|
|
57
64
|
async def wrapped(*args, **kwargs):
|
|
58
|
-
|
|
65
|
+
try:
|
|
66
|
+
bound_args = sig.bind_partial(*args, **kwargs)
|
|
67
|
+
arguments = bound_args.arguments
|
|
68
|
+
except TypeError:
|
|
69
|
+
arguments = {}
|
|
70
|
+
request_id_str = arguments.get("request_id", "")
|
|
59
71
|
if not request_id_str:
|
|
60
72
|
request_id_str = uuid.uuid1()
|
|
61
73
|
if func_name == "text_to_image":
|
|
@@ -269,3 +281,56 @@ def assign_replica_gpu(
|
|
|
269
281
|
if isinstance(gpu_idx, list) and gpu_idx:
|
|
270
282
|
return gpu_idx[rep_id::replica]
|
|
271
283
|
return gpu_idx
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class CancelMixin:
|
|
287
|
+
_CANCEL_TASK_NAME = "abort_block"
|
|
288
|
+
|
|
289
|
+
def __init__(self):
|
|
290
|
+
self._running_tasks: weakref.WeakValueDictionary[
|
|
291
|
+
str, asyncio.Task
|
|
292
|
+
] = weakref.WeakValueDictionary()
|
|
293
|
+
|
|
294
|
+
def _add_running_task(self, request_id: Optional[str]):
|
|
295
|
+
"""Add current asyncio task to the running task.
|
|
296
|
+
:param request_id: The corresponding request id.
|
|
297
|
+
"""
|
|
298
|
+
if request_id is None:
|
|
299
|
+
return
|
|
300
|
+
running_task = self._running_tasks.get(request_id)
|
|
301
|
+
if running_task is not None:
|
|
302
|
+
if running_task.get_name() == self._CANCEL_TASK_NAME:
|
|
303
|
+
raise Exception(f"The request has been aborted: {request_id}")
|
|
304
|
+
raise Exception(f"Duplicate request id: {request_id}")
|
|
305
|
+
current_task = asyncio.current_task()
|
|
306
|
+
assert current_task is not None
|
|
307
|
+
self._running_tasks[request_id] = current_task
|
|
308
|
+
|
|
309
|
+
def _cancel_running_task(
|
|
310
|
+
self,
|
|
311
|
+
request_id: Optional[str],
|
|
312
|
+
block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
|
|
313
|
+
):
|
|
314
|
+
"""Cancel the running asyncio task.
|
|
315
|
+
:param request_id: The request id to cancel.
|
|
316
|
+
:param block_duration: The duration seconds to ensure the request can't be executed.
|
|
317
|
+
"""
|
|
318
|
+
if request_id is None:
|
|
319
|
+
return
|
|
320
|
+
running_task = self._running_tasks.pop(request_id, None)
|
|
321
|
+
if running_task is not None:
|
|
322
|
+
running_task.cancel()
|
|
323
|
+
|
|
324
|
+
async def block_task():
|
|
325
|
+
"""This task is for blocking the request for a duration."""
|
|
326
|
+
try:
|
|
327
|
+
await asyncio.sleep(block_duration)
|
|
328
|
+
logger.info("Abort block end for request: %s", request_id)
|
|
329
|
+
except asyncio.CancelledError:
|
|
330
|
+
logger.info("Abort block is cancelled for request: %s", request_id)
|
|
331
|
+
|
|
332
|
+
if block_duration > 0:
|
|
333
|
+
logger.info("Abort block start for request: %s", request_id)
|
|
334
|
+
self._running_tasks[request_id] = asyncio.create_task(
|
|
335
|
+
block_task(), name=self._CANCEL_TASK_NAME
|
|
336
|
+
)
|
xinference/model/audio/model_spec.json
CHANGED
|
@@ -159,7 +159,7 @@
|
|
|
159
159
|
"model_name": "FishSpeech-1.4",
|
|
160
160
|
"model_family": "FishAudio",
|
|
161
161
|
"model_id": "fishaudio/fish-speech-1.4",
|
|
162
|
-
"model_revision": "
|
|
162
|
+
"model_revision": "069c573759936b35191d3380deb89183c0656f59",
|
|
163
163
|
"model_ability": "text-to-audio",
|
|
164
164
|
"multilingual": true
|
|
165
165
|
}
|
xinference/model/image/stable_diffusion/core.py
CHANGED
|
@@ -17,9 +17,11 @@ import gc
|
|
|
17
17
|
import inspect
|
|
18
18
|
import itertools
|
|
19
19
|
import logging
|
|
20
|
+
import os
|
|
20
21
|
import re
|
|
21
22
|
import sys
|
|
22
23
|
import warnings
|
|
24
|
+
from glob import glob
|
|
23
25
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
|
24
26
|
|
|
25
27
|
import PIL.Image
|
|
@@ -194,8 +196,9 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
|
|
|
194
196
|
if sys.platform != "darwin" and torch_dtype is None:
|
|
195
197
|
# The following params crashes on Mac M2
|
|
196
198
|
self._torch_dtype = self._kwargs["torch_dtype"] = torch.float16
|
|
197
|
-
self._kwargs["
|
|
198
|
-
|
|
199
|
+
self._kwargs["use_safetensors"] = any(
|
|
200
|
+
glob(os.path.join(self._model_path, "*/*.safetensors"))
|
|
201
|
+
)
|
|
199
202
|
if isinstance(torch_dtype, str):
|
|
200
203
|
self._kwargs["torch_dtype"] = getattr(torch, torch_dtype)
|
|
201
204
|
|