xinference 0.11.3__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (75)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +143 -6
  3. xinference/client/restful/restful_client.py +144 -5
  4. xinference/constants.py +5 -0
  5. xinference/core/cache_tracker.py +48 -28
  6. xinference/core/model.py +160 -19
  7. xinference/core/scheduler.py +446 -0
  8. xinference/core/supervisor.py +99 -24
  9. xinference/core/worker.py +68 -2
  10. xinference/deploy/cmdline.py +86 -2
  11. xinference/deploy/test/test_cmdline.py +19 -10
  12. xinference/isolation.py +9 -2
  13. xinference/model/audio/__init__.py +14 -1
  14. xinference/model/audio/chattts.py +84 -0
  15. xinference/model/audio/core.py +22 -4
  16. xinference/model/audio/custom.py +6 -4
  17. xinference/model/audio/model_spec.json +20 -0
  18. xinference/model/audio/model_spec_modelscope.json +20 -0
  19. xinference/model/llm/__init__.py +38 -2
  20. xinference/model/llm/llm_family.json +509 -1
  21. xinference/model/llm/llm_family.py +86 -1
  22. xinference/model/llm/llm_family_csghub.json +66 -0
  23. xinference/model/llm/llm_family_modelscope.json +411 -2
  24. xinference/model/llm/pytorch/chatglm.py +20 -13
  25. xinference/model/llm/pytorch/cogvlm2.py +76 -17
  26. xinference/model/llm/pytorch/core.py +141 -6
  27. xinference/model/llm/pytorch/glm4v.py +268 -0
  28. xinference/model/llm/pytorch/minicpmv25.py +232 -0
  29. xinference/model/llm/pytorch/qwen_vl.py +1 -1
  30. xinference/model/llm/pytorch/utils.py +405 -8
  31. xinference/model/llm/utils.py +14 -13
  32. xinference/model/llm/vllm/core.py +16 -4
  33. xinference/model/utils.py +8 -2
  34. xinference/thirdparty/ChatTTS/__init__.py +1 -0
  35. xinference/thirdparty/ChatTTS/core.py +200 -0
  36. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  37. xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
  38. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  39. xinference/thirdparty/ChatTTS/infer/api.py +125 -0
  40. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  41. xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
  42. xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
  43. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  44. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
  45. xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
  46. xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
  47. xinference/types.py +3 -0
  48. xinference/web/ui/build/asset-manifest.json +6 -6
  49. xinference/web/ui/build/index.html +1 -1
  50. xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
  51. xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
  52. xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
  53. xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
  59. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/METADATA +26 -9
  60. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/RECORD +65 -47
  61. xinference/web/ui/build/static/css/main.54bca460.css +0 -2
  62. xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
  63. xinference/web/ui/build/static/js/main.551aa479.js +0 -3
  64. xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
  65. xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
  66. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
  67. xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
  68. xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
  69. xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
  70. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
  71. /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
  72. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
  73. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
  74. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
  75. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-05-31T17:12:13+0800",
+ "date": "2024-06-14T17:17:50+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "69c09cd068a530cd2fdcac07e4e81f03d48f04f9",
- "version": "0.11.3"
+ "full-revisionid": "34a57df449f0890415c424802d3596f3c8758412",
+ "version": "0.12.1"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -122,6 +122,14 @@ class TextToImageRequest(BaseModel):
     user: Optional[str] = None
 
 
+class SpeechRequest(BaseModel):
+    model: str
+    input: str
+    voice: Optional[str]
+    response_format: Optional[str] = "mp3"
+    speed: Optional[float] = 1.0
+
+
 class RegisterModelRequest(BaseModel):
     model: str
     persist: bool
@@ -337,6 +345,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/models/{model_uid}/requests/{request_id}/abort",
+            self.abort_request,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/models/instance",
             self.launch_model_by_version,
@@ -418,6 +436,16 @@
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/audio/speech",
+            self.create_speech,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/images/generations",
             self.create_images,
@@ -494,11 +522,31 @@
             ),
         )
         self._router.add_api_route(
-            "/v1/cached/list_cached_models",
+            "/v1/cache/models",
             self.list_cached_models,
             methods=["GET"],
             dependencies=(
-                [Security(self._auth_service, scopes=["models:list"])]
+                [Security(self._auth_service, scopes=["cache:list"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
+        self._router.add_api_route(
+            "/v1/cache/models/files",
+            self.list_model_files,
+            methods=["GET"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["cache:list"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
+        self._router.add_api_route(
+            "/v1/cache/models",
+            self.confirm_and_remove_model,
+            methods=["DELETE"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["cache:delete"])]
                 if self.is_authenticated()
                 else None
             ),
@@ -1179,6 +1227,38 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def create_speech(self, request: Request) -> Response:
+        body = SpeechRequest.parse_obj(await request.json())
+        model_uid = body.model
+        try:
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            out = await model.speech(
+                input=body.input,
+                voice=body.voice,
+                response_format=body.response_format,
+                speed=body.speed,
+            )
+            return Response(media_type="application/octet-stream", content=out)
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_images(self, request: Request) -> Response:
         body = TextToImageRequest.parse_obj(await request.json())
         model_uid = body.model
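
Together with the SpeechRequest model and the /v1/audio/speech route added above, this handler gives the server an OpenAI-style text-to-speech endpoint that returns raw audio bytes. A minimal sketch of calling it directly with requests; the endpoint URL and model uid are placeholders, and an audio model (for example one of the new ChatTTS specs) is assumed to be launched already:

import requests

base_url = "http://127.0.0.1:9997"    # placeholder xinference endpoint
body = {
    "model": "my-audio-model",        # placeholder uid of a launched audio model
    "input": "Hello from xinference.",
    "voice": "",                      # optional voice selector
    "response_format": "mp3",         # default declared on SpeechRequest
    "speed": 1.0,
}

resp = requests.post(f"{base_url}/v1/audio/speech", json=body)
resp.raise_for_status()

# create_speech returns the audio as application/octet-stream bytes.
with open("speech.mp3", "wb") as f:
    f.write(resp.content)
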
@@ -1341,9 +1421,11 @@
         model_family = desc.get("model_family", "")
         function_call_models = [
             "chatglm3",
+            "glm4-chat",
             "gorilla-openfunctions-v1",
             "qwen-chat",
             "qwen1.5-chat",
+            "qwen2-instruct",
         ]
 
         is_qwen = desc.get("model_format") == "ggmlv3" and "qwen-chat" == model_family
@@ -1366,7 +1448,11 @@
             )
         if body.tools and body.stream:
             is_vllm = await model.is_vllm_backend()
-            if not is_vllm or model_family not in ["qwen-chat", "qwen1.5-chat"]:
+            if not is_vllm or model_family not in [
+                "qwen-chat",
+                "qwen1.5-chat",
+                "qwen2-instruct",
+            ]:
                 raise HTTPException(
                     status_code=400,
                     detail="Streaming support for tool calls is available only when using vLLM backend and Qwen models.",
@@ -1495,10 +1581,17 @@
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
-    async def list_cached_models(self) -> JSONResponse:
+    async def list_cached_models(
+        self, model_name: str = Query(None), worker_ip: str = Query(None)
+    ) -> JSONResponse:
         try:
-            data = await (await self._get_supervisor_ref()).list_cached_models()
-            return JSONResponse(content=data)
+            data = await (await self._get_supervisor_ref()).list_cached_models(
+                model_name, worker_ip
+            )
+            resp = {
+                "list": data,
+            }
+            return JSONResponse(content=resp)
         except ValueError as re:
             logger.error(re, exc_info=True)
             raise HTTPException(status_code=400, detail=str(re))
@@ -1518,6 +1611,15 @@
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def abort_request(self, model_uid: str, request_id: str) -> JSONResponse:
+        try:
+            supervisor_ref = await self._get_supervisor_ref()
+            res = await supervisor_ref.abort_request(model_uid, request_id)
+            return JSONResponse(content=res)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def list_vllm_supported_model_families(self) -> JSONResponse:
         try:
             from ..model.llm.vllm.core import (
@@ -1554,6 +1656,41 @@
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def list_model_files(
+        self, model_version: str = Query(None), worker_ip: str = Query(None)
+    ) -> JSONResponse:
+        try:
+            data = await (await self._get_supervisor_ref()).list_deletable_models(
+                model_version, worker_ip
+            )
+            response = {
+                "model_version": model_version,
+                "worker_ip": worker_ip,
+                "paths": data,
+            }
+            return JSONResponse(content=response)
+        except ValueError as re:
+            logger.error(re, exc_info=True)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
+    async def confirm_and_remove_model(
+        self, model_version: str = Query(None), worker_ip: str = Query(None)
+    ) -> JSONResponse:
+        try:
+            res = await (await self._get_supervisor_ref()).confirm_and_remove_model(
+                model_version=model_version, worker_ip=worker_ip
+            )
+            return JSONResponse(content={"result": res})
+        except ValueError as re:
+            logger.error(re, exc_info=True)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
 
 def run(
     supervisor_address: str,
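
The two new handlers above complete the cache HTTP surface: GET /v1/cache/models/files reports the deletable paths for a model version, and DELETE /v1/cache/models removes them. A minimal sketch of that two-step flow with requests; the endpoint URL and model version string are placeholders:

import requests

base_url = "http://127.0.0.1:9997"                    # placeholder endpoint
params = {"model_version": "my-model--7B--pytorch"}   # placeholder model version

# Step 1: list the files that would be removed for this model version.
files = requests.get(f"{base_url}/v1/cache/models/files", params=params)
files.raise_for_status()
print(files.json()["paths"])     # list_model_files returns model_version, worker_ip and paths

# Step 2: confirm and remove the cached files.
removed = requests.delete(f"{base_url}/v1/cache/models", params=params)
removed.raise_for_status()
print(removed.json()["result"])  # confirm_and_remove_model wraps the outcome in {"result": ...}
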
xinference/client/restful/restful_client.py CHANGED
@@ -684,6 +684,49 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def speech(
+        self,
+        input: str,
+        voice: str = "",
+        response_format: str = "mp3",
+        speed: float = 1.0,
+    ):
+        """
+        Generates audio from the input text.
+
+        Parameters
+        ----------
+
+        input: str
+            The text to generate audio for. The maximum length is 4096 characters.
+        voice: str
+            The voice to use when generating the audio.
+        response_format: str
+            The format to audio in.
+        speed: str
+            The speed of the generated audio.
+
+        Returns
+        -------
+        bytes
+            The generated audio binary.
+        """
+        url = f"{self._base_url}/v1/audio/speech"
+        params = {
+            "model": self._model_uid,
+            "input": input,
+            "voice": voice,
+            "response_format": response_format,
+            "speed": speed,
+        }
+        response = requests.post(url, json=params, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to speech the text, detail: {_get_error_string(response)}"
+            )
+
+        return response.content
+
 
 class Client:
     def __init__(self, base_url, api_key: Optional[str] = None):
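
On the client side this maps onto RESTfulAudioModelHandle.speech. A minimal usage sketch, assuming an audio model has already been launched and that Client.get_model hands back the audio handle for it; the endpoint URL and model uid are placeholders:

from xinference.client.restful.restful_client import Client

client = Client("http://127.0.0.1:9997")    # placeholder endpoint
model = client.get_model("my-audio-model")  # placeholder uid of a launched audio model

# speech() returns raw audio bytes, mp3 by default.
audio = model.speech(input="Text to speech with xinference.", speed=1.0)
with open("output.mp3", "wb") as f:
    f.write(audio)
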
@@ -1102,13 +1145,17 @@ class Client:
         response_data = response.json()
         return response_data
 
-    def list_cached_models(self) -> List[Dict[Any, Any]]:
+    def list_cached_models(
+        self, model_name: Optional[str] = None, worker_ip: Optional[str] = None
+    ) -> List[Dict[Any, Any]]:
         """
         Get a list of cached models.
-
         Parameters
         ----------
-        None
+        model_name: Optional[str]
+            The name of model.
+        worker_ip: Optional[str]
+            Specify the worker ip where the model is located in a distributed scenario.
 
         Returns
         -------
@@ -1121,16 +1168,81 @@
             Raised when the request fails, including the reason for the failure.
         """
 
-        url = f"{self.base_url}/v1/cached/list_cached_models"
-        response = requests.get(url, headers=self._headers)
+        url = f"{self.base_url}/v1/cache/models"
+        params = {
+            "model_name": model_name,
+            "worker_ip": worker_ip,
+        }
+        response = requests.get(url, headers=self._headers, params=params)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to list cached model, detail: {_get_error_string(response)}"
             )
 
+        response_data = response.json()
+        response_data = response_data.get("list")
+        return response_data
+
+    def list_deletable_models(
+        self, model_version: str, worker_ip: Optional[str] = None
+    ) -> Dict[str, Any]:
+        """
+        Get the cached models with the model path cached on the server.
+        Parameters
+        ----------
+        model_version: str
+            The version of the model.
+        worker_ip: Optional[str]
+            Specify the worker ip where the model is located in a distributed scenario.
+        Returns
+        -------
+        Dict[str, Dict[str,str]]]
+            Dictionary with keys "model_name" and values model_file_location.
+        """
+        url = f"{self.base_url}/v1/cache/models/files"
+        params = {
+            "model_version": model_version,
+            "worker_ip": worker_ip,
+        }
+        response = requests.get(url, headers=self._headers, params=params)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get paths by model name, detail: {_get_error_string(response)}"
+            )
+
         response_data = response.json()
         return response_data
 
+    def confirm_and_remove_model(
+        self, model_version: str, worker_ip: Optional[str] = None
+    ) -> bool:
+        """
+        Remove the cached models with the model name cached on the server.
+        Parameters
+        ----------
+        model_version: str
+            The version of the model.
+        worker_ip: Optional[str]
+            Specify the worker ip where the model is located in a distributed scenario.
+        Returns
+        -------
+        str
+            The response of the server.
+        """
+        url = f"{self.base_url}/v1/cache/models"
+        params = {
+            "model_version": model_version,
+            "worker_ip": worker_ip,
+        }
+        response = requests.delete(url, headers=self._headers, params=params)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to remove cached models, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data.get("result", False)
+
     def get_model_registration(
         self, model_type: str, model_name: str
     ) -> Dict[str, Any]:
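
A minimal sketch of the cache-management round trip these three client methods enable; the endpoint URL is a placeholder and the example assumes at least one model is cached on the server:

from xinference.client.restful.restful_client import Client

client = Client("http://127.0.0.1:9997")   # placeholder endpoint

# List cached models, optionally narrowed by model name and worker ip.
cached = client.list_cached_models()
for entry in cached:
    print(entry.get("model_name"), entry.get("model_file_location"))

# Pick one cached version, inspect the files that would be deleted,
# then remove them once confirmed.
version = cached[0]["model_version"]       # assumes the entries carry a model_version field
print(client.list_deletable_models(model_version=version))
print(client.confirm_and_remove_model(model_version=version))
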
@@ -1181,3 +1293,30 @@
 
         response_data = response.json()
         return response_data
+
+    def abort_request(self, model_uid: str, request_id: str):
+        """
+        Abort a request.
+        Abort a submitted request. If the request is finished or not found, this method will be a no-op.
+        Currently, this interface is only supported when batching is enabled for models on transformers backend.
+
+        Parameters
+        ----------
+        model_uid: str
+            Model uid.
+        request_id: str
+            Request id.
+        Returns
+        -------
+        Dict
+            Return empty dict.
+        """
+        url = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort"
+        response = requests.post(url, headers=self._headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to abort request, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
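
A minimal sketch of cancelling an in-flight request with the new client method. It assumes batching on the transformers backend is enabled (see the XINFERENCE_TRANSFORMERS_ENABLE_BATCHING constant in the next file) and that the caller knows the request id being tracked by the scheduler; the uid and id below are placeholders:

from xinference.client.restful.restful_client import Client

client = Client("http://127.0.0.1:9997")   # placeholder endpoint

# Placeholder identifiers: a running LLM's uid and a request id known to the scheduler.
result = client.abort_request(model_uid="my-llm", request_id="req-1234")

# A no-op (empty dict) if the request already finished or was not found.
print(result)
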
xinference/constants.py CHANGED
@@ -17,6 +17,7 @@ from pathlib import Path
 
 XINFERENCE_ENV_ENDPOINT = "XINFERENCE_ENDPOINT"
 XINFERENCE_ENV_MODEL_SRC = "XINFERENCE_MODEL_SRC"
+XINFERENCE_ENV_CSG_TOKEN = "XINFERENCE_CSG_TOKEN"
 XINFERENCE_ENV_HOME_PATH = "XINFERENCE_HOME"
 XINFERENCE_ENV_HEALTH_CHECK_FAILURE_THRESHOLD = (
     "XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD"
@@ -27,6 +28,7 @@ XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
 XINFERENCE_ENV_DISABLE_VLLM = "XINFERENCE_DISABLE_VLLM"
 XINFERENCE_ENV_ENABLE_SGLANG = "XINFERENCE_ENABLE_SGLANG"
 XINFERENCE_ENV_DISABLE_METRICS = "XINFERENCE_DISABLE_METRICS"
+XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING = "XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"
 
 
 def get_xinference_home() -> str:
@@ -70,3 +72,6 @@ XINFERENCE_ENABLE_SGLANG = bool(int(os.environ.get(XINFERENCE_ENV_ENABLE_SGLANG,
 XINFERENCE_DISABLE_METRICS = bool(
     int(os.environ.get(XINFERENCE_ENV_DISABLE_METRICS, 0))
 )
+XINFERENCE_TRANSFORMERS_ENABLE_BATCHING = bool(
+    int(os.environ.get(XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING, 0))
+)
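
The new flag is parsed with the same bool(int(...)) pattern as the other switches, so batching on the transformers backend is turned on by setting XINFERENCE_TRANSFORMERS_ENABLE_BATCHING=1 in the environment before the server starts. A standalone sketch of that parsing behaviour (not importing xinference):

import os

# Unset or "0" -> False, "1" -> True; non-integer values would raise ValueError.
os.environ["XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"] = "1"
enable_batching = bool(int(os.environ.get("XINFERENCE_TRANSFORMERS_ENABLE_BATCHING", 0)))
print(enable_batching)  # True
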
xinference/core/cache_tracker.py CHANGED
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from logging import getLogger
 from typing import Any, Dict, List, Optional
 
@@ -102,33 +101,54 @@ class CacheTrackerActor(xo.Actor):
     def get_model_version_count(self, model_name: str) -> int:
         return len(self.get_model_versions(model_name))
 
-    def list_cached_models(self) -> List[Dict[Any, Any]]:
+    def list_cached_models(
+        self, worker_ip: str, model_name: Optional[str] = None
+    ) -> List[Dict[Any, Any]]:
         cached_models = []
-        for model_name, model_versions in self._model_name_to_version_info.items():
-            for version_info in model_versions:
-                cache_status = version_info.get("cache_status", None)
-                if cache_status == True:
-                    ret = version_info.copy()
-                    ret["model_name"] = model_name
+        for name, versions in self._model_name_to_version_info.items():
+            # only return assigned cached model if model_name is not none
+            # else return all cached model
+            if model_name and model_name != name:
+                continue
+            for version_info in versions:
+                cache_status = version_info.get("cache_status", False)
+                # search cached model
+                if cache_status:
+                    res = version_info.copy()
+                    res["model_name"] = name
+                    paths = res.get("model_file_location", {})
+                    # only return assigned worker's device path
+                    if worker_ip in paths.keys():
+                        res["model_file_location"] = paths[worker_ip]
+                        cached_models.append(res)
+        return cached_models
 
-                    re_dict = version_info.get("model_file_location", None)
-                    if re_dict is not None and isinstance(re_dict, dict):
-                        if re_dict:
-                            actor_ip_address, path = next(iter(re_dict.items()))
-                        else:
-                            raise ValueError("The dictionary is empty.")
-                    else:
-                        raise ValueError("re_dict must be a non-empty dictionary.")
+    def list_deletable_models(self, model_version: str, worker_ip: str) -> str:
+        model_file_location = ""
+        for model, model_versions in self._model_name_to_version_info.items():
+            for version_info in model_versions:
+                # search assign model version
+                if model_version == version_info.get("model_version", None):
+                    # check if exist
+                    if version_info.get("cache_status", False):
+                        paths = version_info.get("model_file_location", {})
+                        # only return assigned worker's device path
+                        if worker_ip in paths.keys():
+                            model_file_location = paths[worker_ip]
+        return model_file_location
 
-                    ret["actor_ip_address"] = actor_ip_address
-                    ret["path"] = path
-                    if os.path.isdir(path):
-                        files = os.listdir(path)
-                        resolved_file = os.path.realpath(os.path.join(path, files[0]))
-                        if resolved_file:
-                            ret["real_path"] = os.path.dirname(resolved_file)
-                    else:
-                        ret["real_path"] = os.path.realpath(path)
-                    cached_models.append(ret)
-        cached_models = sorted(cached_models, key=lambda x: x["model_name"])
-        return cached_models
+    def confirm_and_remove_model(self, model_version: str, worker_ip: str):
+        # find remove path
+        rm_path = self.list_deletable_models(model_version, worker_ip)
+        # search _model_name_to_version_info if exist this path, and delete
+        for model, model_versions in self._model_name_to_version_info.items():
+            for version_info in model_versions:
+                # check if exist
+                if version_info.get("cache_status", False):
+                    paths = version_info.get("model_file_location", {})
+                    # only delete assigned worker's device path
+                    if worker_ip in paths.keys() and rm_path == paths[worker_ip]:
+                        del paths[worker_ip]
+                        # if path is empty, update cache status
+                        if not paths:
+                            version_info["cache_status"] = False
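
All three rewritten tracker methods walk the same _model_name_to_version_info mapping, where each version entry now keeps model_file_location as a per-worker dict of paths instead of a single path. A standalone sketch of the filtering that list_cached_models performs, over an invented entry (model name, version string, worker ip and path are all placeholders):

from typing import Any, Dict, List, Optional

# Invented example of the assumed structure: model name -> list of version-info dicts.
model_name_to_version_info: Dict[str, List[Dict[str, Any]]] = {
    "my-model": [
        {
            "model_version": "my-model--7B--pytorch",
            "cache_status": True,
            "model_file_location": {"192.168.0.2": "/data/xinference/cache/my-model-pytorch-7b"},
        }
    ]
}


def list_cached_models(worker_ip: str, model_name: Optional[str] = None) -> List[Dict[str, Any]]:
    cached = []
    for name, versions in model_name_to_version_info.items():
        if model_name and model_name != name:
            continue
        for info in versions:
            if info.get("cache_status", False):
                res = dict(info, model_name=name)
                paths = res.get("model_file_location", {})
                # keep only the path that lives on the requested worker
                if worker_ip in paths:
                    res["model_file_location"] = paths[worker_ip]
                    cached.append(res)
    return cached


print(list_cached_models("192.168.0.2"))
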