xinference 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic; see the package's registry page for more details.

Files changed (95)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/auth_service.py +132 -0
  3. xinference/api/restful_api.py +282 -78
  4. xinference/client/handlers.py +3 -0
  5. xinference/client/restful/restful_client.py +108 -75
  6. xinference/constants.py +14 -4
  7. xinference/core/cache_tracker.py +102 -0
  8. xinference/core/chat_interface.py +10 -4
  9. xinference/core/event.py +56 -0
  10. xinference/core/model.py +44 -0
  11. xinference/core/resource.py +19 -12
  12. xinference/core/status_guard.py +4 -0
  13. xinference/core/supervisor.py +278 -87
  14. xinference/core/utils.py +68 -3
  15. xinference/core/worker.py +98 -8
  16. xinference/deploy/cmdline.py +6 -3
  17. xinference/deploy/local.py +2 -2
  18. xinference/deploy/supervisor.py +2 -2
  19. xinference/model/audio/__init__.py +27 -0
  20. xinference/model/audio/core.py +161 -0
  21. xinference/model/audio/model_spec.json +79 -0
  22. xinference/model/audio/utils.py +18 -0
  23. xinference/model/audio/whisper.py +132 -0
  24. xinference/model/core.py +18 -13
  25. xinference/model/embedding/__init__.py +27 -2
  26. xinference/model/embedding/core.py +43 -3
  27. xinference/model/embedding/model_spec.json +24 -0
  28. xinference/model/embedding/model_spec_modelscope.json +24 -0
  29. xinference/model/embedding/utils.py +18 -0
  30. xinference/model/image/__init__.py +12 -1
  31. xinference/model/image/core.py +63 -9
  32. xinference/model/image/utils.py +26 -0
  33. xinference/model/llm/__init__.py +20 -1
  34. xinference/model/llm/core.py +43 -2
  35. xinference/model/llm/ggml/chatglm.py +15 -6
  36. xinference/model/llm/llm_family.json +197 -6
  37. xinference/model/llm/llm_family.py +9 -7
  38. xinference/model/llm/llm_family_modelscope.json +189 -4
  39. xinference/model/llm/pytorch/chatglm.py +3 -3
  40. xinference/model/llm/pytorch/core.py +4 -2
  41. xinference/model/{multimodal → llm/pytorch}/qwen_vl.py +10 -8
  42. xinference/model/llm/pytorch/utils.py +21 -9
  43. xinference/model/llm/pytorch/yi_vl.py +246 -0
  44. xinference/model/llm/utils.py +57 -4
  45. xinference/model/llm/vllm/core.py +5 -4
  46. xinference/model/rerank/__init__.py +25 -2
  47. xinference/model/rerank/core.py +51 -9
  48. xinference/model/rerank/model_spec.json +6 -0
  49. xinference/model/rerank/model_spec_modelscope.json +7 -0
  50. xinference/{api/oauth2/common.py → model/rerank/utils.py} +6 -2
  51. xinference/model/utils.py +5 -3
  52. xinference/thirdparty/__init__.py +0 -0
  53. xinference/thirdparty/llava/__init__.py +1 -0
  54. xinference/thirdparty/llava/conversation.py +205 -0
  55. xinference/thirdparty/llava/mm_utils.py +122 -0
  56. xinference/thirdparty/llava/model/__init__.py +1 -0
  57. xinference/thirdparty/llava/model/clip_encoder/__init__.py +0 -0
  58. xinference/thirdparty/llava/model/clip_encoder/builder.py +11 -0
  59. xinference/thirdparty/llava/model/clip_encoder/clip_encoder.py +86 -0
  60. xinference/thirdparty/llava/model/constants.py +6 -0
  61. xinference/thirdparty/llava/model/llava_arch.py +385 -0
  62. xinference/thirdparty/llava/model/llava_llama.py +163 -0
  63. xinference/thirdparty/llava/model/multimodal_projector/__init__.py +0 -0
  64. xinference/thirdparty/llava/model/multimodal_projector/builder.py +64 -0
  65. xinference/types.py +1 -1
  66. xinference/web/ui/build/asset-manifest.json +3 -3
  67. xinference/web/ui/build/index.html +1 -1
  68. xinference/web/ui/build/static/js/main.15822aeb.js +3 -0
  69. xinference/web/ui/build/static/js/main.15822aeb.js.map +1 -0
  70. xinference/web/ui/node_modules/.cache/babel-loader/139e5e4adf436923107d2b02994c7ff6dba2aac1989e9b6638984f0dfe782c4a.json +1 -0
  71. xinference/web/ui/node_modules/.cache/babel-loader/52aa27272b4b9968f62666262b47661cb1992336a2aff3b13994cc36877b3ec3.json +1 -0
  72. xinference/web/ui/node_modules/.cache/babel-loader/64accc515dc6cd584a2873796cd7da6f93de57f7e465eb5423cca9a2f3fe3eff.json +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/65ca3ba225b8c8dac907210545b51f2fcdb2591f0feeb7195f1c037f2bc956a0.json +1 -0
  74. xinference/web/ui/node_modules/.cache/babel-loader/b80db1012318b97c329c4e3e72454f7512fb107e57c444b437dbe4ba1a3faa5a.json +1 -0
  75. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/METADATA +33 -23
  76. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/RECORD +81 -64
  77. xinference/api/oauth2/core.py +0 -93
  78. xinference/model/multimodal/__init__.py +0 -52
  79. xinference/model/multimodal/core.py +0 -467
  80. xinference/model/multimodal/model_spec.json +0 -43
  81. xinference/model/multimodal/model_spec_modelscope.json +0 -45
  82. xinference/web/ui/build/static/js/main.b83095c2.js +0 -3
  83. xinference/web/ui/build/static/js/main.b83095c2.js.map +0 -1
  84. xinference/web/ui/node_modules/.cache/babel-loader/101923c539819f26ad11fbcbd6f6e56436b285efbb090dcc7dd648c6e924c4a8.json +0 -1
  85. xinference/web/ui/node_modules/.cache/babel-loader/4942da6bc03bf7373af068e22f916341aabc5b5df855d73c1d348c696724ce37.json +0 -1
  86. xinference/web/ui/node_modules/.cache/babel-loader/52a6136cb2dbbf9c51d461724d9b283ebe74a73fb19d5df7ba8e13c42bd7174d.json +0 -1
  87. xinference/web/ui/node_modules/.cache/babel-loader/71493aadd34d568fbe605cacaba220aa69bd09273251ee4ba27930f8d01fccd8.json +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/8b071db2a5a9ef68dc14d5f606540bd23d9785e365a11997c510656764d2dccf.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/a4d72d3b806ba061919115f0c513738726872e3c79cf258f007519d3f91d1a16.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/f037ffef5992af0892d6d991053c1dace364cd39a3f11f1a41f92776e8a59459.json +0 -1
  91. /xinference/web/ui/build/static/js/{main.b83095c2.js.LICENSE.txt → main.15822aeb.js.LICENSE.txt} +0 -0
  92. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/LICENSE +0 -0
  93. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/WHEEL +0 -0
  94. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/entry_points.txt +0 -0
  95. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/top_level.txt +0 -0
