xinference 0.12.3__py3-none-any.whl → 0.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic; see the package's page on the registry for more details.

Files changed (101)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +56 -8
  3. xinference/client/restful/restful_client.py +49 -4
  4. xinference/core/model.py +36 -4
  5. xinference/core/scheduler.py +2 -0
  6. xinference/core/supervisor.py +132 -15
  7. xinference/core/worker.py +239 -53
  8. xinference/deploy/cmdline.py +5 -0
  9. xinference/deploy/utils.py +33 -2
  10. xinference/model/audio/chattts.py +6 -6
  11. xinference/model/audio/core.py +23 -15
  12. xinference/model/core.py +12 -3
  13. xinference/model/embedding/core.py +25 -16
  14. xinference/model/flexible/__init__.py +40 -0
  15. xinference/model/flexible/core.py +228 -0
  16. xinference/model/flexible/launchers/__init__.py +15 -0
  17. xinference/model/flexible/launchers/transformers_launcher.py +63 -0
  18. xinference/model/flexible/utils.py +33 -0
  19. xinference/model/image/core.py +18 -14
  20. xinference/model/image/custom.py +1 -1
  21. xinference/model/llm/__init__.py +5 -2
  22. xinference/model/llm/core.py +3 -2
  23. xinference/model/llm/ggml/llamacpp.py +1 -10
  24. xinference/model/llm/llm_family.json +292 -36
  25. xinference/model/llm/llm_family.py +102 -53
  26. xinference/model/llm/llm_family_modelscope.json +247 -27
  27. xinference/model/llm/mlx/__init__.py +13 -0
  28. xinference/model/llm/mlx/core.py +408 -0
  29. xinference/model/llm/pytorch/chatglm.py +2 -9
  30. xinference/model/llm/pytorch/cogvlm2.py +206 -21
  31. xinference/model/llm/pytorch/core.py +213 -120
  32. xinference/model/llm/pytorch/glm4v.py +171 -15
  33. xinference/model/llm/pytorch/qwen_vl.py +168 -7
  34. xinference/model/llm/pytorch/utils.py +53 -62
  35. xinference/model/llm/utils.py +28 -7
  36. xinference/model/rerank/core.py +29 -25
  37. xinference/thirdparty/deepseek_vl/serve/__init__.py +13 -0
  38. xinference/thirdparty/deepseek_vl/serve/app_deepseek.py +510 -0
  39. xinference/thirdparty/deepseek_vl/serve/app_modules/__init__.py +13 -0
  40. xinference/thirdparty/deepseek_vl/serve/app_modules/gradio_utils.py +94 -0
  41. xinference/thirdparty/deepseek_vl/serve/app_modules/overwrites.py +81 -0
  42. xinference/thirdparty/deepseek_vl/serve/app_modules/presets.py +96 -0
  43. xinference/thirdparty/deepseek_vl/serve/app_modules/utils.py +229 -0
  44. xinference/thirdparty/deepseek_vl/serve/inference.py +170 -0
  45. xinference/types.py +0 -1
  46. xinference/web/ui/build/asset-manifest.json +3 -3
  47. xinference/web/ui/build/index.html +1 -1
  48. xinference/web/ui/build/static/js/main.95c1d652.js +3 -0
  49. xinference/web/ui/build/static/js/main.95c1d652.js.map +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/2c63090c842376cdd368c3ded88a333ef40d94785747651343040a6f7872a223.json +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/70fa8c07463a5fe57c68bf92502910105a8f647371836fe8c3a7408246ca7ba0.json +1 -0
  59. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +1 -0
  60. xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +1 -0
  61. xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +1 -0
  63. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +1 -0
  65. {xinference-0.12.3.dist-info → xinference-0.13.1.dist-info}/METADATA +10 -11
  66. {xinference-0.12.3.dist-info → xinference-0.13.1.dist-info}/RECORD +71 -69
  67. xinference/model/llm/ggml/chatglm.py +0 -457
  68. xinference/thirdparty/ChatTTS/__init__.py +0 -1
  69. xinference/thirdparty/ChatTTS/core.py +0 -200
  70. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  71. xinference/thirdparty/ChatTTS/experimental/llm.py +0 -40
  72. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  73. xinference/thirdparty/ChatTTS/infer/api.py +0 -125
  74. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  75. xinference/thirdparty/ChatTTS/model/dvae.py +0 -155
  76. xinference/thirdparty/ChatTTS/model/gpt.py +0 -265
  77. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  78. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +0 -23
  79. xinference/thirdparty/ChatTTS/utils/infer_utils.py +0 -141
  80. xinference/thirdparty/ChatTTS/utils/io_utils.py +0 -14
  81. xinference/web/ui/build/static/js/main.77dd47c3.js +0 -3
  82. xinference/web/ui/build/static/js/main.77dd47c3.js.map +0 -1
  83. xinference/web/ui/node_modules/.cache/babel-loader/0cd591866aa345566e0b63fb51ff2043e163a770af6fdc2f3bad395d046353e2.json +0 -1
  84. xinference/web/ui/node_modules/.cache/babel-loader/37c1476717199863bbba1530e3513a9368f8f73001b75b4a85c2075956308027.json +0 -1
  85. xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +0 -1
  86. xinference/web/ui/node_modules/.cache/babel-loader/3fa1f69162f9c6dc0f6a6e21b64d49d6b8e6fa8dfa59a82cf829931c5f97d99f.json +0 -1
  87. xinference/web/ui/node_modules/.cache/babel-loader/46edc1fe657dfedb2e673148332bb442c6eb98f09f2592c389209e376510afa5.json +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/62e257ed9016471035fa1a7da57c9e2a4250974ed566b4d1295873d747c68eb2.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/82db357f3fd5b32215d747ee593f69ff06c95ad6cde37f71a96c8290aaab64c0.json +0 -1
  91. xinference/web/ui/node_modules/.cache/babel-loader/935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json +0 -1
  92. xinference/web/ui/node_modules/.cache/babel-loader/bc6da27195ec4607bb472bf61f97c928ad4966fa64e4c2247661bedb7400abba.json +0 -1
  93. xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +0 -1
  94. xinference/web/ui/node_modules/.cache/babel-loader/e606671420d2937102c3c34b4b04056c11736408c1d3347b8cf42dfe61fb394b.json +0 -1
  95. xinference/web/ui/node_modules/.cache/babel-loader/f118f99c22b713c678c1209c4e1dd43fe86e3f6e801a4c0c35d3bbf41fd05fe6.json +0 -1
  96. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +0 -1
  97. /xinference/web/ui/build/static/js/{main.77dd47c3.js.LICENSE.txt → main.95c1d652.js.LICENSE.txt} +0 -0
  98. {xinference-0.12.3.dist-info → xinference-0.13.1.dist-info}/LICENSE +0 -0
  99. {xinference-0.12.3.dist-info → xinference-0.13.1.dist-info}/WHEEL +0 -0
  100. {xinference-0.12.3.dist-info → xinference-0.13.1.dist-info}/entry_points.txt +0 -0
  101. {xinference-0.12.3.dist-info → xinference-0.13.1.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2024-06-28T15:25:07+0800",
11
+ "date": "2024-07-12T17:56:13+0800",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "3d9c261a7d5c4941091d1711cb732ce17b34e7f1",
15
- "version": "0.12.3"
14
+ "full-revisionid": "5e3f254d48383f37d849dd16db564ad9449e5163",
15
+ "version": "0.13.1"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -133,6 +133,7 @@ class SpeechRequest(BaseModel):
133
133
 
134
134
  class RegisterModelRequest(BaseModel):
135
135
  model: str
136
+ worker_ip: Optional[str]
136
137
  persist: bool
137
138
 
138
139
 
@@ -501,6 +502,16 @@ class RESTfulAPI:
501
502
  else None
502
503
  ),
