xinference 0.9.4__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic; see the registry's advisory page for more details.

Files changed (103)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/auth_service.py +47 -18
  3. xinference/api/oauth2/types.py +1 -0
  4. xinference/api/restful_api.py +34 -7
  5. xinference/client/oscar/actor_client.py +4 -3
  6. xinference/client/restful/restful_client.py +20 -4
  7. xinference/conftest.py +13 -2
  8. xinference/core/supervisor.py +48 -1
  9. xinference/core/worker.py +139 -20
  10. xinference/deploy/cmdline.py +119 -20
  11. xinference/model/embedding/core.py +1 -2
  12. xinference/model/llm/__init__.py +4 -6
  13. xinference/model/llm/ggml/llamacpp.py +2 -10
  14. xinference/model/llm/llm_family.json +877 -13
  15. xinference/model/llm/llm_family.py +15 -0
  16. xinference/model/llm/llm_family_modelscope.json +571 -0
  17. xinference/model/llm/pytorch/chatglm.py +2 -0
  18. xinference/model/llm/pytorch/core.py +22 -26
  19. xinference/model/llm/pytorch/deepseek_vl.py +232 -0
  20. xinference/model/llm/pytorch/internlm2.py +2 -0
  21. xinference/model/llm/pytorch/omnilmm.py +153 -0
  22. xinference/model/llm/pytorch/qwen_vl.py +2 -0
  23. xinference/model/llm/pytorch/yi_vl.py +4 -2
  24. xinference/model/llm/utils.py +53 -5
  25. xinference/model/llm/vllm/core.py +54 -6
  26. xinference/model/rerank/core.py +3 -0
  27. xinference/thirdparty/deepseek_vl/__init__.py +31 -0
  28. xinference/thirdparty/deepseek_vl/models/__init__.py +28 -0
  29. xinference/thirdparty/deepseek_vl/models/clip_encoder.py +242 -0
  30. xinference/thirdparty/deepseek_vl/models/image_processing_vlm.py +208 -0
  31. xinference/thirdparty/deepseek_vl/models/modeling_vlm.py +170 -0
  32. xinference/thirdparty/deepseek_vl/models/processing_vlm.py +390 -0
  33. xinference/thirdparty/deepseek_vl/models/projector.py +100 -0
  34. xinference/thirdparty/deepseek_vl/models/sam.py +593 -0
  35. xinference/thirdparty/deepseek_vl/models/siglip_vit.py +681 -0
  36. xinference/thirdparty/deepseek_vl/utils/__init__.py +18 -0
  37. xinference/thirdparty/deepseek_vl/utils/conversation.py +348 -0
  38. xinference/thirdparty/deepseek_vl/utils/io.py +78 -0
  39. xinference/thirdparty/omnilmm/__init__.py +0 -0
  40. xinference/thirdparty/omnilmm/chat.py +216 -0
  41. xinference/thirdparty/omnilmm/constants.py +4 -0
  42. xinference/thirdparty/omnilmm/conversation.py +332 -0
  43. xinference/thirdparty/omnilmm/model/__init__.py +1 -0
  44. xinference/thirdparty/omnilmm/model/omnilmm.py +594 -0
  45. xinference/thirdparty/omnilmm/model/resampler.py +166 -0
  46. xinference/thirdparty/omnilmm/model/utils.py +563 -0
  47. xinference/thirdparty/omnilmm/train/__init__.py +13 -0
  48. xinference/thirdparty/omnilmm/train/train_utils.py +150 -0
  49. xinference/thirdparty/omnilmm/utils.py +134 -0
  50. xinference/types.py +15 -19
  51. xinference/web/ui/build/asset-manifest.json +3 -3
  52. xinference/web/ui/build/index.html +1 -1
  53. xinference/web/ui/build/static/js/main.76ef2b17.js +3 -0
  54. xinference/web/ui/build/static/js/main.76ef2b17.js.map +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/15e2cf8cd8d0989719b6349428ff576f9009ff4c2dcc52378be0bd938e82495e.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/35d0e4a317e5582cbb79d901302e9d706520ac53f8a734c2fd8bfde6eb5a4f02.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/3c2f277c93c5f1638e08db38df0d0fb4e58d1c5571aea03241a5c04ff4094704.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/3fa1f69162f9c6dc0f6a6e21b64d49d6b8e6fa8dfa59a82cf829931c5f97d99f.json +1 -0
  59. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +1 -0
  60. xinference/web/ui/node_modules/.cache/babel-loader/5393569d846332075b93b55656716a34f50e0a8c970be789502d7e6c49755fd7.json +1 -0
  61. xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/62e257ed9016471035fa1a7da57c9e2a4250974ed566b4d1295873d747c68eb2.json +1 -0
  63. xinference/web/ui/node_modules/.cache/babel-loader/63a4c48f0326d071c7772c46598215c006ae41fd3d4ff3577fe717de66ad6e89.json +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/b9cbcb6d77ba21b22c6950b6fb5b305d23c19cf747f99f7d48b6b046f8f7b1b0.json +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +1 -0
  66. xinference/web/ui/node_modules/.cache/babel-loader/d076fd56cf3b15ed2433e3744b98c6b4e4410a19903d1db4de5bba0e1a1b3347.json +1 -0
  67. xinference/web/ui/node_modules/.cache/babel-loader/daad8131d91134f6d7aef895a0c9c32e1cb928277cb5aa66c01028126d215be0.json +1 -0
  68. xinference/web/ui/node_modules/.cache/babel-loader/de0299226173b0662b573f49e3992220f6611947073bd66ac079728a8bc8837d.json +1 -0
  69. xinference/web/ui/node_modules/.cache/babel-loader/e606671420d2937102c3c34b4b04056c11736408c1d3347b8cf42dfe61fb394b.json +1 -0
  70. xinference/web/ui/node_modules/.cache/babel-loader/e6eccc9aa641e7da833492e27846dc965f9750281420977dc84654ca6ed221e4.json +1 -0
  71. xinference/web/ui/node_modules/.cache/babel-loader/e9b52d171223bb59fb918316297a051cdfd42dd453e8260fd918e90bc0a4ebdf.json +1 -0
  72. xinference/web/ui/node_modules/.cache/babel-loader/f16aec63602a77bd561d0e67fa00b76469ac54b8033754bba114ec5eb3257964.json +1 -0
  73. {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/METADATA +25 -12
  74. {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/RECORD +79 -58
  75. xinference/model/llm/ggml/ctransformers.py +0 -281
  76. xinference/model/llm/ggml/ctransformers_util.py +0 -161
  77. xinference/web/ui/build/static/js/main.66b1c4fb.js +0 -3
  78. xinference/web/ui/build/static/js/main.66b1c4fb.js.map +0 -1
  79. xinference/web/ui/node_modules/.cache/babel-loader/0bd70b1ecf307e2681318e864f4692305b6350c8683863007f4caf2f9ac33b6e.json +0 -1
  80. xinference/web/ui/node_modules/.cache/babel-loader/0db651c046ef908f45cde73af0dbea0a797d3e35bb57f4a0863b481502103a64.json +0 -1
  81. xinference/web/ui/node_modules/.cache/babel-loader/18e5d5422e2464abf4a3e6d38164570e2e426e0a921e9a2628bbae81b18da353.json +0 -1
  82. xinference/web/ui/node_modules/.cache/babel-loader/3d93bd9a74a1ab0cec85af40f9baa5f6a8e7384b9e18c409b95a81a7b45bb7e2.json +0 -1
  83. xinference/web/ui/node_modules/.cache/babel-loader/3e055de705e397e1d413d7f429589b1a98dd78ef378b97f0cdb462c5f2487d5e.json +0 -1
  84. xinference/web/ui/node_modules/.cache/babel-loader/4fd24800544873512b540544ae54601240a5bfefd9105ff647855c64f8ad828f.json +0 -1
  85. xinference/web/ui/node_modules/.cache/babel-loader/52aa27272b4b9968f62666262b47661cb1992336a2aff3b13994cc36877b3ec3.json +0 -1
  86. xinference/web/ui/node_modules/.cache/babel-loader/60c4b98d8ea7479fb0c94cfd19c8128f17bd7e27a1e73e6dd9adf6e9d88d18eb.json +0 -1
  87. xinference/web/ui/node_modules/.cache/babel-loader/7e094845f611802b024b57439cbf911038169d06cdf6c34a72a7277f35aa71a4.json +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/95c8cc049fadd23085d8623e1d43d70b614a4e52217676f186a417dca894aa09.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/98b7ef307f436affe13d75a4f265b27e828ccc2b10ffae6513abe2681bc11971.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/a8070ce4b780b4a044218536e158a9e7192a6c80ff593fdc126fee43f46296b5.json +0 -1
  91. xinference/web/ui/node_modules/.cache/babel-loader/b400cfc9db57fa6c70cd2bad055b73c5079fde0ed37974009d898083f6af8cd8.json +0 -1
  92. xinference/web/ui/node_modules/.cache/babel-loader/bd04667474fd9cac2983b03725c218908a6cc0ee9128a5953cd00d26d4877f60.json +0 -1
  93. xinference/web/ui/node_modules/.cache/babel-loader/c2124cfe036b26befcbd386d1d17743b1a58d0b7a041a17bb67f9924400d63c3.json +0 -1
  94. xinference/web/ui/node_modules/.cache/babel-loader/c230a727b8f68f0e62616a75e14a3d33026dc4164f2e325a9a8072d733850edb.json +0 -1
  95. xinference/web/ui/node_modules/.cache/babel-loader/d44a6eb6106e09082b691a315c9f6ce17fcfe25beb7547810e0d271ce3301cd2.json +0 -1
  96. xinference/web/ui/node_modules/.cache/babel-loader/e1d9b2ae4e1248658704bc6bfc5d6160dcd1a9e771ea4ae8c1fed0aaddeedd29.json +0 -1
  97. xinference/web/ui/node_modules/.cache/babel-loader/fd4a8ae5d192331af1bedd1d2d70efcc569708ee6cc4cb479b225d059482aa81.json +0 -1
  98. xinference/web/ui/node_modules/.cache/babel-loader/fe5db70859503a54cbe71f9637e5a314cda88b1f0eecb733b6e6f837697db1ef.json +0 -1
  99. /xinference/web/ui/build/static/js/{main.66b1c4fb.js.LICENSE.txt → main.76ef2b17.js.LICENSE.txt} +0 -0
  100. {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/LICENSE +0 -0
  101. {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/WHEEL +0 -0
  102. {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/entry_points.txt +0 -0
  103. {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2024-03-21T14:58:01+0800",
11
+ "date": "2024-04-11T15:35:46+0800",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "2c9465ade7f358d57d4bc087277882d896a8de15",
15
- "version": "0.9.4"
14
+ "full-revisionid": "e3a947ebddfc53b5e8ec723c1f632c2b895edef1",
15
+ "version": "0.10.1"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -11,8 +11,9 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
+ import re
14
15
  from datetime import timedelta
15
- from typing import List, Optional
16
+ from typing import List, Optional, Tuple
16
17
 
17
18
  from fastapi import Depends, HTTPException, status
18
19
  from fastapi.security import OAuth2PasswordBearer, SecurityScopes
@@ -40,13 +41,30 @@ class AuthService:
40
41
  def config(self):
41
42
  return self._config
42
43
 
44
+ @staticmethod
45
+ def is_legal_api_key(key: str) -> bool:
46
+ pattern = re.compile("^sk-[a-zA-Z0-9]{13}$")
47
+ return re.match(pattern, key) is not None
48
+
43
49
  def init_auth_config(self):
44
50
  if self._auth_config_file:
45
51
  config: AuthStartupConfig = parse_file_as(
46
52
  path=self._auth_config_file, type_=AuthStartupConfig
47
53
  )
54
+ all_api_keys = set()
48
55
  for user in config.user_config:
49
56
  user.password = get_password_hash(user.password)
57
+ for api_key in user.api_keys:
58
+ if not self.is_legal_api_key(api_key):
59
+ raise ValueError(
60
+ "Api-Key should be a string started with 'sk-' with a total length of 16"
61
+ )
62
+ if api_key in all_api_keys:
63
+ raise ValueError(
64
+ "Duplicate api-keys exists, please check your configuration"
65
+ )
66
+ else:
67
+ all_api_keys.add(api_key)
50
68
  return config
51
69
 
52
70
  def __call__(
@@ -67,28 +85,30 @@ class AuthService:
67
85
  headers={"WWW-Authenticate": authenticate_value},
68
86
  )
69
87
 
70
- try:
71
- assert self._config is not None
72
- payload = jwt.decode(
73
- token,
74
- self._config.auth_config.secret_key,
75
- algorithms=[self._config.auth_config.algorithm],
76
- options={"verify_exp": False}, # TODO: supports token expiration
77
- )
78
- username: str = payload.get("sub")
79
- if username is None:
88
+ if self.is_legal_api_key(token):
89
+ user, token_scopes = self.get_user_and_scopes_with_api_key(token)
90
+ else:
91
+ try:
92
+ assert self._config is not None
93
+ payload = jwt.decode(
94
+ token,
95
+ self._config.auth_config.secret_key,
96
+ algorithms=[self._config.auth_config.algorithm],
97
+ options={"verify_exp": False}, # TODO: supports token expiration
98
+ )
99
+ username: str = payload.get("sub")
100
+ if username is None:
101
+ raise credentials_exception
102
+ token_scopes = payload.get("scopes", [])
103
+ user = self.get_user(username)
104
+ except (JWTError, ValidationError):
80
105
  raise credentials_exception
81
- token_scopes = payload.get("scopes", [])
82
- token_data = TokenData(scopes=token_scopes, username=username)
83
- except (JWTError, ValidationError):
84
- raise credentials_exception
85
- user = self.get_user(token_data.username)
86
106
  if user is None:
87
107
  raise credentials_exception
88
- if "admin" in token_data.scopes:
108
+ if "admin" in token_scopes:
89
109
  return user
90
110
  for scope in security_scopes.scopes:
91
- if scope not in token_data.scopes:
111
+ if scope not in token_scopes:
92
112
  raise HTTPException(
93
113
  status_code=status.HTTP_403_FORBIDDEN,
94
114
  detail="Not enough permissions",
@@ -102,6 +122,15 @@ class AuthService:
102
122
  return user
103
123
  return None
104
124
 
125
+ def get_user_and_scopes_with_api_key(
126
+ self, api_key: str
127
+ ) -> Tuple[Optional[User], List]:
128
+ for user in self._config.user_config:
129
+ for key in user.api_keys:
130
+ if api_key == key:
131
+ return user, user.permissions
132
+ return None, []
133
+
105
134
  def authenticate_user(self, username: str, password: str):
106
135
  user = self.get_user(username)
107
136
  if not user:
@@ -23,6 +23,7 @@ class LoginUserForm(BaseModel):
23
23
 
24
24
  class User(LoginUserForm):
25
25
  permissions: List[str]
26
+ api_keys: List[str]
26
27
 
27
28
 
28
29
  class AuthConfig(BaseModel):
@@ -89,7 +89,9 @@ class CreateCompletionRequest(CreateCompletion):
89
89
 
90
90
  class CreateEmbeddingRequest(BaseModel):
91
91
  model: str
92
- input: Union[str, List[str]] = Field(description="The input to embed.")
92
+ input: Union[str, List[str], List[int], List[List[int]]] = Field(
93
+ description="The input to embed."
94
+ )
93
95
  user: Optional[str] = None
94
96
 
95
97
  class Config:
@@ -693,6 +695,8 @@ class RESTfulAPI:
693
695
  peft_model_path = payload.get("peft_model_path", None)
694
696
  image_lora_load_kwargs = payload.get("image_lora_load_kwargs", None)
695
697
  image_lora_fuse_kwargs = payload.get("image_lora_fuse_kwargs", None)
698
+ worker_ip = payload.get("worker_ip", None)
699
+ gpu_idx = payload.get("gpu_idx", None)
696
700
 
697
701
  exclude_keys = {
698
702
  "model_uid",
@@ -707,6 +711,8 @@ class RESTfulAPI:
707
711
  "peft_model_path",
708
712
  "image_lora_load_kwargs",
709
713
  "image_lora_fuse_kwargs",
714
+ "worker_ip",
715
+ "gpu_idx",
710
716
  }
711
717
 
712
718
  kwargs = {
@@ -734,6 +740,8 @@ class RESTfulAPI:
734
740
  peft_model_path=peft_model_path,
735
741
  image_lora_load_kwargs=image_lora_load_kwargs,
736
742
  image_lora_fuse_kwargs=image_lora_fuse_kwargs,
743
+ worker_ip=worker_ip,
744
+ gpu_idx=gpu_idx,
737
745
  **kwargs,
738
746
  )
739
747
 
@@ -999,8 +1007,16 @@ class RESTfulAPI:
999
1007
  raise HTTPException(status_code=500, detail=str(e))
1000
1008
 
1001
1009
  async def create_embedding(self, request: Request) -> Response:
1002
- body = CreateEmbeddingRequest.parse_obj(await request.json())
1010
+ payload = await request.json()
1011
+ body = CreateEmbeddingRequest.parse_obj(payload)
1003
1012
  model_uid = body.model
1013
+ exclude = {
1014
+ "model",
1015
+ "input",
1016
+ "user",
1017
+ "encoding_format",
1018
+ }
1019
+ kwargs = {key: value for key, value in payload.items() if key not in exclude}
1004
1020
 
1005
1021
  try:
1006
1022
  model = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -1014,7 +1030,7 @@ class RESTfulAPI:
1014
1030
  raise HTTPException(status_code=500, detail=str(e))
1015
1031
 
1016
1032
  try:
1017
- embedding = await model.create_embedding(body.input)
1033
+ embedding = await model.create_embedding(body.input, **kwargs)
1018
1034
  return Response(embedding, media_type="application/json")
1019
1035
  except RuntimeError as re:
1020
1036
  logger.error(re, exc_info=True)
@@ -1027,8 +1043,15 @@ class RESTfulAPI:
1027
1043
  raise HTTPException(status_code=500, detail=str(e))
1028
1044
 
1029
1045
  async def rerank(self, request: Request) -> Response:
1030
- body = RerankRequest.parse_obj(await request.json())
1046
+ payload = await request.json()
1047
+ body = RerankRequest.parse_obj(payload)
1031
1048
  model_uid = body.model
1049
+ kwargs = {
1050
+ key: value
1051
+ for key, value in payload.items()
1052
+ if key not in RerankRequest.__annotations__.keys()
1053
+ }
1054
+
1032
1055
  try:
1033
1056
  model = await (await self._get_supervisor_ref()).get_model(model_uid)
1034
1057
  except ValueError as ve:
@@ -1047,6 +1070,7 @@ class RESTfulAPI:
1047
1070
  top_n=body.top_n,
1048
1071
  max_chunks_per_doc=body.max_chunks_per_doc,
1049
1072
  return_documents=body.return_documents,
1073
+ **kwargs,
1050
1074
  )
1051
1075
  return Response(scores, media_type="application/json")
1052
1076
  except RuntimeError as re:
@@ -1337,9 +1361,12 @@ class RESTfulAPI:
1337
1361
  detail=f"Only {function_call_models} support tool messages",
1338
1362
  )
1339
1363
  if body.tools and body.stream:
1340
- raise HTTPException(
1341
- status_code=400, detail="Tool calls does not support stream"
1342
- )
1364
+ is_vllm = await model.is_vllm_backend()
1365
+ if not is_vllm or model_family not in ["qwen-chat", "qwen1.5-chat"]:
1366
+ raise HTTPException(
1367
+ status_code=400,
1368
+ detail="Streaming support for tool calls is available only when using vLLM backend and Qwen models.",
1369
+ )
1343
1370
 
1344
1371
  if body.stream:
1345
1372
 
@@ -111,7 +111,7 @@ class ClientIteratorWrapper(AsyncIterator):
111
111
 
112
112
 
113
113
  class EmbeddingModelHandle(ModelHandle):
114
- def create_embedding(self, input: Union[str, List[str]]) -> bytes:
114
+ def create_embedding(self, input: Union[str, List[str]], **kwargs) -> bytes:
115
115
  """
116
116
  Creates an embedding vector representing the input text.
117
117
 
@@ -128,7 +128,7 @@ class EmbeddingModelHandle(ModelHandle):
128
128
  machine learning models and algorithms.
129
129
  """
130
130
 
131
- coro = self._model_ref.create_embedding(input)
131
+ coro = self._model_ref.create_embedding(input, **kwargs)
132
132
  return orjson.loads(self._isolation.call(coro))
133
133
 
134
134
 
@@ -140,6 +140,7 @@ class RerankModelHandle(ModelHandle):
140
140
  top_n: Optional[int],
141
141
  max_chunks_per_doc: Optional[int],
142
142
  return_documents: Optional[bool],
143
+ **kwargs,
143
144
  ):
144
145
  """
145
146
  Returns an ordered list of documents ordered by their relevance to the provided query.
@@ -163,7 +164,7 @@ class RerankModelHandle(ModelHandle):
163
164
 
164
165
  """
165
166
  coro = self._model_ref.rerank(
166
- documents, query, top_n, max_chunks_per_doc, return_documents
167
+ documents, query, top_n, max_chunks_per_doc, return_documents, **kwargs
167
168
  )
168
169
  results = orjson.loads(self._isolation.call(coro))
169
170
  for r in results["results"]:
@@ -80,7 +80,7 @@ class RESTfulModelHandle:
80
80
 
81
81
 
82
82
  class RESTfulEmbeddingModelHandle(RESTfulModelHandle):
83
- def create_embedding(self, input: Union[str, List[str]]) -> "Embedding":
83
+ def create_embedding(self, input: Union[str, List[str]], **kwargs) -> "Embedding":
84
84
  """
85
85
  Create an Embedding from user input via RESTful APIs.
86
86
 
@@ -102,7 +102,11 @@ class RESTfulEmbeddingModelHandle(RESTfulModelHandle):
102
102
 
103
103
  """
104
104
  url = f"{self._base_url}/v1/embeddings"
105
- request_body = {"model": self._model_uid, "input": input}
105
+ request_body = {
106
+ "model": self._model_uid,
107
+ "input": input,
108
+ }
109
+ request_body.update(kwargs)
106
110
  response = requests.post(url, json=request_body, headers=self.auth_headers)
107
111
  if response.status_code != 200:
108
112
  raise RuntimeError(
@@ -121,6 +125,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
121
125
  top_n: Optional[int] = None,
122
126
  max_chunks_per_doc: Optional[int] = None,
123
127
  return_documents: Optional[bool] = None,
128
+ **kwargs,
124
129
  ):
125
130
  """
126
131
  Returns an ordered list of documents ordered by their relevance to the provided query.
@@ -156,6 +161,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
156
161
  "max_chunks_per_doc": max_chunks_per_doc,
157
162
  "return_documents": return_documents,
158
163
  }
164
+ request_body.update(kwargs)
159
165
  response = requests.post(url, json=request_body, headers=self.auth_headers)
160
166
  if response.status_code != 200:
161
167
  raise RuntimeError(
@@ -651,11 +657,13 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
651
657
 
652
658
 
653
659
  class Client:
654
- def __init__(self, base_url):
660
+ def __init__(self, base_url, api_key: Optional[str] = None):
655
661
  self.base_url = base_url
656
- self._headers = {}
662
+ self._headers: Dict[str, str] = {}
657
663
  self._cluster_authed = False
658
664
  self._check_cluster_authenticated()
665
+ if api_key is not None and self._cluster_authed:
666
+ self._headers["Authorization"] = f"Bearer {api_key}"
659
667
 
660
668
  def _set_token(self, token: Optional[str]):
661
669
  if not self._cluster_authed or token is None:
@@ -795,6 +803,8 @@ class Client:
795
803
  peft_model_path: Optional[str] = None,
796
804
  image_lora_load_kwargs: Optional[Dict] = None,
797
805
  image_lora_fuse_kwargs: Optional[Dict] = None,
806
+ worker_ip: Optional[str] = None,
807
+ gpu_idx: Optional[Union[int, List[int]]] = None,
798
808
  **kwargs,
799
809
  ) -> str:
800
810
  """
@@ -828,6 +838,10 @@ class Client:
828
838
  lora load parameters for image model
829
839
  image_lora_fuse_kwargs: Optional[Dict]
830
840
  lora fuse parameters for image model
841
+ worker_ip: Optional[str]
842
+ Specify the worker ip where the model is located in a distributed scenario.
843
+ gpu_idx: Optional[Union[int, List[int]]]
844
+ Specify the GPU index where the model is located.
831
845
  **kwargs:
832
846
  Any other parameters been specified.
833
847
 
@@ -853,6 +867,8 @@ class Client:
853
867
  "peft_model_path": peft_model_path,
854
868
  "image_lora_load_kwargs": image_lora_load_kwargs,
855
869
  "image_lora_fuse_kwargs": image_lora_fuse_kwargs,
870
+ "worker_ip": worker_ip,
871
+ "gpu_idx": gpu_idx,
856
872
  }
857
873
 
858
874
  for key, value in kwargs.items():
xinference/conftest.py CHANGED
@@ -261,12 +261,23 @@ def setup_with_auth():
261
261
  if not cluster_health_check(supervisor_addr, max_attempts=10, sleep_interval=3):
262
262
  raise RuntimeError("Cluster is not available after multiple attempts")
263
263
 
264
- user1 = User(username="user1", password="pass1", permissions=["admin"])
265
- user2 = User(username="user2", password="pass2", permissions=["models:list"])
264
+ user1 = User(
265
+ username="user1",
266
+ password="pass1",
267
+ permissions=["admin"],
268
+ api_keys=["sk-3sjLbdwqAhhAF", "sk-0HCRO1rauFQDL"],
269
+ )
270
+ user2 = User(
271
+ username="user2",
272
+ password="pass2",
273
+ permissions=["models:list"],
274
+ api_keys=["sk-72tkvudyGLPMi"],
275
+ )
266
276
  user3 = User(
267
277
  username="user3",
268
278
  password="pass3",
269
279
  permissions=["models:list", "models:read", "models:start"],
280
+ api_keys=["sk-m6jEzEwmCc4iQ", "sk-ZOTLIY4gt9w11"],
270
281
  )
271
282
  auth_config = AuthConfig(
272
283
  algorithm="HS256",
@@ -92,6 +92,15 @@ class SupervisorActor(xo.StatelessActor):
92
92
  def uid(cls) -> str:
93
93
  return "supervisor"
94
94
 
95
+ def _get_worker_ref_by_ip(
96
+ self, ip: str
97
+ ) -> Optional[xo.ActorRefType["WorkerActor"]]:
98
+ for addr, ref in self._worker_address_to_worker.items():
99
+ existing_ip = addr.split(":")[0]
100
+ if existing_ip == ip:
101
+ return ref
102
+ return None
103
+
95
104
  async def __post_create__(self):
96
105
  self._uptime = time.time()
97
106
  if not XINFERENCE_DISABLE_HEALTH_CHECK:
@@ -717,8 +726,25 @@ class SupervisorActor(xo.StatelessActor):
717
726
  peft_model_path: Optional[str] = None,
718
727
  image_lora_load_kwargs: Optional[Dict] = None,
719
728
  image_lora_fuse_kwargs: Optional[Dict] = None,
729
+ worker_ip: Optional[str] = None,
730
+ gpu_idx: Optional[Union[int, List[int]]] = None,
720
731
  **kwargs,
721
732
  ) -> str:
733
+ target_ip_worker_ref = (
734
+ self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
735
+ )
736
+ if (
737
+ worker_ip is not None
738
+ and not self.is_local_deployment()
739
+ and target_ip_worker_ref is None
740
+ ):
741
+ raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
742
+ if worker_ip is not None and self.is_local_deployment():
743
+ logger.warning(
744
+ f"You specified the worker ip: {worker_ip} in local mode, "
745
+ f"xinference will ignore this option."
746
+ )
747
+
722
748
  if model_uid is None:
723
749
  model_uid = self._gen_model_uid(model_name)
724
750
 
@@ -735,7 +761,11 @@ class SupervisorActor(xo.StatelessActor):
735
761
  )
736
762
 
737
763
  nonlocal model_type
738
- worker_ref = await self._choose_worker()
764
+ worker_ref = (
765
+ target_ip_worker_ref
766
+ if target_ip_worker_ref is not None
767
+ else await self._choose_worker()
768
+ )
739
769
  # LLM as default for compatibility
740
770
  model_type = model_type or "LLM"
741
771
  await worker_ref.launch_builtin_model(
@@ -750,6 +780,7 @@ class SupervisorActor(xo.StatelessActor):
750
780
  peft_model_path=peft_model_path,
751
781
  image_lora_load_kwargs=image_lora_load_kwargs,
752
782
  image_lora_fuse_kwargs=image_lora_fuse_kwargs,
783
+ gpu_idx=gpu_idx,
753
784
  **kwargs,
754
785
  )
755
786
  self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
@@ -839,6 +870,12 @@ class SupervisorActor(xo.StatelessActor):
839
870
  address,
840
871
  dead_models,
841
872
  )
873
+ for replica_model_uid in dead_models:
874
+ model_uid, _, _ = parse_replica_model_uid(replica_model_uid)
875
+ self._model_uid_to_replica_info.pop(model_uid, None)
876
+ self._replica_model_uid_to_worker.pop(
877
+ replica_model_uid, None
878
+ )
842
879
  dead_nodes.append(address)
843
880
  elif (
844
881
  status.failure_remaining_count
@@ -948,6 +985,16 @@ class SupervisorActor(xo.StatelessActor):
948
985
 
949
986
  @log_async(logger=logger)
950
987
  async def remove_worker(self, worker_address: str):
988
+ uids_to_remove = []
989
+ for model_uid in self._replica_model_uid_to_worker:
990
+ if self._replica_model_uid_to_worker[model_uid].address == worker_address:
991
+ uids_to_remove.append(model_uid)
992
+
993
+ for replica_model_uid in uids_to_remove:
994
+ model_uid, _, _ = parse_replica_model_uid(replica_model_uid)
995
+ self._model_uid_to_replica_info.pop(model_uid, None)
996
+ self._replica_model_uid_to_worker.pop(replica_model_uid, None)
997
+
951
998
  if worker_address in self._worker_address_to_worker:
952
999
  del self._worker_address_to_worker[worker_address]
953
1000
  logger.debug("Worker %s has been removed successfully", worker_address)