xinference 0.7.5__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (120) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/__init__.py +13 -0
  3. xinference/api/oauth2/common.py +14 -0
  4. xinference/api/oauth2/core.py +93 -0
  5. xinference/api/oauth2/types.py +36 -0
  6. xinference/api/oauth2/utils.py +44 -0
  7. xinference/api/restful_api.py +216 -27
  8. xinference/client/oscar/actor_client.py +18 -18
  9. xinference/client/restful/restful_client.py +96 -33
  10. xinference/conftest.py +63 -1
  11. xinference/constants.py +1 -0
  12. xinference/core/chat_interface.py +143 -3
  13. xinference/core/metrics.py +83 -0
  14. xinference/core/model.py +244 -181
  15. xinference/core/status_guard.py +86 -0
  16. xinference/core/supervisor.py +57 -7
  17. xinference/core/worker.py +134 -13
  18. xinference/deploy/cmdline.py +142 -16
  19. xinference/deploy/local.py +39 -7
  20. xinference/deploy/supervisor.py +2 -0
  21. xinference/deploy/worker.py +33 -5
  22. xinference/fields.py +4 -1
  23. xinference/model/core.py +8 -1
  24. xinference/model/embedding/core.py +3 -2
  25. xinference/model/embedding/model_spec_modelscope.json +60 -18
  26. xinference/model/image/stable_diffusion/core.py +4 -3
  27. xinference/model/llm/__init__.py +7 -0
  28. xinference/model/llm/ggml/llamacpp.py +3 -2
  29. xinference/model/llm/llm_family.json +87 -3
  30. xinference/model/llm/llm_family.py +15 -5
  31. xinference/model/llm/llm_family_modelscope.json +92 -3
  32. xinference/model/llm/pytorch/chatglm.py +70 -28
  33. xinference/model/llm/pytorch/core.py +11 -30
  34. xinference/model/llm/pytorch/internlm2.py +155 -0
  35. xinference/model/llm/pytorch/utils.py +0 -153
  36. xinference/model/llm/utils.py +37 -8
  37. xinference/model/llm/vllm/core.py +15 -3
  38. xinference/model/multimodal/__init__.py +15 -8
  39. xinference/model/multimodal/core.py +8 -1
  40. xinference/model/multimodal/model_spec.json +9 -0
  41. xinference/model/multimodal/model_spec_modelscope.json +45 -0
  42. xinference/model/multimodal/qwen_vl.py +5 -9
  43. xinference/model/utils.py +7 -2
  44. xinference/types.py +2 -0
  45. xinference/web/ui/build/asset-manifest.json +3 -3
  46. xinference/web/ui/build/index.html +1 -1
  47. xinference/web/ui/build/static/js/main.b83095c2.js +3 -0
  48. xinference/web/ui/build/static/js/{main.236e72e7.js.LICENSE.txt → main.b83095c2.js.LICENSE.txt} +7 -0
  49. xinference/web/ui/build/static/js/main.b83095c2.js.map +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/0a853b2fa1902551e262a2f1a4b7894341f27b3dd9587f2ef7aaea195af89518.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/101923c539819f26ad11fbcbd6f6e56436b285efbb090dcc7dd648c6e924c4a8.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/193e7ba39e70d4bb2895a5cb317f6f293a5fd02e7e324c02a1eba2f83216419c.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/22858de5265f2d279fca9f2f54dfb147e4b2704200dfb5d2ad3ec9769417328f.json +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/27696db5fcd4fcf0e7974cadf1e4a2ab89690474045c3188eafd586323ad13bb.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/27bdbe25deab8cf08f7fab8f05f8f26cf84a98809527a37986a4ab73a57ba96a.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/2bee7b8bd3d52976a45d6068e1333df88b943e0e679403c809e45382e3818037.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/30670751f55508ef3b861e13dd71b9e5a10d2561373357a12fc3831a0b77fd93.json +1 -0
  59. xinference/web/ui/node_modules/.cache/babel-loader/3605cd3a96ff2a3b443c70a101575482279ad26847924cab0684d165ba0d2492.json +1 -0
  60. xinference/web/ui/node_modules/.cache/babel-loader/3789ef437d3ecbf945bb9cea39093d1f16ebbfa32dbe6daf35abcfb6d48de6f1.json +1 -0
  61. xinference/web/ui/node_modules/.cache/babel-loader/4942da6bc03bf7373af068e22f916341aabc5b5df855d73c1d348c696724ce37.json +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/4d933e35e0fe79867d3aa6c46db28804804efddf5490347cb6c2c2879762a157.json +1 -0
  63. xinference/web/ui/node_modules/.cache/babel-loader/4d96f071168af43965e0fab2ded658fa0a15b8d9ca03789a5ef9c5c16a4e3cee.json +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/4fd24800544873512b540544ae54601240a5bfefd9105ff647855c64f8ad828f.json +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/52a6136cb2dbbf9c51d461724d9b283ebe74a73fb19d5df7ba8e13c42bd7174d.json +1 -0
  66. xinference/web/ui/node_modules/.cache/babel-loader/5c408307c982f07f9c09c85c98212d1b1c22548a9194c69548750a3016b91b88.json +1 -0
  67. xinference/web/ui/node_modules/.cache/babel-loader/663adbcb60b942e9cf094c8d9fabe57517f5e5e6e722d28b4948a40b7445a3b8.json +1 -0
  68. xinference/web/ui/node_modules/.cache/babel-loader/666bb2e1b250dc731311a7e4880886177885dfa768508d2ed63e02630cc78725.json +1 -0
  69. xinference/web/ui/node_modules/.cache/babel-loader/71493aadd34d568fbe605cacaba220aa69bd09273251ee4ba27930f8d01fccd8.json +1 -0
  70. xinference/web/ui/node_modules/.cache/babel-loader/8b071db2a5a9ef68dc14d5f606540bd23d9785e365a11997c510656764d2dccf.json +1 -0
  71. xinference/web/ui/node_modules/.cache/babel-loader/8b246d79cd3f6fc78f11777e6a6acca6a2c5d4ecce7f2dd4dcf9a48126440d3c.json +1 -0
  72. xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/95c8cc049fadd23085d8623e1d43d70b614a4e52217676f186a417dca894aa09.json +1 -0
  74. xinference/web/ui/node_modules/.cache/babel-loader/a4d72d3b806ba061919115f0c513738726872e3c79cf258f007519d3f91d1a16.json +1 -0
  75. xinference/web/ui/node_modules/.cache/babel-loader/a8070ce4b780b4a044218536e158a9e7192a6c80ff593fdc126fee43f46296b5.json +1 -0
  76. xinference/web/ui/node_modules/.cache/babel-loader/b4e4fccaf8f2489a29081f0bf3b191656bd452fb3c8b5e3c6d92d94f680964d5.json +1 -0
  77. xinference/web/ui/node_modules/.cache/babel-loader/b53eb7c7967f6577bd3e678293c44204fb03ffa7fdc1dd59d3099015c68f6f7f.json +1 -0
  78. xinference/web/ui/node_modules/.cache/babel-loader/bd04667474fd9cac2983b03725c218908a6cc0ee9128a5953cd00d26d4877f60.json +1 -0
  79. xinference/web/ui/node_modules/.cache/babel-loader/c230a727b8f68f0e62616a75e14a3d33026dc4164f2e325a9a8072d733850edb.json +1 -0
  80. xinference/web/ui/node_modules/.cache/babel-loader/d06af85a84e5c5a29d3acf2dbb5b30c0cf75c8aec4ab5f975e6096f944ee4324.json +1 -0
  81. xinference/web/ui/node_modules/.cache/babel-loader/d44a6eb6106e09082b691a315c9f6ce17fcfe25beb7547810e0d271ce3301cd2.json +1 -0
  82. xinference/web/ui/node_modules/.cache/babel-loader/d5e150bff31715977d8f537c970f06d4fe3de9909d7e8342244a83a9f6447121.json +1 -0
  83. xinference/web/ui/node_modules/.cache/babel-loader/de36e5c08fd524e341d664883dda6cb1745acc852a4f1b011a35a0b4615f72fa.json +1 -0
  84. xinference/web/ui/node_modules/.cache/babel-loader/f037ffef5992af0892d6d991053c1dace364cd39a3f11f1a41f92776e8a59459.json +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/f23ab356a8603d4a2aaa74388c2f381675c207d37c4d1c832df922e9655c9a6b.json +1 -0
  86. xinference/web/ui/node_modules/.cache/babel-loader/f7c23b0922f4087b9e2e3e46f15c946b772daa46c28c3a12426212ecaf481deb.json +1 -0
  87. xinference/web/ui/node_modules/.cache/babel-loader/f95a8bd358eeb55fa2f49f1224cc2f4f36006359856744ff09ae4bb295f59ec1.json +1 -0
  88. xinference/web/ui/node_modules/.cache/babel-loader/fe5db70859503a54cbe71f9637e5a314cda88b1f0eecb733b6e6f837697db1ef.json +1 -0
  89. xinference/web/ui/node_modules/.package-lock.json +36 -0
  90. xinference/web/ui/node_modules/@types/cookie/package.json +30 -0
  91. xinference/web/ui/node_modules/@types/hoist-non-react-statics/package.json +33 -0
  92. xinference/web/ui/node_modules/react-cookie/package.json +55 -0
  93. xinference/web/ui/node_modules/universal-cookie/package.json +48 -0
  94. xinference/web/ui/package-lock.json +37 -0
  95. xinference/web/ui/package.json +3 -2
  96. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/METADATA +17 -6
  97. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/RECORD +101 -66
  98. xinference/web/ui/build/static/js/main.236e72e7.js +0 -3
  99. xinference/web/ui/build/static/js/main.236e72e7.js.map +0 -1
  100. xinference/web/ui/node_modules/.cache/babel-loader/0cccfbe5d963b8e31eb679f9d9677392839cedd04aa2956ac6b33cf19599d597.json +0 -1
  101. xinference/web/ui/node_modules/.cache/babel-loader/0f3b6cc71b7c83bdc85aa4835927aeb86af2ce0d2ac241917ecfbf90f75c6d27.json +0 -1
  102. xinference/web/ui/node_modules/.cache/babel-loader/2f651cf60b1bde50c0601c7110f77dd44819fb6e2501ff748a631724d91445d4.json +0 -1
  103. xinference/web/ui/node_modules/.cache/babel-loader/42bb623f337ad08ed076484185726e072ca52bb88e373d72c7b052db4c273342.json +0 -1
  104. xinference/web/ui/node_modules/.cache/babel-loader/57af83639c604bd3362d0f03f7505e81c6f67ff77bee7c6bb31f6e5523eba185.json +0 -1
  105. xinference/web/ui/node_modules/.cache/babel-loader/667753ce39ce1d4bcbf9a5f1a103d653be1d19d42f4e1fbaceb9b507679a52c7.json +0 -1
  106. xinference/web/ui/node_modules/.cache/babel-loader/66ed1bd4c06748c1b176a625c25c856997edc787856c73162f82f2b465c5d956.json +0 -1
  107. xinference/web/ui/node_modules/.cache/babel-loader/78f2521da2e2a98b075a2666cb782c7e2c019cd3c72199eecd5901c82d8655df.json +0 -1
  108. xinference/web/ui/node_modules/.cache/babel-loader/8d2b0b3c6988d1894694dcbbe708ef91cfe62d62dac317031f09915ced637953.json +0 -1
  109. xinference/web/ui/node_modules/.cache/babel-loader/9427ae7f1e94ae8dcd2333fb361e381f4054fde07394fe5448658e3417368476.json +0 -1
  110. xinference/web/ui/node_modules/.cache/babel-loader/bcee2b4e76b07620f9087989eb86d43c645ba3c7a74132cf926260af1164af0e.json +0 -1
  111. xinference/web/ui/node_modules/.cache/babel-loader/cc2ddd02ccc1dad1a2737ac247c79e6f6ed2c7836c6b68e511e3048f666b64af.json +0 -1
  112. xinference/web/ui/node_modules/.cache/babel-loader/d2e8e6665a7efc832b43907dadf4e3c896a59eaf8129f9a520882466c8f2e489.json +0 -1
  113. xinference/web/ui/node_modules/.cache/babel-loader/d8a42e9df7157de9f28eecefdf178fd113bf2280d28471b6e32a8a45276042df.json +0 -1
  114. xinference/web/ui/node_modules/.cache/babel-loader/e26750d9556e9741912333349e4da454c53dbfddbfc6002ab49518dcf02af745.json +0 -1
  115. xinference/web/ui/node_modules/.cache/babel-loader/ef42ec014d7bc373b874b2a1ff0dcd785490f125e913698bc049b0bd778e4d66.json +0 -1
  116. xinference/web/ui/node_modules/.cache/babel-loader/fe3eb4d76c79ca98833f686d642224eeeb94cc83ad14300d281623796d087f0a.json +0 -1
  117. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/LICENSE +0 -0
  118. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/WHEEL +0 -0
  119. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/entry_points.txt +0 -0
  120. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2024-01-05T15:29:43+0800",