503
504
  )
505
+ self._router.add_api_route(
506
+ "/v1/flexible/infers",
507
+ self.create_flexible_infer,
508
+ methods=["POST"],
509
+ dependencies=(
510
+ [Security(self._auth_service, scopes=["models:read"])]
511
+ if self.is_authenticated()
512
+ else None
513
+ ),
514
+ )
504
515
 
505
516
  # for custom models
506
517
  self._router.add_api_route(
@@ -772,6 +783,7 @@ class RESTfulAPI:
772
783
  peft_model_config = payload.get("peft_model_config", None)
773
784
  worker_ip = payload.get("worker_ip", None)
774
785
  gpu_idx = payload.get("gpu_idx", None)
786
+ download_hub = payload.get("download_hub", None)
775
787
 
776
788
  exclude_keys = {
777
789
  "model_uid",
@@ -787,6 +799,7 @@ class RESTfulAPI:
787
799
  "peft_model_config",
788
800
  "worker_ip",
789
801
  "gpu_idx",
802
+ "download_hub",
790
803
  }
791
804
 
792
805
  kwargs = {
@@ -834,9 +847,9 @@ class RESTfulAPI:
834
847
  peft_model_config=peft_model_config,
835
848
  worker_ip=worker_ip,
836
849
  gpu_idx=gpu_idx,
850
+ download_hub=download_hub,
837
851
  **kwargs,
838
852
  )
839
-
840
853
  except ValueError as ve:
841
854
  logger.error(str(ve), exc_info=True)
842
855
  raise HTTPException(status_code=400, detail=str(ve))
@@ -1397,6 +1410,40 @@ class RESTfulAPI:
1397
1410
  await self._report_error_event(model_uid, str(e))
1398
1411
  raise HTTPException(status_code=500, detail=str(e))
1399
1412
 
1413
+ async def create_flexible_infer(self, request: Request) -> Response:
1414
+ payload = await request.json()
1415
+
1416
+ model_uid = payload.get("model")
1417
+
1418
+ exclude = {
1419
+ "model",
1420
+ }
1421
+ kwargs = {key: value for key, value in payload.items() if key not in exclude}
1422
+
1423
+ try:
1424
+ model = await (await self._get_supervisor_ref()).get_model(model_uid)
1425
+ except ValueError as ve:
1426
+ logger.error(str(ve), exc_info=True)
1427
+ await self._report_error_event(model_uid, str(ve))
1428
+ raise HTTPException(status_code=400, detail=str(ve))
1429
+ except Exception as e:
1430
+ logger.error(e, exc_info=True)
1431
+ await self._report_error_event(model_uid, str(e))
1432
+ raise HTTPException(status_code=500, detail=str(e))
1433
+
1434
+ try:
1435
+ result = await model.infer(**kwargs)
1436
+ return Response(result, media_type="application/json")
1437
+ except RuntimeError as re:
1438
+ logger.error(re, exc_info=True)
1439
+ await self._report_error_event(model_uid, str(re))
1440
+ self.handle_request_limit_error(re)
1441
+ raise HTTPException(status_code=400, detail=str(re))
1442
+ except Exception as e:
1443
+ logger.error(e, exc_info=True)
1444
+ await self._report_error_event(model_uid, str(e))
1445
+ raise HTTPException(status_code=500, detail=str(e))
1446
+
1400
1447
  async def create_chat_completion(self, request: Request) -> Response:
1401
1448
  raw_body = await request.json()
1402
1449
  body = CreateChatCompletion.parse_obj(raw_body)
@@ -1477,14 +1524,14 @@ class RESTfulAPI:
1477
1524
  await self._report_error_event(model_uid, str(e))
1478
1525
  raise HTTPException(status_code=500, detail=str(e))
1479
1526
 
1480
- from ..model.llm.utils import QWEN_TOOL_CALL_FAMILY
1527
+ from ..model.llm.utils import GLM4_TOOL_CALL_FAMILY, QWEN_TOOL_CALL_FAMILY
1481
1528
 
1482
1529
  model_family = desc.get("model_family", "")
1483
- function_call_models = [
1484
- "chatglm3",
1485
- "glm4-chat",
1486
- "gorilla-openfunctions-v1",
1487
- ] + QWEN_TOOL_CALL_FAMILY
1530
+ function_call_models = (
1531
+ ["chatglm3", "gorilla-openfunctions-v1"]
1532
+ + QWEN_TOOL_CALL_FAMILY
1533
+ + GLM4_TOOL_CALL_FAMILY
1534
+ )
1488
1535
 
1489
1536
  is_qwen = desc.get("model_format") == "ggmlv3" and "qwen-chat" == model_family
1490
1537
 
@@ -1593,11 +1640,12 @@ class RESTfulAPI:
1593
1640
  async def register_model(self, model_type: str, request: Request) -> JSONResponse:
1594
1641
  body = RegisterModelRequest.parse_obj(await request.json())
1595
1642
  model = body.model
1643
+ worker_ip = body.worker_ip
1596
1644
  persist = body.persist
1597
1645
 
1598
1646
  try:
1599
1647
  await (await self._get_supervisor_ref()).register_model(
1600
- model_type, model, persist
1648
+ model_type, model, persist, worker_ip
1601
1649
  )
1602
1650
  except ValueError as re:
1603
1651
  logger.error(re, exc_info=True)
@@ -182,8 +182,6 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
182
182
  f"Failed to rerank documents, detail: {response.json()['detail']}"
183
183
  )
