xinference 0.9.4__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/oauth2/auth_service.py +47 -18
- xinference/api/oauth2/types.py +1 -0
- xinference/api/restful_api.py +34 -7
- xinference/client/oscar/actor_client.py +4 -3
- xinference/client/restful/restful_client.py +20 -4
- xinference/conftest.py +13 -2
- xinference/core/supervisor.py +48 -1
- xinference/core/worker.py +139 -20
- xinference/deploy/cmdline.py +119 -20
- xinference/model/embedding/core.py +1 -2
- xinference/model/llm/__init__.py +4 -6
- xinference/model/llm/ggml/llamacpp.py +2 -10
- xinference/model/llm/llm_family.json +877 -13
- xinference/model/llm/llm_family.py +15 -0
- xinference/model/llm/llm_family_modelscope.json +571 -0
- xinference/model/llm/pytorch/chatglm.py +2 -0
- xinference/model/llm/pytorch/core.py +22 -26
- xinference/model/llm/pytorch/deepseek_vl.py +232 -0
- xinference/model/llm/pytorch/internlm2.py +2 -0
- xinference/model/llm/pytorch/omnilmm.py +153 -0
- xinference/model/llm/pytorch/qwen_vl.py +2 -0
- xinference/model/llm/pytorch/yi_vl.py +4 -2
- xinference/model/llm/utils.py +53 -5
- xinference/model/llm/vllm/core.py +54 -6
- xinference/model/rerank/core.py +3 -0
- xinference/thirdparty/deepseek_vl/__init__.py +31 -0
- xinference/thirdparty/deepseek_vl/models/__init__.py +28 -0
- xinference/thirdparty/deepseek_vl/models/clip_encoder.py +242 -0
- xinference/thirdparty/deepseek_vl/models/image_processing_vlm.py +208 -0
- xinference/thirdparty/deepseek_vl/models/modeling_vlm.py +170 -0
- xinference/thirdparty/deepseek_vl/models/processing_vlm.py +390 -0
- xinference/thirdparty/deepseek_vl/models/projector.py +100 -0
- xinference/thirdparty/deepseek_vl/models/sam.py +593 -0
- xinference/thirdparty/deepseek_vl/models/siglip_vit.py +681 -0
- xinference/thirdparty/deepseek_vl/utils/__init__.py +18 -0
- xinference/thirdparty/deepseek_vl/utils/conversation.py +348 -0
- xinference/thirdparty/deepseek_vl/utils/io.py +78 -0
- xinference/thirdparty/omnilmm/__init__.py +0 -0
- xinference/thirdparty/omnilmm/chat.py +216 -0
- xinference/thirdparty/omnilmm/constants.py +4 -0
- xinference/thirdparty/omnilmm/conversation.py +332 -0
- xinference/thirdparty/omnilmm/model/__init__.py +1 -0
- xinference/thirdparty/omnilmm/model/omnilmm.py +594 -0
- xinference/thirdparty/omnilmm/model/resampler.py +166 -0
- xinference/thirdparty/omnilmm/model/utils.py +563 -0
- xinference/thirdparty/omnilmm/train/__init__.py +13 -0
- xinference/thirdparty/omnilmm/train/train_utils.py +150 -0
- xinference/thirdparty/omnilmm/utils.py +134 -0
- xinference/types.py +15 -19
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.76ef2b17.js +3 -0
- xinference/web/ui/build/static/js/main.76ef2b17.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/15e2cf8cd8d0989719b6349428ff576f9009ff4c2dcc52378be0bd938e82495e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/35d0e4a317e5582cbb79d901302e9d706520ac53f8a734c2fd8bfde6eb5a4f02.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3c2f277c93c5f1638e08db38df0d0fb4e58d1c5571aea03241a5c04ff4094704.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3fa1f69162f9c6dc0f6a6e21b64d49d6b8e6fa8dfa59a82cf829931c5f97d99f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5393569d846332075b93b55656716a34f50e0a8c970be789502d7e6c49755fd7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/62e257ed9016471035fa1a7da57c9e2a4250974ed566b4d1295873d747c68eb2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/63a4c48f0326d071c7772c46598215c006ae41fd3d4ff3577fe717de66ad6e89.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b9cbcb6d77ba21b22c6950b6fb5b305d23c19cf747f99f7d48b6b046f8f7b1b0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d076fd56cf3b15ed2433e3744b98c6b4e4410a19903d1db4de5bba0e1a1b3347.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/daad8131d91134f6d7aef895a0c9c32e1cb928277cb5aa66c01028126d215be0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/de0299226173b0662b573f49e3992220f6611947073bd66ac079728a8bc8837d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e606671420d2937102c3c34b4b04056c11736408c1d3347b8cf42dfe61fb394b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e6eccc9aa641e7da833492e27846dc965f9750281420977dc84654ca6ed221e4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e9b52d171223bb59fb918316297a051cdfd42dd453e8260fd918e90bc0a4ebdf.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f16aec63602a77bd561d0e67fa00b76469ac54b8033754bba114ec5eb3257964.json +1 -0
- {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/METADATA +25 -12
- {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/RECORD +79 -58
- xinference/model/llm/ggml/ctransformers.py +0 -281
- xinference/model/llm/ggml/ctransformers_util.py +0 -161
- xinference/web/ui/build/static/js/main.66b1c4fb.js +0 -3
- xinference/web/ui/build/static/js/main.66b1c4fb.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0bd70b1ecf307e2681318e864f4692305b6350c8683863007f4caf2f9ac33b6e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0db651c046ef908f45cde73af0dbea0a797d3e35bb57f4a0863b481502103a64.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/18e5d5422e2464abf4a3e6d38164570e2e426e0a921e9a2628bbae81b18da353.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3d93bd9a74a1ab0cec85af40f9baa5f6a8e7384b9e18c409b95a81a7b45bb7e2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3e055de705e397e1d413d7f429589b1a98dd78ef378b97f0cdb462c5f2487d5e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4fd24800544873512b540544ae54601240a5bfefd9105ff647855c64f8ad828f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/52aa27272b4b9968f62666262b47661cb1992336a2aff3b13994cc36877b3ec3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/60c4b98d8ea7479fb0c94cfd19c8128f17bd7e27a1e73e6dd9adf6e9d88d18eb.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/7e094845f611802b024b57439cbf911038169d06cdf6c34a72a7277f35aa71a4.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/95c8cc049fadd23085d8623e1d43d70b614a4e52217676f186a417dca894aa09.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/98b7ef307f436affe13d75a4f265b27e828ccc2b10ffae6513abe2681bc11971.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a8070ce4b780b4a044218536e158a9e7192a6c80ff593fdc126fee43f46296b5.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/b400cfc9db57fa6c70cd2bad055b73c5079fde0ed37974009d898083f6af8cd8.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd04667474fd9cac2983b03725c218908a6cc0ee9128a5953cd00d26d4877f60.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c2124cfe036b26befcbd386d1d17743b1a58d0b7a041a17bb67f9924400d63c3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c230a727b8f68f0e62616a75e14a3d33026dc4164f2e325a9a8072d733850edb.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d44a6eb6106e09082b691a315c9f6ce17fcfe25beb7547810e0d271ce3301cd2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e1d9b2ae4e1248658704bc6bfc5d6160dcd1a9e771ea4ae8c1fed0aaddeedd29.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/fd4a8ae5d192331af1bedd1d2d70efcc569708ee6cc4cb479b225d059482aa81.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/fe5db70859503a54cbe71f9637e5a314cda88b1f0eecb733b6e6f837697db1ef.json +0 -1
- /xinference/web/ui/build/static/js/{main.66b1c4fb.js.LICENSE.txt → main.76ef2b17.js.LICENSE.txt} +0 -0
- {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/LICENSE +0 -0
- {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/WHEEL +0 -0
- {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.9.4.dist-info → xinference-0.10.1.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2024-
|
|
11
|
+
"date": "2024-04-11T15:35:46+0800",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "0.
|
|
14
|
+
"full-revisionid": "e3a947ebddfc53b5e8ec723c1f632c2b895edef1",
|
|
15
|
+
"version": "0.10.1"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
|
@@ -11,8 +11,9 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
import re
|
|
14
15
|
from datetime import timedelta
|
|
15
|
-
from typing import List, Optional
|
|
16
|
+
from typing import List, Optional, Tuple
|
|
16
17
|
|
|
17
18
|
from fastapi import Depends, HTTPException, status
|
|
18
19
|
from fastapi.security import OAuth2PasswordBearer, SecurityScopes
|
|
@@ -40,13 +41,30 @@ class AuthService:
|
|
|
40
41
|
def config(self):
|
|
41
42
|
return self._config
|
|
42
43
|
|
|
44
|
+
@staticmethod
|
|
45
|
+
def is_legal_api_key(key: str) -> bool:
|
|
46
|
+
pattern = re.compile("^sk-[a-zA-Z0-9]{13}$")
|
|
47
|
+
return re.match(pattern, key) is not None
|
|
48
|
+
|
|
43
49
|
def init_auth_config(self):
|
|
44
50
|
if self._auth_config_file:
|
|
45
51
|
config: AuthStartupConfig = parse_file_as(
|
|
46
52
|
path=self._auth_config_file, type_=AuthStartupConfig
|
|
47
53
|
)
|
|
54
|
+
all_api_keys = set()
|
|
48
55
|
for user in config.user_config:
|
|
49
56
|
user.password = get_password_hash(user.password)
|
|
57
|
+
for api_key in user.api_keys:
|
|
58
|
+
if not self.is_legal_api_key(api_key):
|
|
59
|
+
raise ValueError(
|
|
60
|
+
"Api-Key should be a string started with 'sk-' with a total length of 16"
|
|
61
|
+
)
|
|
62
|
+
if api_key in all_api_keys:
|
|
63
|
+
raise ValueError(
|
|
64
|
+
"Duplicate api-keys exists, please check your configuration"
|
|
65
|
+
)
|
|
66
|
+
else:
|
|
67
|
+
all_api_keys.add(api_key)
|
|
50
68
|
return config
|
|
51
69
|
|
|
52
70
|
def __call__(
|
|
@@ -67,28 +85,30 @@ class AuthService:
|
|
|
67
85
|
headers={"WWW-Authenticate": authenticate_value},
|
|
68
86
|
)
|
|
69
87
|
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
self._config
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
88
|
+
if self.is_legal_api_key(token):
|
|
89
|
+
user, token_scopes = self.get_user_and_scopes_with_api_key(token)
|
|
90
|
+
else:
|
|
91
|
+
try:
|
|
92
|
+
assert self._config is not None
|
|
93
|
+
payload = jwt.decode(
|
|
94
|
+
token,
|
|
95
|
+
self._config.auth_config.secret_key,
|
|
96
|
+
algorithms=[self._config.auth_config.algorithm],
|
|
97
|
+
options={"verify_exp": False}, # TODO: supports token expiration
|
|
98
|
+
)
|
|
99
|
+
username: str = payload.get("sub")
|
|
100
|
+
if username is None:
|
|
101
|
+
raise credentials_exception
|
|
102
|
+
token_scopes = payload.get("scopes", [])
|
|
103
|
+
user = self.get_user(username)
|
|
104
|
+
except (JWTError, ValidationError):
|
|
80
105
|
raise credentials_exception
|
|
81
|
-
token_scopes = payload.get("scopes", [])
|
|
82
|
-
token_data = TokenData(scopes=token_scopes, username=username)
|
|
83
|
-
except (JWTError, ValidationError):
|
|
84
|
-
raise credentials_exception
|
|
85
|
-
user = self.get_user(token_data.username)
|
|
86
106
|
if user is None:
|
|
87
107
|
raise credentials_exception
|
|
88
|
-
if "admin" in
|
|
108
|
+
if "admin" in token_scopes:
|
|
89
109
|
return user
|
|
90
110
|
for scope in security_scopes.scopes:
|
|
91
|
-
if scope not in
|
|
111
|
+
if scope not in token_scopes:
|
|
92
112
|
raise HTTPException(
|
|
93
113
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
94
114
|
detail="Not enough permissions",
|
|
@@ -102,6 +122,15 @@ class AuthService:
|
|
|
102
122
|
return user
|
|
103
123
|
return None
|
|
104
124
|
|
|
125
|
+
def get_user_and_scopes_with_api_key(
|
|
126
|
+
self, api_key: str
|
|
127
|
+
) -> Tuple[Optional[User], List]:
|
|
128
|
+
for user in self._config.user_config:
|
|
129
|
+
for key in user.api_keys:
|
|
130
|
+
if api_key == key:
|
|
131
|
+
return user, user.permissions
|
|
132
|
+
return None, []
|
|
133
|
+
|
|
105
134
|
def authenticate_user(self, username: str, password: str):
|
|
106
135
|
user = self.get_user(username)
|
|
107
136
|
if not user:
|
xinference/api/oauth2/types.py
CHANGED
xinference/api/restful_api.py
CHANGED
|
@@ -89,7 +89,9 @@ class CreateCompletionRequest(CreateCompletion):
|
|
|
89
89
|
|
|
90
90
|
class CreateEmbeddingRequest(BaseModel):
|
|
91
91
|
model: str
|
|
92
|
-
input: Union[str, List[str]] = Field(
|
|
92
|
+
input: Union[str, List[str], List[int], List[List[int]]] = Field(
|
|
93
|
+
description="The input to embed."
|
|
94
|
+
)
|
|
93
95
|
user: Optional[str] = None
|
|
94
96
|
|
|
95
97
|
class Config:
|
|
@@ -693,6 +695,8 @@ class RESTfulAPI:
|
|
|
693
695
|
peft_model_path = payload.get("peft_model_path", None)
|
|
694
696
|
image_lora_load_kwargs = payload.get("image_lora_load_kwargs", None)
|
|
695
697
|
image_lora_fuse_kwargs = payload.get("image_lora_fuse_kwargs", None)
|
|
698
|
+
worker_ip = payload.get("worker_ip", None)
|
|
699
|
+
gpu_idx = payload.get("gpu_idx", None)
|
|
696
700
|
|
|
697
701
|
exclude_keys = {
|
|
698
702
|
"model_uid",
|
|
@@ -707,6 +711,8 @@ class RESTfulAPI:
|
|
|
707
711
|
"peft_model_path",
|
|
708
712
|
"image_lora_load_kwargs",
|
|
709
713
|
"image_lora_fuse_kwargs",
|
|
714
|
+
"worker_ip",
|
|
715
|
+
"gpu_idx",
|
|
710
716
|
}
|
|
711
717
|
|
|
712
718
|
kwargs = {
|
|
@@ -734,6 +740,8 @@ class RESTfulAPI:
|
|
|
734
740
|
peft_model_path=peft_model_path,
|
|
735
741
|
image_lora_load_kwargs=image_lora_load_kwargs,
|
|
736
742
|
image_lora_fuse_kwargs=image_lora_fuse_kwargs,
|
|
743
|
+
worker_ip=worker_ip,
|
|
744
|
+
gpu_idx=gpu_idx,
|
|
737
745
|
**kwargs,
|
|
738
746
|
)
|
|
739
747
|
|
|
@@ -999,8 +1007,16 @@ class RESTfulAPI:
|
|
|
999
1007
|
raise HTTPException(status_code=500, detail=str(e))
|
|
1000
1008
|
|
|
1001
1009
|
async def create_embedding(self, request: Request) -> Response:
|
|
1002
|
-
|
|
1010
|
+
payload = await request.json()
|
|
1011
|
+
body = CreateEmbeddingRequest.parse_obj(payload)
|
|
1003
1012
|
model_uid = body.model
|
|
1013
|
+
exclude = {
|
|
1014
|
+
"model",
|
|
1015
|
+
"input",
|
|
1016
|
+
"user",
|
|
1017
|
+
"encoding_format",
|
|
1018
|
+
}
|
|
1019
|
+
kwargs = {key: value for key, value in payload.items() if key not in exclude}
|
|
1004
1020
|
|
|
1005
1021
|
try:
|
|
1006
1022
|
model = await (await self._get_supervisor_ref()).get_model(model_uid)
|
|
@@ -1014,7 +1030,7 @@ class RESTfulAPI:
|
|
|
1014
1030
|
raise HTTPException(status_code=500, detail=str(e))
|
|
1015
1031
|
|
|
1016
1032
|
try:
|
|
1017
|
-
embedding = await model.create_embedding(body.input)
|
|
1033
|
+
embedding = await model.create_embedding(body.input, **kwargs)
|
|
1018
1034
|
return Response(embedding, media_type="application/json")
|
|
1019
1035
|
except RuntimeError as re:
|
|
1020
1036
|
logger.error(re, exc_info=True)
|
|
@@ -1027,8 +1043,15 @@ class RESTfulAPI:
|
|
|
1027
1043
|
raise HTTPException(status_code=500, detail=str(e))
|
|
1028
1044
|
|
|
1029
1045
|
async def rerank(self, request: Request) -> Response:
|
|
1030
|
-
|
|
1046
|
+
payload = await request.json()
|
|
1047
|
+
body = RerankRequest.parse_obj(payload)
|
|
1031
1048
|
model_uid = body.model
|
|
1049
|
+
kwargs = {
|
|
1050
|
+
key: value
|
|
1051
|
+
for key, value in payload.items()
|
|
1052
|
+
if key not in RerankRequest.__annotations__.keys()
|
|
1053
|
+
}
|
|
1054
|
+
|
|
1032
1055
|
try:
|
|
1033
1056
|
model = await (await self._get_supervisor_ref()).get_model(model_uid)
|
|
1034
1057
|
except ValueError as ve:
|
|
@@ -1047,6 +1070,7 @@ class RESTfulAPI:
|
|
|
1047
1070
|
top_n=body.top_n,
|
|
1048
1071
|
max_chunks_per_doc=body.max_chunks_per_doc,
|
|
1049
1072
|
return_documents=body.return_documents,
|
|
1073
|
+
**kwargs,
|
|
1050
1074
|
)
|
|
1051
1075
|
return Response(scores, media_type="application/json")
|
|
1052
1076
|
except RuntimeError as re:
|
|
@@ -1337,9 +1361,12 @@ class RESTfulAPI:
|
|
|
1337
1361
|
detail=f"Only {function_call_models} support tool messages",
|
|
1338
1362
|
)
|
|
1339
1363
|
if body.tools and body.stream:
|
|
1340
|
-
|
|
1341
|
-
|
|
1342
|
-
|
|
1364
|
+
is_vllm = await model.is_vllm_backend()
|
|
1365
|
+
if not is_vllm or model_family not in ["qwen-chat", "qwen1.5-chat"]:
|
|
1366
|
+
raise HTTPException(
|
|
1367
|
+
status_code=400,
|
|
1368
|
+
detail="Streaming support for tool calls is available only when using vLLM backend and Qwen models.",
|
|
1369
|
+
)
|
|
1343
1370
|
|
|
1344
1371
|
if body.stream:
|
|
1345
1372
|
|
|
@@ -111,7 +111,7 @@ class ClientIteratorWrapper(AsyncIterator):
|
|
|
111
111
|
|
|
112
112
|
|
|
113
113
|
class EmbeddingModelHandle(ModelHandle):
|
|
114
|
-
def create_embedding(self, input: Union[str, List[str]]) -> bytes:
|
|
114
|
+
def create_embedding(self, input: Union[str, List[str]], **kwargs) -> bytes:
|
|
115
115
|
"""
|
|
116
116
|
Creates an embedding vector representing the input text.
|
|
117
117
|
|
|
@@ -128,7 +128,7 @@ class EmbeddingModelHandle(ModelHandle):
|
|
|
128
128
|
machine learning models and algorithms.
|
|
129
129
|
"""
|
|
130
130
|
|
|
131
|
-
coro = self._model_ref.create_embedding(input)
|
|
131
|
+
coro = self._model_ref.create_embedding(input, **kwargs)
|
|
132
132
|
return orjson.loads(self._isolation.call(coro))
|
|
133
133
|
|
|
134
134
|
|
|
@@ -140,6 +140,7 @@ class RerankModelHandle(ModelHandle):
|
|
|
140
140
|
top_n: Optional[int],
|
|
141
141
|
max_chunks_per_doc: Optional[int],
|
|
142
142
|
return_documents: Optional[bool],
|
|
143
|
+
**kwargs,
|
|
143
144
|
):
|
|
144
145
|
"""
|
|
145
146
|
Returns an ordered list of documents ordered by their relevance to the provided query.
|
|
@@ -163,7 +164,7 @@ class RerankModelHandle(ModelHandle):
|
|
|
163
164
|
|
|
164
165
|
"""
|
|
165
166
|
coro = self._model_ref.rerank(
|
|
166
|
-
documents, query, top_n, max_chunks_per_doc, return_documents
|
|
167
|
+
documents, query, top_n, max_chunks_per_doc, return_documents, **kwargs
|
|
167
168
|
)
|
|
168
169
|
results = orjson.loads(self._isolation.call(coro))
|
|
169
170
|
for r in results["results"]:
|
|
@@ -80,7 +80,7 @@ class RESTfulModelHandle:
|
|
|
80
80
|
|
|
81
81
|
|
|
82
82
|
class RESTfulEmbeddingModelHandle(RESTfulModelHandle):
|
|
83
|
-
def create_embedding(self, input: Union[str, List[str]]) -> "Embedding":
|
|
83
|
+
def create_embedding(self, input: Union[str, List[str]], **kwargs) -> "Embedding":
|
|
84
84
|
"""
|
|
85
85
|
Create an Embedding from user input via RESTful APIs.
|
|
86
86
|
|
|
@@ -102,7 +102,11 @@ class RESTfulEmbeddingModelHandle(RESTfulModelHandle):
|
|
|
102
102
|
|
|
103
103
|
"""
|
|
104
104
|
url = f"{self._base_url}/v1/embeddings"
|
|
105
|
-
request_body = {
|
|
105
|
+
request_body = {
|
|
106
|
+
"model": self._model_uid,
|
|
107
|
+
"input": input,
|
|
108
|
+
}
|
|
109
|
+
request_body.update(kwargs)
|
|
106
110
|
response = requests.post(url, json=request_body, headers=self.auth_headers)
|
|
107
111
|
if response.status_code != 200:
|
|
108
112
|
raise RuntimeError(
|
|
@@ -121,6 +125,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
|
|
|
121
125
|
top_n: Optional[int] = None,
|
|
122
126
|
max_chunks_per_doc: Optional[int] = None,
|
|
123
127
|
return_documents: Optional[bool] = None,
|
|
128
|
+
**kwargs,
|
|
124
129
|
):
|
|
125
130
|
"""
|
|
126
131
|
Returns an ordered list of documents ordered by their relevance to the provided query.
|
|
@@ -156,6 +161,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
|
|
|
156
161
|
"max_chunks_per_doc": max_chunks_per_doc,
|
|
157
162
|
"return_documents": return_documents,
|
|
158
163
|
}
|
|
164
|
+
request_body.update(kwargs)
|
|
159
165
|
response = requests.post(url, json=request_body, headers=self.auth_headers)
|
|
160
166
|
if response.status_code != 200:
|
|
161
167
|
raise RuntimeError(
|
|
@@ -651,11 +657,13 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
|
|
|
651
657
|
|
|
652
658
|
|
|
653
659
|
class Client:
|
|
654
|
-
def __init__(self, base_url):
|
|
660
|
+
def __init__(self, base_url, api_key: Optional[str] = None):
|
|
655
661
|
self.base_url = base_url
|
|
656
|
-
self._headers = {}
|
|
662
|
+
self._headers: Dict[str, str] = {}
|
|
657
663
|
self._cluster_authed = False
|
|
658
664
|
self._check_cluster_authenticated()
|
|
665
|
+
if api_key is not None and self._cluster_authed:
|
|
666
|
+
self._headers["Authorization"] = f"Bearer {api_key}"
|
|
659
667
|
|
|
660
668
|
def _set_token(self, token: Optional[str]):
|
|
661
669
|
if not self._cluster_authed or token is None:
|
|
@@ -795,6 +803,8 @@ class Client:
|
|
|
795
803
|
peft_model_path: Optional[str] = None,
|
|
796
804
|
image_lora_load_kwargs: Optional[Dict] = None,
|
|
797
805
|
image_lora_fuse_kwargs: Optional[Dict] = None,
|
|
806
|
+
worker_ip: Optional[str] = None,
|
|
807
|
+
gpu_idx: Optional[Union[int, List[int]]] = None,
|
|
798
808
|
**kwargs,
|
|
799
809
|
) -> str:
|
|
800
810
|
"""
|
|
@@ -828,6 +838,10 @@ class Client:
|
|
|
828
838
|
lora load parameters for image model
|
|
829
839
|
image_lora_fuse_kwargs: Optional[Dict]
|
|
830
840
|
lora fuse parameters for image model
|
|
841
|
+
worker_ip: Optional[str]
|
|
842
|
+
Specify the worker ip where the model is located in a distributed scenario.
|
|
843
|
+
gpu_idx: Optional[Union[int, List[int]]]
|
|
844
|
+
Specify the GPU index where the model is located.
|
|
831
845
|
**kwargs:
|
|
832
846
|
Any other parameters been specified.
|
|
833
847
|
|
|
@@ -853,6 +867,8 @@ class Client:
|
|
|
853
867
|
"peft_model_path": peft_model_path,
|
|
854
868
|
"image_lora_load_kwargs": image_lora_load_kwargs,
|
|
855
869
|
"image_lora_fuse_kwargs": image_lora_fuse_kwargs,
|
|
870
|
+
"worker_ip": worker_ip,
|
|
871
|
+
"gpu_idx": gpu_idx,
|
|
856
872
|
}
|
|
857
873
|
|
|
858
874
|
for key, value in kwargs.items():
|
xinference/conftest.py
CHANGED
|
@@ -261,12 +261,23 @@ def setup_with_auth():
|
|
|
261
261
|
if not cluster_health_check(supervisor_addr, max_attempts=10, sleep_interval=3):
|
|
262
262
|
raise RuntimeError("Cluster is not available after multiple attempts")
|
|
263
263
|
|
|
264
|
-
user1 = User(
|
|
265
|
-
|
|
264
|
+
user1 = User(
|
|
265
|
+
username="user1",
|
|
266
|
+
password="pass1",
|
|
267
|
+
permissions=["admin"],
|
|
268
|
+
api_keys=["sk-3sjLbdwqAhhAF", "sk-0HCRO1rauFQDL"],
|
|
269
|
+
)
|
|
270
|
+
user2 = User(
|
|
271
|
+
username="user2",
|
|
272
|
+
password="pass2",
|
|
273
|
+
permissions=["models:list"],
|
|
274
|
+
api_keys=["sk-72tkvudyGLPMi"],
|
|
275
|
+
)
|
|
266
276
|
user3 = User(
|
|
267
277
|
username="user3",
|
|
268
278
|
password="pass3",
|
|
269
279
|
permissions=["models:list", "models:read", "models:start"],
|
|
280
|
+
api_keys=["sk-m6jEzEwmCc4iQ", "sk-ZOTLIY4gt9w11"],
|
|
270
281
|
)
|
|
271
282
|
auth_config = AuthConfig(
|
|
272
283
|
algorithm="HS256",
|
xinference/core/supervisor.py
CHANGED
|
@@ -92,6 +92,15 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
92
92
|
def uid(cls) -> str:
|
|
93
93
|
return "supervisor"
|
|
94
94
|
|
|
95
|
+
def _get_worker_ref_by_ip(
|
|
96
|
+
self, ip: str
|
|
97
|
+
) -> Optional[xo.ActorRefType["WorkerActor"]]:
|
|
98
|
+
for addr, ref in self._worker_address_to_worker.items():
|
|
99
|
+
existing_ip = addr.split(":")[0]
|
|
100
|
+
if existing_ip == ip:
|
|
101
|
+
return ref
|
|
102
|
+
return None
|
|
103
|
+
|
|
95
104
|
async def __post_create__(self):
|
|
96
105
|
self._uptime = time.time()
|
|
97
106
|
if not XINFERENCE_DISABLE_HEALTH_CHECK:
|
|
@@ -717,8 +726,25 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
717
726
|
peft_model_path: Optional[str] = None,
|
|
718
727
|
image_lora_load_kwargs: Optional[Dict] = None,
|
|
719
728
|
image_lora_fuse_kwargs: Optional[Dict] = None,
|
|
729
|
+
worker_ip: Optional[str] = None,
|
|
730
|
+
gpu_idx: Optional[Union[int, List[int]]] = None,
|
|
720
731
|
**kwargs,
|
|
721
732
|
) -> str:
|
|
733
|
+
target_ip_worker_ref = (
|
|
734
|
+
self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
|
|
735
|
+
)
|
|
736
|
+
if (
|
|
737
|
+
worker_ip is not None
|
|
738
|
+
and not self.is_local_deployment()
|
|
739
|
+
and target_ip_worker_ref is None
|
|
740
|
+
):
|
|
741
|
+
raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
|
|
742
|
+
if worker_ip is not None and self.is_local_deployment():
|
|
743
|
+
logger.warning(
|
|
744
|
+
f"You specified the worker ip: {worker_ip} in local mode, "
|
|
745
|
+
f"xinference will ignore this option."
|
|
746
|
+
)
|
|
747
|
+
|
|
722
748
|
if model_uid is None:
|
|
723
749
|
model_uid = self._gen_model_uid(model_name)
|
|
724
750
|
|
|
@@ -735,7 +761,11 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
735
761
|
)
|
|
736
762
|
|
|
737
763
|
nonlocal model_type
|
|
738
|
-
worker_ref =
|
|
764
|
+
worker_ref = (
|
|
765
|
+
target_ip_worker_ref
|
|
766
|
+
if target_ip_worker_ref is not None
|
|
767
|
+
else await self._choose_worker()
|
|
768
|
+
)
|
|
739
769
|
# LLM as default for compatibility
|
|
740
770
|
model_type = model_type or "LLM"
|
|
741
771
|
await worker_ref.launch_builtin_model(
|
|
@@ -750,6 +780,7 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
750
780
|
peft_model_path=peft_model_path,
|
|
751
781
|
image_lora_load_kwargs=image_lora_load_kwargs,
|
|
752
782
|
image_lora_fuse_kwargs=image_lora_fuse_kwargs,
|
|
783
|
+
gpu_idx=gpu_idx,
|
|
753
784
|
**kwargs,
|
|
754
785
|
)
|
|
755
786
|
self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
|
|
@@ -839,6 +870,12 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
839
870
|
address,
|
|
840
871
|
dead_models,
|
|
841
872
|
)
|
|
873
|
+
for replica_model_uid in dead_models:
|
|
874
|
+
model_uid, _, _ = parse_replica_model_uid(replica_model_uid)
|
|
875
|
+
self._model_uid_to_replica_info.pop(model_uid, None)
|
|
876
|
+
self._replica_model_uid_to_worker.pop(
|
|
877
|
+
replica_model_uid, None
|
|
878
|
+
)
|
|
842
879
|
dead_nodes.append(address)
|
|
843
880
|
elif (
|
|
844
881
|
status.failure_remaining_count
|
|
@@ -948,6 +985,16 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
948
985
|
|
|
949
986
|
@log_async(logger=logger)
|
|
950
987
|
async def remove_worker(self, worker_address: str):
|
|
988
|
+
uids_to_remove = []
|
|
989
|
+
for model_uid in self._replica_model_uid_to_worker:
|
|
990
|
+
if self._replica_model_uid_to_worker[model_uid].address == worker_address:
|
|
991
|
+
uids_to_remove.append(model_uid)
|
|
992
|
+
|
|
993
|
+
for replica_model_uid in uids_to_remove:
|
|
994
|
+
model_uid, _, _ = parse_replica_model_uid(replica_model_uid)
|
|
995
|
+
self._model_uid_to_replica_info.pop(model_uid, None)
|
|
996
|
+
self._replica_model_uid_to_worker.pop(replica_model_uid, None)
|
|
997
|
+
|
|
951
998
|
if worker_address in self._worker_address_to_worker:
|
|
952
999
|
del self._worker_address_to_worker[worker_address]
|
|
953
1000
|
logger.debug("Worker %s has been removed successfully", worker_address)
|