11
+ "date": "2024-01-19T17:14:28+0800",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "56b28b3e4149b0a9ab6f5322401b1c3f1fc95c1a",
15
- "version": "0.7.5"
14
+ "full-revisionid": "fb3985e95fbb3e6cb51a321d6d6a9a10661128fe",
15
+ "version": "0.8.1"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -0,0 +1,13 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,14 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
# Module-level slot for the OAuth2 startup configuration.
# Starts out as None; code that loads an auth config file is expected to
# rebind this attribute on the module.
XINFERENCE_OAUTH2_CONFIG = None
@@ -0,0 +1,93 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ from typing import List, Optional, Union
16
+
17
+ from fastapi import Depends, HTTPException, status
18
+ from fastapi.security import OAuth2PasswordBearer, SecurityScopes
19
+ from jose import JWTError, jwt
20
+ from pydantic import BaseModel, ValidationError
21
+ from typing_extensions import Annotated
22
+
23
+ from .types import AuthStartupConfig, User
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
29
+
30
+
31
def get_db():
    """Yield the OAuth2 startup configuration stored in ``.common``.

    Written as a generator so it can be plugged into FastAPI's ``Depends``.
    The yielded value is whatever ``XINFERENCE_OAUTH2_CONFIG`` is bound to
    at the moment of the call, which may be ``None``.
    """
    # Look the attribute up at call time, not import time: the module-level
    # default is None and the attribute may be rebound after import.
    from . import common as _oauth2_common

    # In a real enterprise-level environment, this should be the database
    yield _oauth2_common.XINFERENCE_OAUTH2_CONFIG