@@ -20,12 +20,11 @@ import multiprocessing
20
20
  import os
21
21
  import pprint
22
22
  import sys
23
+ import time
23
24
  import warnings
24
- from datetime import timedelta
25
25
  from typing import Any, List, Optional, Union
26
26
 
27
27
  import gradio as gr
28
- import pydantic
29
28
  import xoscar as xo
30
29
  from aioprometheus import REGISTRY, MetricsMiddleware
31
30
  from aioprometheus.asgi.starlette import metrics
@@ -40,7 +39,6 @@ from fastapi import (
40
39
  Response,
41
40
  Security,
42
41
  UploadFile,
43
- status,
44
42
  )
45
43
  from fastapi.middleware.cors import CORSMiddleware
46
44
  from fastapi.responses import JSONResponse
@@ -54,6 +52,7 @@ from uvicorn import Config, Server
54
52
  from xoscar.utils import get_next_port
55
53
 
56
54
  from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT
55
+ from ..core.event import Event, EventCollectorActor, EventType
57
56
  from ..core.supervisor import SupervisorActor
58
57
  from ..core.utils import json_dumps
59
58
  from ..types import (
@@ -63,10 +62,10 @@ from ..types import (
63
62
  CreateChatCompletion,
64
63
  CreateCompletion,
65
64
  ImageList,
65
+ max_tokens_field,
66
66
  )
67
- from .oauth2.core import get_user, verify_token
68
- from .oauth2.types import AuthStartupConfig, LoginUserForm, User
69
- from .oauth2.utils import create_access_token, get_password_hash, verify_password
67
+ from .oauth2.auth_service import AuthService
68
+ from .oauth2.types import LoginUserForm
70
69
 
71
70
  logger = logging.getLogger(__name__)
72
71
 
@@ -135,15 +134,6 @@ class BuildGradioInterfaceRequest(BaseModel):
135
134
  model_lang: List[str]
136
135
 
137
136
 
138
- def authenticate_user(db_users: List[User], username: str, password: str):
139
- user = get_user(db_users, username)
140
- if not user:
141
- return False
142
- if not verify_password(password, user.password):
143
- return False
144
- return user
145
-
146
-
147
137
  class RESTfulAPI:
148
138
  def __init__(
149
139
  self,
@@ -157,25 +147,13 @@ class RESTfulAPI:
157
147
  self._host = host
158
148
  self._port = port
159
149
  self._supervisor_ref = None
160
- self._auth_config: AuthStartupConfig = self.init_auth_config(auth_config_file)
150
+ self._event_collector_ref = None
151
+ self._auth_service = AuthService(auth_config_file)
161
152
  self._router = APIRouter()
162
153
  self._app = FastAPI()
163
154
 
164
- @staticmethod
165
- def init_auth_config(auth_config_file: Optional[str]):
166
- from .oauth2 import common
167
-
168
- if auth_config_file:
169
- config: AuthStartupConfig = pydantic.parse_file_as(
170
- path=auth_config_file, type_=AuthStartupConfig
171
- )
172
- for user in config.user_config:
173
- user.password = get_password_hash(user.password)
174
- common.XINFERENCE_OAUTH2_CONFIG = config # type: ignore
175
- return config
176
-
177
155
  def is_authenticated(self):
178
- return False if self._auth_config is None else True
156
+ return False if self._auth_service.config is None else True
179
157
 
180
158
  @staticmethod
181
159
  def handle_request_limit_error(e: Exception):
@@ -189,29 +167,34 @@ class RESTfulAPI:
189
167
  )
190
168
  return self._supervisor_ref
191
169
 
192
- async def login_for_access_token(self, form_data: LoginUserForm) -> JSONResponse:
193
- user = authenticate_user(
194
- self._auth_config.user_config, form_data.username, form_data.password
195
- )
196
- if not user:
197
- raise HTTPException(
198
- status_code=status.HTTP_401_UNAUTHORIZED,
199
- detail="Incorrect username or password",
200
- headers={"WWW-Authenticate": "Bearer"},
170
+ async def _get_event_collector_ref(self) -> xo.ActorRefType[EventCollectorActor]:
171
+ if self._event_collector_ref is None:
172
+ self._event_collector_ref = await xo.actor_ref(
173
+ address=self._supervisor_address, uid=EventCollectorActor.uid()
201
174
  )
202
- assert user is not None and isinstance(user, User)
203
- access_token_expires = timedelta(
204
- minutes=self._auth_config.auth_config.token_expire_in_minutes
205
- )
206
- access_token = create_access_token(
207
- data={"sub": user.username, "scopes": user.permissions},
208
- secret_key=self._auth_config.auth_config.secret_key,
209
- algorithm=self._auth_config.auth_config.algorithm,
210
- expires_delta=access_token_expires,
211
- )
212
- return JSONResponse(
213
- content={"access_token": access_token, "token_type": "bearer"}
175
+ return self._event_collector_ref
176
+
177
+ async def _report_error_event(self, model_uid: str, content: str):
178
+ try:
179
+ event_collector_ref = await self._get_event_collector_ref()
180
+ await event_collector_ref.report_event(
181
+ model_uid,
182
+ Event(
183
+ event_type=EventType.ERROR,
184
+ event_ts=int(time.time()),
185
+ event_content=content,
186
+ ),
187
+ )
188
+ except Exception:
189
+ logger.exception(
190
+ "Report error event failed, model: %s, content: %s", model_uid, content
191
+ )
192
+
193
+ async def login_for_access_token(self, form_data: LoginUserForm) -> JSONResponse:
194
+ result = self._auth_service.generate_token_for_user(
195
+ form_data.username, form_data.password
214
196
  )
197
+ return JSONResponse(content=result)
215
198
 
216
199
  async def is_cluster_authenticated(self) -> JSONResponse:
217
200
  return JSONResponse(content={"auth": self.is_authenticated()})
@@ -234,6 +217,9 @@ class RESTfulAPI:
234
217
  self._router.add_api_route(
235
218
  "/v1/models/families", self._get_builtin_families, methods=["GET"]
236
219
  )
220
+ self._router.add_api_route(
221
+ "/v1/cluster/info", self.get_cluster_device_info, methods=["GET"]
222
+ )
237
223
  self._router.add_api_route(
238
224
  "/v1/cluster/devices", self._get_devices_count, methods=["GET"]
239
225
  )
@@ -244,7 +230,7 @@ class RESTfulAPI:
244
230
  "/v1/ui/{model_uid}",
245
231
  self.build_gradio_interface,
246
232
  methods=["POST"],
247
- dependencies=[Security(verify_token, scopes=["models:read"])]
233
+ dependencies=[Security(self._auth_service, scopes=["models:read"])]
248
234
  if self.is_authenticated()
249
235
  else None,
250
236
  )
@@ -259,7 +245,15 @@ class RESTfulAPI:
259
245
  "/v1/models/instances",
260
246
  self.get_instance_info,
261
247
  methods=["GET"],
262
- dependencies=[Security(verify_token, scopes=["models:list"])]
248
+ dependencies=[Security(self._auth_service, scopes=["models:list"])]
249
+ if self.is_authenticated()
250
+ else None,
251
+ )
252
+ self._router.add_api_route(
253
+ "/v1/models/{model_type}/{model_name}/versions",
254
+ self.get_model_versions,
255
+ methods=["GET"],
256
+ dependencies=[Security(self._auth_service, scopes=["models:list"])]
263
257
  if self.is_authenticated()
264
258
  else None,
265
259
  )
@@ -267,7 +261,7 @@ class RESTfulAPI:
267
261
  "/v1/models",
268
262
  self.list_models,
269
263
  methods=["GET"],
270
- dependencies=[Security(verify_token, scopes=["models:list"])]
264
+ dependencies=[Security(self._auth_service, scopes=["models:list"])]
271
265
  if self.is_authenticated()
272
266
  else None,
273
267
  )