184
184
  response_data = response.json()
185
- for r in response_data["results"]:
186
- r["document"] = documents[r["index"]]
187
185
  return response_data
188
186
 
189
187
 
@@ -732,6 +730,41 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
732
730
  return response.content
733
731
 
734
732
 
733
+ class RESTfulFlexibleModelHandle(RESTfulModelHandle):
734
+ def infer(
735
+ self,
736
+ **kwargs,
737
+ ):
738
+ """
739
+ Call flexible model.
740
+
741
+ Parameters
742
+ ----------
743
+
744
+ kwargs: dict
745
+ The inference arguments.
746
+
747
+
748
+ Returns
749
+ -------
750
+ bytes
751
+ The inference result.
752
+ """
753
+ url = f"{self._base_url}/v1/flexible/infers"
754
+ params = {
755
+ "model": self._model_uid,
756
+ }
757
+ params.update(kwargs)
758
+
759
+ response = requests.post(url, json=params, headers=self.auth_headers)
760
+ if response.status_code != 200:
761
+ raise RuntimeError(
762
+ f"Failed to predict, detail: {_get_error_string(response)}"
763
+ )
764
+
765
+ return response.content
766
+
767
+
735
768
  class Client:
736
769
  def __init__(self, base_url, api_key: Optional[str] = None):