36
+
37
+
38
def get_user(db_users: List[User], username: str) -> Optional[User]:
    """Return the first entry of ``db_users`` whose ``username`` matches.

    Returns ``None`` when no entry matches (including when ``db_users`` is
    empty).
    """
    matching_users = (
        candidate for candidate in db_users if candidate.username == username
    )
    # next() with a default stops at the first hit, like the explicit loop.
    return next(matching_users, None)
43
+
44
+
45
class TokenData(BaseModel):
    """Validated subset of the claims carried by a decoded access token."""

    # Subject of the token; None when not provided.
    username: Optional[str] = None
    # Scope names carried by the token; empty when not provided.
    scopes: List[str] = []
48
+
49
+
50
def verify_token(
    security_scopes: SecurityScopes,
    token: Annotated[str, Depends(oauth2_scheme)],
    config: Optional[AuthStartupConfig] = Depends(get_db),
):
    """Validate a bearer token and enforce the scopes required by the route.

    Args:
        security_scopes: Scopes the protected route requires.
        token: Encoded JWT extracted by ``oauth2_scheme``.
        config: Auth configuration yielded by ``get_db``; ``None`` when no
            configuration has been set.

    Returns:
        The configured user whose ``username`` equals the token's ``sub``
        claim.

    Raises:
        HTTPException: 401 when no configuration is available, the token
            cannot be decoded or validated, the ``sub`` claim is missing, or
            the user is not present in the configuration. 403 when the token
            lacks a required scope and does not carry the ``admin`` scope.
    """
    if security_scopes.scopes:
        authenticate_value = f'Bearer scope="{security_scopes.scope_str}"'
    else:
        authenticate_value = "Bearer"
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": authenticate_value},
    )

    # Explicit check instead of `assert`: an AssertionError is not caught by
    # the handler below (it would surface as a 500), and asserts are removed
    # entirely when Python runs with -O.
    if config is None:
        raise credentials_exception

    try:
        payload = jwt.decode(
            token,
            config.auth_config.secret_key,
            algorithms=[config.auth_config.algorithm],
            # NOTE(review): expiry is deliberately not verified here, so a
            # token stays valid after its "exp" claim has passed.
            options={"verify_exp": False},  # TODO: supports token expiration
        )
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
        token_scopes = payload.get("scopes", [])
        # TODO: check expire
        token_data = TokenData(scopes=token_scopes, username=username)
    except (JWTError, ValidationError):
        raise credentials_exception

    user = get_user(config.user_config, username=token_data.username)  # type: ignore
    if user is None:
        raise credentials_exception

    # The "admin" scope bypasses the per-route scope check.
    if "admin" in token_data.scopes:
        return user
    for scope in security_scopes.scopes:
        if scope not in token_data.scopes:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Not enough permissions",
                headers={"WWW-Authenticate": authenticate_value},
            )
    return user
