xinference 0.16.3__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_compat.py +22 -2
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +148 -12
- xinference/client/restful/restful_client.py +47 -2
- xinference/constants.py +1 -0
- xinference/core/model.py +45 -15
- xinference/core/supervisor.py +8 -2
- xinference/core/utils.py +67 -2
- xinference/model/audio/__init__.py +12 -0
- xinference/model/audio/core.py +21 -4
- xinference/model/audio/fish_speech.py +70 -35
- xinference/model/audio/model_spec.json +81 -1
- xinference/model/audio/whisper_mlx.py +208 -0
- xinference/model/embedding/core.py +259 -4
- xinference/model/embedding/model_spec.json +1 -1
- xinference/model/embedding/model_spec_modelscope.json +1 -1
- xinference/model/image/stable_diffusion/core.py +5 -2
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +485 -6
- xinference/model/llm/llm_family_modelscope.json +519 -0
- xinference/model/llm/mlx/core.py +45 -3
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/glm_edge_v.py +230 -0
- xinference/model/llm/utils.py +19 -0
- xinference/model/llm/vllm/core.py +84 -2
- xinference/model/rerank/core.py +11 -4
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/api.py +578 -75
- xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
- xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
- xinference/thirdparty/fish_speech/tools/schema.py +187 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
- xinference/thirdparty/fish_speech/tools/webui.py +138 -75
- xinference/types.py +2 -1
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/METADATA +30 -6
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/RECORD +58 -63
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/WHEEL +1 -1
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/commons.py +0 -35
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/LICENSE +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/top_level.txt +0 -0
xinference/_compat.py
CHANGED
@@ -60,6 +60,10 @@ from openai.types.chat.chat_completion_stream_options_param import (
     ChatCompletionStreamOptionsParam,
 )
 from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+from openai.types.shared_params.response_format_json_object import (
+    ResponseFormatJSONObject,
+)
+from openai.types.shared_params.response_format_text import ResponseFormatText
 
 OpenAIChatCompletionStreamOptionsParam = create_model_from_typeddict(
     ChatCompletionStreamOptionsParam
@@ -70,6 +74,23 @@ OpenAIChatCompletionNamedToolChoiceParam = create_model_from_typeddict(
 )
 
 
+class JSONSchema(BaseModel):
+    name: str
+    description: Optional[str] = None
+    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
+    strict: Optional[bool] = None
+
+
+class ResponseFormatJSONSchema(BaseModel):
+    json_schema: JSONSchema
+    type: Literal["json_schema"]
+
+
+ResponseFormat = Union[
+    ResponseFormatText, ResponseFormatJSONObject, ResponseFormatJSONSchema
+]
+
+
 class CreateChatCompletionOpenAI(BaseModel):
     """
     Comes from source code: https://github.com/openai/openai-python/blob/main/src/openai/types/chat/completion_create_params.py
@@ -84,8 +105,7 @@ class CreateChatCompletionOpenAI(BaseModel):
     n: Optional[int]
     parallel_tool_calls: Optional[bool]
     presence_penalty: Optional[float]
-
-    # response_format: ResponseFormat
+    response_format: Optional[ResponseFormat]
     seed: Optional[int]
     service_tier: Optional[Literal["auto", "default"]]
     stop: Union[Optional[str], List[str]]
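Taken together, this hunk means `CreateChatCompletionOpenAI` now validates the same three `response_format` shapes as the OpenAI client, instead of ignoring the field. A minimal sketch of those payload shapes (the schema content below is invented for illustration, not taken from the diff):

# The three shapes the new ResponseFormat union accepts; field names follow the
# pydantic models added above, the schema content itself is hypothetical.
format_text = {"type": "text"}
format_json_object = {"type": "json_object"}
format_json_schema = {
    "type": "json_schema",
    "json_schema": {
        "name": "profile",  # required by JSONSchema
        "schema": {"type": "object", "properties": {"name": {"type": "string"}}},
        "strict": True,
    },
}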
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
-    "date": "2024-11-…
+    "date": "2024-11-29T16:57:04+0800",
     "dirty": false,
     "error": null,
-    "full-revisionid": "…
-    "version": "0.16.3"
+    "full-revisionid": "eb8ddd431f5c5fcb2216e25e0d43745f8455d9b9",
+    "version": "1.0.1"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py
CHANGED
@@ -52,10 +52,14 @@ from xoscar.utils import get_next_port
 
 from .._compat import BaseModel, Field
 from .._version import get_versions
-from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT, XINFERENCE_DISABLE_METRICS
+from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    XINFERENCE_DEFAULT_ENDPOINT_PORT,
+    XINFERENCE_DISABLE_METRICS,
+)
 from ..core.event import Event, EventCollectorActor, EventType
 from ..core.supervisor import SupervisorActor
-from ..core.utils import json_dumps
+from ..core.utils import CancelMixin, json_dumps
 from ..types import (
     ChatCompletion,
     Completion,
@@ -111,6 +115,7 @@ class RerankRequest(BaseModel):
     return_documents: Optional[bool] = False
     return_len: Optional[bool] = False
     max_chunks_per_doc: Optional[int] = None
+    kwargs: Optional[str] = None
 
 
 class TextToImageRequest(BaseModel):
@@ -206,7 +211,7 @@ class BuildGradioImageInterfaceRequest(BaseModel):
     model_ability: List[str]
 
 
-class RESTfulAPI:
+class RESTfulAPI(CancelMixin):
     def __init__(
         self,
         supervisor_address: str,
@@ -484,6 +489,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/convert_ids_to_tokens",
+            self.convert_ids_to_tokens,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/rerank",
             self.rerank,
@@ -1214,6 +1229,9 @@ class RESTfulAPI:
         raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
         kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
+        # guided_decoding params
+        kwargs.update(self.extract_guided_params(raw_body=raw_body))
+
         # TODO: Decide if this default value override is necessary #1061
         if body.max_tokens is None:
             kwargs["max_tokens"] = max_tokens_field.default
@@ -1259,6 +1277,8 @@ class RESTfulAPI:
                     # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
                     yield dict(data=json.dumps({"error": str(ex)}))
                     return
+                finally:
+                    await model.decrease_serve_count()
 
             return EventSourceResponse(stream_results())
         else:
@@ -1307,15 +1327,45 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def convert_ids_to_tokens(self, request: Request) -> Response:
+        payload = await request.json()
+        body = CreateEmbeddingRequest.parse_obj(payload)
+        model_uid = body.model
+        exclude = {
+            "model",
+            "input",
+            "user",
+        }
+        kwargs = {key: value for key, value in payload.items() if key not in exclude}
+
+        try:
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            decoded_texts = await model.convert_ids_to_tokens(body.input, **kwargs)
+            return Response(decoded_texts, media_type="application/json")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def rerank(self, request: Request) -> Response:
         payload = await request.json()
         body = RerankRequest.parse_obj(payload)
         model_uid = body.model
-        kwargs = {
-            key: value
-            for key, value in payload.items()
-            if key not in RerankRequest.__annotations__.keys()
-        }
 
         try:
             model = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -1329,6 +1379,10 @@ class RESTfulAPI:
             raise HTTPException(status_code=500, detail=str(e))
 
         try:
+            if body.kwargs is not None:
+                parsed_kwargs = json.loads(body.kwargs)
+            else:
+                parsed_kwargs = {}
             scores = await model.rerank(
                 body.documents,
                 body.query,
@@ -1336,7 +1390,7 @@ class RESTfulAPI:
                 max_chunks_per_doc=body.max_chunks_per_doc,
                 return_documents=body.return_documents,
                 return_len=body.return_len,
-                **kwargs,
+                **parsed_kwargs,
             )
             return Response(scores, media_type="application/json")
         except RuntimeError as re:
@@ -1491,8 +1545,16 @@ class RESTfulAPI:
                 **parsed_kwargs,
             )
             if body.stream:
+
+                async def stream_results():
+                    try:
+                        async for item in out:
+                            yield item
+                    finally:
+                        await model.decrease_serve_count()
+
                 return EventSourceResponse(
-                    media_type="application/octet-stream", content=out
+                    media_type="application/octet-stream", content=stream_results()
                 )
             else:
                 return Response(media_type="application/octet-stream", content=out)
@@ -1531,8 +1593,11 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             kwargs = json.loads(body.kwargs) if body.kwargs else {}
+            request_id = kwargs.get("request_id")
+            self._add_running_task(request_id)
             image_list = await model.text_to_image(
                 prompt=body.prompt,
                 n=body.n,
@@ -1541,6 +1606,11 @@ class RESTfulAPI:
                 **kwargs,
             )
             return Response(content=image_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1686,11 +1756,14 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             if kwargs is not None:
                 parsed_kwargs = json.loads(kwargs)
             else:
                 parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
             image_list = await model_ref.image_to_image(
                 image=Image.open(image.file),
                 prompt=prompt,
@@ -1701,6 +1774,11 @@ class RESTfulAPI:
                 **parsed_kwargs,
             )
             return Response(content=image_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1734,11 +1812,14 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             if kwargs is not None:
                 parsed_kwargs = json.loads(kwargs)
             else:
                 parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
             im = Image.open(image.file)
             mask_im = Image.open(mask_image.file)
             if not size:
@@ -1755,6 +1836,11 @@ class RESTfulAPI:
                 **parsed_kwargs,
             )
             return Response(content=image_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1782,17 +1868,25 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             if kwargs is not None:
                 parsed_kwargs = json.loads(kwargs)
             else:
                 parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
             im = Image.open(image.file)
             text = await model_ref.ocr(
                 image=im,
                 **parsed_kwargs,
             )
             return Response(content=text, media_type="text/plain")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1880,9 +1974,13 @@ class RESTfulAPI:
             "logit_bias_type",
             "user",
         }
+
         raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
         kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
+        # guided_decoding params
+        kwargs.update(self.extract_guided_params(raw_body=raw_body))
+
         # TODO: Decide if this default value override is necessary #1061
         if body.max_tokens is None:
             kwargs["max_tokens"] = max_tokens_field.default
@@ -1991,6 +2089,8 @@ class RESTfulAPI:
                     # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
                     yield dict(data=json.dumps({"error": str(ex)}))
                     return
+                finally:
+                    await model.decrease_serve_count()
 
             return EventSourceResponse(stream_results())
         else:
@@ -2111,10 +2211,25 @@ class RESTfulAPI:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
-    async def abort_request(self, model_uid: str, request_id: str) -> JSONResponse:
+    async def abort_request(
+        self, request: Request, model_uid: str, request_id: str
+    ) -> JSONResponse:
         try:
+            payload = await request.json()
+            block_duration = payload.get(
+                "block_duration", XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION
+            )
+            logger.info(
+                "Abort request with model uid: %s, request id: %s, block duration: %s",
+                model_uid,
+                request_id,
+                block_duration,
+            )
             supervisor_ref = await self._get_supervisor_ref()
-            res = await supervisor_ref.abort_request(model_uid, request_id)
+            res = await supervisor_ref.abort_request(
+                model_uid, request_id, block_duration
+            )
+            self._cancel_running_task(request_id, block_duration)
             return JSONResponse(content=res)
         except Exception as e:
             logger.error(e, exc_info=True)
@@ -2228,6 +2343,27 @@ class RESTfulAPI:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+    @staticmethod
+    def extract_guided_params(raw_body: dict) -> dict:
+        kwargs = {}
+        if raw_body.get("guided_json") is not None:
+            kwargs["guided_json"] = raw_body.get("guided_json")
+        if raw_body.get("guided_regex") is not None:
+            kwargs["guided_regex"] = raw_body.get("guided_regex")
+        if raw_body.get("guided_choice") is not None:
+            kwargs["guided_choice"] = raw_body.get("guided_choice")
+        if raw_body.get("guided_grammar") is not None:
+            kwargs["guided_grammar"] = raw_body.get("guided_grammar")
+        if raw_body.get("guided_json_object") is not None:
+            kwargs["guided_json_object"] = raw_body.get("guided_json_object")
+        if raw_body.get("guided_decoding_backend") is not None:
+            kwargs["guided_decoding_backend"] = raw_body.get("guided_decoding_backend")
+        if raw_body.get("guided_whitespace_pattern") is not None:
+            kwargs["guided_whitespace_pattern"] = raw_body.get(
+                "guided_whitespace_pattern"
+            )
+        return kwargs
+
 
 def run(
     supervisor_address: str,
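Because `extract_guided_params` reads these fields straight from the raw JSON body, any OpenAI-compatible client can attach them as extra parameters on `/v1/completions` or `/v1/chat/completions`. A hedged sketch of such a request against a local server (the host, port, and model uid below are assumptions, not taken from the diff):

import requests

# Hypothetical endpoint and model uid; guided_choice is one of the seven
# passthrough fields handled by extract_guided_params above.
resp = requests.post(
    "http://127.0.0.1:9997/v1/chat/completions",
    json={
        "model": "my-llm",
        "messages": [{"role": "user", "content": "Is the sky blue? Answer yes or no."}],
        "guided_choice": ["yes", "no"],
    },
)
print(resp.json()["choices"][0]["message"]["content"])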
xinference/client/restful/restful_client.py
CHANGED

@@ -126,6 +126,43 @@ class RESTfulEmbeddingModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def convert_ids_to_tokens(
+        self, input: Union[List, List[List]], **kwargs
+    ) -> List[str]:
+        """
+        Convert token IDs to human readable tokens via RESTful APIs.
+
+        Parameters
+        ----------
+        input: Union[List, List[List]]
+            Input token IDs to convert, can be a single list of token IDs or a list of token ID lists.
+            To convert multiple sequences in a single request, pass a list of token ID lists.
+
+        Returns
+        -------
+        list
+            A list of decoded tokens in human readable format.
+
+        Raises
+        ------
+        RuntimeError
+            Report the failure of token conversion and provide the error message.
+
+        """
+        url = f"{self._base_url}/v1/convert_ids_to_tokens"
+        request_body = {
+            "model": self._model_uid,
+            "input": input,
+        }
+        request_body.update(kwargs)
+        response = requests.post(url, json=request_body, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to decode token ids, detail: {_get_error_string(response)}"
+            )
+        response_data = response.json()
+        return response_data
+
 
 class RESTfulRerankModelHandle(RESTfulModelHandle):
     def rerank(
@@ -174,6 +211,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
             "max_chunks_per_doc": max_chunks_per_doc,
             "return_documents": return_documents,
             "return_len": return_len,
+            "kwargs": json.dumps(kwargs),
         }
         request_body.update(kwargs)
         response = requests.post(url, json=request_body, headers=self.auth_headers)
@@ -703,6 +741,8 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
             The speed of the generated audio.
         stream: bool
             Use stream or not.
+        prompt_speech: bytes
+            The audio bytes to be provided to the model.
 
         Returns
         -------
@@ -1357,7 +1397,7 @@ class Client:
         response_data = response.json()
         return response_data
 
-    def abort_request(self, model_uid: str, request_id: str):
+    def abort_request(self, model_uid: str, request_id: str, block_duration: int = 30):
         """
         Abort a request.
         Abort a submitted request. If the request is finished or not found, this method will be a no-op.
@@ -1369,13 +1409,18 @@ class Client:
             Model uid.
         request_id: str
             Request id.
+        block_duration: int
+            The duration to make the request id abort. If set to 0, the abort_request will be immediate, which may
+            prevent it from taking effect if it arrives before the request operation.
         Returns
         -------
         Dict
             Return empty dict.
         """
         url = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort"
-        response = requests.post(url, headers=self._headers)
+        response = requests.post(
+            url, headers=self._headers, json={"block_duration": block_duration}
+        )
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to abort request, detail: {_get_error_string(response)}"
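A sketch of the new embedding-handle surface end to end; the endpoint and model uid below are placeholders and the token ids are arbitrary:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")        # placeholder endpoint
model = client.get_model("my-embedding-model")  # placeholder embedding model uid

# New in this release: decode token ids back into human-readable tokens.
tokens = model.convert_ids_to_tokens([101, 7592, 102])
print(tokens)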
xinference/constants.py
CHANGED
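The single added line (+1) is collapsed in this view. Judging from the `XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION` imports it satisfies in `restful_api.py`, `core/model.py`, and `core/supervisor.py`, it presumably looks something like:

# Hypothetical reconstruction of the collapsed +1 line; the value 30 is
# inferred from Client.abort_request's block_duration default, not confirmed.
XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION = 30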
xinference/core/model.py
CHANGED
@@ -41,6 +41,7 @@ import sse_starlette.sse
 import xoscar as xo
 
 from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
     XINFERENCE_LAUNCH_MODEL_RETRY,
     XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE,
 )
@@ -57,7 +58,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 from ..device_utils import empty_cache
-from .utils import json_dumps, log_async
+from .utils import CancelMixin, json_dumps, log_async
 
 try:
     from torch.cuda import OutOfMemoryError
@@ -90,21 +91,26 @@ def request_limit(fn):
         logger.debug(
             f"Request {fn.__name__}, current serve request count: {self._serve_count}, request limit: {self._request_limits} for the model {self.model_uid()}"
         )
-        if self.…
-        …
-        …
-        …
-        …
-        …
-        …
+        if 1 + self._serve_count <= self._request_limits:
+            self._serve_count += 1
+        else:
+            raise RuntimeError(
+                f"Rate limit reached for the model. Request limit {self._request_limits} for the model: {self.model_uid()}"
+            )
+        ret = None
         try:
             ret = await fn(self, *args, **kwargs)
         finally:
-            if …
+            if ret is not None and (
+                inspect.isasyncgen(ret) or inspect.isgenerator(ret)
+            ):
+                # stream case, let client call model_ref to decrease self._serve_count
+                pass
+            else:
                 self._serve_count -= 1
-            …
-            …
-            …
+            logger.debug(
+                f"After request {fn.__name__}, current serve request count: {self._serve_count} for the model {self.model_uid()}"
+            )
         return ret
 
     return wrapped_func
@@ -136,7 +142,7 @@ def oom_check(fn):
     return _wrapper
 
 
-class ModelActor(xo.StatelessActor):
+class ModelActor(xo.StatelessActor, CancelMixin):
     _replica_model_uid: Optional[str]
 
     @classmethod
@@ -214,7 +220,9 @@ class ModelActor(xo.StatelessActor):
         self._model_description = (
             model_description.to_dict() if model_description else {}
         )
-        self._request_limits = …
+        self._request_limits = (
+            float("inf") if request_limits is None else request_limits
+        )
         self._pending_requests: asyncio.Queue = asyncio.Queue()
         self._handle_pending_requests_task = None
         self._lock = (
@@ -267,6 +275,9 @@ class ModelActor(xo.StatelessActor):
     def __repr__(self) -> str:
         return f"ModelActor({self._replica_model_uid})"
 
+    def decrease_serve_count(self):
+        self._serve_count -= 1
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):
@@ -553,6 +564,7 @@ class ModelActor(xo.StatelessActor):
 
     @oom_check
     async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs):
+        self._add_running_task(kwargs.get("request_id"))
         if self._lock is None:
             if inspect.iscoroutinefunction(fn):
                 ret = await fn(*args, **kwargs)
@@ -761,9 +773,14 @@ class ModelActor(xo.StatelessActor):
             prompt_tokens,
         )
 
-    async def abort_request(self, request_id: str) -> str:
+    async def abort_request(
+        self,
+        request_id: str,
+        block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    ) -> str:
         from .utils import AbortRequestMessage
 
+        self._cancel_running_task(request_id, block_duration)
         if self.allow_batching():
             if self._scheduler_ref is None:
                 return AbortRequestMessage.NOT_FOUND.name
@@ -787,6 +804,19 @@ class ModelActor(xo.StatelessActor):
                 f"Model {self._model.model_spec} is not for creating embedding."
             )
 
+    @request_limit
+    @log_async(logger=logger)
+    async def convert_ids_to_tokens(
+        self, input: Union[List, List[List]], *args, **kwargs
+    ):
+        kwargs.pop("request_id", None)
+        if hasattr(self._model, "convert_ids_to_tokens"):
+            return await self._call_wrapper_json(
+                self._model.convert_ids_to_tokens, input, *args, **kwargs
+            )
+
+        raise AttributeError(f"Model {self._model.model_spec} can convert token id.")
+
     @request_limit
     @log_async(logger=logger)
     async def rerank(
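Note the asymmetry the new `request_limit` logic introduces: when the wrapped call returns a generator, the decorator leaves the serve count incremented, and whoever drains the stream is expected to call `decrease_serve_count()` exactly once on completion, which is what the `finally` blocks added in `restful_api.py` do. A minimal sketch of that contract (the server plumbing around it is elided; names match the diff):

# Whoever consumes a streamed result must release the slot exactly once;
# this mirrors the stream_results()/finally pattern added in restful_api.py.
async def consume_stream(model, stream):
    try:
        async for chunk in stream:
            ...  # forward the chunk to the HTTP client
    finally:
        await model.decrease_serve_count()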
xinference/core/supervisor.py
CHANGED
@@ -35,6 +35,7 @@ from typing import (
 import xoscar as xo
 
 from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
     XINFERENCE_DISABLE_HEALTH_CHECK,
     XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
     XINFERENCE_HEALTH_CHECK_INTERVAL,
@@ -1213,7 +1214,12 @@ class SupervisorActor(xo.StatelessActor):
         return cached_models
 
     @log_async(logger=logger)
-    async def abort_request(self, model_uid: str, request_id: str) -> Dict:
+    async def abort_request(
+        self,
+        model_uid: str,
+        request_id: str,
+        block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    ) -> Dict:
         from .scheduler import AbortRequestMessage
 
         res = {"msg": AbortRequestMessage.NO_OP.name}
@@ -1228,7 +1234,7 @@ class SupervisorActor(xo.StatelessActor):
                 if worker_ref is None:
                     continue
                 model_ref = await worker_ref.get_model(model_uid=rep_mid)
-                result_info = await model_ref.abort_request(request_id)
+                result_info = await model_ref.abort_request(request_id, block_duration)
                 res["msg"] = result_info
                 if result_info == AbortRequestMessage.DONE.name:
                     break
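With the supervisor now forwarding `block_duration` to every replica's `abort_request`, the whole cancel path can be driven from the client. A usage sketch (the endpoint and both ids are placeholders):

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # placeholder endpoint

# request_id must match the id the original request was submitted with;
# block_duration keeps the id blocked so a cancel that races ahead of the
# request still takes effect once the request arrives.
result = client.abort_request(
    model_uid="my-llm",
    request_id="req-123",
    block_duration=30,
)
print(result)  # e.g. {"msg": "DONE"}, {"msg": "NOT_FOUND"} or {"msg": "NO_OP"}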