@@ -276,7 +270,23 @@ class RESTfulAPI:
276
270
  "/v1/models/{model_uid}",
277
271
  self.describe_model,
278
272
  methods=["GET"],
279
- dependencies=[Security(verify_token, scopes=["models:list"])]
273
+ dependencies=[Security(self._auth_service, scopes=["models:list"])]
274
+ if self.is_authenticated()
275
+ else None,
276
+ )
277
+ self._router.add_api_route(
278
+ "/v1/models/{model_uid}/events",
279
+ self.get_model_events,
280
+ methods=["GET"],
281
+ dependencies=[Security(self._auth_service, scopes=["models:read"])]
282
+ if self.is_authenticated()
283
+ else None,
284
+ )
285
+ self._router.add_api_route(
286
+ "/v1/models/instance",
287
+ self.launch_model_by_version,
288
+ methods=["POST"],
289
+ dependencies=[Security(self._auth_service, scopes=["models:start"])]
280
290
  if self.is_authenticated()
281
291
  else None,
282
292
  )
@@ -284,7 +294,7 @@ class RESTfulAPI:
284
294
  "/v1/models",
285
295
  self.launch_model,
286
296
  methods=["POST"],
287
- dependencies=[Security(verify_token, scopes=["models:start"])]
297
+ dependencies=[Security(self._auth_service, scopes=["models:start"])]
288
298
  if self.is_authenticated()