@@ -0,0 +1,36 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List
15
+
16
+ from pydantic import BaseModel
17
+
18
+
19
class LoginUserForm(BaseModel):
    """Credentials submitted when requesting an access token."""

    # Account name to log in with.
    username: str
    # Password supplied for that account.
    password: str
22
+
23
+
24
class User(LoginUserForm):
    """A configured account: login credentials plus granted permissions."""

    # Permission names granted to this account.
    permissions: List[str]
26
+
27
+
28
class AuthConfig(BaseModel):
    """Settings that control how access tokens are signed and issued."""

    # Name of the signing algorithm; defaults to "HS256".
    algorithm: str = "HS256"
    # Secret used as the signing key (required, no default).
    secret_key: str
    # Token lifetime expressed in minutes (required, no default).
    token_expire_in_minutes: int
32
+
33
+
34
class AuthStartupConfig(BaseModel):
    """Top-level auth configuration: token settings plus the known users."""

    # Token signing and lifetime settings.
    auth_config: AuthConfig
    # Accounts that are allowed to authenticate.
    user_config: List[User]
@@ -0,0 +1,44 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from datetime import datetime, timedelta
15
+ from typing import Union
16
+
17
+ from jose import jwt
18
+ from passlib.context import CryptContext
19
+
20
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
21
+
22
+
23
def create_access_token(
    data: dict,
    secret_key: str,
    algorithm: str,
    expires_delta: Union[timedelta, None] = None,
):
    """Encode ``data`` as a signed JWT with an ``exp`` claim added.

    Args:
        data: Claims to embed. The mapping is shallow-copied, so the
            caller's object is not modified.
        secret_key: Key used to sign the token.
        algorithm: Signing algorithm name passed to ``jwt.encode``.
        expires_delta: Token lifetime. When ``None`` a 15 minute lifetime
            is used.

    Returns:
        The encoded token as returned by ``jwt.encode``.
    """
    # Local import keeps this function self-contained; the module only
    # imports ``datetime`` and ``timedelta`` at the top.
    from datetime import timezone

    # Compare against None rather than relying on truthiness: timedelta(0)
    # is falsy and would otherwise be silently replaced by the default.
    lifetime = timedelta(minutes=15) if expires_delta is None else expires_delta

    to_encode = data.copy()
    # Timezone-aware "now" in UTC; datetime.utcnow() is deprecated and
    # returns a naive value.
    to_encode.update({"exp": datetime.now(timezone.utc) + lifetime})
    return jwt.encode(to_encode, secret_key, algorithm=algorithm)
