xinference 0.7.5__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (120)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/__init__.py +13 -0
  3. xinference/api/oauth2/common.py +14 -0
  4. xinference/api/oauth2/core.py +93 -0
  5. xinference/api/oauth2/types.py +36 -0
  6. xinference/api/oauth2/utils.py +44 -0
  7. xinference/api/restful_api.py +216 -27
  8. xinference/client/oscar/actor_client.py +18 -18
  9. xinference/client/restful/restful_client.py +96 -33
  10. xinference/conftest.py +63 -1
  11. xinference/constants.py +1 -0
  12. xinference/core/chat_interface.py +143 -3
  13. xinference/core/metrics.py +83 -0
  14. xinference/core/model.py +244 -181
  15. xinference/core/status_guard.py +86 -0
  16. xinference/core/supervisor.py +57 -7
  17. xinference/core/worker.py +134 -13
  18. xinference/deploy/cmdline.py +142 -16
  19. xinference/deploy/local.py +39 -7
  20. xinference/deploy/supervisor.py +2 -0
  21. xinference/deploy/worker.py +33 -5
  22. xinference/fields.py +4 -1
  23. xinference/model/core.py +8 -1
  24. xinference/model/embedding/core.py +3 -2
  25. xinference/model/embedding/model_spec_modelscope.json +60 -18
  26. xinference/model/image/stable_diffusion/core.py +4 -3
  27. xinference/model/llm/__init__.py +7 -0
  28. xinference/model/llm/ggml/llamacpp.py +3 -2
  29. xinference/model/llm/llm_family.json +87 -3
  30. xinference/model/llm/llm_family.py +15 -5
  31. xinference/model/llm/llm_family_modelscope.json +92 -3
  32. xinference/model/llm/pytorch/chatglm.py +70 -28
  33. xinference/model/llm/pytorch/core.py +11 -30
  34. xinference/model/llm/pytorch/internlm2.py +155 -0
  35. xinference/model/llm/pytorch/utils.py +0 -153
  36. xinference/model/llm/utils.py +37 -8
  37. xinference/model/llm/vllm/core.py +15 -3
  38. xinference/model/multimodal/__init__.py +15 -8
  39. xinference/model/multimodal/core.py +8 -1
  40. xinference/model/multimodal/model_spec.json +9 -0
  41. xinference/model/multimodal/model_spec_modelscope.json +45 -0
  42. xinference/model/multimodal/qwen_vl.py +5 -9
  43. xinference/model/utils.py +7 -2
  44. xinference/types.py +2 -0
  45. xinference/web/ui/build/asset-manifest.json +3 -3
  46. xinference/web/ui/build/index.html +1 -1
  47. xinference/web/ui/build/static/js/main.b83095c2.js +3 -0
  48. xinference/web/ui/build/static/js/{main.236e72e7.js.LICENSE.txt → main.b83095c2.js.LICENSE.txt} +7 -0
  49. xinference/web/ui/build/static/js/main.b83095c2.js.map +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/0a853b2fa1902551e262a2f1a4b7894341f27b3dd9587f2ef7aaea195af89518.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/101923c539819f26ad11fbcbd6f6e56436b285efbb090dcc7dd648c6e924c4a8.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/193e7ba39e70d4bb2895a5cb317f6f293a5fd02e7e324c02a1eba2f83216419c.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/22858de5265f2d279fca9f2f54dfb147e4b2704200dfb5d2ad3ec9769417328f.json +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/27696db5fcd4fcf0e7974cadf1e4a2ab89690474045c3188eafd586323ad13bb.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/27bdbe25deab8cf08f7fab8f05f8f26cf84a98809527a37986a4ab73a57ba96a.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/2bee7b8bd3d52976a45d6068e1333df88b943e0e679403c809e45382e3818037.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/30670751f55508ef3b861e13dd71b9e5a10d2561373357a12fc3831a0b77fd93.json +1 -0
  59. xinference/web/ui/node_modules/.cache/babel-loader/3605cd3a96ff2a3b443c70a101575482279ad26847924cab0684d165ba0d2492.json +1 -0
  60. xinference/web/ui/node_modules/.cache/babel-loader/3789ef437d3ecbf945bb9cea39093d1f16ebbfa32dbe6daf35abcfb6d48de6f1.json +1 -0
  61. xinference/web/ui/node_modules/.cache/babel-loader/4942da6bc03bf7373af068e22f916341aabc5b5df855d73c1d348c696724ce37.json +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/4d933e35e0fe79867d3aa6c46db28804804efddf5490347cb6c2c2879762a157.json +1 -0
  63. xinference/web/ui/node_modules/.cache/babel-loader/4d96f071168af43965e0fab2ded658fa0a15b8d9ca03789a5ef9c5c16a4e3cee.json +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/4fd24800544873512b540544ae54601240a5bfefd9105ff647855c64f8ad828f.json +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/52a6136cb2dbbf9c51d461724d9b283ebe74a73fb19d5df7ba8e13c42bd7174d.json +1 -0
  66. xinference/web/ui/node_modules/.cache/babel-loader/5c408307c982f07f9c09c85c98212d1b1c22548a9194c69548750a3016b91b88.json +1 -0
  67. xinference/web/ui/node_modules/.cache/babel-loader/663adbcb60b942e9cf094c8d9fabe57517f5e5e6e722d28b4948a40b7445a3b8.json +1 -0
  68. xinference/web/ui/node_modules/.cache/babel-loader/666bb2e1b250dc731311a7e4880886177885dfa768508d2ed63e02630cc78725.json +1 -0
  69. xinference/web/ui/node_modules/.cache/babel-loader/71493aadd34d568fbe605cacaba220aa69bd09273251ee4ba27930f8d01fccd8.json +1 -0
  70. xinference/web/ui/node_modules/.cache/babel-loader/8b071db2a5a9ef68dc14d5f606540bd23d9785e365a11997c510656764d2dccf.json +1 -0
  71. xinference/web/ui/node_modules/.cache/babel-loader/8b246d79cd3f6fc78f11777e6a6acca6a2c5d4ecce7f2dd4dcf9a48126440d3c.json +1 -0
  72. xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/95c8cc049fadd23085d8623e1d43d70b614a4e52217676f186a417dca894aa09.json +1 -0
  74. xinference/web/ui/node_modules/.cache/babel-loader/a4d72d3b806ba061919115f0c513738726872e3c79cf258f007519d3f91d1a16.json +1 -0
  75. xinference/web/ui/node_modules/.cache/babel-loader/a8070ce4b780b4a044218536e158a9e7192a6c80ff593fdc126fee43f46296b5.json +1 -0
  76. xinference/web/ui/node_modules/.cache/babel-loader/b4e4fccaf8f2489a29081f0bf3b191656bd452fb3c8b5e3c6d92d94f680964d5.json +1 -0
  77. xinference/web/ui/node_modules/.cache/babel-loader/b53eb7c7967f6577bd3e678293c44204fb03ffa7fdc1dd59d3099015c68f6f7f.json +1 -0
  78. xinference/web/ui/node_modules/.cache/babel-loader/bd04667474fd9cac2983b03725c218908a6cc0ee9128a5953cd00d26d4877f60.json +1 -0
  79. xinference/web/ui/node_modules/.cache/babel-loader/c230a727b8f68f0e62616a75e14a3d33026dc4164f2e325a9a8072d733850edb.json +1 -0
  80. xinference/web/ui/node_modules/.cache/babel-loader/d06af85a84e5c5a29d3acf2dbb5b30c0cf75c8aec4ab5f975e6096f944ee4324.json +1 -0
  81. xinference/web/ui/node_modules/.cache/babel-loader/d44a6eb6106e09082b691a315c9f6ce17fcfe25beb7547810e0d271ce3301cd2.json +1 -0
  82. xinference/web/ui/node_modules/.cache/babel-loader/d5e150bff31715977d8f537c970f06d4fe3de9909d7e8342244a83a9f6447121.json +1 -0
  83. xinference/web/ui/node_modules/.cache/babel-loader/de36e5c08fd524e341d664883dda6cb1745acc852a4f1b011a35a0b4615f72fa.json +1 -0
  84. xinference/web/ui/node_modules/.cache/babel-loader/f037ffef5992af0892d6d991053c1dace364cd39a3f11f1a41f92776e8a59459.json +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/f23ab356a8603d4a2aaa74388c2f381675c207d37c4d1c832df922e9655c9a6b.json +1 -0
  86. xinference/web/ui/node_modules/.cache/babel-loader/f7c23b0922f4087b9e2e3e46f15c946b772daa46c28c3a12426212ecaf481deb.json +1 -0
  87. xinference/web/ui/node_modules/.cache/babel-loader/f95a8bd358eeb55fa2f49f1224cc2f4f36006359856744ff09ae4bb295f59ec1.json +1 -0
  88. xinference/web/ui/node_modules/.cache/babel-loader/fe5db70859503a54cbe71f9637e5a314cda88b1f0eecb733b6e6f837697db1ef.json +1 -0
  89. xinference/web/ui/node_modules/.package-lock.json +36 -0
  90. xinference/web/ui/node_modules/@types/cookie/package.json +30 -0
  91. xinference/web/ui/node_modules/@types/hoist-non-react-statics/package.json +33 -0
  92. xinference/web/ui/node_modules/react-cookie/package.json +55 -0
  93. xinference/web/ui/node_modules/universal-cookie/package.json +48 -0
  94. xinference/web/ui/package-lock.json +37 -0
  95. xinference/web/ui/package.json +3 -2
  96. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/METADATA +17 -6
  97. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/RECORD +101 -66
  98. xinference/web/ui/build/static/js/main.236e72e7.js +0 -3
  99. xinference/web/ui/build/static/js/main.236e72e7.js.map +0 -1
  100. xinference/web/ui/node_modules/.cache/babel-loader/0cccfbe5d963b8e31eb679f9d9677392839cedd04aa2956ac6b33cf19599d597.json +0 -1
  101. xinference/web/ui/node_modules/.cache/babel-loader/0f3b6cc71b7c83bdc85aa4835927aeb86af2ce0d2ac241917ecfbf90f75c6d27.json +0 -1
  102. xinference/web/ui/node_modules/.cache/babel-loader/2f651cf60b1bde50c0601c7110f77dd44819fb6e2501ff748a631724d91445d4.json +0 -1
  103. xinference/web/ui/node_modules/.cache/babel-loader/42bb623f337ad08ed076484185726e072ca52bb88e373d72c7b052db4c273342.json +0 -1
  104. xinference/web/ui/node_modules/.cache/babel-loader/57af83639c604bd3362d0f03f7505e81c6f67ff77bee7c6bb31f6e5523eba185.json +0 -1
  105. xinference/web/ui/node_modules/.cache/babel-loader/667753ce39ce1d4bcbf9a5f1a103d653be1d19d42f4e1fbaceb9b507679a52c7.json +0 -1
  106. xinference/web/ui/node_modules/.cache/babel-loader/66ed1bd4c06748c1b176a625c25c856997edc787856c73162f82f2b465c5d956.json +0 -1
  107. xinference/web/ui/node_modules/.cache/babel-loader/78f2521da2e2a98b075a2666cb782c7e2c019cd3c72199eecd5901c82d8655df.json +0 -1
  108. xinference/web/ui/node_modules/.cache/babel-loader/8d2b0b3c6988d1894694dcbbe708ef91cfe62d62dac317031f09915ced637953.json +0 -1
  109. xinference/web/ui/node_modules/.cache/babel-loader/9427ae7f1e94ae8dcd2333fb361e381f4054fde07394fe5448658e3417368476.json +0 -1
  110. xinference/web/ui/node_modules/.cache/babel-loader/bcee2b4e76b07620f9087989eb86d43c645ba3c7a74132cf926260af1164af0e.json +0 -1
  111. xinference/web/ui/node_modules/.cache/babel-loader/cc2ddd02ccc1dad1a2737ac247c79e6f6ed2c7836c6b68e511e3048f666b64af.json +0 -1
  112. xinference/web/ui/node_modules/.cache/babel-loader/d2e8e6665a7efc832b43907dadf4e3c896a59eaf8129f9a520882466c8f2e489.json +0 -1
  113. xinference/web/ui/node_modules/.cache/babel-loader/d8a42e9df7157de9f28eecefdf178fd113bf2280d28471b6e32a8a45276042df.json +0 -1
  114. xinference/web/ui/node_modules/.cache/babel-loader/e26750d9556e9741912333349e4da454c53dbfddbfc6002ab49518dcf02af745.json +0 -1
  115. xinference/web/ui/node_modules/.cache/babel-loader/ef42ec014d7bc373b874b2a1ff0dcd785490f125e913698bc049b0bd778e4d66.json +0 -1
  116. xinference/web/ui/node_modules/.cache/babel-loader/fe3eb4d76c79ca98833f686d642224eeeb94cc83ad14300d281623796d087f0a.json +0 -1
  117. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/LICENSE +0 -0
  118. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/WHEEL +0 -0
  119. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/entry_points.txt +0 -0
  120. {xinference-0.7.5.dist-info → xinference-0.8.1.dist-info}/top_level.txt +0 -0
xinference/client/oscar/actor_client.py CHANGED
@@ -14,12 +14,12 @@
 
 import asyncio
 import re
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
+from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union
 
 import orjson
 import xoscar as xo
 
-from ...core.model import IteratorWrapper, ModelActor
+from ...core.model import ModelActor
 from ...core.supervisor import SupervisorActor
 from ...isolation import Isolation
 from ..restful.restful_client import Client
@@ -97,18 +97,18 @@ class ModelHandle:
         self._isolation = isolation
 
 
-class ClientIteratorWrapper(IteratorWrapper):
+class ClientIteratorWrapper(AsyncIterator):
+    def __init__(self, iterator_wrapper):
+        self._iw = iterator_wrapper
+
+    def __aiter__(self):
+        return self
+
     async def __anext__(self):
-        r = await super().__anext__()
+        r = await self._iw.__anext__()
         text = r.decode("utf-8")
         return orjson.loads(SSEEvent.parse(text).data)
 
-    @classmethod
-    def wrap(cls, iterator_wrapper):
-        c = cls.__new__(cls)
-        c.__dict__.update(iterator_wrapper.__dict__)
-        return c
-
 
 class EmbeddingModelHandle(ModelHandle):
     def create_embedding(self, input: Union[str, List[str]]) -> bytes:
@@ -171,14 +171,14 @@ class RerankModelHandle:
         return results
 
 
-class GenerateModelHandle(EmbeddingModelHandle):
+class GenerateModelHandle(ModelHandle):
     def generate(
         self,
         prompt: str,
         generate_config: Optional[
             Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]
         ] = None,
-    ) -> Union["Completion", Iterator["CompletionChunk"]]:
+    ) -> Union["Completion", AsyncIterator["CompletionChunk"]]:
         """
         Creates a completion for the provided prompt and parameters.
 
@@ -204,7 +204,7 @@ class GenerateModelHandle(EmbeddingModelHandle):
         r = self._isolation.call(coro)
         if isinstance(r, bytes):
             return orjson.loads(r)
-        return ClientIteratorWrapper.wrap(r)
+        return ClientIteratorWrapper(r)
 
 
 class ChatModelHandle(GenerateModelHandle):
@@ -216,7 +216,7 @@ class ChatModelHandle(GenerateModelHandle):
         generate_config: Optional[
             Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]
         ] = None,
-    ) -> Union["ChatCompletion", Iterator["ChatCompletionChunk"]]:
+    ) -> Union["ChatCompletion", AsyncIterator["ChatCompletionChunk"]]:
         """
         Given a list of messages comprising a conversation, the model will return a response.
 
@@ -252,16 +252,16 @@
         r = self._isolation.call(coro)
         if isinstance(r, bytes):
             return orjson.loads(r)
-        return ClientIteratorWrapper.wrap(r)
+        return ClientIteratorWrapper(r)
 
 
-class ChatglmCppChatModelHandle(EmbeddingModelHandle):
+class ChatglmCppChatModelHandle(ModelHandle):
     def chat(
         self,
         prompt: str,
         chat_history: Optional[List["ChatCompletionMessage"]] = None,
         generate_config: Optional["ChatglmCppGenerateConfig"] = None,
-    ) -> Union["ChatCompletion", Iterator["ChatCompletionChunk"]]:
+    ) -> Union["ChatCompletion", AsyncIterator["ChatCompletionChunk"]]:
         """
         Given a list of messages comprising a conversation, the ChatGLM model will return a response.
 
@@ -287,7 +287,7 @@ class ChatglmCppChatModelHandle(EmbeddingModelHandle):
         r = self._isolation.call(coro)
         if isinstance(r, bytes):
             return orjson.loads(r)
-        return ClientIteratorWrapper.wrap(r)
+        return ClientIteratorWrapper(r)
 
 
 class ImageModelHandle(ModelHandle):
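
The upshot of these hunks is that streamed results from the actor client are now true async iterators rather than plain ones. A minimal consumption sketch, assuming `model` is a ChatModelHandle for an already-launched chat model (the prompt and config values are illustrative, not taken from this diff):

import asyncio

def demo(model):
    # With stream=True the handle now returns a ClientIteratorWrapper, which
    # implements __aiter__/__anext__ and yields parsed SSE payloads (dicts).
    result = model.chat("Hello!", generate_config={"stream": True})

    async def consume():
        async for chunk in result:  # `async for` replaces the old plain `for`
            print(chunk)

    asyncio.run(consume())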
xinference/client/restful/restful_client.py CHANGED
@@ -53,9 +53,10 @@ class RESTfulModelHandle:
     programmatically.
     """
 
-    def __init__(self, model_uid: str, base_url: str):
+    def __init__(self, model_uid: str, base_url: str, auth_headers: Dict):
         self._model_uid = model_uid
         self._base_url = base_url
+        self.auth_headers = auth_headers
 
 
 class RESTfulEmbeddingModelHandle(RESTfulModelHandle):
@@ -82,7 +83,7 @@ class RESTfulEmbeddingModelHandle(RESTfulModelHandle):
         """
         url = f"{self._base_url}/v1/embeddings"
         request_body = {"model": self._model_uid, "input": input}
-        response = requests.post(url, json=request_body)
+        response = requests.post(url, json=request_body, headers=self.auth_headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to create the embeddings, detail: {_get_error_string(response)}"
@@ -135,7 +136,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
             "max_chunks_per_doc": max_chunks_per_doc,
             "return_documents": return_documents,
         }
-        response = requests.post(url, json=request_body)
+        response = requests.post(url, json=request_body, headers=self.auth_headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to rerank documents, detail: {response.json()['detail']}"
@@ -182,7 +183,7 @@ class RESTfulImageModelHandle(RESTfulModelHandle):
             "response_format": response_format,
             "kwargs": json.dumps(kwargs),
         }
-        response = requests.post(url, json=request_body)
+        response = requests.post(url, json=request_body, headers=self.auth_headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to create the images, detail: {_get_error_string(response)}"
@@ -246,10 +247,7 @@
         for key, value in params.items():
             files.append((key, (None, value)))
         files.append(("image", ("image", image, "application/octet-stream")))
-        response = requests.post(
-            url,
-            files=files,
-        )
+        response = requests.post(url, files=files, headers=self.auth_headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to variants the images, detail: {_get_error_string(response)}"
@@ -259,7 +257,7 @@
         return response_data
 
 
-class RESTfulGenerateModelHandle(RESTfulEmbeddingModelHandle):
+class RESTfulGenerateModelHandle(RESTfulModelHandle):
     def generate(
         self,
         prompt: str,
@@ -302,7 +300,9 @@ class RESTfulGenerateModelHandle(RESTfulEmbeddingModelHandle):
 
         stream = bool(generate_config and generate_config.get("stream"))
 
-        response = requests.post(url, json=request_body, stream=stream)
+        response = requests.post(
+            url, json=request_body, stream=stream, headers=self.auth_headers
+        )
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to generate completion, detail: {_get_error_string(response)}"
@@ -384,7 +384,9 @@ class RESTfulChatModelHandle(RESTfulGenerateModelHandle):
                 request_body[key] = value
 
         stream = bool(generate_config and generate_config.get("stream"))
-        response = requests.post(url, json=request_body, stream=stream)
+        response = requests.post(
+            url, json=request_body, stream=stream, headers=self.auth_headers
+        )
 
         if response.status_code != 200:
             raise RuntimeError(
@@ -468,7 +470,9 @@ class RESTfulMultimodalModelHandle(RESTfulModelHandle):
                 request_body[key] = value
 
         stream = bool(generate_config and generate_config.get("stream"))
-        response = requests.post(url, json=request_body, stream=stream)
+        response = requests.post(
+            url, json=request_body, stream=stream, headers=self.auth_headers
+        )
 
         if response.status_code != 200:
             raise RuntimeError(
@@ -482,7 +486,7 @@
         return response_data
 
 
-class RESTfulChatglmCppChatModelHandle(RESTfulEmbeddingModelHandle):
+class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle):
     def chat(
         self,
         prompt: str,
@@ -536,7 +540,9 @@ class RESTfulChatglmCppChatModelHandle(RESTfulEmbeddingModelHandle):
                 request_body[key] = value
 
         stream = bool(generate_config and generate_config.get("stream"))
-        response = requests.post(url, json=request_body, stream=stream)
+        response = requests.post(
+            url, json=request_body, stream=stream, headers=self.auth_headers
+        )
 
         if response.status_code != 200:
             raise RuntimeError(
@@ -589,7 +595,9 @@ class RESTfulChatglmCppGenerateModelHandle(RESTfulChatglmCppChatModelHandle):
 
         stream = bool(generate_config and generate_config.get("stream"))
 
-        response = requests.post(url, json=request_body, stream=stream)
+        response = requests.post(
+            url, json=request_body, stream=stream, headers=self.auth_headers
+        )
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to generate completion, detail: {response.json()['detail']}"
@@ -605,6 +613,47 @@
 class Client:
     def __init__(self, base_url):
         self.base_url = base_url
+        self._headers = {}
+        self._cluster_authed = False
+        self._check_cluster_authenticated()
+
+    def _set_token(self, token: Optional[str]):
+        if not self._cluster_authed or token is None:
+            return
+        self._headers["Authorization"] = f"Bearer {token}"
+
+    def _get_token(self) -> Optional[str]:
+        return (
+            str(self._headers["Authorization"]).replace("Bearer ", "")
+            if "Authorization" in self._headers
+            else None
+        )
+
+    def _check_cluster_authenticated(self):
+        url = f"{self.base_url}/v1/cluster/auth"
+        response = requests.get(url)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get cluster information, detail: {response.json()['detail']}"
+            )
+        response_data = response.json()
+        self._cluster_authed = bool(response_data["auth"])
+
+    def login(self, username: str, password: str):
+        if not self._cluster_authed:
+            return
+        url = f"{self.base_url}/token"
+
+        payload = {"username": username, "password": password}
+
+        response = requests.post(url, json=payload)
+        if response.status_code != 200:
+            raise RuntimeError(f"Failed to login, detail: {response.json()['detail']}")
+
+        response_data = response.json()
+        # Only bearer token for now
+        access_token = response_data["access_token"]
+        self._headers["Authorization"] = f"Bearer {access_token}"
 
     def list_models(self) -> Dict[str, Dict[str, Any]]:
         """
@@ -619,7 +668,7 @@ class Client:
 
         url = f"{self.base_url}/v1/models"
 
-        response = requests.get(url)
+        response = requests.get(url, headers=self._headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to list model, detail: {_get_error_string(response)}"
@@ -664,7 +713,7 @@ class Client:
         }
 
         url = f"{self.base_url}/experimental/speculative_llms"
-        response = requests.post(url, json=payload)
+        response = requests.post(url, json=payload, headers=self._headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to launch model, detail: {_get_error_string(response)}"
@@ -739,7 +788,7 @@ class Client:
         for key, value in kwargs.items():
             payload[str(key)] = value
 
-        response = requests.post(url, json=payload)
+        response = requests.post(url, json=payload, headers=self._headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to launch model, detail: {_get_error_string(response)}"
@@ -766,7 +815,7 @@ class Client:
 
         url = f"{self.base_url}/v1/models/{model_uid}"
 
-        response = requests.delete(url)
+        response = requests.delete(url, headers=self._headers)
        if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to terminate model, detail: {_get_error_string(response)}"
@@ -774,7 +823,7 @@ class Client:
 
     def _get_supervisor_internal_address(self):
         url = f"{self.base_url}/v1/address"
-        response = requests.get(url)
+        response = requests.get(url, headers=self._headers)
         if response.status_code != 200:
             raise RuntimeError(f"Failed to get supervisor internal address")
         response_data = response.json()
@@ -806,7 +855,7 @@ class Client:
         """
 
         url = f"{self.base_url}/v1/models/{model_uid}"
-        response = requests.get(url)
+        response = requests.get(url, headers=self._headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to get the model description, detail: {_get_error_string(response)}"
@@ -815,21 +864,35 @@ class Client:
 
         if desc["model_type"] == "LLM":
             if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
-                return RESTfulChatglmCppGenerateModelHandle(model_uid, self.base_url)
+                return RESTfulChatglmCppGenerateModelHandle(
+                    model_uid, self.base_url, auth_headers=self._headers
+                )
             elif "chat" in desc["model_ability"]:
-                return RESTfulChatModelHandle(model_uid, self.base_url)
+                return RESTfulChatModelHandle(
+                    model_uid, self.base_url, auth_headers=self._headers
+                )
             elif "generate" in desc["model_ability"]:
-                return RESTfulGenerateModelHandle(model_uid, self.base_url)
+                return RESTfulGenerateModelHandle(
+                    model_uid, self.base_url, auth_headers=self._headers
+                )
             else:
                 raise ValueError(f"Unrecognized model ability: {desc['model_ability']}")
         elif desc["model_type"] == "embedding":
-            return RESTfulEmbeddingModelHandle(model_uid, self.base_url)
+            return RESTfulEmbeddingModelHandle(
+                model_uid, self.base_url, auth_headers=self._headers
+            )
         elif desc["model_type"] == "image":
-            return RESTfulImageModelHandle(model_uid, self.base_url)
+            return RESTfulImageModelHandle(
+                model_uid, self.base_url, auth_headers=self._headers
+            )
         elif desc["model_type"] == "rerank":
-            return RESTfulRerankModelHandle(model_uid, self.base_url)
+            return RESTfulRerankModelHandle(
+                model_uid, self.base_url, auth_headers=self._headers
+            )
         elif desc["model_type"] == "multimodal":
-            return RESTfulMultimodalModelHandle(model_uid, self.base_url)
+            return RESTfulMultimodalModelHandle(
+                model_uid, self.base_url, auth_headers=self._headers
+            )
         else:
             raise ValueError(f"Unknown model type:{desc['model_type']}")
 
@@ -876,7 +939,7 @@ class Client:
         """
 
         url = f"{self.base_url}/v1/models/{model_uid}"
-        response = requests.get(url)
+        response = requests.get(url, headers=self._headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to get the model description, detail: {_get_error_string(response)}"
@@ -903,7 +966,7 @@ class Client:
         """
         url = f"{self.base_url}/v1/model_registrations/{model_type}"
         request_body = {"model": model, "persist": persist}
-        response = requests.post(url, json=request_body)
+        response = requests.post(url, json=request_body, headers=self._headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to register model, detail: {_get_error_string(response)}"
@@ -929,7 +992,7 @@ class Client:
         Report failure to unregister the custom model. Provide details of failure through error message.
         """
         url = f"{self.base_url}/v1/model_registrations/{model_type}/{model_name}"
-        response = requests.delete(url)
+        response = requests.delete(url, headers=self._headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to register model, detail: {_get_error_string(response)}"
@@ -959,7 +1022,7 @@ class Client:
 
         """
         url = f"{self.base_url}/v1/model_registrations/{model_type}"
-        response = requests.get(url)
+        response = requests.get(url, headers=self._headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to list model registration, detail: {_get_error_string(response)}"
@@ -987,7 +1050,7 @@ class Client:
         The collection of registered models on the server.
         """
         url = f"{self.base_url}/v1/model_registrations/{model_type}/{model_name}"
-        response = requests.get(url)
+        response = requests.get(url, headers=self._headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to list model registration, detail: {_get_error_string(response)}"
xinference/conftest.py CHANGED
@@ -13,16 +13,19 @@
 # limitations under the License.
 
 import asyncio
+import json
 import logging
 import multiprocessing
 import os
 import signal
 import sys
+import tempfile
 from typing import Dict, Optional
 
 import pytest
 import xoscar as xo
 
+from .api.oauth2.types import AuthConfig, AuthStartupConfig, User
 from .constants import XINFERENCE_LOG_BACKUP_COUNT, XINFERENCE_LOG_MAX_BYTES
 from .core.supervisor import SupervisorActor
 from .deploy.utils import create_worker_actor_pool, get_log_file, get_timestamp_ms
@@ -141,7 +144,11 @@ async def _start_test_cluster(
             SupervisorActor, address=address, uid=SupervisorActor.uid()
         )
         await start_worker_components(
-            address=address, supervisor_address=address, main_pool=pool
+            address=address,
+            supervisor_address=address,
+            main_pool=pool,
+            metrics_exporter_host=None,
+            metrics_exporter_port=None,
         )
         await pool.join()
     except asyncio.CancelledError:
@@ -233,3 +240,58 @@ def setup_with_file_logging():
 
     local_cluster_proc.terminate()
     restful_api_proc.terminate()
+
+
+@pytest.fixture
+def setup_with_auth():
+    from .api.restful_api import run_in_subprocess as run_restful_api
+    from .deploy.utils import health_check as cluster_health_check
+
+    logging.config.dictConfig(TEST_LOGGING_CONF)  # type: ignore
+
+    supervisor_addr = f"localhost:{xo.utils.get_next_port()}"
+    local_cluster_proc = run_test_cluster_in_subprocess(
+        supervisor_addr, TEST_LOGGING_CONF
+    )
+    if not cluster_health_check(supervisor_addr, max_attempts=10, sleep_interval=3):
+        raise RuntimeError("Cluster is not available after multiple attempts")
+
+    user1 = User(username="user1", password="pass1", permissions=["admin"])
+    user2 = User(username="user2", password="pass2", permissions=["models:list"])
+    user3 = User(
+        username="user3",
+        password="pass3",
+        permissions=["models:list", "models:read", "models:start"],
+    )
+    auth_config = AuthConfig(
+        algorithm="HS256",
+        secret_key="09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7",
+        token_expire_in_minutes=30,
+    )
+    startup_config = AuthStartupConfig(
+        auth_config=auth_config, user_config=[user1, user2, user3]
+    )
+    _, auth_file = tempfile.mkstemp()
+    with open(auth_file, "w") as fd:
+        fd.write(json.dumps(startup_config.dict()))
+
+    port = xo.utils.get_next_port()
+    restful_api_proc = run_restful_api(
+        supervisor_addr,
+        host="localhost",
+        port=port,
+        logging_conf=TEST_LOGGING_CONF,
+        auth_config_file=auth_file,
+    )
+    endpoint = f"http://localhost:{port}"
+    if not api_health_check(endpoint, max_attempts=10, sleep_interval=5):
+        raise RuntimeError("Endpoint is not available after multiple attempts")
+
+    yield f"http://localhost:{port}", supervisor_addr
+
+    local_cluster_proc.terminate()
+    restful_api_proc.terminate()
+    try:
+        os.remove(auth_file)
+    except:
+        pass
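
The fixture above doubles as documentation for the auth config file format: it is just a serialized AuthStartupConfig. A sketch of writing one by hand with the same models (the username, password, secret key, and file name are placeholders, not values from this diff):

import json

from xinference.api.oauth2.types import AuthConfig, AuthStartupConfig, User

config = AuthStartupConfig(
    auth_config=AuthConfig(
        algorithm="HS256",
        secret_key="<replace, e.g. with the output of `openssl rand -hex 32`>",
        token_expire_in_minutes=30,
    ),
    user_config=[User(username="admin", password="change-me", permissions=["admin"])],
)

with open("auth_config.json", "w") as fd:
    fd.write(json.dumps(config.dict()))  # same serialization the fixture uses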
xinference/constants.py CHANGED
@@ -39,6 +39,7 @@ XINFERENCE_CACHE_DIR = os.path.join(XINFERENCE_HOME, "cache")
 XINFERENCE_MODEL_DIR = os.path.join(XINFERENCE_HOME, "model")
 XINFERENCE_LOG_DIR = os.path.join(XINFERENCE_HOME, "logs")
 XINFERENCE_IMAGE_DIR = os.path.join(XINFERENCE_HOME, "image")
+XINFERENCE_AUTH_DIR = os.path.join(XINFERENCE_HOME, "auth")
 
 XINFERENCE_DEFAULT_LOCAL_HOST = "127.0.0.1"
 XINFERENCE_DEFAULT_DISTRIBUTED_HOST = "0.0.0.0"
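
The new constant slots into the existing per-feature directory layout under XINFERENCE_HOME. A quick illustration of the resulting paths, assuming the default home of ~/.xinference (the actual root follows the XINFERENCE_HOME environment variable and is not shown in this diff):

import os

home = os.environ.get("XINFERENCE_HOME", os.path.expanduser("~/.xinference"))
# Mirrors the constants above; "auth" is the directory added in this release.
for sub in ("cache", "model", "logs", "image", "auth"):
    print(os.path.join(home, sub))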