289
299
  else None,
290
300
  )
@@ -292,7 +302,7 @@ class RESTfulAPI:
292
302
  "/experimental/speculative_llms",
293
303
  self.launch_speculative_llm,
294
304
  methods=["POST"],
295
- dependencies=[Security(verify_token, scopes=["models:start"])]
305
+ dependencies=[Security(self._auth_service, scopes=["models:start"])]
296
306
  if self.is_authenticated()
297
307
  else None,
298
308
  )
@@ -300,7 +310,7 @@ class RESTfulAPI:
300
310
  "/v1/models/{model_uid}",
301
311
  self.terminate_model,
302
312
  methods=["DELETE"],
303
- dependencies=[Security(verify_token, scopes=["models:stop"])]
313
+ dependencies=[Security(self._auth_service, scopes=["models:stop"])]
304
314
  if self.is_authenticated()
305
315
  else None,
306
316
  )
@@ -309,7 +319,7 @@ class RESTfulAPI:
309
319
  self.create_completion,
310
320
  methods=["POST"],
311
321
  response_model=Completion,
312
- dependencies=[Security(verify_token, scopes=["models:read"])]
322
+ dependencies=[Security(self._auth_service, scopes=["models:read"])]
313
323
  if self.is_authenticated()
314
324
  else None,
315
325
  )
@@ -317,7 +327,7 @@ class RESTfulAPI:
317
327
  "/v1/embeddings",
318
328
  self.create_embedding,
319
329
  methods=["POST"],
320
- dependencies=[Security(verify_token, scopes=["models:read"])]
330
+ dependencies=[Security(self._auth_service, scopes=["models:read"])]
321
331
  if self.is_authenticated()
322
332
  else None,
323
333
  )
@@ -325,7 +335,23 @@ class RESTfulAPI:
325
335
  "/v1/rerank",
326
336
  self.rerank,
327
337
  methods=["POST"],
328
- dependencies=[Security(verify_token, scopes=["models:read"])]
338
+ dependencies=[Security(self._auth_service, scopes=["models:read"])]
339
+ if self.is_authenticated()
340
+ else None,
341
+ )
342
+ self._router.add_api_route(
343
+ "/v1/audio/transcriptions",
344
+ self.create_transcriptions,
345
+ methods=["POST"],
346
+ dependencies=[Security(self._auth_service, scopes=["models:read"])]
347
+ if self.is_authenticated()
348
+ else None,
349
+ )
350
+ self._router.add_api_route(
351
+ "/v1/audio/translations",
352
+ self.create_translations,
353
+ methods=["POST"],
354
+ dependencies=[Security(self._auth_service, scopes=["models:read"])]
329
355
  if self.is_authenticated()
330
356
  else None,
331
357
  )
@@ -334,7 +360,7 @@ class RESTfulAPI:
334
360
  self.create_images,
335
361
  methods=["POST"],
336
362
  response_model=ImageList,
337
- dependencies=[Security(verify_token, scopes=["models:read"])]
363
+ dependencies=[Security(self._auth_service, scopes=["models:read"])]
338
364
  if self.is_authenticated()
339
365
  else None,
340
366
  )
@@ -343,7 +369,7 @@ class RESTfulAPI:
343
369
  self.create_variations,
344
370
  methods=["POST"],
345
371
  response_model=ImageList,
346
- dependencies=[Security(verify_token, scopes=["models:read"])]
372
+ dependencies=[Security(self._auth_service, scopes=["models:read"])]
347
373
  if self.is_authenticated()
348
374
  else None,
349
375
  )
@@ -352,7 +378,7 @@ class RESTfulAPI:
352
378
  self.create_chat_completion,
353
379
  methods=["POST"],
354
380
  response_model=ChatCompletion,
355
- dependencies=[Security(verify_token, scopes=["models:read"])]
381
+ dependencies=[Security(self._auth_service, scopes=["models:read"])]
356
382
  if self.is_authenticated()
357
383
  else None,
358
384
  )
@@ -362,7 +388,7 @@ class RESTfulAPI:
362
388
  "/v1/model_registrations/{model_type}",