37
+
38
+
39
def verify_password(plain_password, hashed_password):
    """Check ``plain_password`` against ``hashed_password`` via ``pwd_context``.

    Returns the result of ``pwd_context.verify`` unchanged.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
41
+
42
+
43
def get_password_hash(password):
    """Hash ``password`` with ``pwd_context`` and return the resulting hash."""
    hashed = pwd_context.hash(password)
    return hashed
@@ -21,10 +21,14 @@ import os
21
21
  import pprint
22
22
  import sys
23
23
  import warnings
24
+ from datetime import timedelta
24
25
  from typing import Any, List, Optional, Union
25
26
 
26
27
  import gradio as gr
28
+ import pydantic
27
29
  import xoscar as xo
30
+ from aioprometheus import REGISTRY, MetricsMiddleware
31
+ from aioprometheus.asgi.starlette import metrics
28
32
  from fastapi import (
29
33
  APIRouter,
30
34
  FastAPI,
@@ -34,9 +38,12 @@ from fastapi import (
34
38
  Query,
35
39
  Request,
36
40
  Response,
41
+ Security,
37
42
  UploadFile,
43
+ status,
38
44
  )
39
45
  from fastapi.middleware.cors import CORSMiddleware
46
+ from fastapi.responses import JSONResponse
40
47
  from fastapi.staticfiles import StaticFiles
41
48
  from PIL import Image
42
49
  from pydantic import BaseModel, Field
@@ -57,11 +64,14 @@ from ..types import (
57
64
  CreateCompletion,
58
65
  ImageList,
59
66
  )
67
+ from .oauth2.core import get_user, verify_token
68
+ from .oauth2.types import AuthStartupConfig, LoginUserForm, User
69
+ from .oauth2.utils import create_access_token, get_password_hash, verify_password
60
70
 
61
71
  logger = logging.getLogger(__name__)
62
72
 
63
73
 
64
- class JSONResponse(StarletteJSONResponse):
74
+ class JSONResponse(StarletteJSONResponse): # type: ignore # noqa: F811
65
75
  def render(self, content: Any) -> bytes:
66
76
  return json_dumps(content)
67
77
 
@@ -125,16 +135,48 @@ class BuildGradioInterfaceRequest(BaseModel):
125
135
  model_lang: List[str]
126
136
 
127
137
 
138
def authenticate_user(db_users: List[User], username: str, password: str):
    """Look up ``username`` and check ``password`` against its stored hash.

    Returns the matching user when both the lookup and the password check
    succeed, otherwise ``False``.
    """
    candidate = get_user(db_users, username)
    # Short-circuit keeps the password check from running for unknown users,
    # exactly as the two separate guards did.
    if candidate and verify_password(password, candidate.password):
        return candidate
    return False
145
+
146
+
128
147
  class RESTfulAPI:
129
- def __init__(self, supervisor_address: str, host: str, port: int):
148
+ def __init__(
149
+ self,
150
+ supervisor_address: str,
151
+ host: str,
152
+ port: int,
153
+ auth_config_file: Optional[str] = None,
154
+ ):
130
155
  super().__init__()
131
156
  self._supervisor_address = supervisor_address
132
157
  self._host = host
133
158
  self._port = port
134
159
  self._supervisor_ref = None
160
+ self._auth_config: AuthStartupConfig = self.init_auth_config(auth_config_file)
135
161
  self._router = APIRouter()
136
162
  self._app = FastAPI()
137
163
 
164
@staticmethod
def init_auth_config(auth_config_file: Optional[str]):
    """Load the auth configuration from ``auth_config_file``, if one is given.

    When a path is supplied, the file is parsed into an
    ``AuthStartupConfig``, every user's password is replaced by its hash,
    the result is stored on ``oauth2.common.XINFERENCE_OAUTH2_CONFIG`` and
    returned. When no path is supplied nothing is stored and ``None`` is
    returned.
    """
    from .oauth2 import common

    if not auth_config_file:
        return None

    loaded: AuthStartupConfig = pydantic.parse_file_as(
        path=auth_config_file, type_=AuthStartupConfig
    )
    # Replace each password from the file with its hash, in place.
    for account in loaded.user_config:
        account.password = get_password_hash(account.password)
    common.XINFERENCE_OAUTH2_CONFIG = loaded  # type: ignore
    return loaded
176
+
177
def is_authenticated(self):
    """Return ``True`` when an auth configuration is set on this instance."""
    return self._auth_config is not None
179
+
138
180
  @staticmethod
139
181
  def handle_request_limit_error(e: Exception):
140
182
  if "Rate limit reached" in str(e):
@@ -147,6 +189,33 @@ class RESTfulAPI:
147
189
  )
148
190
  return self._supervisor_ref
149
191
 
192
async def login_for_access_token(self, form_data: LoginUserForm) -> JSONResponse:
    """Exchange a username/password pair for a bearer access token.

    Args:
        form_data: Credentials submitted by the client.

    Returns:
        A JSON response with ``access_token`` and ``token_type`` keys.

    Raises:
        HTTPException: 400 when no auth configuration is set on this
            instance; 401 when the credentials do not match a configured
            user.
    """
    # Without this guard a missing configuration fails with an
    # AttributeError on ``None`` below, which reaches the client as a 500.
    if self._auth_config is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Authentication is not enabled",
        )

    user = authenticate_user(
        self._auth_config.user_config, form_data.username, form_data.password
    )
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # authenticate_user returns either False or a User; narrow the type for
    # the attribute accesses below.
    assert user is not None and isinstance(user, User)

    access_token_expires = timedelta(
        minutes=self._auth_config.auth_config.token_expire_in_minutes
    )
    # The user's permissions are embedded in the token as its "scopes".
    access_token = create_access_token(
        data={"sub": user.username, "scopes": user.permissions},
        secret_key=self._auth_config.auth_config.secret_key,
        algorithm=self._auth_config.auth_config.algorithm,
        expires_delta=access_token_expires,
    )
    return JSONResponse(
        content={"access_token": access_token, "token_type": "bearer"}
    )
215
+
216
async def is_cluster_authenticated(self) -> JSONResponse:
    """Return a JSON body ``{"auth": <bool>}`` from ``is_authenticated()``."""
    auth_state = {"auth": self.is_authenticated()}
    return JSONResponse(content=auth_state)
218
+
150
219
  def serve(self, logging_conf: Optional[dict] = None):
151
220
  self._app.add_middleware(
152
221
  CORSMiddleware,
@@ -155,8 +224,10 @@ class RESTfulAPI:
155
224
  allow_methods=["*"],
156
225
  allow_headers=["*"],
157
226
  )
227
+
228
+ # internal interface
158
229
  self._router.add_api_route("/status", self.get_status, methods=["GET"])
159
- self._router.add_api_route("/v1/models", self.list_models, methods=["GET"])
230
+ # conflict with /v1/models/{model_uid} below, so register this first
160
231
  self._router.add_api_route(
161
232
  "/v1/models/prompts", self._get_builtin_prompts, methods=["GET"]
162
233
  )
@@ -166,52 +237,124 @@ class RESTfulAPI:
166
237
  self._router.add_api_route(
167
238
  "/v1/cluster/devices", self._get_devices_count, methods=["GET"]
168
239
  )
240
+ self._router.add_api_route("/v1/address", self.get_address, methods=["GET"])
241
+
242
+ # user interface
243
+ self._router.add_api_route(
244
+ "/v1/ui/{model_uid}",
245
+ self.build_gradio_interface,
246
+ methods=["POST"],
247
+ dependencies=[Security(verify_token, scopes=["models:read"])]
248
+ if self.is_authenticated()
249
+ else None,
250
+ )
251
+ self._router.add_api_route(
252
+ "/token", self.login_for_access_token, methods=["POST"]
253
+ )
254
+ self._router.add_api_route(
255
+ "/v1/cluster/auth", self.is_cluster_authenticated, methods=["GET"]
256
+ )
257
+ # running instances
169
258
  self._router.add_api_route(
170
- "/v1/models/{model_uid}", self.describe_model, methods=["GET"]
259
+ "/v1/models/instances",
260
+ self.get_instance_info,
261
+ methods=["GET"],
262
+ dependencies=[Security(verify_token, scopes=["models:list"])]
263
+ if self.is_authenticated()
264
+ else None,
265
+ )
266
+ self._router.add_api_route(
267
+ "/v1/models",
268
+ self.list_models,
269
+ methods=["GET"],
270
+ dependencies=[Security(verify_token, scopes=["models:list"])]
271
+ if self.is_authenticated()
272
+ else None,
273
+ )
274
+
275
+ self._router.add_api_route(
276
+ "/v1/models/{model_uid}",
277
+ self.describe_model,
278
+ methods=["GET"],
279
+ dependencies=[Security(verify_token, scopes=["models:list"])]
280
+ if self.is_authenticated()
281
+ else None,
282
+ )
283
+ self._router.add_api_route(
284
+ "/v1/models",
285
+ self.launch_model,
286
+ methods=["POST"],
287
+ dependencies=[Security(verify_token, scopes=["models:start"])]
288
+ if self.is_authenticated()
289
+ else None,
171
290
  )
172
- self._router.add_api_route("/v1/models", self.launch_model, methods=["POST"])
173
291
  self._router.add_api_route(
174
292
  "/experimental/speculative_llms",
175
293
  self.launch_speculative_llm,
176
294
  methods=["POST"],
295
+ dependencies=[Security(verify_token, scopes=["models:start"])]
296
+ if self.is_authenticated()
297
+ else None,
177
298
  )
178
299
  self._router.add_api_route(
179
- "/v1/models/{model_uid}", self.terminate_model, methods=["DELETE"]
300
+ "/v1/models/{model_uid}",
301
+ self.terminate_model,
302
+ methods=["DELETE"],
303
+ dependencies=[Security(verify_token, scopes=["models:stop"])]
304
+ if self.is_authenticated()
305
+ else None,
180
306
  )
181
- self._router.add_api_route("/v1/address", self.get_address, methods=["GET"])
182
307
  self._router.add_api_route(
183
308
  "/v1/completions",
184
309
  self.create_completion,
185
310
  methods=["POST"],
186
311
  response_model=Completion,
312
+ dependencies=[Security(verify_token, scopes=["models:read"])]
313
+ if self.is_authenticated()
314
+ else None,
187
315
  )
188
316
  self._router.add_api_route(
189
317
  "/v1/embeddings",
190
318
  self.create_embedding,
191
319
  methods=["POST"],
320
+ dependencies=[Security(verify_token, scopes=["models:read"])]
321
+ if self.is_authenticated()
322
+ else None,
192
323
  )
193
324
  self._router.add_api_route(
194
325
  "/v1/rerank",
195
326
  self.rerank,
196
327
  methods=["POST"],
328
+ dependencies=[Security(verify_token, scopes=["models:read"])]
329
+ if self.is_authenticated()
330
+ else None,
197
331
  )
198
332
  self._router.add_api_route(
199
333
  "/v1/images/generations",
200
334
  self.create_images,
201
335
  methods=["POST"],
202
336
  response_model=ImageList,
337
+ dependencies=[Security(verify_token, scopes=["models:read"])]
338
+ if self.is_authenticated()
339
+ else None,
203
340
  )
204
341
  self._router.add_api_route(
205
342
  "/v1/images/variations",
206
343
  self.create_variations,
207
344
  methods=["POST"],
208
345
  response_model=ImageList,
346
+ dependencies=[Security(verify_token, scopes=["models:read"])]
347
+ if self.is_authenticated()
348
+ else None,
209
349
  )
210
350
  self._router.add_api_route(
211
351
  "/v1/chat/completions",
212
352
  self.create_chat_completion,
213
353
  methods=["POST"],
214
354
  response_model=ChatCompletion,
355
+ dependencies=[Security(verify_token, scopes=["models:read"])]
356
+ if self.is_authenticated()
357
+ else None,
215
358
  )
216
359
 
217
360
  # for custom models
@@ -219,28 +362,42 @@ class RESTfulAPI:
219
362
  "/v1/model_registrations/{model_type}",
220
363
  self.register_model,
221
364
  methods=["POST"],
365
+ dependencies=[Security(verify_token, scopes=["models:register"])]
366
+ if self.is_authenticated()
367
+ else None,
222
368
  )
223
369
  self._router.add_api_route(
224
370
  "/v1/model_registrations/{model_type}/{model_name}",
225
371
  self.unregister_model,
226
372
  methods=["DELETE"],
373
+ dependencies=[Security(verify_token, scopes=["models:unregister"])]
374
+ if self.is_authenticated()
375
+ else None,
227
376
  )
228
377
  self._router.add_api_route(
229
378
  "/v1/model_registrations/{model_type}",
230
379
  self.list_model_registrations,
231
380
  methods=["GET"],
381
+ dependencies=[Security(verify_token, scopes=["models:list"])]
382
+ if self.is_authenticated()
383
+ else None,
232
384
  )
233
385
  self._router.add_api_route(
234
386
  "/v1/model_registrations/{model_type}/{model_name}",
235
387
  self.get_model_registrations,
236
388
  methods=["GET"],
389
+ dependencies=[Security(verify_token, scopes=["models:list"])]
390
+ if self.is_authenticated()
391
+ else None,
237
392
  )
238
393
 
239
- self._router.add_api_route(
240
- "/v1/ui/{model_uid}", self.build_gradio_interface, methods=["POST"]
241
- )
242
-
394
+ # Clear the global Registry for the MetricsMiddleware, or
395
+ # the MetricsMiddleware will register duplicated metrics if the port
396
+ # conflict (This serve method run more than once).
397
+ REGISTRY.clear()
398
+ self._app.add_middleware(MetricsMiddleware)
243
399
  self._app.include_router(self._router)
400
+ self._app.add_route("/metrics", metrics)
244
401
 
245
402
  # Check all the routes returns Response.
246
403
  # This is to avoid `jsonable_encoder` performance issue:
@@ -406,7 +563,9 @@ class RESTfulAPI:
406
563
 
407
564
  return JSONResponse(content={"model_uid": model_uid})
408
565
 
409
- async def launch_model(self, request: Request) -> JSONResponse:
566
+ async def launch_model(
567
+ self, request: Request, wait_ready: bool = Query(True)
568
+ ) -> JSONResponse:
410
569
  payload = await request.json()
411
570
  model_uid = payload.get("model_uid")
412
571
  model_name = payload.get("model_name")
@@ -451,6 +610,7 @@ class RESTfulAPI:
451
610
  replica=replica,
452
611
  n_gpu=n_gpu,
453
612
  request_limits=request_limits,
613
+ wait_ready=wait_ready,
454
614
  **kwargs,
455
615
  )
456
616
 
@@ -466,8 +626,22 @@ class RESTfulAPI:
466
626
 
467
627
  return JSONResponse(content={"model_uid": model_uid})
468
628
 
629
async def get_instance_info(
    self,
    model_name: Optional[str] = Query(None),
    model_uid: Optional[str] = Query(None),
) -> JSONResponse:
    """Fetch instance information from the supervisor and return it as JSON.

    Both query parameters are optional and are forwarded to the
    supervisor's ``get_instance_info`` unchanged.

    Raises:
        HTTPException: 500 carrying the error text when obtaining the
            supervisor reference or the call itself fails.
    """
    try:
        supervisor_ref = await self._get_supervisor_ref()
        infos = await supervisor_ref.get_instance_info(model_name, model_uid)
    except Exception as e:
        logger.error(str(e), exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return JSONResponse(content=infos)
642
+
469
643
  async def build_gradio_interface(
470
- self, model_uid: str, body: BuildGradioInterfaceRequest
644
+ self, model_uid: str, body: BuildGradioInterfaceRequest, request: Request
471
645
  ) -> JSONResponse:
472
646
  """
