xinference 0.12.0__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

This version of xinference might be problematic.

Files changed (67)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +74 -6
  3. xinference/client/restful/restful_client.py +74 -5
  4. xinference/constants.py +1 -0
  5. xinference/core/cache_tracker.py +48 -28
  6. xinference/core/model.py +54 -42
  7. xinference/core/scheduler.py +34 -16
  8. xinference/core/supervisor.py +73 -24
  9. xinference/core/worker.py +68 -2
  10. xinference/deploy/cmdline.py +86 -2
  11. xinference/deploy/test/test_cmdline.py +19 -10
  12. xinference/model/audio/__init__.py +14 -1
  13. xinference/model/audio/core.py +12 -1
  14. xinference/model/audio/custom.py +6 -4
  15. xinference/model/audio/model_spec_modelscope.json +20 -0
  16. xinference/model/llm/__init__.py +34 -2
  17. xinference/model/llm/llm_family.json +2 -0
  18. xinference/model/llm/llm_family.py +86 -1
  19. xinference/model/llm/llm_family_csghub.json +66 -0
  20. xinference/model/llm/llm_family_modelscope.json +2 -0
  21. xinference/model/llm/pytorch/chatglm.py +18 -12
  22. xinference/model/llm/pytorch/core.py +92 -42
  23. xinference/model/llm/pytorch/glm4v.py +13 -3
  24. xinference/model/llm/pytorch/qwen_vl.py +1 -1
  25. xinference/model/llm/pytorch/utils.py +27 -14
  26. xinference/model/llm/utils.py +14 -13
  27. xinference/model/llm/vllm/core.py +10 -4
  28. xinference/model/utils.py +8 -2
  29. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  30. xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
  31. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  32. xinference/thirdparty/ChatTTS/infer/api.py +125 -0
  33. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  34. xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
  35. xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
  36. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  37. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
  38. xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
  39. xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
  40. xinference/web/ui/build/asset-manifest.json +6 -6
  41. xinference/web/ui/build/index.html +1 -1
  42. xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
  43. xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
  44. xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
  45. xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
  46. xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
  47. xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
  48. xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
  49. xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
  51. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/METADATA +1 -1
  52. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/RECORD +57 -45
  53. xinference/web/ui/build/static/css/main.54bca460.css +0 -2
  54. xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
  55. xinference/web/ui/build/static/js/main.551aa479.js +0 -3
  56. xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
  57. xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
  58. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
  63. /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
  64. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
  65. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
  66. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
  67. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-06-07T15:04:33+0800",
+ "date": "2024-06-14T17:17:50+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "55c5636f2b6022842d1827eae373c8e5f162a1a3",
- "version": "0.12.0"
+ "full-revisionid": "34a57df449f0890415c424802d3596f3c8758412",
+ "version": "0.12.1"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -522,11 +522,31 @@ class RESTfulAPI:
             ),
         )
         self._router.add_api_route(
-            "/v1/cached/list_cached_models",
+            "/v1/cache/models",
             self.list_cached_models,
             methods=["GET"],
             dependencies=(
-                [Security(self._auth_service, scopes=["models:list"])]
+                [Security(self._auth_service, scopes=["cache:list"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
+        self._router.add_api_route(
+            "/v1/cache/models/files",
+            self.list_model_files,
+            methods=["GET"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["cache:list"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
+        self._router.add_api_route(
+            "/v1/cache/models",
+            self.confirm_and_remove_model,
+            methods=["DELETE"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["cache:delete"])]
                 if self.is_authenticated()
                 else None
             ),
@@ -1401,9 +1421,11 @@
         model_family = desc.get("model_family", "")
         function_call_models = [
             "chatglm3",
+            "glm4-chat",
             "gorilla-openfunctions-v1",
             "qwen-chat",
             "qwen1.5-chat",
+            "qwen2-instruct",
         ]
 
         is_qwen = desc.get("model_format") == "ggmlv3" and "qwen-chat" == model_family
@@ -1426,7 +1448,11 @@
             )
         if body.tools and body.stream:
             is_vllm = await model.is_vllm_backend()
-            if not is_vllm or model_family not in ["qwen-chat", "qwen1.5-chat"]:
+            if not is_vllm or model_family not in [
+                "qwen-chat",
+                "qwen1.5-chat",
+                "qwen2-instruct",
+            ]:
                 raise HTTPException(
                     status_code=400,
                     detail="Streaming support for tool calls is available only when using vLLM backend and Qwen models.",
@@ -1555,10 +1581,17 @@
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
-    async def list_cached_models(self) -> JSONResponse:
+    async def list_cached_models(
+        self, model_name: str = Query(None), worker_ip: str = Query(None)
+    ) -> JSONResponse:
         try:
-            data = await (await self._get_supervisor_ref()).list_cached_models()
-            return JSONResponse(content=data)
+            data = await (await self._get_supervisor_ref()).list_cached_models(
+                model_name, worker_ip
+            )
+            resp = {
+                "list": data,
+            }
+            return JSONResponse(content=resp)
         except ValueError as re:
             logger.error(re, exc_info=True)
             raise HTTPException(status_code=400, detail=str(re))
@@ -1623,6 +1656,41 @@
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def list_model_files(
+        self, model_version: str = Query(None), worker_ip: str = Query(None)
+    ) -> JSONResponse:
+        try:
+            data = await (await self._get_supervisor_ref()).list_deletable_models(
+                model_version, worker_ip
+            )
+            response = {
+                "model_version": model_version,
+                "worker_ip": worker_ip,
+                "paths": data,
+            }
+            return JSONResponse(content=response)
+        except ValueError as re:
+            logger.error(re, exc_info=True)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
+    async def confirm_and_remove_model(
+        self, model_version: str = Query(None), worker_ip: str = Query(None)
+    ) -> JSONResponse:
+        try:
+            res = await (await self._get_supervisor_ref()).confirm_and_remove_model(
+                model_version=model_version, worker_ip=worker_ip
+            )
+            return JSONResponse(content={"result": res})
+        except ValueError as re:
+            logger.error(re, exc_info=True)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
 
 def run(
     supervisor_address: str,
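
The three routes registered above replace the old /v1/cached/list_cached_models endpoint with a small cache-management API. As a rough usage sketch only (not part of this diff), assuming an unauthenticated server on http://127.0.0.1:9997 and using the query parameters and response shapes defined by the handlers above; the model name is a placeholder:

    import requests

    BASE = "http://127.0.0.1:9997"  # placeholder endpoint

    # GET /v1/cache/models: list cached models, optionally filtered by model_name / worker_ip.
    cached = requests.get(f"{BASE}/v1/cache/models", params={"model_name": "qwen2-instruct"}).json()
    print(cached["list"])  # list_cached_models wraps its result in {"list": [...]}

    # GET /v1/cache/models/files: resolve the on-disk paths for one model version.
    version = cached["list"][0]["model_version"]
    files = requests.get(f"{BASE}/v1/cache/models/files", params={"model_version": version}).json()
    print(files["paths"])

    # DELETE /v1/cache/models: remove the cached files for that model version.
    removed = requests.delete(f"{BASE}/v1/cache/models", params={"model_version": version}).json()
    print(removed["result"])
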
xinference/client/restful/restful_client.py CHANGED
@@ -1145,13 +1145,17 @@ class Client:
         response_data = response.json()
         return response_data
 
-    def list_cached_models(self) -> List[Dict[Any, Any]]:
+    def list_cached_models(
+        self, model_name: Optional[str] = None, worker_ip: Optional[str] = None
+    ) -> List[Dict[Any, Any]]:
         """
         Get a list of cached models.
-
         Parameters
         ----------
-        None
+        model_name: Optional[str]
+            The name of model.
+        worker_ip: Optional[str]
+            Specify the worker ip where the model is located in a distributed scenario.
 
         Returns
         -------
@@ -1164,16 +1168,81 @@
             Raised when the request fails, including the reason for the failure.
         """
 
-        url = f"{self.base_url}/v1/cached/list_cached_models"
-        response = requests.get(url, headers=self._headers)
+        url = f"{self.base_url}/v1/cache/models"
+        params = {
+            "model_name": model_name,
+            "worker_ip": worker_ip,
+        }
+        response = requests.get(url, headers=self._headers, params=params)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to list cached model, detail: {_get_error_string(response)}"
             )
 
+        response_data = response.json()
+        response_data = response_data.get("list")
+        return response_data
+
+    def list_deletable_models(
+        self, model_version: str, worker_ip: Optional[str] = None
+    ) -> Dict[str, Any]:
+        """
+        Get the cached models with the model path cached on the server.
+        Parameters
+        ----------
+        model_version: str
+            The version of the model.
+        worker_ip: Optional[str]
+            Specify the worker ip where the model is located in a distributed scenario.
+        Returns
+        -------
+        Dict[str, Dict[str,str]]]
+            Dictionary with keys "model_name" and values model_file_location.
+        """
+        url = f"{self.base_url}/v1/cache/models/files"
+        params = {
+            "model_version": model_version,
+            "worker_ip": worker_ip,
+        }
+        response = requests.get(url, headers=self._headers, params=params)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get paths by model name, detail: {_get_error_string(response)}"
+            )
+
         response_data = response.json()
         return response_data
 
+    def confirm_and_remove_model(
+        self, model_version: str, worker_ip: Optional[str] = None
+    ) -> bool:
+        """
+        Remove the cached models with the model name cached on the server.
+        Parameters
+        ----------
+        model_version: str
+            The version of the model.
+        worker_ip: Optional[str]
+            Specify the worker ip where the model is located in a distributed scenario.
+        Returns
+        -------
+        str
+            The response of the server.
+        """
+        url = f"{self.base_url}/v1/cache/models"
+        params = {
+            "model_version": model_version,
+            "worker_ip": worker_ip,
+        }
+        response = requests.delete(url, headers=self._headers, params=params)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to remove cached models, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data.get("result", False)
+
     def get_model_registration(
         self, model_type: str, model_name: str
    ) -> Dict[str, Any]:
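
For the client side, a minimal usage sketch of the three methods above (illustrative only; the endpoint and the model version string are placeholders):

    from xinference.client.restful.restful_client import Client

    client = Client("http://127.0.0.1:9997")  # placeholder endpoint

    # List cached models, optionally narrowed to one model name and/or one worker.
    for entry in client.list_cached_models(model_name=None, worker_ip=None):
        print(entry["model_name"], entry.get("model_file_location"))

    # Resolve the deletable paths for a specific model version, then remove them.
    version = "qwen2-instruct--7b--pytorch--4-bit"  # placeholder version string
    print(client.list_deletable_models(model_version=version))
    if client.confirm_and_remove_model(model_version=version):
        print("cache entry removed")
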
xinference/constants.py CHANGED
@@ -17,6 +17,7 @@ from pathlib import Path
 
 XINFERENCE_ENV_ENDPOINT = "XINFERENCE_ENDPOINT"
 XINFERENCE_ENV_MODEL_SRC = "XINFERENCE_MODEL_SRC"
+XINFERENCE_ENV_CSG_TOKEN = "XINFERENCE_CSG_TOKEN"
 XINFERENCE_ENV_HOME_PATH = "XINFERENCE_HOME"
 XINFERENCE_ENV_HEALTH_CHECK_FAILURE_THRESHOLD = (
     "XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD"
xinference/core/cache_tracker.py CHANGED
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from logging import getLogger
 from typing import Any, Dict, List, Optional
 
@@ -102,33 +101,54 @@ class CacheTrackerActor(xo.Actor):
     def get_model_version_count(self, model_name: str) -> int:
         return len(self.get_model_versions(model_name))
 
-    def list_cached_models(self) -> List[Dict[Any, Any]]:
+    def list_cached_models(
+        self, worker_ip: str, model_name: Optional[str] = None
+    ) -> List[Dict[Any, Any]]:
         cached_models = []
-        for model_name, model_versions in self._model_name_to_version_info.items():
-            for version_info in model_versions:
-                cache_status = version_info.get("cache_status", None)
-                if cache_status == True:
-                    ret = version_info.copy()
-                    ret["model_name"] = model_name
+        for name, versions in self._model_name_to_version_info.items():
+            # only return assigned cached model if model_name is not none
+            # else return all cached model
+            if model_name and model_name != name:
+                continue
+            for version_info in versions:
+                cache_status = version_info.get("cache_status", False)
+                # search cached model
+                if cache_status:
+                    res = version_info.copy()
+                    res["model_name"] = name
+                    paths = res.get("model_file_location", {})
+                    # only return assigned worker's device path
+                    if worker_ip in paths.keys():
+                        res["model_file_location"] = paths[worker_ip]
+                        cached_models.append(res)
+        return cached_models
 
-                    re_dict = version_info.get("model_file_location", None)
-                    if re_dict is not None and isinstance(re_dict, dict):
-                        if re_dict:
-                            actor_ip_address, path = next(iter(re_dict.items()))
-                        else:
-                            raise ValueError("The dictionary is empty.")
-                    else:
-                        raise ValueError("re_dict must be a non-empty dictionary.")
+    def list_deletable_models(self, model_version: str, worker_ip: str) -> str:
+        model_file_location = ""
+        for model, model_versions in self._model_name_to_version_info.items():
+            for version_info in model_versions:
+                # search assign model version
+                if model_version == version_info.get("model_version", None):
+                    # check if exist
+                    if version_info.get("cache_status", False):
+                        paths = version_info.get("model_file_location", {})
+                        # only return assigned worker's device path
+                        if worker_ip in paths.keys():
+                            model_file_location = paths[worker_ip]
+        return model_file_location
 
-                    ret["actor_ip_address"] = actor_ip_address
-                    ret["path"] = path
-                    if os.path.isdir(path):
-                        files = os.listdir(path)
-                        resolved_file = os.path.realpath(os.path.join(path, files[0]))
-                        if resolved_file:
-                            ret["real_path"] = os.path.dirname(resolved_file)
-                    else:
-                        ret["real_path"] = os.path.realpath(path)
-                    cached_models.append(ret)
-        cached_models = sorted(cached_models, key=lambda x: x["model_name"])
-        return cached_models
+    def confirm_and_remove_model(self, model_version: str, worker_ip: str):
+        # find remove path
+        rm_path = self.list_deletable_models(model_version, worker_ip)
+        # search _model_name_to_version_info if exist this path, and delete
+        for model, model_versions in self._model_name_to_version_info.items():
+            for version_info in model_versions:
+                # check if exist
+                if version_info.get("cache_status", False):
+                    paths = version_info.get("model_file_location", {})
+                    # only delete assigned worker's device path
+                    if worker_ip in paths.keys() and rm_path == paths[worker_ip]:
+                        del paths[worker_ip]
+                        # if path is empty, update cache status
+                        if not paths:
+                            version_info["cache_status"] = False
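
All three tracker methods above walk _model_name_to_version_info. A sketch of the bookkeeping they assume (illustrative values, not data from the package): each model name maps to a list of version records whose model_file_location is keyed by worker IP.

    # Hypothetical snapshot of CacheTrackerActor._model_name_to_version_info
    _model_name_to_version_info = {
        "qwen2-instruct": [
            {
                "model_version": "qwen2-instruct--7b--pytorch--4-bit",  # placeholder
                "cache_status": True,
                "model_file_location": {
                    "192.168.1.10": "/root/.xinference/cache/qwen2-instruct-pytorch-7b",
                },
            },
        ],
    }

    # list_cached_models("192.168.1.10") returns a copy of the record above with
    # model_file_location collapsed to that worker's path; confirm_and_remove_model(
    # version, "192.168.1.10") drops the worker's path entry and, once no worker
    # holds a copy, flips cache_status back to False.
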
xinference/core/model.py CHANGED
@@ -264,12 +264,13 @@ class ModelActor(xo.StatelessActor):
         return isinstance(self._model, VLLMModel)
 
     def allow_batching(self) -> bool:
-        from ..model.llm.pytorch.core import PytorchChatModel
+        from ..model.llm.pytorch.core import PytorchChatModel, PytorchModel
 
         return (
             XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
-            and isinstance(self._model, PytorchChatModel)
-            and self._model.__class__.__name__ == PytorchChatModel.__name__
+            and isinstance(self._model, PytorchModel)
+            and self._model.__class__.__name__
+            in (PytorchChatModel.__name__, PytorchModel.__name__)
         )
 
     async def load(self):
@@ -393,18 +394,24 @@
     @request_limit
     @xo.generator
     async def generate(self, prompt: str, *args, **kwargs):
-        if hasattr(self._model, "generate"):
-            return await self._call_wrapper(
-                self._model.generate, prompt, *args, **kwargs
-            )
-        if hasattr(self._model, "async_generate"):
-            return await self._call_wrapper(
-                self._model.async_generate, prompt, *args, **kwargs
+        if self.allow_batching():
+            return await self.handle_batching_request(
+                prompt, "generate", *args, **kwargs
             )
-        raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
+        else:
+            if hasattr(self._model, "generate"):
+                return await self._call_wrapper(
+                    self._model.generate, prompt, *args, **kwargs
+                )
+            if hasattr(self._model, "async_generate"):
+                return await self._call_wrapper(
+                    self._model.async_generate, prompt, *args, **kwargs
+                )
+            raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
 
+    @staticmethod
     async def _queue_consumer(
-        self, queue: Queue, timeout: Optional[float] = None
+        queue: Queue, timeout: Optional[float] = None
     ) -> AsyncIterator[Any]:
         from .scheduler import (
             XINFERENCE_STREAMING_ABORT_FLAG,
@@ -429,9 +436,38 @@
             yield res
 
     @staticmethod
-    def get_stream_from_args(*args) -> bool:
-        assert args[2] is None or isinstance(args[2], dict)
-        return False if args[2] is None else args[2].get("stream", False)
+    def _get_stream_from_args(ability: str, *args) -> bool:
+        if ability == "chat":
+            assert args[2] is None or isinstance(args[2], dict)
+            return False if args[2] is None else args[2].get("stream", False)
+        else:
+            assert args[0] is None or isinstance(args[0], dict)
+            return False if args[0] is None else args[0].get("stream", False)
+
+    async def handle_batching_request(self, prompt: str, ability: str, *args, **kwargs):
+        stream = self._get_stream_from_args(ability, *args)
+        assert self._scheduler_ref is not None
+        if stream:
+            assert self._scheduler_ref is not None
+            queue: Queue[Any] = Queue()
+            ret = self._queue_consumer(queue)
+            await self._scheduler_ref.add_request(prompt, queue, *args, **kwargs)
+            gen = self._to_json_async_gen(ret)
+            self._current_generator = weakref.ref(gen)
+            return gen
+        else:
+            from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
+
+            assert self._loop is not None
+            future = ConcurrentFuture()
+            await self._scheduler_ref.add_request(prompt, future, *args, **kwargs)
+            fut = asyncio.wrap_future(future, loop=self._loop)
+            result = await fut
+            if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
+                raise RuntimeError(
+                    f"This request has been cancelled by another `abort_request` request."
+                )
+            return await asyncio.to_thread(json_dumps, result)
 
     @log_async(logger=logger)
     @request_limit
@@ -441,33 +477,9 @@
         response = None
         try:
             if self.allow_batching():
-                stream = self.get_stream_from_args(*args)
-                assert self._scheduler_ref is not None
-                if stream:
-                    assert self._scheduler_ref is not None
-                    queue: Queue[Any] = Queue()
-                    ret = self._queue_consumer(queue)
-                    await self._scheduler_ref.add_request(
-                        prompt, queue, *args, **kwargs
-                    )
-                    gen = self._to_json_async_gen(ret)
-                    self._current_generator = weakref.ref(gen)
-                    return gen
-                else:
-                    from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
-
-                    assert self._loop is not None
-                    future = ConcurrentFuture()
-                    await self._scheduler_ref.add_request(
-                        prompt, future, *args, **kwargs
-                    )
-                    fut = asyncio.wrap_future(future, loop=self._loop)
-                    result = await fut
-                    if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
-                        raise RuntimeError(
-                            f"This request has been cancelled by another `abort_request` request."
-                        )
-                    return await asyncio.to_thread(json_dumps, result)
+                return await self.handle_batching_request(
+                    prompt, "chat", *args, **kwargs
+                )
             else:
                 if hasattr(self._model, "chat"):
                     response = await self._call_wrapper(
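
Both generate and chat now route batched requests through handle_batching_request, so _get_stream_from_args has to find the stream flag in two different positional layouts. A small standalone mirror of that convention (illustrative, detached from the actor):

    # Positional layouts assumed by _get_stream_from_args:
    #   chat:     (system_prompt, chat_history, generate_config)
    #   generate: (generate_config,)
    def stream_requested(ability: str, *args) -> bool:
        config = args[2] if ability == "chat" else args[0]
        return bool(config and config.get("stream", False))

    assert stream_requested("chat", None, [], {"stream": True}) is True
    assert stream_requested("generate", {"max_tokens": 16}) is False
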
xinference/core/scheduler.py CHANGED
@@ -15,6 +15,7 @@
 import asyncio
 import functools
 import logging
+import uuid
 from collections import deque
 from enum import Enum
 from typing import List, Optional, Set
@@ -50,9 +51,9 @@ class InferenceRequest:
         self._new_tokens = []
         # kv_cache used in decode phase
         self._kv_cache = None
-        # use passed args from `chat` interface
+        # use passed args from upstream interface
         self._inference_args = args
-        # use passed kwargs from `chat` interface, basically not used for now
+        # use passed kwargs from upstream interface, basically not used for now
         self._inference_kwargs = kwargs
         # should this request be stopped
         self._stopped = False
@@ -63,6 +64,8 @@
         self._aborted = False
         # sanitized generate config
         self._sanitized_generate_config = None
+        # Chunk id for results. In stream mode, all the chunk ids should be same.
+        self._stream_chunk_id = str(uuid.uuid4())
         # Use in stream mode
         self.last_output_length = 0
         # inference results,
@@ -81,19 +84,26 @@
         self._check_args()
 
     def _check_args(self):
-        assert len(self._inference_args) == 3
-        # system prompt
-        assert self._inference_args[0] is None or isinstance(
-            self._inference_args[0], str
-        )
-        # chat history
-        assert self._inference_args[1] is None or isinstance(
-            self._inference_args[1], list
-        )
-        # generate config
-        assert self._inference_args[2] is None or isinstance(
-            self._inference_args[2], dict
-        )
+        # chat
+        if len(self._inference_args) == 3:
+            # system prompt
+            assert self._inference_args[0] is None or isinstance(
+                self._inference_args[0], str
+            )
+            # chat history
+            assert self._inference_args[1] is None or isinstance(
+                self._inference_args[1], list
+            )
+            # generate config
+            assert self._inference_args[2] is None or isinstance(
+                self._inference_args[2], dict
+            )
+        else:  # generate
+            assert len(self._inference_args) == 1
+            # generate config
+            assert self._inference_args[0] is None or isinstance(
+                self._inference_args[0], dict
+            )
 
     @property
     def prompt(self):
@@ -148,7 +158,11 @@
 
     @property
     def generate_config(self):
-        return self._inference_args[2]
+        return (
+            self._inference_args[2]
+            if len(self._inference_args) == 3
+            else self._inference_args[0]
+        )
 
     @property
     def sanitized_generate_config(self):
@@ -174,6 +188,10 @@
     def finish_reason(self, value: Optional[str]):
         self._finish_reason = value
 
+    @property
+    def chunk_id(self):
+        return self._stream_chunk_id
+
     @property
     def stream(self) -> bool:
         return (
xinference/core/supervisor.py CHANGED
@@ -982,32 +982,31 @@
         )
 
     @log_async(logger=logger)
-    async def list_cached_models(self) -> List[Dict[str, Any]]:
+    async def list_cached_models(
+        self, model_name: Optional[str] = None, worker_ip: Optional[str] = None
+    ) -> List[Dict[str, Any]]:
+        target_ip_worker_ref = (
+            self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
+        )
+        if (
+            worker_ip is not None
+            and not self.is_local_deployment()
+            and target_ip_worker_ref is None
+        ):
+            raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
+
+        # search assigned worker and return
+        if target_ip_worker_ref:
+            cached_models = await target_ip_worker_ref.list_cached_models(model_name)
+            cached_models = sorted(cached_models, key=lambda x: x["model_name"])
+            return cached_models
+
+        # search all worker
         cached_models = []
         for worker in self._worker_address_to_worker.values():
-            ret = await worker.list_cached_models()
-            for model_version in ret:
-                model_name = model_version.get("model_name", None)
-                model_format = model_version.get("model_format", None)
-                model_size_in_billions = model_version.get(
-                    "model_size_in_billions", None
-                )
-                quantizations = model_version.get("quantization", None)
-                actor_ip_address = model_version.get("actor_ip_address", None)
-                path = model_version.get("path", None)
-                real_path = model_version.get("real_path", None)
-
-                cache_entry = {
-                    "model_name": model_name,
-                    "model_format": model_format,
-                    "model_size_in_billions": model_size_in_billions,
-                    "quantizations": quantizations,
-                    "path": path,
-                    "Actor IP Address": actor_ip_address,
-                    "real_path": real_path,
-                }
-
-                cached_models.append(cache_entry)
+            res = await worker.list_cached_models(model_name)
+            cached_models.extend(res)
+        cached_models = sorted(cached_models, key=lambda x: x["model_name"])
         return cached_models
 
     @log_async(logger=logger)
@@ -1083,6 +1082,56 @@
         worker_status.update_time = time.time()
         worker_status.status = status
 
+    async def list_deletable_models(
+        self, model_version: str, worker_ip: Optional[str] = None
+    ) -> List[str]:
+        target_ip_worker_ref = (
+            self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
+        )
+        if (
+            worker_ip is not None
+            and not self.is_local_deployment()
+            and target_ip_worker_ref is None
+        ):
+            raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
+
+        ret = []
+        if target_ip_worker_ref:
+            ret = await target_ip_worker_ref.list_deletable_models(
+                model_version=model_version,
+            )
+            return ret
+
+        for worker in self._worker_address_to_worker.values():
+            path = await worker.list_deletable_models(model_version=model_version)
+            ret.extend(path)
+        return ret
+
+    async def confirm_and_remove_model(
+        self, model_version: str, worker_ip: Optional[str] = None
+    ) -> bool:
+        target_ip_worker_ref = (
+            self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
+        )
+        if (
+            worker_ip is not None
+            and not self.is_local_deployment()
+            and target_ip_worker_ref is None
+        ):
+            raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
+
+        if target_ip_worker_ref:
+            ret = await target_ip_worker_ref.confirm_and_remove_model(
+                model_version=model_version,
+            )
+            return ret
+        ret = True
+        for worker in self._worker_address_to_worker.values():
+            ret = ret and await worker.confirm_and_remove_model(
+                model_version=model_version,
+            )
+        return ret
+
     @staticmethod
     def record_metrics(name, op, kwargs):
         record_metrics(name, op, kwargs)