363
389
  self.register_model,
364
390
  methods=["POST"],
365
- dependencies=[Security(verify_token, scopes=["models:register"])]
391
+ dependencies=[Security(self._auth_service, scopes=["models:register"])]
366
392
  if self.is_authenticated()
367
393
  else None,
368
394
  )
@@ -370,7 +396,7 @@ class RESTfulAPI:
370
396
  "/v1/model_registrations/{model_type}/{model_name}",
371
397
  self.unregister_model,
372
398
  methods=["DELETE"],
373
- dependencies=[Security(verify_token, scopes=["models:unregister"])]
399
+ dependencies=[Security(self._auth_service, scopes=["models:unregister"])]
374
400
  if self.is_authenticated()
375
401
  else None,
376
402
  )
@@ -378,7 +404,7 @@ class RESTfulAPI:
378
404
  "/v1/model_registrations/{model_type}",
379
405
  self.list_model_registrations,
380
406
  methods=["GET"],
381
- dependencies=[Security(verify_token, scopes=["models:list"])]
407
+ dependencies=[Security(self._auth_service, scopes=["models:list"])]
382
408
  if self.is_authenticated()
383
409
  else None,
384
410
  )
@@ -386,7 +412,7 @@ class RESTfulAPI:
386
412
  "/v1/model_registrations/{model_type}/{model_name}",
387
413
  self.get_model_registrations,
388
414
  methods=["GET"],
389
- dependencies=[Security(verify_token, scopes=["models:list"])]
415
+ dependencies=[Security(self._auth_service, scopes=["models:list"])]
390
416
  if self.is_authenticated()
391
417
  else None,
392
418
  )
@@ -640,6 +666,44 @@ class RESTfulAPI:
640
666
  raise HTTPException(status_code=500, detail=str(e))
641
667
  return JSONResponse(content=infos)
642
668
 
669
+ async def launch_model_by_version(
670
+ self, request: Request, wait_ready: bool = Query(True)
671
+ ) -> JSONResponse:
672
+ payload = await request.json()
673
+ model_uid = payload.get("model_uid")
674
+ model_type = payload.get("model_type")
675
+ model_version = payload.get("model_version")
676
+ replica = payload.get("replica", 1)
677
+ n_gpu = payload.get("n_gpu", "auto")
678
+
679
+ try:
680
+ model_uid = await (
681
+ await self._get_supervisor_ref()
682
+ ).launch_model_by_version(
683
+ model_uid=model_uid,
684
+ model_type=model_type,
685
+ model_version=model_version,
686
+ replica=replica,
687
+ n_gpu=n_gpu,
688
+ wait_ready=wait_ready,
689
+ )
690
+ except Exception as e:
691
+ logger.error(str(e), exc_info=True)
692
+ raise HTTPException(status_code=500, detail=str(e))
693
+ return JSONResponse(content={"model_uid": model_uid})
694
+
695
+ async def get_model_versions(
696
+ self, model_type: str, model_name: str
697
+ ) -> JSONResponse:
698
+ try:
699
+ content = await (await self._get_supervisor_ref()).get_model_versions(
700
+ model_type, model_name
701
+ )
702
+ return JSONResponse(content=content)
703
+ except Exception as e:
704
+ logger.error(e, exc_info=True)
705
+ raise HTTPException(status_code=500, detail=str(e))
706
+
643
707
  async def build_gradio_interface(
644
708
  self, model_uid: str, body: BuildGradioInterfaceRequest, request: Request
645
709
  ) -> JSONResponse:
@@ -649,7 +713,7 @@ class RESTfulAPI:
649
713
  but calling API in async function does not return