737
770
  self.base_url = base_url
@@ -1011,6 +1044,10 @@ class Client:
1011
1044
  return RESTfulAudioModelHandle(
1012
1045
  model_uid, self.base_url, auth_headers=self._headers
1013
1046
  )
1047
+ elif desc["model_type"] == "flexible":
1048
+ return RESTfulFlexibleModelHandle(
1049
+ model_uid, self.base_url, auth_headers=self._headers
1050
+ )
1014
1051
  else:
1015
1052
  raise ValueError(f"Unknown model type:{desc['model_type']}")
1016
1053
 
@@ -1064,7 +1101,13 @@ class Client:
1064
1101
  )
1065
1102
  return response.json()
1066
1103
 
1067
- def register_model(self, model_type: str, model: str, persist: bool):
1104
+ def register_model(
1105
+ self,
1106
+ model_type: str,
1107
+ model: str,
1108
+ persist: bool,
1109
+ worker_ip: Optional[str] = None,
1110
+ ):
1068
1111
  """
1069
1112
  Register a custom model.
1070
1113
 
@@ -1074,6 +1117,8 @@ class Client:
1074
1117
  The type of model.
1075
1118
  model: str
1076
1119
  The model definition. (refer to: https://inference.readthedocs.io/en/latest/models/custom.html)
1120
+ worker_ip: Optional[str]
1121
+ The IP address of the worker on which the model is running.
1077
1122
  persist: bool
1078
1123
 
1079
1124
 
@@ -1083,7 +1128,7 @@ class Client:
1083
1128
  Report failure to register the custom model. Provide details of failure through error message.
1084
1129
  """
