xinference 0.16.3__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (69)
  1. xinference/_compat.py +22 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +148 -12
  4. xinference/client/restful/restful_client.py +47 -2
  5. xinference/constants.py +1 -0
  6. xinference/core/model.py +45 -15
  7. xinference/core/supervisor.py +8 -2
  8. xinference/core/utils.py +67 -2
  9. xinference/model/audio/__init__.py +12 -0
  10. xinference/model/audio/core.py +21 -4
  11. xinference/model/audio/fish_speech.py +70 -35
  12. xinference/model/audio/model_spec.json +81 -1
  13. xinference/model/audio/whisper_mlx.py +208 -0
  14. xinference/model/embedding/core.py +259 -4
  15. xinference/model/embedding/model_spec.json +1 -1
  16. xinference/model/embedding/model_spec_modelscope.json +1 -1
  17. xinference/model/image/stable_diffusion/core.py +5 -2
  18. xinference/model/llm/__init__.py +2 -0
  19. xinference/model/llm/llm_family.json +485 -6
  20. xinference/model/llm/llm_family_modelscope.json +519 -0
  21. xinference/model/llm/mlx/core.py +45 -3
  22. xinference/model/llm/sglang/core.py +1 -0
  23. xinference/model/llm/transformers/core.py +1 -0
  24. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  25. xinference/model/llm/utils.py +19 -0
  26. xinference/model/llm/vllm/core.py +84 -2
  27. xinference/model/rerank/core.py +11 -4
  28. xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
  29. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  30. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  31. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  32. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  33. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  34. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
  35. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  36. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  37. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
  38. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  39. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  40. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  41. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  42. xinference/thirdparty/fish_speech/tools/api.py +578 -75
  43. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  44. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  45. xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
  46. xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
  47. xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
  48. xinference/thirdparty/fish_speech/tools/schema.py +187 -0
  49. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  50. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  51. xinference/thirdparty/fish_speech/tools/webui.py +138 -75
  52. xinference/types.py +2 -1
  53. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/METADATA +30 -6
  54. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/RECORD +58 -63
  55. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/WHEEL +1 -1
  56. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  57. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  58. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  61. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  62. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  63. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  64. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  65. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  67. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/LICENSE +0 -0
  68. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/entry_points.txt +0 -0
  69. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/top_level.txt +0 -0
xinference/_compat.py CHANGED
@@ -60,6 +60,10 @@ from openai.types.chat.chat_completion_stream_options_param import (
     ChatCompletionStreamOptionsParam,
 )
 from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+from openai.types.shared_params.response_format_json_object import (
+    ResponseFormatJSONObject,
+)
+from openai.types.shared_params.response_format_text import ResponseFormatText
 
 OpenAIChatCompletionStreamOptionsParam = create_model_from_typeddict(
     ChatCompletionStreamOptionsParam
@@ -70,6 +74,23 @@ OpenAIChatCompletionNamedToolChoiceParam = create_model_from_typeddict(
 )
 
 
+class JSONSchema(BaseModel):
+    name: str
+    description: Optional[str] = None
+    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
+    strict: Optional[bool] = None
+
+
+class ResponseFormatJSONSchema(BaseModel):
+    json_schema: JSONSchema
+    type: Literal["json_schema"]
+
+
+ResponseFormat = Union[
+    ResponseFormatText, ResponseFormatJSONObject, ResponseFormatJSONSchema
+]
+
+
 class CreateChatCompletionOpenAI(BaseModel):
     """
     Comes from source code: https://github.com/openai/openai-python/blob/main/src/openai/types/chat/completion_create_params.py
@@ -84,8 +105,7 @@ class CreateChatCompletionOpenAI(BaseModel):
     n: Optional[int]
     parallel_tool_calls: Optional[bool]
     presence_penalty: Optional[float]
-    # we do not support this
-    # response_format: ResponseFormat
+    response_format: Optional[ResponseFormat]
     seed: Optional[int]
     service_tier: Optional[Literal["auto", "default"]]
     stop: Union[Optional[str], List[str]]
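
Note: with `response_format` now accepted (it was previously commented out as unsupported), OpenAI-style structured output can be requested through the compatible chat endpoint. A minimal usage sketch with the official `openai` client; the base URL and model uid below are placeholders, not values from this diff, and whether a given backend honors the format depends on the engine:

```python
# Sketch only: base_url and model uid are assumptions for illustration.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="not-needed")
resp = client.chat.completions.create(
    model="my-llm",  # uid of a launched chat model (placeholder)
    messages=[{"role": "user", "content": "Reply with a JSON object with a 'city' key."}],
    # "json_object" shown here; the new ResponseFormat union also admits
    # {"type": "json_schema", "json_schema": {...}} per the models above.
    response_format={"type": "json_object"},
)
print(resp.choices[0].message.content)
```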
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-11-07T16:55:36+0800",
+ "date": "2024-11-29T16:57:04+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "85ab86bf1c0967e45fbec995534cd5a0c9a9c439",
- "version": "0.16.3"
+ "full-revisionid": "eb8ddd431f5c5fcb2216e25e0d43745f8455d9b9",
+ "version": "1.0.1"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -52,10 +52,14 @@ from xoscar.utils import get_next_port
 
 from .._compat import BaseModel, Field
 from .._version import get_versions
-from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT, XINFERENCE_DISABLE_METRICS
+from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    XINFERENCE_DEFAULT_ENDPOINT_PORT,
+    XINFERENCE_DISABLE_METRICS,
+)
 from ..core.event import Event, EventCollectorActor, EventType
 from ..core.supervisor import SupervisorActor
-from ..core.utils import json_dumps
+from ..core.utils import CancelMixin, json_dumps
 from ..types import (
     ChatCompletion,
     Completion,
@@ -111,6 +115,7 @@ class RerankRequest(BaseModel):
     return_documents: Optional[bool] = False
     return_len: Optional[bool] = False
     max_chunks_per_doc: Optional[int] = None
+    kwargs: Optional[str] = None
 
 
 class TextToImageRequest(BaseModel):
@@ -206,7 +211,7 @@ class BuildGradioImageInterfaceRequest(BaseModel):
     model_ability: List[str]
 
 
-class RESTfulAPI:
+class RESTfulAPI(CancelMixin):
     def __init__(
         self,
         supervisor_address: str,
@@ -484,6 +489,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/convert_ids_to_tokens",
+            self.convert_ids_to_tokens,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/rerank",
             self.rerank,
@@ -1214,6 +1229,9 @@ class RESTfulAPI:
         raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
         kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
+        # guided_decoding params
+        kwargs.update(self.extract_guided_params(raw_body=raw_body))
+
         # TODO: Decide if this default value override is necessary #1061
         if body.max_tokens is None:
             kwargs["max_tokens"] = max_tokens_field.default
@@ -1259,6 +1277,8 @@ class RESTfulAPI:
                     # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
                     yield dict(data=json.dumps({"error": str(ex)}))
                     return
+                finally:
+                    await model.decrease_serve_count()
 
             return EventSourceResponse(stream_results())
         else:
@@ -1307,15 +1327,45 @@
         await self._report_error_event(model_uid, str(e))
         raise HTTPException(status_code=500, detail=str(e))
 
+    async def convert_ids_to_tokens(self, request: Request) -> Response:
+        payload = await request.json()
+        body = CreateEmbeddingRequest.parse_obj(payload)
+        model_uid = body.model
+        exclude = {
+            "model",
+            "input",
+            "user",
+        }
+        kwargs = {key: value for key, value in payload.items() if key not in exclude}
+
+        try:
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            decoded_texts = await model.convert_ids_to_tokens(body.input, **kwargs)
+            return Response(decoded_texts, media_type="application/json")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def rerank(self, request: Request) -> Response:
         payload = await request.json()
         body = RerankRequest.parse_obj(payload)
         model_uid = body.model
-        kwargs = {
-            key: value
-            for key, value in payload.items()
-            if key not in RerankRequest.__annotations__.keys()
-        }
 
         try:
             model = await (await self._get_supervisor_ref()).get_model(model_uid)
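
The new handler reuses `CreateEmbeddingRequest` for parsing, so the request body is just `model` plus `input`. A hypothetical direct call against the route; the endpoint address and model uid are placeholders:

```python
# Sketch only: endpoint address and model uid are assumptions.
import requests

resp = requests.post(
    "http://127.0.0.1:9997/v1/convert_ids_to_tokens",
    json={"model": "my-embedding-model", "input": [[101, 7592, 102]]},
)
resp.raise_for_status()
print(resp.json())  # human-readable tokens for each input id sequence
```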
@@ -1329,6 +1379,10 @@
             raise HTTPException(status_code=500, detail=str(e))
 
         try:
+            if body.kwargs is not None:
+                parsed_kwargs = json.loads(body.kwargs)
+            else:
+                parsed_kwargs = {}
             scores = await model.rerank(
                 body.documents,
                 body.query,
@@ -1336,7 +1390,7 @@
                 max_chunks_per_doc=body.max_chunks_per_doc,
                 return_documents=body.return_documents,
                 return_len=body.return_len,
-                **kwargs,
+                **parsed_kwargs,
             )
             return Response(scores, media_type="application/json")
         except RuntimeError as re:
@@ -1491,8 +1545,16 @@
                 **parsed_kwargs,
             )
         if body.stream:
+
+            async def stream_results():
+                try:
+                    async for item in out:
+                        yield item
+                finally:
+                    await model.decrease_serve_count()
+
             return EventSourceResponse(
-                media_type="application/octet-stream", content=out
+                media_type="application/octet-stream", content=stream_results()
             )
         else:
             return Response(media_type="application/octet-stream", content=out)
@@ -1531,8 +1593,11 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             kwargs = json.loads(body.kwargs) if body.kwargs else {}
+            request_id = kwargs.get("request_id")
+            self._add_running_task(request_id)
             image_list = await model.text_to_image(
                 prompt=body.prompt,
                 n=body.n,
@@ -1541,6 +1606,11 @@
                 **kwargs,
             )
             return Response(content=image_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1686,11 +1756,14 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             if kwargs is not None:
                 parsed_kwargs = json.loads(kwargs)
             else:
                 parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
             image_list = await model_ref.image_to_image(
                 image=Image.open(image.file),
                 prompt=prompt,
@@ -1701,6 +1774,11 @@
                 **parsed_kwargs,
             )
             return Response(content=image_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1734,11 +1812,14 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             if kwargs is not None:
                 parsed_kwargs = json.loads(kwargs)
             else:
                 parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
             im = Image.open(image.file)
             mask_im = Image.open(mask_image.file)
             if not size:
@@ -1755,6 +1836,11 @@
                 **parsed_kwargs,
             )
             return Response(content=image_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1782,17 +1868,25 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             if kwargs is not None:
                 parsed_kwargs = json.loads(kwargs)
             else:
                 parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
             im = Image.open(image.file)
             text = await model_ref.ocr(
                 image=im,
                 **parsed_kwargs,
             )
             return Response(content=text, media_type="text/plain")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1880,9 +1974,13 @@
             "logit_bias_type",
             "user",
         }
+
         raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
         kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
+        # guided_decoding params
+        kwargs.update(self.extract_guided_params(raw_body=raw_body))
+
         # TODO: Decide if this default value override is necessary #1061
         if body.max_tokens is None:
             kwargs["max_tokens"] = max_tokens_field.default
@@ -1991,6 +2089,8 @@
                     # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
                     yield dict(data=json.dumps({"error": str(ex)}))
                     return
+                finally:
+                    await model.decrease_serve_count()
 
             return EventSourceResponse(stream_results())
         else:
@@ -2111,10 +2211,25 @@
         logger.error(e, exc_info=True)
         raise HTTPException(status_code=500, detail=str(e))
 
-    async def abort_request(self, model_uid: str, request_id: str) -> JSONResponse:
+    async def abort_request(
+        self, request: Request, model_uid: str, request_id: str
+    ) -> JSONResponse:
         try:
+            payload = await request.json()
+            block_duration = payload.get(
+                "block_duration", XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION
+            )
+            logger.info(
+                "Abort request with model uid: %s, request id: %s, block duration: %s",
+                model_uid,
+                request_id,
+                block_duration,
+            )
             supervisor_ref = await self._get_supervisor_ref()
-            res = await supervisor_ref.abort_request(model_uid, request_id)
+            res = await supervisor_ref.abort_request(
+                model_uid, request_id, block_duration
+            )
+            self._cancel_running_task(request_id, block_duration)
             return JSONResponse(content=res)
         except Exception as e:
             logger.error(e, exc_info=True)
@@ -2228,6 +2343,27 @@
         logger.error(e, exc_info=True)
         raise HTTPException(status_code=500, detail=str(e))
 
+    @staticmethod
+    def extract_guided_params(raw_body: dict) -> dict:
+        kwargs = {}
+        if raw_body.get("guided_json") is not None:
+            kwargs["guided_json"] = raw_body.get("guided_json")
+        if raw_body.get("guided_regex") is not None:
+            kwargs["guided_regex"] = raw_body.get("guided_regex")
+        if raw_body.get("guided_choice") is not None:
+            kwargs["guided_choice"] = raw_body.get("guided_choice")
+        if raw_body.get("guided_grammar") is not None:
+            kwargs["guided_grammar"] = raw_body.get("guided_grammar")
+        if raw_body.get("guided_json_object") is not None:
+            kwargs["guided_json_object"] = raw_body.get("guided_json_object")
+        if raw_body.get("guided_decoding_backend") is not None:
+            kwargs["guided_decoding_backend"] = raw_body.get("guided_decoding_backend")
+        if raw_body.get("guided_whitespace_pattern") is not None:
+            kwargs["guided_whitespace_pattern"] = raw_body.get(
+                "guided_whitespace_pattern"
+            )
+        return kwargs
+
 
 def run(
     supervisor_address: str,
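
`extract_guided_params` lifts vLLM-style guided-decoding fields (`guided_json`, `guided_regex`, `guided_choice`, and the rest) out of the raw request body and forwards them to the engine; whether they take effect depends on the backend (the vLLM core also changes in this release). A sketch of passing `guided_choice` through the OpenAI-compatible client via `extra_body`; the model uid is a placeholder:

```python
# Sketch only: assumes a backend that honors guided decoding (e.g. vLLM).
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="not-needed")
resp = client.chat.completions.create(
    model="my-llm",  # placeholder uid
    messages=[{"role": "user", "content": "Is the sky blue? Answer yes or no."}],
    extra_body={"guided_choice": ["yes", "no"]},  # collected by extract_guided_params
)
print(resp.choices[0].message.content)
```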
xinference/client/restful/restful_client.py CHANGED
@@ -126,6 +126,43 @@ class RESTfulEmbeddingModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def convert_ids_to_tokens(
+        self, input: Union[List, List[List]], **kwargs
+    ) -> List[str]:
+        """
+        Convert token IDs to human readable tokens via RESTful APIs.
+
+        Parameters
+        ----------
+        input: Union[List, List[List]]
+            Input token IDs to convert, can be a single list of token IDs or a list of token ID lists.
+            To convert multiple sequences in a single request, pass a list of token ID lists.
+
+        Returns
+        -------
+        list
+            A list of decoded tokens in human readable format.
+
+        Raises
+        ------
+        RuntimeError
+            Report the failure of token conversion and provide the error message.
+
+        """
+        url = f"{self._base_url}/v1/convert_ids_to_tokens"
+        request_body = {
+            "model": self._model_uid,
+            "input": input,
+        }
+        request_body.update(kwargs)
+        response = requests.post(url, json=request_body, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to decode token ids, detail: {_get_error_string(response)}"
+            )
+        response_data = response.json()
+        return response_data
+
 
 class RESTfulRerankModelHandle(RESTfulModelHandle):
     def rerank(
@@ -174,6 +211,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
             "max_chunks_per_doc": max_chunks_per_doc,
             "return_documents": return_documents,
             "return_len": return_len,
+            "kwargs": json.dumps(kwargs),
         }
         request_body.update(kwargs)
         response = requests.post(url, json=request_body, headers=self.auth_headers)
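
Extra rerank options are now also serialized into the `kwargs` field as a JSON string, which the server decodes with `json.loads` and splats into `model.rerank(...)`, keeping keyword types intact across the HTTP boundary. A sketch with the RESTful client; the server address and model uid are placeholders:

```python
# Sketch only: server address and model uid are assumptions.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model = client.get_model("my-rerank-model")  # placeholder uid
result = model.rerank(
    documents=["Paris is the capital of France.", "Berlin is in Germany."],
    query="capital of France",
    top_n=1,
    # Any additional keyword arguments are JSON-encoded into the
    # request's "kwargs" field and decoded server-side.
)
print(result)
```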
@@ -703,6 +741,8 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
             The speed of the generated audio.
         stream: bool
             Use stream or not.
+        prompt_speech: bytes
+            The audio bytes to be provided to the model.
 
         Returns
         -------
@@ -1357,7 +1397,7 @@ class Client:
         response_data = response.json()
         return response_data
 
-    def abort_request(self, model_uid: str, request_id: str):
+    def abort_request(self, model_uid: str, request_id: str, block_duration: int = 30):
         """
         Abort a request.
         Abort a submitted request. If the request is finished or not found, this method will be a no-op.
@@ -1369,13 +1409,18 @@
             Model uid.
         request_id: str
             Request id.
+        block_duration: int
+            The duration to make the request id abort. If set to 0, the abort_request will be immediate, which may
+            prevent it from taking effect if it arrives before the request operation.
         Returns
         -------
         Dict
             Return empty dict.
         """
         url = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort"
-        response = requests.post(url, headers=self._headers)
+        response = requests.post(
+            url, headers=self._headers, json={"block_duration": block_duration}
+        )
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to abort request, detail: {_get_error_string(response)}"
xinference/constants.py CHANGED
@@ -88,3 +88,4 @@ XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE = os.environ.get(
     XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE, None
 )
 XINFERENCE_LAUNCH_MODEL_RETRY = 3
+XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION = 30
xinference/core/model.py CHANGED
@@ -41,6 +41,7 @@ import sse_starlette.sse
 import xoscar as xo
 
 from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
     XINFERENCE_LAUNCH_MODEL_RETRY,
     XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE,
 )
@@ -57,7 +58,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 from ..device_utils import empty_cache
-from .utils import json_dumps, log_async
+from .utils import CancelMixin, json_dumps, log_async
 
 try:
     from torch.cuda import OutOfMemoryError
@@ -90,21 +91,26 @@ def request_limit(fn):
         logger.debug(
             f"Request {fn.__name__}, current serve request count: {self._serve_count}, request limit: {self._request_limits} for the model {self.model_uid()}"
         )
-        if self._request_limits is not None:
-            if 1 + self._serve_count <= self._request_limits:
-                self._serve_count += 1
-            else:
-                raise RuntimeError(
-                    f"Rate limit reached for the model. Request limit {self._request_limits} for the model: {self.model_uid()}"
-                )
+        if 1 + self._serve_count <= self._request_limits:
+            self._serve_count += 1
+        else:
+            raise RuntimeError(
+                f"Rate limit reached for the model. Request limit {self._request_limits} for the model: {self.model_uid()}"
+            )
+        ret = None
         try:
             ret = await fn(self, *args, **kwargs)
         finally:
-            if self._request_limits is not None:
+            if ret is not None and (
+                inspect.isasyncgen(ret) or inspect.isgenerator(ret)
+            ):
+                # stream case, let client call model_ref to decrease self._serve_count
+                pass
+            else:
                 self._serve_count -= 1
-            logger.debug(
-                f"After request {fn.__name__}, current serve request count: {self._serve_count} for the model {self.model_uid()}"
-            )
+                logger.debug(
+                    f"After request {fn.__name__}, current serve request count: {self._serve_count} for the model {self.model_uid()}"
+                )
         return ret
 
     return wrapped_func
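
The rework makes the limit check unconditional (`None` limits become `float("inf")` in `__init__`, below) and skips the decrement when the wrapped call returns a generator: the stream is still being consumed after the coroutine returns, so the REST layer's `stream_results()` wrappers call `decrease_serve_count()` in their `finally` blocks instead. A standalone illustration of the detection idiom:

```python
# Standalone sketch of the generator checks used in the decorator above.
import inspect

async def astream():
    yield "chunk"

def stream():
    yield "chunk"

print(inspect.isasyncgen(astream()))  # True: decrement deferred to the stream consumer
print(inspect.isgenerator(stream()))  # True for sync generators as well
```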
@@ -136,7 +142,7 @@ def oom_check(fn):
     return _wrapper
 
 
-class ModelActor(xo.StatelessActor):
+class ModelActor(xo.StatelessActor, CancelMixin):
     _replica_model_uid: Optional[str]
 
     @classmethod
@@ -214,7 +220,9 @@ class ModelActor(xo.StatelessActor):
         self._model_description = (
             model_description.to_dict() if model_description else {}
         )
-        self._request_limits = request_limits
+        self._request_limits = (
+            float("inf") if request_limits is None else request_limits
+        )
         self._pending_requests: asyncio.Queue = asyncio.Queue()
         self._handle_pending_requests_task = None
         self._lock = (
@@ -267,6 +275,9 @@ class ModelActor(xo.StatelessActor):
     def __repr__(self) -> str:
         return f"ModelActor({self._replica_model_uid})"
 
+    def decrease_serve_count(self):
+        self._serve_count -= 1
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):
@@ -553,6 +564,7 @@ class ModelActor(xo.StatelessActor):
 
     @oom_check
     async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs):
+        self._add_running_task(kwargs.get("request_id"))
         if self._lock is None:
             if inspect.iscoroutinefunction(fn):
                 ret = await fn(*args, **kwargs)
@@ -761,9 +773,14 @@ class ModelActor(xo.StatelessActor):
             prompt_tokens,
         )
 
-    async def abort_request(self, request_id: str) -> str:
+    async def abort_request(
+        self,
+        request_id: str,
+        block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    ) -> str:
         from .utils import AbortRequestMessage
 
+        self._cancel_running_task(request_id, block_duration)
         if self.allow_batching():
             if self._scheduler_ref is None:
                 return AbortRequestMessage.NOT_FOUND.name
@@ -787,6 +804,19 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating embedding."
         )
 
+    @request_limit
+    @log_async(logger=logger)
+    async def convert_ids_to_tokens(
+        self, input: Union[List, List[List]], *args, **kwargs
+    ):
+        kwargs.pop("request_id", None)
+        if hasattr(self._model, "convert_ids_to_tokens"):
+            return await self._call_wrapper_json(
+                self._model.convert_ids_to_tokens, input, *args, **kwargs
+            )
+
+        raise AttributeError(f"Model {self._model.model_spec} can convert token id.")
+
     @request_limit
     @log_async(logger=logger)
     async def rerank(
xinference/core/supervisor.py CHANGED
@@ -35,6 +35,7 @@ from typing import (
 import xoscar as xo
 
 from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
     XINFERENCE_DISABLE_HEALTH_CHECK,
     XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
     XINFERENCE_HEALTH_CHECK_INTERVAL,
@@ -1213,7 +1214,12 @@ class SupervisorActor(xo.StatelessActor):
         return cached_models
 
     @log_async(logger=logger)
-    async def abort_request(self, model_uid: str, request_id: str) -> Dict:
+    async def abort_request(
+        self,
+        model_uid: str,
+        request_id: str,
+        block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    ) -> Dict:
         from .scheduler import AbortRequestMessage
 
         res = {"msg": AbortRequestMessage.NO_OP.name}
@@ -1228,7 +1234,7 @@
             if worker_ref is None:
                 continue
             model_ref = await worker_ref.get_model(model_uid=rep_mid)
-            result_info = await model_ref.abort_request(request_id)
+            result_info = await model_ref.abort_request(request_id, block_duration)
             res["msg"] = result_info
             if result_info == AbortRequestMessage.DONE.name:
                 break