650
714
  """
651
715
  assert self._app is not None
652
- assert body.model_type in ["LLM", "multimodal"]
716
+ assert body.model_type == "LLM"
653
717
 
654
718
  # asyncio.Lock() behaves differently in 3.9 than 3.10+
655
719
  # A event loop is required in 3.9 but not 3.10+
@@ -731,6 +795,9 @@ class RESTfulAPI:
731
795
  }
732
796
  kwargs = body.dict(exclude_unset=True, exclude=exclude)
733
797
 
798
+ if body.max_tokens is None:
799
+ kwargs["max_tokens"] = max_tokens_field.default
800
+
734
801
  if body.logit_bias is not None:
735
802
  raise HTTPException(status_code=501, detail="Not implemented")
736
803
 
@@ -740,10 +807,12 @@ class RESTfulAPI:
740
807
  model = await (await self._get_supervisor_ref()).get_model(model_uid)
741
808
  except ValueError as ve:
742
809
  logger.error(str(ve), exc_info=True)
810
+ await self._report_error_event(model_uid, str(ve))
743
811
  raise HTTPException(status_code=400, detail=str(ve))
744
812
 
745
813
  except Exception as e:
746
814
  logger.error(e, exc_info=True)
815
+ await self._report_error_event(model_uid, str(e))
747
816
  raise HTTPException(status_code=500, detail=str(e))
748
817
 
749
818
  if body.stream:
@@ -759,6 +828,7 @@ class RESTfulAPI:
759
828
  yield item
760
829
  except Exception as ex:
761
830
  logger.exception("Completion stream got an error: %s", ex)
831
+ await self._report_error_event(model_uid, str(ex))
762
832
  # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
763
833
  yield dict(data=json.dumps({"error": str(ex)}))
764
834
 
@@ -769,6 +839,7 @@ class RESTfulAPI:
769
839
  return Response(data, media_type="application/json")
770
840
  except Exception as e:
771
841
  logger.error(e, exc_info=True)
842
+ await self._report_error_event(model_uid, str(e))
772
843
  self.handle_request_limit_error(e)
773
844
  raise HTTPException(status_code=500, detail=str(e))
774
845
 
@@ -779,9 +850,11 @@ class RESTfulAPI:
779
850
  model = await (await self._get_supervisor_ref()).get_model(model_uid)
780
851
  except ValueError as ve:
781
852
  logger.error(str(ve), exc_info=True)
853
+ await self._report_error_event(model_uid, str(ve))
782
854
  raise HTTPException(status_code=400, detail=str(ve))
783
855
  except Exception as e:
784
856
  logger.error(e, exc_info=True)
857
+ await self._report_error_event(model_uid, str(e))
785
858
  raise HTTPException(status_code=500, detail=str(e))
786
859
 
787
860
  try:
@@ -789,10 +862,12 @@ class RESTfulAPI:
789
862
  return Response(embedding, media_type="application/json")
790
863
  except RuntimeError as re:
791
864
  logger.error(re, exc_info=True)
865
+ await self._report_error_event(model_uid, str(re))
792
866
  self.handle_request_limit_error(re)
793
867
  raise HTTPException(status_code=400, detail=str(re))
794
868
  except Exception as e:
795
869
  logger.error(e, exc_info=True)
870
+ await self._report_error_event(model_uid, str(e))
796
871
  raise HTTPException(status_code=500, detail=str(e))
797
872
 
798
873
  async def rerank(self, request: RerankRequest) -> Response:
@@ -801,9 +876,11 @@ class RESTfulAPI:
801
876
  model = await (await self._get_supervisor_ref()).get_model(model_uid)
802
877
  except ValueError as ve:
803
878
  logger.error(str(ve), exc_info=True)
879
+ await self._report_error_event(model_uid, str(ve))
804
880
  raise HTTPException(status_code=400, detail=str(ve))
805
881
  except Exception as e:
806
882
  logger.error(e, exc_info=True)
883
+ await self._report_error_event(model_uid, str(e))
807
884
  raise HTTPException(status_code=500, detail=str(e))
808
885
 
809
886
  try:
@@ -817,10 +894,100 @@ class RESTfulAPI:
817
894
  return Response(scores, media_type="application/json")
818
895
  except RuntimeError as re:
819
896
  logger.error(re, exc_info=True)
897
+ await self._report_error_event(model_uid, str(re))
820
898
  self.handle_request_limit_error(re)
821
899
  raise HTTPException(status_code=400, detail=str(re))
822
900
  except Exception as e:
823
901
  logger.error(e, exc_info=True)
902
+ await self._report_error_event(model_uid, str(e))
903
+ raise HTTPException(status_code=500, detail=str(e))
904
+
905
+ async def create_transcriptions(
906
+ self,
907
+ model: str = Form(...),
908
+ file: UploadFile = File(media_type="application/octet-stream"),
909
+ language: Optional[str] = Form(None),
910
+ prompt: Optional[str] = Form(None),
911
+ response_format: Optional[str] = Form("json"),
912
+ temperature: Optional[float] = Form(0),
913
+ kwargs: Optional[str] = Form(None),
914
+ ) -> Response:
915
+ model_uid = model
916
+ try:
917
+ model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
918
+ except ValueError as ve:
919
+ logger.error(str(ve), exc_info=True)
920
+ await self._report_error_event(model_uid, str(ve))
921
+ raise HTTPException(status_code=400, detail=str(ve))
922
+ except Exception as e:
923
+ logger.error(e, exc_info=True)
924
+ await self._report_error_event(model_uid, str(e))
925
+ raise HTTPException(status_code=500, detail=str(e))
926
+
927
+ try:
928
+ if kwargs is not None:
929
+ parsed_kwargs = json.loads(kwargs)
930
+ else:
931
+ parsed_kwargs = {}
932
+ transcription = await model_ref.transcriptions(
933
+ audio=await file.read(),
934
+ language=language,
935
+ prompt=prompt,
936
+ response_format=response_format,
937
+ temperature=temperature,
938
+ **parsed_kwargs,
939
+ )
940
+ return Response(content=transcription, media_type="application/json")
941
+ except RuntimeError as re:
942
+ logger.error(re, exc_info=True)
943
+ await self._report_error_event(model_uid, str(re))
944
+ raise HTTPException(status_code=400, detail=str(re))
945
+ except Exception as e:
946
+ logger.error(e, exc_info=True)
947
+ await self._report_error_event(model_uid, str(e))
948
+ raise HTTPException(status_code=500, detail=str(e))
949
+
950
+ async def create_translations(
951
+ self,
952
+ model: str = Form(...),
953
+ file: UploadFile = File(media_type="application/octet-stream"),
954
+ prompt: Optional[str] = Form(None),
955
+ response_format: Optional[str] = Form("json"),
956
+ temperature: Optional[float] = Form(0),
957
+ kwargs: Optional[str] = Form(None),
958
+ ) -> Response:
959
+ model_uid = model
960
+ try:
961
+ model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
962
+ except ValueError as ve:
963
+ logger.error(str(ve), exc_info=True)
964
+ await self._report_error_event(model_uid, str(ve))
965
+ raise HTTPException(status_code=400, detail=str(ve))
966
+ except Exception as e:
967
+ logger.error(e, exc_info=True)
968
+ await self._report_error_event(model_uid, str(e))
969
+ raise HTTPException(status_code=500, detail=str(e))
970
+
971
+ try:
972
+ if kwargs is not None:
973
+ parsed_kwargs = json.loads(kwargs)
974
+ else:
975
+ parsed_kwargs = {}
976
+ translation = await model_ref.translations(
977
+ audio=await file.read(),
978
+ prompt=prompt,
979
+ response_format=response_format,
980
+ temperature=temperature,
981
+ **parsed_kwargs,
982
+ )
983
+ return Response(content=translation, media_type="application/json")
984
+ except RuntimeError as re:
985
+ logger.error(re, exc_info=True)
986
+ await self._report_error_event(model_uid, str(re))
987
+ raise HTTPException(status_code=400, detail=str(re))
988
+ except Exception as e:
989
+ logger.error(e, exc_info=True)
990
+ await self._report_error_event(model_uid, str(e))
824
991
  raise HTTPException(status_code=500, detail=str(e))
825
992
 
826
993
  async def create_images(self, request: TextToImageRequest) -> Response:
@@ -829,9 +996,11 @@ class RESTfulAPI:
829
996
  model = await (await self._get_supervisor_ref()).get_model(model_uid)
830
997
  except ValueError as ve:
831
998
  logger.error(str(ve), exc_info=True)
999
+ await self._report_error_event(model_uid, str(ve))
832
1000
  raise HTTPException(status_code=400, detail=str(ve))
833
1001
  except Exception as e:
834
1002
  logger.error(e, exc_info=True)
1003
+ await self._report_error_event(model_uid, str(e))
835
1004
  raise HTTPException(status_code=500, detail=str(e))
836
1005
 
837
1006
  try:
@@ -846,10 +1015,12 @@ class RESTfulAPI:
846
1015
  return Response(content=image_list, media_type="application/json")
847
1016
  except RuntimeError as re:
848
1017
  logger.error(re, exc_info=True)
1018
+ await self._report_error_event(model_uid, str(re))
849
1019
  self.handle_request_limit_error(re)
850
1020
  raise HTTPException(status_code=400, detail=str(re))
851
1021
  except Exception as e:
852
1022
  logger.error(e, exc_info=True)
1023
+ await self._report_error_event(model_uid, str(e))
853
1024
  raise HTTPException(status_code=500, detail=str(e))
854
1025
 
855
1026
  async def create_variations(
@@ -868,14 +1039,18 @@ class RESTfulAPI:
868
1039
  model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
869
1040
  except ValueError as ve:
870
1041
  logger.error(str(ve), exc_info=True)
1042
+ await self._report_error_event(model_uid, str(ve))
871
1043
  raise HTTPException(status_code=400, detail=str(ve))
872
1044
  except Exception as e:
873
1045
  logger.error(e, exc_info=True)
1046
+ await self._report_error_event(model_uid, str(e))
874
1047
  raise HTTPException(status_code=500, detail=str(e))
875
1048
 
876
1049
  try:
877
1050
  if kwargs is not None:
878
- kwargs = json.loads(kwargs)
1051
+ parsed_kwargs = json.loads(kwargs)
1052
+ else:
1053
+ parsed_kwargs = {}
879
1054
  image_list = await model_ref.image_to_image(
880
1055
  image=Image.open(image.file),
881
1056
  prompt=prompt,
@@ -883,14 +1058,16 @@ class RESTfulAPI:
883
1058
  n=n,
884
1059
  size=size,
885
1060
  response_format=response_format,
886
- **kwargs,
1061
+ **parsed_kwargs,
887
1062
  )
888
1063
  return Response(content=image_list, media_type="application/json")
889
1064
  except RuntimeError as re:
890
1065
  logger.error(re, exc_info=True)
1066
+ await self._report_error_event(model_uid, str(re))
891
1067
  raise HTTPException(status_code=400, detail=str(re))
892
1068
  except Exception as e:
893
1069
  logger.error(e, exc_info=True)
1070
+ await self._report_error_event(model_uid, str(e))
894
1071
  raise HTTPException(status_code=500, detail=str(e))
895
1072
 
896
1073
  async def create_chat_completion(
@@ -909,6 +1086,9 @@ class RESTfulAPI:
909
1086
  }
910
1087
  kwargs = body.dict(exclude_unset=True, exclude=exclude)
911
1088
 
1089
+ if body.max_tokens is None:
1090
+ kwargs["max_tokens"] = max_tokens_field.default
1091
+
912
1092
  if body.logit_bias is not None:
913
1093
  raise HTTPException(status_code=501, detail="Not implemented")
914
1094
 
@@ -958,31 +1138,32 @@ class RESTfulAPI:
958
1138
  model = await (await self._get_supervisor_ref()).get_model(model_uid)
959
1139
  except ValueError as ve:
960
1140
  logger.error(str(ve), exc_info=True)
1141
+ await self._report_error_event(model_uid, str(ve))
961
1142
  raise HTTPException(status_code=400, detail=str(ve))
962
1143
  except Exception as e:
963
1144
  logger.error(e, exc_info=True)
1145
+ await self._report_error_event(model_uid, str(e))
964
1146
  raise HTTPException(status_code=500, detail=str(e))
965
1147
 
966
1148
  try:
967
1149
  desc = await (await self._get_supervisor_ref()).describe_model(model_uid)
968
1150
  except ValueError as ve:
969
1151
  logger.error(str(ve), exc_info=True)
1152
+ await self._report_error_event(model_uid, str(ve))
970
1153
  raise HTTPException(status_code=400, detail=str(ve))
971
1154
  except Exception as e:
972
1155
  logger.error(e, exc_info=True)
1156
+ await self._report_error_event(model_uid, str(e))
973
1157
  raise HTTPException(status_code=500, detail=str(e))
974
1158
 
975
1159
  model_name = desc.get("model_name", "")
976
- is_chatglm_ggml = (
977
- desc.get("model_format") == "ggmlv3" and "chatglm" in model_name
978
- )
979
1160
  function_call_models = ["chatglm3", "gorilla-openfunctions-v1", "qwen-chat"]
980
1161
 
981
1162
  is_qwen = desc.get("model_format") == "ggmlv3" and "qwen" in model_name
982
1163
 
983
- if (is_chatglm_ggml or is_qwen) and system_prompt is not None:
1164
+ if is_qwen and system_prompt is not None:
984
1165
  raise HTTPException(
985
- status_code=400, detail="ChatGLM ggml does not have system prompt"
1166
+ status_code=400, detail="Qwen ggml does not have system prompt"
986
1167
  )
987
1168
 
988
1169
  if not any(name in model_name for name in function_call_models):
@@ -1007,31 +1188,34 @@ class RESTfulAPI:
1007
1188
  iterator = None
1008
1189
  try:
1009
1190
  try:
1010
- if is_chatglm_ggml or is_qwen:
1191
+ if is_qwen:
1011
1192
  iterator = await model.chat(prompt, chat_history, kwargs)
1012
1193
  else:
1013
1194
  iterator = await model.chat(
1014
1195
  prompt, system_prompt, chat_history, kwargs
1015
1196
  )
1016
1197
  except RuntimeError as re:
1198
+ await self._report_error_event(model_uid, str(re))
1017
1199
  self.handle_request_limit_error(re)
1018
1200
  async for item in iterator:
1019
1201
  yield item
1020
1202
  except Exception as ex:
1021
1203
  logger.exception("Chat completion stream got an error: %s", ex)
1204
+ await self._report_error_event(model_uid, str(ex))
1022
1205
  # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
1023
1206
  yield dict(data=json.dumps({"error": str(ex)}))
1024
1207
 
1025
1208
  return EventSourceResponse(stream_results())
1026
1209
  else:
1027
1210
  try:
1028
- if is_chatglm_ggml or is_qwen:
1211
+ if is_qwen:
1029
1212
  data = await model.chat(prompt, chat_history, kwargs)
1030
1213
  else:
1031
1214
  data = await model.chat(prompt, system_prompt, chat_history, kwargs)
1032
1215
  return Response(content=data, media_type="application/json")
1033
1216
  except Exception as e:
1034
1217
  logger.error(e, exc_info=True)
1218
+ await self._report_error_event(model_uid, str(e))
1035
1219
  self.handle_request_limit_error(e)
1036
1220
  raise HTTPException(status_code=500, detail=str(e))
1037
1221
 
@@ -1096,6 +1280,26 @@ class RESTfulAPI:
1096
1280
  logger.error(e, exc_info=True)
1097
1281
  raise HTTPException(status_code=500, detail=str(e))
1098
1282
 
1283
+ async def get_model_events(self, model_uid: str) -> JSONResponse:
1284
+ try:
1285
+ event_collector_ref = await self._get_event_collector_ref()
1286
+ events = await event_collector_ref.get_model_events(model_uid)
1287
+ return JSONResponse(content=events)
1288
+ except ValueError as re:
1289
+ logger.error(re, exc_info=True)
1290
+ raise HTTPException(status_code=400, detail=str(re))
1291
+ except Exception as e:
1292
+ logger.error(e, exc_info=True)
1293
+ raise HTTPException(status_code=500, detail=str(e))
1294
+
1295
+ async def get_cluster_device_info(self) -> JSONResponse:
1296
+ try:
1297
+ data = await (await self._get_supervisor_ref()).get_cluster_device_info()
1298
+ return JSONResponse(content=data)
1299
+ except Exception as e:
1300
+ logger.error(e, exc_info=True)
1301
+ raise HTTPException(status_code=500, detail=str(e))
1302
+
1099
1303
 
1100
1304
  def run(
1101
1305
  supervisor_address: str,