473
647
  Separate build_interface with launch_model
@@ -475,7 +649,7 @@ class RESTfulAPI:
475
649
  but calling API in async function does not return
476
650
  """
477
651
  assert self._app is not None
478
- assert body.model_type == "LLM"
652
+ assert body.model_type in ["LLM", "multimodal"]
479
653
 
480
654
  # asyncio.Lock() behaves differently in 3.9 than 3.10+
481
655
  # A event loop is required in 3.9 but not 3.10+
@@ -489,21 +663,24 @@ class RESTfulAPI:
489
663
  )
490
664
  asyncio.set_event_loop(asyncio.new_event_loop())
491
665
 
492
- from ..core.chat_interface import LLMInterface
666
+ from ..core.chat_interface import GradioInterface
493
667
 
494
668
  try:
669
+ access_token = request.headers.get("Authorization")
495
670
  internal_host = "localhost" if self._host == "0.0.0.0" else self._host
496
- interface = LLMInterface(
671
+ interface = GradioInterface(
497
672
  endpoint=f"http://{internal_host}:{self._port}",
498
673
  model_uid=model_uid,
499
674
  model_name=body.model_name,
500
675
  model_size_in_billions=body.model_size_in_billions,
676
+ model_type=body.model_type,
501
677
  model_format=body.model_format,
502
678
  quantization=body.quantization,
503
679
  context_length=body.context_length,
504
680
  model_ability=body.model_ability,
505
681
  model_description=body.model_description,
506
682
  model_lang=body.model_lang,
683
+ access_token=access_token,
507
684
  ).build()
508
685
  gr.mount_gradio_app(self._app, interface, f"/{model_uid}")
509
686
  except ValueError as ve:
@@ -581,8 +758,6 @@ class RESTfulAPI:
581
758
  async for item in iterator:
582
759
  yield item
583
760
  except Exception as ex:
584
- if iterator is not None:
585
- await iterator.destroy()
586
761
  logger.exception("Completion stream got an error: %s", ex)
587
762
  # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
588
763
  yield dict(data=json.dumps({"error": str(ex)}))
@@ -660,8 +835,7 @@ class RESTfulAPI:
660
835
  raise HTTPException(status_code=500, detail=str(e))
661
836
 
662
837
  try:
663
- if request.kwargs:
664
- kwargs = json.loads(request.kwargs)
838
+ kwargs = json.loads(request.kwargs) if request.kwargs else {}
665
839
  image_list = await model.text_to_image(
666
840
  prompt=request.prompt,
667
841
  n=request.n,
@@ -844,8 +1018,6 @@ class RESTfulAPI:
844
1018
  async for item in iterator:
845
1019
  yield item
846
1020
  except Exception as ex:
847
- if iterator is not None:
848
- await iterator.destroy()
849
1021
  logger.exception("Chat completion stream got an error: %s", ex)
850
1022
  # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
851
1023
  yield dict(data=json.dumps({"error": str(ex)}))
@@ -926,11 +1098,20 @@ class RESTfulAPI:
926
1098
 
927
1099
 
928
1100
  def run(
929
- supervisor_address: str, host: str, port: int, logging_conf: Optional[dict] = None
1101
+ supervisor_address: str,
1102
+ host: str,
1103
+ port: int,
1104
+ logging_conf: Optional[dict] = None,
1105
+ auth_config_file: Optional[str] = None,
930
1106
  ):
931
1107
  logger.info(f"Starting Xinference at endpoint: http://{host}:{port}")
932
1108
  try:
933
- api = RESTfulAPI(supervisor_address=supervisor_address, host=host, port=port)
1109
+ api = RESTfulAPI(
1110
+ supervisor_address=supervisor_address,
1111
+ host=host,
1112
+ port=port,
1113
+ auth_config_file=auth_config_file,
1114
+ )
934
1115
  api.serve(logging_conf=logging_conf)
935
1116
  except SystemExit:
936
1117
  logger.warning("Failed to create socket with port %d", port)
@@ -941,7 +1122,10 @@ def run(
941
1122
  logger.info(f"Found available port: {port}")
942
1123
  logger.info(f"Starting Xinference at endpoint: http://{host}:{port}")
943
1124
  api = RESTfulAPI(
944
- supervisor_address=supervisor_address, host=host, port=port
1125
+ supervisor_address=supervisor_address,
1126
+ host=host,
1127
+ port=port,
1128
+ auth_config_file=auth_config_file,
945
1129
  )
946
1130
  api.serve(logging_conf=logging_conf)
947
1131
  else:
@@ -949,10 +1133,15 @@ def run(
949
1133
 
950
1134
 
951
1135
  def run_in_subprocess(
952
- supervisor_address: str, host: str, port: int, logging_conf: Optional[dict] = None
1136
+ supervisor_address: str,
1137
+ host: str,
1138
+ port: int,
1139
+ logging_conf: Optional[dict] = None,
1140
+ auth_config_file: Optional[str] = None,
953
1141
  ) -> multiprocessing.Process:
954
1142
  p = multiprocessing.Process(
955
- target=run, args=(supervisor_address, host, port, logging_conf)
1143
+ target=run,
1144
+ args=(supervisor_address, host, port, logging_conf, auth_config_file),
956
1145
  )
957
1146
  p.daemon = True
958
1147
  p.start()