1085
1130
  url = f"{self.base_url}/v1/model_registrations/{model_type}"
1086
- request_body = {"model": model, "persist": persist}
1131
+ request_body = {"model": model, "worker_ip": worker_ip, "persist": persist}
1087
1132
  response = requests.post(url, json=request_body, headers=self._headers)
1088
1133
  if response.status_code != 200:
1089
1134
  raise RuntimeError(
xinference/core/model.py CHANGED
@@ -65,6 +65,9 @@ except ImportError:
65
65
  OutOfMemoryError = _OutOfMemoryError
66
66
 
67
67
 
68
+ XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = ["qwen-vl-chat", "cogvlm2", "glm-4v"]
69
+
70
+
68
71
  def request_limit(fn):
69
72
  """
70
73
  Used by ModelActor.
@@ -268,11 +271,25 @@ class ModelActor(xo.StatelessActor):
268
271
 
269
272
  model_ability = self._model_description.get("model_ability", [])
270
273
 
271
- return (
272
- XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
273
- and isinstance(self._model, PytorchModel)
274
- and "vision" not in model_ability
274
+ condition = XINFERENCE_TRANSFORMERS_ENABLE_BATCHING and isinstance(
275
+ self._model, PytorchModel
275
276
  )
277
+ if condition and "vision" in model_ability:
278
+ if (
279
+ self._model.model_family.model_name
280
+ in XINFERENCE_BATCHING_ALLOWED_VISION_MODELS
281
+ or self._model.model_family.model_family
282
+ in XINFERENCE_BATCHING_ALLOWED_VISION_MODELS
283
+ ):
284
+ return True
285
+ else:
286
+ logger.warning(
287
+ f"Currently for multimodal models, "
288
+ f"xinference only supports {', '.join(XINFERENCE_BATCHING_ALLOWED_VISION_MODELS)} for batching. "
289
+ f"Your model {self._model.model_family.model_name} with model family {self._model.model_family.model_family} is disqualified."
290
+ )
291
+ return False
292
+ return condition
276
293
 
277
294
  async def load(self):
278
295
  self._model.load()
@@ -680,6 +697,21 @@ class ModelActor(xo.StatelessActor):
680
697
  f"Model {self._model.model_spec} is not for creating image."
681
698
  )
682
699
 
700
+ @log_async(logger=logger)
701
+ @request_limit
702
+ async def infer(
703
+ self,
704
+ **kwargs,
705
+ ):
706
+ if hasattr(self._model, "infer"):
707
+ return await self._call_wrapper(
708
+ self._model.infer,
709
+ **kwargs,
710
+ )
711
+ raise AttributeError(
712
+ f"Model {self._model.model_spec} is not for flexible infer."
713
+ )
714
+
683
715
  async def record_metrics(self, name, op, kwargs):
684
716
  worker_ref = await self._get_worker_ref()
685
717
  await worker_ref.record_metrics(name, op, kwargs)
@@ -82,6 +82,8 @@ class InferenceRequest:
82
82
  # Record error message when this request has error.
83
83
  # Must set stopped=True when this field is set.
84
84
  self.error_msg: Optional[str] = None
85
+ # For compatibility. Record some extra parameters for some special cases.
86
+ self.extra_kwargs = {}
85
87
 
86
88
  # check the integrity of args passed upstream
87
89
  self._check_args()
@@ -20,7 +20,17 @@ import time
20
20
  import typing
21
21
  from dataclasses import dataclass
22
22
  from logging import getLogger
23
- from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
23
+ from typing import (
24
+ TYPE_CHECKING,
25
+ Any,
26
+ Dict,
27
+ Iterator,
28
+ List,
29
+ Literal,
30
+ Optional,
31
+ Tuple,
32
+ Union,
33
+ )
24
34
 
25
35
  import xoscar as xo
26
36
 
@@ -50,6 +60,7 @@ from .utils import (
50
60
  if TYPE_CHECKING:
51
61
  from ..model.audio import AudioModelFamilyV1
52
62
  from ..model.embedding import EmbeddingModelSpec
63
+ from ..model.flexible import FlexibleModelSpec
53
64
  from ..model.image import ImageModelFamilyV1
54
65
  from ..model.llm import LLMFamilyV1
55
66
  from ..model.rerank import RerankModelSpec
@@ -153,6 +164,13 @@ class SupervisorActor(xo.StatelessActor):
153
164
  register_embedding,
154
165
  unregister_embedding,
155
166
  )
167
+ from ..model.flexible import (
168
+ FlexibleModelSpec,
169
+ generate_flexible_model_description,
170
+ get_flexible_model_descriptions,
171
+ register_flexible_model,
172
+ unregister_flexible_model,
173
+ )
156
174
  from ..model.image import (
157
175
  CustomImageModelFamilyV1,
158
176
  generate_image_description,
@@ -206,6 +224,12 @@ class SupervisorActor(xo.StatelessActor):
206
224
  unregister_audio,
207
225
  generate_audio_description,
208
226
  ),
227
+ "flexible": (
228
+ FlexibleModelSpec,
229
+ register_flexible_model,
230
+ unregister_flexible_model,
231
+ generate_flexible_model_description,
232
+ ),
209
233
  }
210
234
 
211
235
  # record model version
@@ -215,6 +239,7 @@ class SupervisorActor(xo.StatelessActor):
215
239
  model_version_infos.update(get_rerank_model_descriptions())
216
240
  model_version_infos.update(get_image_model_descriptions())
217
241
  model_version_infos.update(get_audio_model_descriptions())
242
+ model_version_infos.update(get_flexible_model_descriptions())
218
243
  await self._cache_tracker_ref.record_model_version(
219
244
  model_version_infos, self.address
220
245
  )
@@ -459,6 +484,27 @@ class SupervisorActor(xo.StatelessActor):
459
484
  res["model_instance_count"] = instance_cnt
460
485
  return res
461
486
 
487
+ async def _to_flexible_model_reg(
488
+ self, model_spec: "FlexibleModelSpec", is_builtin: bool
489
+ ) -> Dict[str, Any]:
490
+ instance_cnt = await self.get_instance_count(model_spec.model_name)
491
+ version_cnt = await self.get_model_version_count(model_spec.model_name)
492
+
493
+ if self.is_local_deployment():
494
+ res = {
495
+ **model_spec.dict(),
496
+ "cache_status": True,
497
+ "is_builtin": is_builtin,
498
+ }
499
+ else:
500
+ res = {
501
+ **model_spec.dict(),
502
+ "is_builtin": is_builtin,
503
+ }
504
+ res["model_version_count"] = version_cnt
505
+ res["model_instance_count"] = instance_cnt
506
+ return res
507
+
462
508
  @log_async(logger=logger)
463
509
  async def list_model_registrations(
464
510
  self, model_type: str, detailed: bool = False
@@ -467,10 +513,15 @@ class SupervisorActor(xo.StatelessActor):
467
513
  assert isinstance(item["model_name"], str)
468
514
  return item.get("model_name").lower()
469
515
 
516
+ ret = []
517
+ if not self.is_local_deployment():
518
+ workers = list(self._worker_address_to_worker.values())
519
+ for worker in workers:
520
+ ret.extend(await worker.list_model_registrations(model_type, detailed))
521
+
470
522
  if model_type == "LLM":
471
523
  from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families
472
524
 
473
- ret = []
474
525
  for family in BUILTIN_LLM_FAMILIES:
475
526
  if detailed:
476
527
  ret.append(await self._to_llm_reg(family, True))
@@ -489,7 +540,6 @@ class SupervisorActor(xo.StatelessActor):
489
540
  from ..model.embedding import BUILTIN_EMBEDDING_MODELS
490
541
  from ..model.embedding.custom import get_user_defined_embeddings
491
542
 
492
- ret = []
493
543
  for model_name, family in BUILTIN_EMBEDDING_MODELS.items():
494
544
  if detailed:
495
545
  ret.append(
@@ -514,7 +564,6 @@ class SupervisorActor(xo.StatelessActor):
514
564
  from ..model.image import BUILTIN_IMAGE_MODELS
515
565
  from ..model.image.custom import get_user_defined_images
516
566
 
517
- ret = []
518
567
  for model_name, family in BUILTIN_IMAGE_MODELS.items():
519
568
  if detailed:
520
569
  ret.append(await self._to_image_model_reg(family, is_builtin=True))
@@ -537,7 +586,6 @@ class SupervisorActor(xo.StatelessActor):
537
586
  from ..model.audio import BUILTIN_AUDIO_MODELS
538
587
  from ..model.audio.custom import get_user_defined_audios
539
588
 
540
- ret = []
541
589
  for model_name, family in BUILTIN_AUDIO_MODELS.items():
542
590
  if detailed:
543
591
  ret.append(await self._to_audio_model_reg(family, is_builtin=True))
@@ -560,7 +608,6 @@ class SupervisorActor(xo.StatelessActor):
560
608
  from ..model.rerank import BUILTIN_RERANK_MODELS
561
609
  from ..model.rerank.custom import get_user_defined_reranks
562
610
 
563
- ret = []
564
611
  for model_name, family in BUILTIN_RERANK_MODELS.items():
565
612
  if detailed:
566
613
  ret.append(await self._to_rerank_model_reg(family, is_builtin=True))
@@ -577,13 +624,38 @@ class SupervisorActor(xo.StatelessActor):
577
624
  {"model_name": model_spec.model_name, "is_builtin": False}
578
625
  )
579
626
 
627
+ ret.sort(key=sort_helper)
628
+ return ret
629
+ elif model_type == "flexible":
630
+ from ..model.flexible import get_flexible_models
631
+
632
+ ret = []
633
+
634
+ for model_spec in get_flexible_models():
635
+ if detailed:
636
+ ret.append(
637
+ await self._to_flexible_model_reg(model_spec, is_builtin=False)
638
+ )
639
+ else:
640
+ ret.append(
641
+ {"model_name": model_spec.model_name, "is_builtin": False}
642
+ )
643
+
580
644
  ret.sort(key=sort_helper)
581
645
  return ret
582
646
  else:
583
647
  raise ValueError(f"Unsupported model type: {model_type}")
584
648
 
585
649
  @log_sync(logger=logger)
586
- def get_model_registration(self, model_type: str, model_name: str) -> Any:
650
+ async def get_model_registration(self, model_type: str, model_name: str) -> Any:
651
+ # search in worker first
652
+ if not self.is_local_deployment():
653
+ workers = list(self._worker_address_to_worker.values())
654
+ for worker in workers:
655
+ f = await worker.get_model_registration(model_type, model_name)
656
+ if f is not None:
657
+ return f
658
+
587
659
  if model_type == "LLM":
588
660
  from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families
589
661
 
@@ -626,6 +698,13 @@ class SupervisorActor(xo.StatelessActor):
626
698
  if f.model_name == model_name:
627
699
  return f
628
700
  raise ValueError(f"Model {model_name} not found")
701
+ elif model_type == "flexible":
702
+ from ..model.flexible import get_flexible_models
703
+
704
+ for f in get_flexible_models():
705
+ if f.model_name == model_name:
706
+ return f
707
+ raise ValueError(f"Model {model_name} not found")
629
708
  else:
630
709
  raise ValueError(f"Unsupported model type: {model_type}")
631
710
 
@@ -635,6 +714,13 @@ class SupervisorActor(xo.StatelessActor):
635
714
 
636
715
  from ..model.llm.llm_family import LLM_ENGINES
637
716
 
717
+ # search in worker first
718
+ workers = list(self._worker_address_to_worker.values())
719
+ for worker in workers:
720
+ res = await worker.query_engines_by_model_name(model_name)
721
+ if res is not None:
722
+ return res
723
+
638
724
  if model_name not in LLM_ENGINES:
639
725
  raise ValueError(f"Model {model_name} not found")
640
726
 
@@ -648,7 +734,13 @@ class SupervisorActor(xo.StatelessActor):
648
734
  return engine_params
649
735
 
650
736
  @log_async(logger=logger)
651
- async def register_model(self, model_type: str, model: str, persist: bool):
737
+ async def register_model(
738
+ self,
739
+ model_type: str,
740
+ model: str,
741
+ persist: bool,
742
+ worker_ip: Optional[str] = None,
743
+ ):
652
744
  if model_type in self._custom_register_type_to_cls:
653
745
  (
654
746
  model_spec_cls,
@@ -657,10 +749,21 @@ class SupervisorActor(xo.StatelessActor):
657
749
  generate_fn,
658
750
  ) = self._custom_register_type_to_cls[model_type]
659
751
 
660
- if not self.is_local_deployment():
661
- workers = list(self._worker_address_to_worker.values())
662
- for worker in workers:
663
- await worker.register_model(model_type, model, persist)
752
+ target_ip_worker_ref = (
753
+ self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
754
+ )
755
+ if (
756
+ worker_ip is not None
757
+ and not self.is_local_deployment()
758
+ and target_ip_worker_ref is None
759
+ ):
760
+ raise ValueError(
761
+ f"Worker ip address {worker_ip} is not in the cluster."
762
+ )
763
+
764
+ if target_ip_worker_ref:
765
+ await target_ip_worker_ref.register_model(model_type, model, persist)
766
+ return
664
767
 
665
768
  model_spec = model_spec_cls.parse_raw(model)
666
769
  try:
@@ -668,6 +771,8 @@ class SupervisorActor(xo.StatelessActor):
668
771
  await self._cache_tracker_ref.record_model_version(
669
772
  generate_fn(model_spec), self.address
670
773
  )
774
+ except ValueError as e:
775
+ raise e
671
776
  except Exception as e:
672
777
  unregister_fn(model_spec.model_name, raise_error=False)
673
778
  raise e
@@ -678,13 +783,14 @@ class SupervisorActor(xo.StatelessActor):
678
783
  async def unregister_model(self, model_type: str, model_name: str):
679
784
  if model_type in self._custom_register_type_to_cls:
680
785
  _, _, unregister_fn, _ = self._custom_register_type_to_cls[model_type]
681
- unregister_fn(model_name)
682
- await self._cache_tracker_ref.unregister_model_version(model_name)
786
+ unregister_fn(model_name, False)
683
787
 
684
788
  if not self.is_local_deployment():
685
789
  workers = list(self._worker_address_to_worker.values())
686
790
  for worker in workers:
687
- await worker.unregister_model(model_name)
791
+ await worker.unregister_model(model_type, model_name)
792
+
793
+ await self._cache_tracker_ref.unregister_model_version(model_name)
688
794
  else:
689
795
  raise ValueError(f"Unsupported model type: {model_type}")
690
796
 
@@ -752,8 +858,17 @@ class SupervisorActor(xo.StatelessActor):
752
858
  peft_model_config: Optional[PeftModelConfig] = None,
753
859
  worker_ip: Optional[str] = None,
754
860
  gpu_idx: Optional[Union[int, List[int]]] = None,
861
+ download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
755
862
  **kwargs,
756
863
  ) -> str:
864
+ # search in worker first
865
+ if not self.is_local_deployment():
866
+ workers = list(self._worker_address_to_worker.values())
867
+ for worker in workers:
868
+ res = await worker.get_model_registration(model_type, model_name)
869
+ if res is not None:
870
+ worker_ip = worker.address.split(":")[0]
871
+
757
872
  target_ip_worker_ref = (
758
873
  self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
759
874
  )
@@ -806,6 +921,7 @@ class SupervisorActor(xo.StatelessActor):
806
921
  )
807
922
  replica_gpu_idx = assign_replica_gpu(_replica_model_uid, gpu_idx)
808
923
  nonlocal model_type
924
+
809
925
  worker_ref = (
810
926
  target_ip_worker_ref
811
927
  if target_ip_worker_ref is not None
@@ -825,6 +941,7 @@ class SupervisorActor(xo.StatelessActor):
825
941
  request_limits=request_limits,
826
942
  peft_model_config=peft_model_config,
827
943
  gpu_idx=replica_gpu_idx,
944
+ download_hub=download_hub,
828
945
  **kwargs,
829
946
  )
830
947
  self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref