xinference 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (85)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +108 -14
  3. xinference/client/restful/restful_client.py +78 -5
  4. xinference/constants.py +1 -0
  5. xinference/core/cache_tracker.py +48 -28
  6. xinference/core/event.py +5 -6
  7. xinference/core/model.py +59 -42
  8. xinference/core/scheduler.py +46 -18
  9. xinference/core/supervisor.py +73 -24
  10. xinference/core/worker.py +68 -2
  11. xinference/deploy/cmdline.py +86 -2
  12. xinference/deploy/test/test_cmdline.py +19 -10
  13. xinference/model/audio/__init__.py +14 -1
  14. xinference/model/audio/core.py +12 -1
  15. xinference/model/audio/custom.py +6 -4
  16. xinference/model/audio/model_spec_modelscope.json +20 -0
  17. xinference/model/llm/__init__.py +34 -2
  18. xinference/model/llm/llm_family.json +8 -2
  19. xinference/model/llm/llm_family.py +86 -1
  20. xinference/model/llm/llm_family_csghub.json +66 -0
  21. xinference/model/llm/llm_family_modelscope.json +8 -2
  22. xinference/model/llm/pytorch/chatglm.py +41 -12
  23. xinference/model/llm/pytorch/core.py +128 -88
  24. xinference/model/llm/pytorch/glm4v.py +24 -3
  25. xinference/model/llm/pytorch/internlm2.py +15 -0
  26. xinference/model/llm/pytorch/qwen_vl.py +1 -1
  27. xinference/model/llm/pytorch/utils.py +69 -189
  28. xinference/model/llm/utils.py +27 -14
  29. xinference/model/llm/vllm/core.py +10 -4
  30. xinference/model/rerank/core.py +35 -6
  31. xinference/model/utils.py +8 -2
  32. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  33. xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
  34. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  35. xinference/thirdparty/ChatTTS/infer/api.py +125 -0
  36. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  37. xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
  38. xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
  39. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  40. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
  41. xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
  42. xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
  43. xinference/types.py +28 -0
  44. xinference/web/ui/build/asset-manifest.json +6 -6
  45. xinference/web/ui/build/index.html +1 -1
  46. xinference/web/ui/build/static/css/main.4bafd904.css +2 -0
  47. xinference/web/ui/build/static/css/main.4bafd904.css.map +1 -0
  48. xinference/web/ui/build/static/js/main.b80d9c08.js +3 -0
  49. xinference/web/ui/build/static/js/main.b80d9c08.js.map +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/0c2fb5375667931c4a331c99e0d87dc145e8f327cea3f44d6e56f54c7c1d4020.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/16537795de12c61903b6110c241f62a7855b2d0fc1e7c3d1faa347267f3a6893.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/17b8f071491402d70b146532358b1a612226e5dc7b3e8755a1322d27b4680cee.json +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/395409bd005e19d48b437c48d88e5126c7865ba9631fe98535333c952e383dc5.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/43991bb67c3136863e6fb37f796466b12eb547a1465408cc77820fddafb3bed3.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/{15e2cf8cd8d0989719b6349428ff576f9009ff4c2dcc52378be0bd938e82495e.json → 935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json} +1 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/a7109d4425e3d94ca2726fc7020fd33bf5030afd4c9cf4bf71e21776cd70646a.json +1 -0
  60. xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +1 -0
  61. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
  63. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/METADATA +1 -1
  64. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/RECORD +69 -56
  65. xinference/web/ui/build/static/css/main.54bca460.css +0 -2
  66. xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
  67. xinference/web/ui/build/static/js/main.551aa479.js +0 -3
  68. xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
  69. xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
  70. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
  71. xinference/web/ui/node_modules/.cache/babel-loader/2c63e940b945fd5817157e08a42b889b30d668ea4c91332f48ef2b1b9d26f520.json +0 -1
  72. xinference/web/ui/node_modules/.cache/babel-loader/3c2f277c93c5f1638e08db38df0d0fb4e58d1c5571aea03241a5c04ff4094704.json +0 -1
  73. xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
  74. xinference/web/ui/node_modules/.cache/babel-loader/4135fe8745434cbce6438d1ebfa47422e0c77d884db4edc75c8bf32ea1d50621.json +0 -1
  75. xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
  76. xinference/web/ui/node_modules/.cache/babel-loader/4de0a71074f9cbe1e7862750dcdd08cbc1bae7d9d9849a78b1783ca670017b3c.json +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
  78. xinference/web/ui/node_modules/.cache/babel-loader/9cfd33238ca43e5bf9fc7e442690e8cc6027c73553db36de87e3597ed524ee4b.json +0 -1
  79. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
  80. xinference/web/ui/node_modules/.cache/babel-loader/e6eccc9aa641e7da833492e27846dc965f9750281420977dc84654ca6ed221e4.json +0 -1
  81. /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.b80d9c08.js.LICENSE.txt} +0 -0
  82. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/LICENSE +0 -0
  83. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/WHEEL +0 -0
  84. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/entry_points.txt +0 -0
  85. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
  version_json = '''
  {
- "date": "2024-06-07T15:04:33+0800",
+ "date": "2024-06-21T15:34:17+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "55c5636f2b6022842d1827eae373c8e5f162a1a3",
- "version": "0.12.0"
+ "full-revisionid": "5cef7c3d4bb0c5208d262fc3ffb7d7083724de1c",
+ "version": "0.12.2"
  }
  ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -109,6 +109,7 @@ class RerankRequest(BaseModel):
  documents: List[str]
  top_n: Optional[int] = None
  return_documents: Optional[bool] = False
+ return_len: Optional[bool] = False
  max_chunks_per_doc: Optional[int] = None
 
 
@@ -522,11 +523,31 @@ class RESTfulAPI:
  ),
  )
  self._router.add_api_route(
- "/v1/cached/list_cached_models",
+ "/v1/cache/models",
  self.list_cached_models,
  methods=["GET"],
  dependencies=(
- [Security(self._auth_service, scopes=["models:list"])]
+ [Security(self._auth_service, scopes=["cache:list"])]
+ if self.is_authenticated()
+ else None
+ ),
+ )
+ self._router.add_api_route(
+ "/v1/cache/models/files",
+ self.list_model_files,
+ methods=["GET"],
+ dependencies=(
+ [Security(self._auth_service, scopes=["cache:list"])]
+ if self.is_authenticated()
+ else None
+ ),
+ )
+ self._router.add_api_route(
+ "/v1/cache/models",
+ self.confirm_and_remove_model,
+ methods=["DELETE"],
+ dependencies=(
+ [Security(self._auth_service, scopes=["cache:delete"])]
  if self.is_authenticated()
  else None
  ),
@@ -961,7 +982,8 @@ class RESTfulAPI:
  return JSONResponse(content=self._supervisor_address)
 
  async def create_completion(self, request: Request) -> Response:
- body = CreateCompletionRequest.parse_obj(await request.json())
+ raw_body = await request.json()
+ body = CreateCompletionRequest.parse_obj(raw_body)
  exclude = {
  "prompt",
  "model",
@@ -971,6 +993,7 @@ class RESTfulAPI:
  "logit_bias_type",
  "user",
  }
+ raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
  kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
  # TODO: Decide if this default value override is necessary #1061
@@ -1000,7 +1023,9 @@ class RESTfulAPI:
  iterator = None
  try:
  try:
- iterator = await model.generate(body.prompt, kwargs)
+ iterator = await model.generate(
+ body.prompt, kwargs, raw_params=raw_kwargs
+ )
  except RuntimeError as re:
  self.handle_request_limit_error(re)
  async for item in iterator:
@@ -1020,7 +1045,7 @@ class RESTfulAPI:
  return EventSourceResponse(stream_results())
  else:
  try:
- data = await model.generate(body.prompt, kwargs)
+ data = await model.generate(body.prompt, kwargs, raw_params=raw_kwargs)
  return Response(data, media_type="application/json")
  except Exception as e:
  logger.error(e, exc_info=True)
@@ -1092,6 +1117,7 @@ class RESTfulAPI:
  top_n=body.top_n,
  max_chunks_per_doc=body.max_chunks_per_doc,
  return_documents=body.return_documents,
+ return_len=body.return_len,
  **kwargs,
  )
  return Response(scores, media_type="application/json")
@@ -1321,7 +1347,8 @@ class RESTfulAPI:
  raise HTTPException(status_code=500, detail=str(e))
 
  async def create_chat_completion(self, request: Request) -> Response:
- body = CreateChatCompletion.parse_obj(await request.json())
+ raw_body = await request.json()
+ body = CreateChatCompletion.parse_obj(raw_body)
  exclude = {
  "prompt",
  "model",
@@ -1331,6 +1358,7 @@ class RESTfulAPI:
  "logit_bias_type",
  "user",
  }
+ raw_kwargs = {k: v for k, v in raw_body.items() if k not in exclude}
  kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
  # TODO: Decide if this default value override is necessary #1061
@@ -1401,9 +1429,13 @@ class RESTfulAPI:
  model_family = desc.get("model_family", "")
  function_call_models = [
  "chatglm3",
+ "glm4-chat",
  "gorilla-openfunctions-v1",
  "qwen-chat",
  "qwen1.5-chat",
+ "qwen1.5-moe-chat",
+ "qwen2-instruct",
+ "qwen2-moe-instruct",
  ]
 
  is_qwen = desc.get("model_format") == "ggmlv3" and "qwen-chat" == model_family
@@ -1426,7 +1458,13 @@
  )
  if body.tools and body.stream:
  is_vllm = await model.is_vllm_backend()
- if not is_vllm or model_family not in ["qwen-chat", "qwen1.5-chat"]:
+ if not is_vllm or model_family not in [
+ "qwen-chat",
+ "qwen1.5-chat",
+ "qwen1.5-moe-chat",
+ "qwen2-instruct",
+ "qwen2-moe-instruct",
+ ]:
  raise HTTPException(
  status_code=400,
  detail="Streaming support for tool calls is available only when using vLLM backend and Qwen models.",
@@ -1439,10 +1477,16 @@
  try:
  try:
  if is_qwen:
- iterator = await model.chat(prompt, chat_history, kwargs)
+ iterator = await model.chat(
+ prompt, chat_history, kwargs, raw_params=raw_kwargs
+ )
  else:
  iterator = await model.chat(
- prompt, system_prompt, chat_history, kwargs
+ prompt,
+ system_prompt,
+ chat_history,
+ kwargs,
+ raw_params=raw_kwargs,
  )
  except RuntimeError as re:
  await self._report_error_event(model_uid, str(re))
@@ -1472,9 +1516,17 @@
  else:
  try:
  if is_qwen:
- data = await model.chat(prompt, chat_history, kwargs)
+ data = await model.chat(
+ prompt, chat_history, kwargs, raw_params=raw_kwargs
+ )
  else:
- data = await model.chat(prompt, system_prompt, chat_history, kwargs)
+ data = await model.chat(
+ prompt,
+ system_prompt,
+ chat_history,
+ kwargs,
+ raw_params=raw_kwargs,
+ )
  return Response(content=data, media_type="application/json")
  except Exception as e:
  logger.error(e, exc_info=True)
@@ -1555,10 +1607,17 @@
  logger.error(e, exc_info=True)
  raise HTTPException(status_code=500, detail=str(e))
 
- async def list_cached_models(self) -> JSONResponse:
+ async def list_cached_models(
+ self, model_name: str = Query(None), worker_ip: str = Query(None)
+ ) -> JSONResponse:
  try:
- data = await (await self._get_supervisor_ref()).list_cached_models()
- return JSONResponse(content=data)
+ data = await (await self._get_supervisor_ref()).list_cached_models(
+ model_name, worker_ip
+ )
+ resp = {
+ "list": data,
+ }
+ return JSONResponse(content=resp)
  except ValueError as re:
  logger.error(re, exc_info=True)
  raise HTTPException(status_code=400, detail=str(re))
@@ -1623,6 +1682,41 @@
  logger.error(e, exc_info=True)
  raise HTTPException(status_code=500, detail=str(e))
 
+ async def list_model_files(
+ self, model_version: str = Query(None), worker_ip: str = Query(None)
+ ) -> JSONResponse:
+ try:
+ data = await (await self._get_supervisor_ref()).list_deletable_models(
+ model_version, worker_ip
+ )
+ response = {
+ "model_version": model_version,
+ "worker_ip": worker_ip,
+ "paths": data,
+ }
+ return JSONResponse(content=response)
+ except ValueError as re:
+ logger.error(re, exc_info=True)
+ raise HTTPException(status_code=400, detail=str(re))
+ except Exception as e:
+ logger.error(e, exc_info=True)
+ raise HTTPException(status_code=500, detail=str(e))
+
+ async def confirm_and_remove_model(
+ self, model_version: str = Query(None), worker_ip: str = Query(None)
+ ) -> JSONResponse:
+ try:
+ res = await (await self._get_supervisor_ref()).confirm_and_remove_model(
+ model_version=model_version, worker_ip=worker_ip
+ )
+ return JSONResponse(content={"result": res})
+ except ValueError as re:
+ logger.error(re, exc_info=True)
+ raise HTTPException(status_code=400, detail=str(re))
+ except Exception as e:
+ logger.error(e, exc_info=True)
+ raise HTTPException(status_code=500, detail=str(e))
+
 
  def run(
  supervisor_address: str,
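
Taken together, the restful_api.py hunks above replace the old /v1/cached/list_cached_models route with a small cache-management API: GET /v1/cache/models lists cached models (optionally filtered by model_name and worker_ip), GET /v1/cache/models/files returns the on-disk paths for a given model_version, and DELETE /v1/cache/models removes them. The sketch below is not part of the diff; it shows how these endpoints could be called with requests, assuming a local server on the default port 9997, no authentication, and placeholder model_version / worker_ip values.

    import requests

    base_url = "http://127.0.0.1:9997"  # assumed local endpoint

    # List all cached models known to the supervisor.
    cached = requests.get(f"{base_url}/v1/cache/models").json()["list"]

    # Inspect the cached files of one model version on a specific worker
    # (both values below are placeholders).
    params = {"model_version": "<model_version>", "worker_ip": "<worker_ip>"}
    files = requests.get(f"{base_url}/v1/cache/models/files", params=params).json()

    # Remove that cached model version from the worker; the server replies {"result": ...}.
    removed = requests.delete(f"{base_url}/v1/cache/models", params=params).json()["result"]
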
xinference/client/restful/restful_client.py CHANGED
@@ -135,6 +135,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
  top_n: Optional[int] = None,
  max_chunks_per_doc: Optional[int] = None,
  return_documents: Optional[bool] = None,
+ return_len: Optional[bool] = None,
  **kwargs,
  ):
  """
@@ -152,6 +153,8 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
  The maximum number of chunks derived from a document
  return_documents: bool
  if return documents
+ return_len: bool
+ if return tokens len
  Returns
  -------
  Scores
@@ -170,6 +173,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
  "top_n": top_n,
  "max_chunks_per_doc": max_chunks_per_doc,
  "return_documents": return_documents,
+ "return_len": return_len,
  }
  request_body.update(kwargs)
  response = requests.post(url, json=request_body, headers=self.auth_headers)
@@ -1145,13 +1149,17 @@ class Client:
  response_data = response.json()
  return response_data
 
- def list_cached_models(self) -> List[Dict[Any, Any]]:
+ def list_cached_models(
+ self, model_name: Optional[str] = None, worker_ip: Optional[str] = None
+ ) -> List[Dict[Any, Any]]:
  """
  Get a list of cached models.
-
  Parameters
  ----------
- None
+ model_name: Optional[str]
+ The name of model.
+ worker_ip: Optional[str]
+ Specify the worker ip where the model is located in a distributed scenario.
 
  Returns
  -------
@@ -1164,16 +1172,81 @@ class Client:
  Raised when the request fails, including the reason for the failure.
  """
 
- url = f"{self.base_url}/v1/cached/list_cached_models"
- response = requests.get(url, headers=self._headers)
+ url = f"{self.base_url}/v1/cache/models"
+ params = {
+ "model_name": model_name,
+ "worker_ip": worker_ip,
+ }
+ response = requests.get(url, headers=self._headers, params=params)
  if response.status_code != 200:
  raise RuntimeError(
  f"Failed to list cached model, detail: {_get_error_string(response)}"
  )
 
+ response_data = response.json()
+ response_data = response_data.get("list")
+ return response_data
+
+ def list_deletable_models(
+ self, model_version: str, worker_ip: Optional[str] = None
+ ) -> Dict[str, Any]:
+ """
+ Get the cached models with the model path cached on the server.
+ Parameters
+ ----------
+ model_version: str
+ The version of the model.
+ worker_ip: Optional[str]
+ Specify the worker ip where the model is located in a distributed scenario.
+ Returns
+ -------
+ Dict[str, Dict[str,str]]]
+ Dictionary with keys "model_name" and values model_file_location.
+ """
+ url = f"{self.base_url}/v1/cache/models/files"
+ params = {
+ "model_version": model_version,
+ "worker_ip": worker_ip,
+ }
+ response = requests.get(url, headers=self._headers, params=params)
+ if response.status_code != 200:
+ raise RuntimeError(
+ f"Failed to get paths by model name, detail: {_get_error_string(response)}"
+ )
+
  response_data = response.json()
  return response_data
 
+ def confirm_and_remove_model(
+ self, model_version: str, worker_ip: Optional[str] = None
+ ) -> bool:
+ """
+ Remove the cached models with the model name cached on the server.
+ Parameters
+ ----------
+ model_version: str
+ The version of the model.
+ worker_ip: Optional[str]
+ Specify the worker ip where the model is located in a distributed scenario.
+ Returns
+ -------
+ str
+ The response of the server.
+ """
+ url = f"{self.base_url}/v1/cache/models"
+ params = {
+ "model_version": model_version,
+ "worker_ip": worker_ip,
+ }
+ response = requests.delete(url, headers=self._headers, params=params)
+ if response.status_code != 200:
+ raise RuntimeError(
+ f"Failed to remove cached models, detail: {_get_error_string(response)}"
+ )
+
+ response_data = response.json()
+ return response_data.get("result", False)
+
  def get_model_registration(
  self, model_type: str, model_name: str
  ) -> Dict[str, Any]:
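
The restful_client.py hunks mirror those endpoints on the Python client: list_cached_models gains model_name / worker_ip filters, while list_deletable_models and confirm_and_remove_model wrap the new files and delete routes. A minimal usage sketch, not part of the diff, assuming a locally running server and placeholder identifiers:

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")  # assumed local endpoint

    # List cached models, optionally narrowed to one model name or worker.
    for item in client.list_cached_models(model_name="qwen2-instruct"):
        print(item.get("model_name"), item.get("model_file_location"))

    # Look up the cached files behind one model version, then delete them.
    paths = client.list_deletable_models(model_version="<model_version>", worker_ip="<worker_ip>")
    if client.confirm_and_remove_model(model_version="<model_version>", worker_ip="<worker_ip>"):
        print("cache entry removed")
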
xinference/constants.py CHANGED
@@ -17,6 +17,7 @@ from pathlib import Path
 
  XINFERENCE_ENV_ENDPOINT = "XINFERENCE_ENDPOINT"
  XINFERENCE_ENV_MODEL_SRC = "XINFERENCE_MODEL_SRC"
+ XINFERENCE_ENV_CSG_TOKEN = "XINFERENCE_CSG_TOKEN"
  XINFERENCE_ENV_HOME_PATH = "XINFERENCE_HOME"
  XINFERENCE_ENV_HEALTH_CHECK_FAILURE_THRESHOLD = (
  "XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD"
xinference/core/cache_tracker.py CHANGED
@@ -11,7 +11,6 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- import os
  from logging import getLogger
  from typing import Any, Dict, List, Optional
 
@@ -102,33 +101,54 @@ class CacheTrackerActor(xo.Actor):
  def get_model_version_count(self, model_name: str) -> int:
  return len(self.get_model_versions(model_name))
 
- def list_cached_models(self) -> List[Dict[Any, Any]]:
+ def list_cached_models(
+ self, worker_ip: str, model_name: Optional[str] = None
+ ) -> List[Dict[Any, Any]]:
  cached_models = []
- for model_name, model_versions in self._model_name_to_version_info.items():
- for version_info in model_versions:
- cache_status = version_info.get("cache_status", None)
- if cache_status == True:
- ret = version_info.copy()
- ret["model_name"] = model_name
+ for name, versions in self._model_name_to_version_info.items():
+ # only return assigned cached model if model_name is not none
+ # else return all cached model
+ if model_name and model_name != name:
+ continue
+ for version_info in versions:
+ cache_status = version_info.get("cache_status", False)
+ # search cached model
+ if cache_status:
+ res = version_info.copy()
+ res["model_name"] = name
+ paths = res.get("model_file_location", {})
+ # only return assigned worker's device path
+ if worker_ip in paths.keys():
+ res["model_file_location"] = paths[worker_ip]
+ cached_models.append(res)
+ return cached_models
 
- re_dict = version_info.get("model_file_location", None)
- if re_dict is not None and isinstance(re_dict, dict):
- if re_dict:
- actor_ip_address, path = next(iter(re_dict.items()))
- else:
- raise ValueError("The dictionary is empty.")
- else:
- raise ValueError("re_dict must be a non-empty dictionary.")
+ def list_deletable_models(self, model_version: str, worker_ip: str) -> str:
+ model_file_location = ""
+ for model, model_versions in self._model_name_to_version_info.items():
+ for version_info in model_versions:
+ # search assign model version
+ if model_version == version_info.get("model_version", None):
+ # check if exist
+ if version_info.get("cache_status", False):
+ paths = version_info.get("model_file_location", {})
+ # only return assigned worker's device path
+ if worker_ip in paths.keys():
+ model_file_location = paths[worker_ip]
+ return model_file_location
 
- ret["actor_ip_address"] = actor_ip_address
- ret["path"] = path
- if os.path.isdir(path):
- files = os.listdir(path)
- resolved_file = os.path.realpath(os.path.join(path, files[0]))
- if resolved_file:
- ret["real_path"] = os.path.dirname(resolved_file)
- else:
- ret["real_path"] = os.path.realpath(path)
- cached_models.append(ret)
- cached_models = sorted(cached_models, key=lambda x: x["model_name"])
- return cached_models
+ def confirm_and_remove_model(self, model_version: str, worker_ip: str):
+ # find remove path
+ rm_path = self.list_deletable_models(model_version, worker_ip)
+ # search _model_name_to_version_info if exist this path, and delete
+ for model, model_versions in self._model_name_to_version_info.items():
+ for version_info in model_versions:
+ # check if exist
+ if version_info.get("cache_status", False):
+ paths = version_info.get("model_file_location", {})
+ # only delete assigned worker's device path
+ if worker_ip in paths.keys() and rm_path == paths[worker_ip]:
+ del paths[worker_ip]
+ # if path is empty, update cache status
+ if not paths:
+ version_info["cache_status"] = False
xinference/core/event.py CHANGED
@@ -12,8 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- import queue
- from collections import defaultdict
+ from collections import defaultdict, deque
  from enum import Enum
  from typing import Dict, List, TypedDict
 
@@ -37,8 +36,8 @@ class Event(TypedDict):
  class EventCollectorActor(xo.StatelessActor):
  def __init__(self):
  super().__init__()
- self._model_uid_to_events: Dict[str, queue.Queue] = defaultdict( # type: ignore
- lambda: queue.Queue(maxsize=MAX_EVENT_COUNT_PER_MODEL)
+ self._model_uid_to_events: Dict[str, deque] = defaultdict( # type: ignore
+ lambda: deque(maxlen=MAX_EVENT_COUNT_PER_MODEL)
  )
 
  @classmethod
@@ -50,7 +49,7 @@ class EventCollectorActor(xo.StatelessActor):
  if event_queue is None:
  return []
  else:
- return [dict(e, event_type=e["event_type"].name) for e in event_queue.queue]
+ return [dict(e, event_type=e["event_type"].name) for e in iter(event_queue)]
 
  def report_event(self, model_uid: str, event: Event):
- self._model_uid_to_events[model_uid].put(event)
+ self._model_uid_to_events[model_uid].append(event)
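
The event collector change swaps queue.Queue(maxsize=...) for collections.deque(maxlen=...): a full queue.Queue blocks on put() (or raises queue.Full with put_nowait), whereas a bounded deque silently evicts the oldest entry, which is what a rolling per-model event buffer wants. A small standalone illustration, not part of the diff:

    from collections import deque

    events = deque(maxlen=3)
    for i in range(5):
        events.append(f"event-{i}")

    # Only the three most recent entries survive; older ones are dropped automatically.
    print(list(events))  # ['event-2', 'event-3', 'event-4']
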
xinference/core/model.py CHANGED
@@ -264,12 +264,14 @@ class ModelActor(xo.StatelessActor):
  return isinstance(self._model, VLLMModel)
 
  def allow_batching(self) -> bool:
- from ..model.llm.pytorch.core import PytorchChatModel
+ from ..model.llm.pytorch.core import PytorchModel
+
+ model_ability = self._model_description.get("model_ability", [])
 
  return (
  XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
- and isinstance(self._model, PytorchChatModel)
- and self._model.__class__.__name__ == PytorchChatModel.__name__
+ and isinstance(self._model, PytorchModel)
+ and "vision" not in model_ability
  )
 
  async def load(self):
@@ -393,18 +395,25 @@ class ModelActor(xo.StatelessActor):
  @request_limit
  @xo.generator
  async def generate(self, prompt: str, *args, **kwargs):
- if hasattr(self._model, "generate"):
- return await self._call_wrapper(
- self._model.generate, prompt, *args, **kwargs
- )
- if hasattr(self._model, "async_generate"):
- return await self._call_wrapper(
- self._model.async_generate, prompt, *args, **kwargs
+ if self.allow_batching():
+ return await self.handle_batching_request(
+ prompt, "generate", *args, **kwargs
  )
- raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
+ else:
+ kwargs.pop("raw_params", None)
+ if hasattr(self._model, "generate"):
+ return await self._call_wrapper(
+ self._model.generate, prompt, *args, **kwargs
+ )
+ if hasattr(self._model, "async_generate"):
+ return await self._call_wrapper(
+ self._model.async_generate, prompt, *args, **kwargs
+ )
+ raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
 
+ @staticmethod
  async def _queue_consumer(
- self, queue: Queue, timeout: Optional[float] = None
+ queue: Queue, timeout: Optional[float] = None
  ) -> AsyncIterator[Any]:
  from .scheduler import (
  XINFERENCE_STREAMING_ABORT_FLAG,
@@ -429,9 +438,38 @@ class ModelActor(xo.StatelessActor):
  yield res
 
  @staticmethod
- def get_stream_from_args(*args) -> bool:
- assert args[2] is None or isinstance(args[2], dict)
- return False if args[2] is None else args[2].get("stream", False)
+ def _get_stream_from_args(ability: str, *args) -> bool:
+ if ability == "chat":
+ assert args[2] is None or isinstance(args[2], dict)
+ return False if args[2] is None else args[2].get("stream", False)
+ else:
+ assert args[0] is None or isinstance(args[0], dict)
+ return False if args[0] is None else args[0].get("stream", False)
+
+ async def handle_batching_request(self, prompt: str, ability: str, *args, **kwargs):
+ stream = self._get_stream_from_args(ability, *args)
+ assert self._scheduler_ref is not None
+ if stream:
+ assert self._scheduler_ref is not None
+ queue: Queue[Any] = Queue()
+ ret = self._queue_consumer(queue)
+ await self._scheduler_ref.add_request(prompt, queue, *args, **kwargs)
+ gen = self._to_json_async_gen(ret)
+ self._current_generator = weakref.ref(gen)
+ return gen
+ else:
+ from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
+
+ assert self._loop is not None
+ future = ConcurrentFuture()
+ await self._scheduler_ref.add_request(prompt, future, *args, **kwargs)
+ fut = asyncio.wrap_future(future, loop=self._loop)
+ result = await fut
+ if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
+ raise RuntimeError(
+ f"This request has been cancelled by another `abort_request` request."
+ )
+ return await asyncio.to_thread(json_dumps, result)
 
  @log_async(logger=logger)
  @request_limit
@@ -441,34 +479,11 @@
  response = None
  try:
  if self.allow_batching():
- stream = self.get_stream_from_args(*args)
- assert self._scheduler_ref is not None
- if stream:
- assert self._scheduler_ref is not None
- queue: Queue[Any] = Queue()
- ret = self._queue_consumer(queue)
- await self._scheduler_ref.add_request(
- prompt, queue, *args, **kwargs
- )
- gen = self._to_json_async_gen(ret)
- self._current_generator = weakref.ref(gen)
- return gen
- else:
- from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
-
- assert self._loop is not None
- future = ConcurrentFuture()
- await self._scheduler_ref.add_request(
- prompt, future, *args, **kwargs
- )
- fut = asyncio.wrap_future(future, loop=self._loop)
- result = await fut
- if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
- raise RuntimeError(
- f"This request has been cancelled by another `abort_request` request."
- )
- return await asyncio.to_thread(json_dumps, result)
+ return await self.handle_batching_request(
+ prompt, "chat", *args, **kwargs
+ )
  else:
+ kwargs.pop("raw_params", None)
  if hasattr(self._model, "chat"):
  response = await self._call_wrapper(
  self._model.chat, prompt, *args, **kwargs
@@ -528,6 +543,7 @@ class ModelActor(xo.StatelessActor):
  top_n: Optional[int],
  max_chunks_per_doc: Optional[int],
  return_documents: Optional[bool],
+ return_len: Optional[bool],
  *args,
  **kwargs,
  ):
@@ -539,6 +555,7 @@ class ModelActor(xo.StatelessActor):
  top_n,
  max_chunks_per_doc,
  return_documents,
+ return_len,
  *args,
  **kwargs,
  )
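
The return_len flag added in these last hunks is threaded all the way from the REST handler through the client into ModelActor.rerank. A hedged end-to-end sketch, not part of the diff, with a placeholder endpoint and model uid (the exact extra length fields in the response depend on the rerank model):

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")  # assumed local endpoint
    model = client.get_model("<rerank-model-uid>")  # an already launched rerank model

    result = model.rerank(
        documents=["Xinference serves LLMs.", "An unrelated sentence."],
        query="how to serve large language models",
        top_n=1,
        return_documents=True,
        return_len=True,  # new in this diff: also report token lengths
    )
    print(result)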