xinference 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/oauth2/auth_service.py +132 -0
- xinference/api/restful_api.py +282 -78
- xinference/client/handlers.py +3 -0
- xinference/client/restful/restful_client.py +108 -75
- xinference/constants.py +14 -4
- xinference/core/cache_tracker.py +102 -0
- xinference/core/chat_interface.py +10 -4
- xinference/core/event.py +56 -0
- xinference/core/model.py +44 -0
- xinference/core/resource.py +19 -12
- xinference/core/status_guard.py +4 -0
- xinference/core/supervisor.py +278 -87
- xinference/core/utils.py +68 -3
- xinference/core/worker.py +98 -8
- xinference/deploy/cmdline.py +6 -3
- xinference/deploy/local.py +2 -2
- xinference/deploy/supervisor.py +2 -2
- xinference/model/audio/__init__.py +27 -0
- xinference/model/audio/core.py +161 -0
- xinference/model/audio/model_spec.json +79 -0
- xinference/model/audio/utils.py +18 -0
- xinference/model/audio/whisper.py +132 -0
- xinference/model/core.py +18 -13
- xinference/model/embedding/__init__.py +27 -2
- xinference/model/embedding/core.py +43 -3
- xinference/model/embedding/model_spec.json +24 -0
- xinference/model/embedding/model_spec_modelscope.json +24 -0
- xinference/model/embedding/utils.py +18 -0
- xinference/model/image/__init__.py +12 -1
- xinference/model/image/core.py +63 -9
- xinference/model/image/utils.py +26 -0
- xinference/model/llm/__init__.py +20 -1
- xinference/model/llm/core.py +43 -2
- xinference/model/llm/ggml/chatglm.py +15 -6
- xinference/model/llm/llm_family.json +197 -6
- xinference/model/llm/llm_family.py +9 -7
- xinference/model/llm/llm_family_modelscope.json +189 -4
- xinference/model/llm/pytorch/chatglm.py +3 -3
- xinference/model/llm/pytorch/core.py +4 -2
- xinference/model/{multimodal → llm/pytorch}/qwen_vl.py +10 -8
- xinference/model/llm/pytorch/utils.py +21 -9
- xinference/model/llm/pytorch/yi_vl.py +246 -0
- xinference/model/llm/utils.py +57 -4
- xinference/model/llm/vllm/core.py +5 -4
- xinference/model/rerank/__init__.py +25 -2
- xinference/model/rerank/core.py +51 -9
- xinference/model/rerank/model_spec.json +6 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -0
- xinference/{api/oauth2/common.py → model/rerank/utils.py} +6 -2
- xinference/model/utils.py +5 -3
- xinference/thirdparty/__init__.py +0 -0
- xinference/thirdparty/llava/__init__.py +1 -0
- xinference/thirdparty/llava/conversation.py +205 -0
- xinference/thirdparty/llava/mm_utils.py +122 -0
- xinference/thirdparty/llava/model/__init__.py +1 -0
- xinference/thirdparty/llava/model/clip_encoder/__init__.py +0 -0
- xinference/thirdparty/llava/model/clip_encoder/builder.py +11 -0
- xinference/thirdparty/llava/model/clip_encoder/clip_encoder.py +86 -0
- xinference/thirdparty/llava/model/constants.py +6 -0
- xinference/thirdparty/llava/model/llava_arch.py +385 -0
- xinference/thirdparty/llava/model/llava_llama.py +163 -0
- xinference/thirdparty/llava/model/multimodal_projector/__init__.py +0 -0
- xinference/thirdparty/llava/model/multimodal_projector/builder.py +64 -0
- xinference/types.py +1 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.15822aeb.js +3 -0
- xinference/web/ui/build/static/js/main.15822aeb.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/139e5e4adf436923107d2b02994c7ff6dba2aac1989e9b6638984f0dfe782c4a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/52aa27272b4b9968f62666262b47661cb1992336a2aff3b13994cc36877b3ec3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/64accc515dc6cd584a2873796cd7da6f93de57f7e465eb5423cca9a2f3fe3eff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/65ca3ba225b8c8dac907210545b51f2fcdb2591f0feeb7195f1c037f2bc956a0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b80db1012318b97c329c4e3e72454f7512fb107e57c444b437dbe4ba1a3faa5a.json +1 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/METADATA +33 -23
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/RECORD +81 -64
- xinference/api/oauth2/core.py +0 -93
- xinference/model/multimodal/__init__.py +0 -52
- xinference/model/multimodal/core.py +0 -467
- xinference/model/multimodal/model_spec.json +0 -43
- xinference/model/multimodal/model_spec_modelscope.json +0 -45
- xinference/web/ui/build/static/js/main.b83095c2.js +0 -3
- xinference/web/ui/build/static/js/main.b83095c2.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/101923c539819f26ad11fbcbd6f6e56436b285efbb090dcc7dd648c6e924c4a8.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4942da6bc03bf7373af068e22f916341aabc5b5df855d73c1d348c696724ce37.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/52a6136cb2dbbf9c51d461724d9b283ebe74a73fb19d5df7ba8e13c42bd7174d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/71493aadd34d568fbe605cacaba220aa69bd09273251ee4ba27930f8d01fccd8.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8b071db2a5a9ef68dc14d5f606540bd23d9785e365a11997c510656764d2dccf.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a4d72d3b806ba061919115f0c513738726872e3c79cf258f007519d3f91d1a16.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f037ffef5992af0892d6d991053c1dace364cd39a3f11f1a41f92776e8a59459.json +0 -1
- /xinference/web/ui/build/static/js/{main.b83095c2.js.LICENSE.txt → main.15822aeb.js.LICENSE.txt} +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/LICENSE +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/WHEEL +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/top_level.txt +0 -0
xinference/api/restful_api.py
CHANGED
|
@@ -20,12 +20,11 @@ import multiprocessing
|
|
|
20
20
|
import os
|
|
21
21
|
import pprint
|
|
22
22
|
import sys
|
|
23
|
+
import time
|
|
23
24
|
import warnings
|
|
24
|
-
from datetime import timedelta
|
|
25
25
|
from typing import Any, List, Optional, Union
|
|
26
26
|
|
|
27
27
|
import gradio as gr
|
|
28
|
-
import pydantic
|
|
29
28
|
import xoscar as xo
|
|
30
29
|
from aioprometheus import REGISTRY, MetricsMiddleware
|
|
31
30
|
from aioprometheus.asgi.starlette import metrics
|
|
@@ -40,7 +39,6 @@ from fastapi import (
|
|
|
40
39
|
Response,
|
|
41
40
|
Security,
|
|
42
41
|
UploadFile,
|
|
43
|
-
status,
|
|
44
42
|
)
|
|
45
43
|
from fastapi.middleware.cors import CORSMiddleware
|
|
46
44
|
from fastapi.responses import JSONResponse
|
|
@@ -54,6 +52,7 @@ from uvicorn import Config, Server
|
|
|
54
52
|
from xoscar.utils import get_next_port
|
|
55
53
|
|
|
56
54
|
from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT
|
|
55
|
+
from ..core.event import Event, EventCollectorActor, EventType
|
|
57
56
|
from ..core.supervisor import SupervisorActor
|
|
58
57
|
from ..core.utils import json_dumps
|
|
59
58
|
from ..types import (
|
|
@@ -63,10 +62,10 @@ from ..types import (
|
|
|
63
62
|
CreateChatCompletion,
|
|
64
63
|
CreateCompletion,
|
|
65
64
|
ImageList,
|
|
65
|
+
max_tokens_field,
|
|
66
66
|
)
|
|
67
|
-
from .oauth2.
|
|
68
|
-
from .oauth2.types import
|
|
69
|
-
from .oauth2.utils import create_access_token, get_password_hash, verify_password
|
|
67
|
+
from .oauth2.auth_service import AuthService
|
|
68
|
+
from .oauth2.types import LoginUserForm
|
|
70
69
|
|
|
71
70
|
logger = logging.getLogger(__name__)
|
|
72
71
|
|
|
@@ -135,15 +134,6 @@ class BuildGradioInterfaceRequest(BaseModel):
|
|
|
135
134
|
model_lang: List[str]
|
|
136
135
|
|
|
137
136
|
|
|
138
|
-
def authenticate_user(db_users: List[User], username: str, password: str):
|
|
139
|
-
user = get_user(db_users, username)
|
|
140
|
-
if not user:
|
|
141
|
-
return False
|
|
142
|
-
if not verify_password(password, user.password):
|
|
143
|
-
return False
|
|
144
|
-
return user
|
|
145
|
-
|
|
146
|
-
|
|
147
137
|
class RESTfulAPI:
|
|
148
138
|
def __init__(
|
|
149
139
|
self,
|
|
@@ -157,25 +147,13 @@ class RESTfulAPI:
|
|
|
157
147
|
self._host = host
|
|
158
148
|
self._port = port
|
|
159
149
|
self._supervisor_ref = None
|
|
160
|
-
self.
|
|
150
|
+
self._event_collector_ref = None
|
|
151
|
+
self._auth_service = AuthService(auth_config_file)
|
|
161
152
|
self._router = APIRouter()
|
|
162
153
|
self._app = FastAPI()
|
|
163
154
|
|
|
164
|
-
@staticmethod
|
|
165
|
-
def init_auth_config(auth_config_file: Optional[str]):
|
|
166
|
-
from .oauth2 import common
|
|
167
|
-
|
|
168
|
-
if auth_config_file:
|
|
169
|
-
config: AuthStartupConfig = pydantic.parse_file_as(
|
|
170
|
-
path=auth_config_file, type_=AuthStartupConfig
|
|
171
|
-
)
|
|
172
|
-
for user in config.user_config:
|
|
173
|
-
user.password = get_password_hash(user.password)
|
|
174
|
-
common.XINFERENCE_OAUTH2_CONFIG = config # type: ignore
|
|
175
|
-
return config
|
|
176
|
-
|
|
177
155
|
def is_authenticated(self):
|
|
178
|
-
return False if self.
|
|
156
|
+
return False if self._auth_service.config is None else True
|
|
179
157
|
|
|
180
158
|
@staticmethod
|
|
181
159
|
def handle_request_limit_error(e: Exception):
|
|
@@ -189,29 +167,34 @@ class RESTfulAPI:
|
|
|
189
167
|
)
|
|
190
168
|
return self._supervisor_ref
|
|
191
169
|
|
|
192
|
-
async def
|
|
193
|
-
|
|
194
|
-
self.
|
|
195
|
-
|
|
196
|
-
if not user:
|
|
197
|
-
raise HTTPException(
|
|
198
|
-
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
199
|
-
detail="Incorrect username or password",
|
|
200
|
-
headers={"WWW-Authenticate": "Bearer"},
|
|
170
|
+
async def _get_event_collector_ref(self) -> xo.ActorRefType[EventCollectorActor]:
|
|
171
|
+
if self._event_collector_ref is None:
|
|
172
|
+
self._event_collector_ref = await xo.actor_ref(
|
|
173
|
+
address=self._supervisor_address, uid=EventCollectorActor.uid()
|
|
201
174
|
)
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
175
|
+
return self._event_collector_ref
|
|
176
|
+
|
|
177
|
+
async def _report_error_event(self, model_uid: str, content: str):
|
|
178
|
+
try:
|
|
179
|
+
event_collector_ref = await self._get_event_collector_ref()
|
|
180
|
+
await event_collector_ref.report_event(
|
|
181
|
+
model_uid,
|
|
182
|
+
Event(
|
|
183
|
+
event_type=EventType.ERROR,
|
|
184
|
+
event_ts=int(time.time()),
|
|
185
|
+
event_content=content,
|
|
186
|
+
),
|
|
187
|
+
)
|
|
188
|
+
except Exception:
|
|
189
|
+
logger.exception(
|
|
190
|
+
"Report error event failed, model: %s, content: %s", model_uid, content
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
async def login_for_access_token(self, form_data: LoginUserForm) -> JSONResponse:
|
|
194
|
+
result = self._auth_service.generate_token_for_user(
|
|
195
|
+
form_data.username, form_data.password
|
|
214
196
|
)
|
|
197
|
+
return JSONResponse(content=result)
|
|
215
198
|
|
|
216
199
|
async def is_cluster_authenticated(self) -> JSONResponse:
|
|
217
200
|
return JSONResponse(content={"auth": self.is_authenticated()})
|
|
@@ -234,6 +217,9 @@ class RESTfulAPI:
|
|
|
234
217
|
self._router.add_api_route(
|
|
235
218
|
"/v1/models/families", self._get_builtin_families, methods=["GET"]
|
|
236
219
|
)
|
|
220
|
+
self._router.add_api_route(
|
|
221
|
+
"/v1/cluster/info", self.get_cluster_device_info, methods=["GET"]
|
|
222
|
+
)
|
|
237
223
|
self._router.add_api_route(
|
|
238
224
|
"/v1/cluster/devices", self._get_devices_count, methods=["GET"]
|
|
239
225
|
)
|
|
@@ -244,7 +230,7 @@ class RESTfulAPI:
|
|
|
244
230
|
"/v1/ui/{model_uid}",
|
|
245
231
|
self.build_gradio_interface,
|
|
246
232
|
methods=["POST"],
|
|
247
|
-
dependencies=[Security(
|
|
233
|
+
dependencies=[Security(self._auth_service, scopes=["models:read"])]
|
|
248
234
|
if self.is_authenticated()
|
|
249
235
|
else None,
|
|
250
236
|
)
|
|
@@ -259,7 +245,15 @@ class RESTfulAPI:
|
|
|
259
245
|
"/v1/models/instances",
|
|
260
246
|
self.get_instance_info,
|
|
261
247
|
methods=["GET"],
|
|
262
|
-
dependencies=[Security(
|
|
248
|
+
dependencies=[Security(self._auth_service, scopes=["models:list"])]
|
|
249
|
+
if self.is_authenticated()
|
|
250
|
+
else None,
|
|
251
|
+
)
|
|
252
|
+
self._router.add_api_route(
|
|
253
|
+
"/v1/models/{model_type}/{model_name}/versions",
|
|
254
|
+
self.get_model_versions,
|
|
255
|
+
methods=["GET"],
|
|
256
|
+
dependencies=[Security(self._auth_service, scopes=["models:list"])]
|
|
263
257
|
if self.is_authenticated()
|
|
264
258
|
else None,
|
|
265
259
|
)
|
|
@@ -267,7 +261,7 @@ class RESTfulAPI:
|
|
|
267
261
|
"/v1/models",
|
|
268
262
|
self.list_models,
|
|
269
263
|
methods=["GET"],
|
|
270
|
-
dependencies=[Security(
|
|
264
|
+
dependencies=[Security(self._auth_service, scopes=["models:list"])]
|
|
271
265
|
if self.is_authenticated()
|
|
272
266
|
else None,
|
|
273
267
|
)
|
|
@@ -276,7 +270,23 @@ class RESTfulAPI:
|
|
|
276
270
|
"/v1/models/{model_uid}",
|
|
277
271
|
self.describe_model,
|
|
278
272
|
methods=["GET"],
|
|
279
|
-
dependencies=[Security(
|
|
273
|
+
dependencies=[Security(self._auth_service, scopes=["models:list"])]
|
|
274
|
+
if self.is_authenticated()
|
|
275
|
+
else None,
|
|
276
|
+
)
|
|
277
|
+
self._router.add_api_route(
|
|
278
|
+
"/v1/models/{model_uid}/events",
|
|
279
|
+
self.get_model_events,
|
|
280
|
+
methods=["GET"],
|
|
281
|
+
dependencies=[Security(self._auth_service, scopes=["models:read"])]
|
|
282
|
+
if self.is_authenticated()
|
|
283
|
+
else None,
|
|
284
|
+
)
|
|
285
|
+
self._router.add_api_route(
|
|
286
|
+
"/v1/models/instance",
|
|
287
|
+
self.launch_model_by_version,
|
|
288
|
+
methods=["POST"],
|
|
289
|
+
dependencies=[Security(self._auth_service, scopes=["models:start"])]
|
|
280
290
|
if self.is_authenticated()
|
|
281
291
|
else None,
|
|
282
292
|
)
|
|
@@ -284,7 +294,7 @@ class RESTfulAPI:
|
|
|
284
294
|
"/v1/models",
|
|
285
295
|
self.launch_model,
|
|
286
296
|
methods=["POST"],
|
|
287
|
-
dependencies=[Security(
|
|
297
|
+
dependencies=[Security(self._auth_service, scopes=["models:start"])]
|
|
288
298
|
if self.is_authenticated()
|
|
289
299
|
else None,
|
|
290
300
|
)
|
|
@@ -292,7 +302,7 @@ class RESTfulAPI:
|
|
|
292
302
|
"/experimental/speculative_llms",
|
|
293
303
|
self.launch_speculative_llm,
|
|
294
304
|
methods=["POST"],
|
|
295
|
-
dependencies=[Security(
|
|
305
|
+
dependencies=[Security(self._auth_service, scopes=["models:start"])]
|
|
296
306
|
if self.is_authenticated()
|
|
297
307
|
else None,
|
|
298
308
|
)
|
|
@@ -300,7 +310,7 @@ class RESTfulAPI:
|
|
|
300
310
|
"/v1/models/{model_uid}",
|
|
301
311
|
self.terminate_model,
|
|
302
312
|
methods=["DELETE"],
|
|
303
|
-
dependencies=[Security(
|
|
313
|
+
dependencies=[Security(self._auth_service, scopes=["models:stop"])]
|
|
304
314
|
if self.is_authenticated()
|
|
305
315
|
else None,
|
|
306
316
|
)
|
|
@@ -309,7 +319,7 @@ class RESTfulAPI:
|
|
|
309
319
|
self.create_completion,
|
|
310
320
|
methods=["POST"],
|
|
311
321
|
response_model=Completion,
|
|
312
|
-
dependencies=[Security(
|
|
322
|
+
dependencies=[Security(self._auth_service, scopes=["models:read"])]
|
|
313
323
|
if self.is_authenticated()
|
|
314
324
|
else None,
|
|
315
325
|
)
|
|
@@ -317,7 +327,7 @@ class RESTfulAPI:
|
|
|
317
327
|
"/v1/embeddings",
|
|
318
328
|
self.create_embedding,
|
|
319
329
|
methods=["POST"],
|
|
320
|
-
dependencies=[Security(
|
|
330
|
+
dependencies=[Security(self._auth_service, scopes=["models:read"])]
|
|
321
331
|
if self.is_authenticated()
|
|
322
332
|
else None,
|
|
323
333
|
)
|
|
@@ -325,7 +335,23 @@ class RESTfulAPI:
|
|
|
325
335
|
"/v1/rerank",
|
|
326
336
|
self.rerank,
|
|
327
337
|
methods=["POST"],
|
|
328
|
-
dependencies=[Security(
|
|
338
|
+
dependencies=[Security(self._auth_service, scopes=["models:read"])]
|
|
339
|
+
if self.is_authenticated()
|
|
340
|
+
else None,
|
|
341
|
+
)
|
|
342
|
+
self._router.add_api_route(
|
|
343
|
+
"/v1/audio/transcriptions",
|
|
344
|
+
self.create_transcriptions,
|
|
345
|
+
methods=["POST"],
|
|
346
|
+
dependencies=[Security(self._auth_service, scopes=["models:read"])]
|
|
347
|
+
if self.is_authenticated()
|
|
348
|
+
else None,
|
|
349
|
+
)
|
|
350
|
+
self._router.add_api_route(
|
|
351
|
+
"/v1/audio/translations",
|
|
352
|
+
self.create_translations,
|
|
353
|
+
methods=["POST"],
|
|
354
|
+
dependencies=[Security(self._auth_service, scopes=["models:read"])]
|
|
329
355
|
if self.is_authenticated()
|
|
330
356
|
else None,
|
|
331
357
|
)
|
|
@@ -334,7 +360,7 @@ class RESTfulAPI:
|
|
|
334
360
|
self.create_images,
|
|
335
361
|
methods=["POST"],
|
|
336
362
|
response_model=ImageList,
|
|
337
|
-
dependencies=[Security(
|
|
363
|
+
dependencies=[Security(self._auth_service, scopes=["models:read"])]
|
|
338
364
|
if self.is_authenticated()
|
|
339
365
|
else None,
|
|
340
366
|
)
|
|
@@ -343,7 +369,7 @@ class RESTfulAPI:
|
|
|
343
369
|
self.create_variations,
|
|
344
370
|
methods=["POST"],
|
|
345
371
|
response_model=ImageList,
|
|
346
|
-
dependencies=[Security(
|
|
372
|
+
dependencies=[Security(self._auth_service, scopes=["models:read"])]
|
|
347
373
|
if self.is_authenticated()
|
|
348
374
|
else None,
|
|
349
375
|
)
|
|
@@ -352,7 +378,7 @@ class RESTfulAPI:
|
|
|
352
378
|
self.create_chat_completion,
|
|
353
379
|
methods=["POST"],
|
|
354
380
|
response_model=ChatCompletion,
|
|
355
|
-
dependencies=[Security(
|
|
381
|
+
dependencies=[Security(self._auth_service, scopes=["models:read"])]
|
|
356
382
|
if self.is_authenticated()
|
|
357
383
|
else None,
|
|
358
384
|
)
|
|
@@ -362,7 +388,7 @@ class RESTfulAPI:
|
|
|
362
388
|
"/v1/model_registrations/{model_type}",
|
|
363
389
|
self.register_model,
|
|
364
390
|
methods=["POST"],
|
|
365
|
-
dependencies=[Security(
|
|
391
|
+
dependencies=[Security(self._auth_service, scopes=["models:register"])]
|
|
366
392
|
if self.is_authenticated()
|
|
367
393
|
else None,
|
|
368
394
|
)
|
|
@@ -370,7 +396,7 @@ class RESTfulAPI:
|
|
|
370
396
|
"/v1/model_registrations/{model_type}/{model_name}",
|
|
371
397
|
self.unregister_model,
|
|
372
398
|
methods=["DELETE"],
|
|
373
|
-
dependencies=[Security(
|
|
399
|
+
dependencies=[Security(self._auth_service, scopes=["models:unregister"])]
|
|
374
400
|
if self.is_authenticated()
|
|
375
401
|
else None,
|
|
376
402
|
)
|
|
@@ -378,7 +404,7 @@ class RESTfulAPI:
|
|
|
378
404
|
"/v1/model_registrations/{model_type}",
|
|
379
405
|
self.list_model_registrations,
|
|
380
406
|
methods=["GET"],
|
|
381
|
-
dependencies=[Security(
|
|
407
|
+
dependencies=[Security(self._auth_service, scopes=["models:list"])]
|
|
382
408
|
if self.is_authenticated()
|
|
383
409
|
else None,
|
|
384
410
|
)
|
|
@@ -386,7 +412,7 @@ class RESTfulAPI:
|
|
|
386
412
|
"/v1/model_registrations/{model_type}/{model_name}",
|
|
387
413
|
self.get_model_registrations,
|
|
388
414
|
methods=["GET"],
|
|
389
|
-
dependencies=[Security(
|
|
415
|
+
dependencies=[Security(self._auth_service, scopes=["models:list"])]
|
|
390
416
|
if self.is_authenticated()
|
|
391
417
|
else None,
|
|
392
418
|
)
|
|
@@ -640,6 +666,44 @@ class RESTfulAPI:
|
|
|
640
666
|
raise HTTPException(status_code=500, detail=str(e))
|
|
641
667
|
return JSONResponse(content=infos)
|
|
642
668
|
|
|
669
|
+
async def launch_model_by_version(
|
|
670
|
+
self, request: Request, wait_ready: bool = Query(True)
|
|
671
|
+
) -> JSONResponse:
|
|
672
|
+
payload = await request.json()
|
|
673
|
+
model_uid = payload.get("model_uid")
|
|
674
|
+
model_type = payload.get("model_type")
|
|
675
|
+
model_version = payload.get("model_version")
|
|
676
|
+
replica = payload.get("replica", 1)
|
|
677
|
+
n_gpu = payload.get("n_gpu", "auto")
|
|
678
|
+
|
|
679
|
+
try:
|
|
680
|
+
model_uid = await (
|
|
681
|
+
await self._get_supervisor_ref()
|
|
682
|
+
).launch_model_by_version(
|
|
683
|
+
model_uid=model_uid,
|
|
684
|
+
model_type=model_type,
|
|
685
|
+
model_version=model_version,
|
|
686
|
+
replica=replica,
|
|
687
|
+
n_gpu=n_gpu,
|
|
688
|
+
wait_ready=wait_ready,
|
|
689
|
+
)
|
|
690
|
+
except Exception as e:
|
|
691
|
+
logger.error(str(e), exc_info=True)
|
|
692
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
693
|
+
return JSONResponse(content={"model_uid": model_uid})
|
|
694
|
+
|
|
695
|
+
async def get_model_versions(
|
|
696
|
+
self, model_type: str, model_name: str
|
|
697
|
+
) -> JSONResponse:
|
|
698
|
+
try:
|
|
699
|
+
content = await (await self._get_supervisor_ref()).get_model_versions(
|
|
700
|
+
model_type, model_name
|
|
701
|
+
)
|
|
702
|
+
return JSONResponse(content=content)
|
|
703
|
+
except Exception as e:
|
|
704
|
+
logger.error(e, exc_info=True)
|
|
705
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
706
|
+
|
|
643
707
|
async def build_gradio_interface(
|
|
644
708
|
self, model_uid: str, body: BuildGradioInterfaceRequest, request: Request
|
|
645
709
|
) -> JSONResponse:
|
|
@@ -649,7 +713,7 @@ class RESTfulAPI:
|
|
|
649
713
|
but calling API in async function does not return
|
|
650
714
|
"""
|
|
651
715
|
assert self._app is not None
|
|
652
|
-
assert body.model_type
|
|
716
|
+
assert body.model_type == "LLM"
|
|
653
717
|
|
|
654
718
|
# asyncio.Lock() behaves differently in 3.9 than 3.10+
|
|
655
719
|
# A event loop is required in 3.9 but not 3.10+
|
|
@@ -731,6 +795,9 @@ class RESTfulAPI:
|
|
|
731
795
|
}
|
|
732
796
|
kwargs = body.dict(exclude_unset=True, exclude=exclude)
|
|
733
797
|
|
|
798
|
+
if body.max_tokens is None:
|
|
799
|
+
kwargs["max_tokens"] = max_tokens_field.default
|
|
800
|
+
|
|
734
801
|
if body.logit_bias is not None:
|
|
735
802
|
raise HTTPException(status_code=501, detail="Not implemented")
|
|
736
803
|
|
|
@@ -740,10 +807,12 @@ class RESTfulAPI:
|
|
|
740
807
|
model = await (await self._get_supervisor_ref()).get_model(model_uid)
|
|
741
808
|
except ValueError as ve:
|
|
742
809
|
logger.error(str(ve), exc_info=True)
|
|
810
|
+
await self._report_error_event(model_uid, str(ve))
|
|
743
811
|
raise HTTPException(status_code=400, detail=str(ve))
|
|
744
812
|
|
|
745
813
|
except Exception as e:
|
|
746
814
|
logger.error(e, exc_info=True)
|
|
815
|
+
await self._report_error_event(model_uid, str(e))
|
|
747
816
|
raise HTTPException(status_code=500, detail=str(e))
|
|
748
817
|
|
|
749
818
|
if body.stream:
|
|
@@ -759,6 +828,7 @@ class RESTfulAPI:
|
|
|
759
828
|
yield item
|
|
760
829
|
except Exception as ex:
|
|
761
830
|
logger.exception("Completion stream got an error: %s", ex)
|
|
831
|
+
await self._report_error_event(model_uid, str(ex))
|
|
762
832
|
# https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
|
|
763
833
|
yield dict(data=json.dumps({"error": str(ex)}))
|
|
764
834
|
|
|
@@ -769,6 +839,7 @@ class RESTfulAPI:
|
|
|
769
839
|
return Response(data, media_type="application/json")
|
|
770
840
|
except Exception as e:
|
|
771
841
|
logger.error(e, exc_info=True)
|
|
842
|
+
await self._report_error_event(model_uid, str(e))
|
|
772
843
|
self.handle_request_limit_error(e)
|
|
773
844
|
raise HTTPException(status_code=500, detail=str(e))
|
|
774
845
|
|
|
@@ -779,9 +850,11 @@ class RESTfulAPI:
|
|
|
779
850
|
model = await (await self._get_supervisor_ref()).get_model(model_uid)
|
|
780
851
|
except ValueError as ve:
|
|
781
852
|
logger.error(str(ve), exc_info=True)
|
|
853
|
+
await self._report_error_event(model_uid, str(ve))
|
|
782
854
|
raise HTTPException(status_code=400, detail=str(ve))
|
|
783
855
|
except Exception as e:
|
|
784
856
|
logger.error(e, exc_info=True)
|
|
857
|
+
await self._report_error_event(model_uid, str(e))
|
|
785
858
|
raise HTTPException(status_code=500, detail=str(e))
|
|
786
859
|
|
|
787
860
|
try:
|
|
@@ -789,10 +862,12 @@ class RESTfulAPI:
|
|
|
789
862
|
return Response(embedding, media_type="application/json")
|
|
790
863
|
except RuntimeError as re:
|
|
791
864
|
logger.error(re, exc_info=True)
|
|
865
|
+
await self._report_error_event(model_uid, str(re))
|
|
792
866
|
self.handle_request_limit_error(re)
|
|
793
867
|
raise HTTPException(status_code=400, detail=str(re))
|
|
794
868
|
except Exception as e:
|
|
795
869
|
logger.error(e, exc_info=True)
|
|
870
|
+
await self._report_error_event(model_uid, str(e))
|
|
796
871
|
raise HTTPException(status_code=500, detail=str(e))
|
|
797
872
|
|
|
798
873
|
async def rerank(self, request: RerankRequest) -> Response:
|
|
@@ -801,9 +876,11 @@ class RESTfulAPI:
|
|
|
801
876
|
model = await (await self._get_supervisor_ref()).get_model(model_uid)
|
|
802
877
|
except ValueError as ve:
|
|
803
878
|
logger.error(str(ve), exc_info=True)
|
|
879
|
+
await self._report_error_event(model_uid, str(ve))
|
|
804
880
|
raise HTTPException(status_code=400, detail=str(ve))
|
|
805
881
|
except Exception as e:
|
|
806
882
|
logger.error(e, exc_info=True)
|
|
883
|
+
await self._report_error_event(model_uid, str(e))
|
|
807
884
|
raise HTTPException(status_code=500, detail=str(e))
|
|
808
885
|
|
|
809
886
|
try:
|
|
@@ -817,10 +894,100 @@ class RESTfulAPI:
|
|
|
817
894
|
return Response(scores, media_type="application/json")
|
|
818
895
|
except RuntimeError as re:
|
|
819
896
|
logger.error(re, exc_info=True)
|
|
897
|
+
await self._report_error_event(model_uid, str(re))
|
|
820
898
|
self.handle_request_limit_error(re)
|
|
821
899
|
raise HTTPException(status_code=400, detail=str(re))
|
|
822
900
|
except Exception as e:
|
|
823
901
|
logger.error(e, exc_info=True)
|
|
902
|
+
await self._report_error_event(model_uid, str(e))
|
|
903
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
904
|
+
|
|
905
|
+
async def create_transcriptions(
|
|
906
|
+
self,
|
|
907
|
+
model: str = Form(...),
|
|
908
|
+
file: UploadFile = File(media_type="application/octet-stream"),
|
|
909
|
+
language: Optional[str] = Form(None),
|
|
910
|
+
prompt: Optional[str] = Form(None),
|
|
911
|
+
response_format: Optional[str] = Form("json"),
|
|
912
|
+
temperature: Optional[float] = Form(0),
|
|
913
|
+
kwargs: Optional[str] = Form(None),
|
|
914
|
+
) -> Response:
|
|
915
|
+
model_uid = model
|
|
916
|
+
try:
|
|
917
|
+
model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
|
|
918
|
+
except ValueError as ve:
|
|
919
|
+
logger.error(str(ve), exc_info=True)
|
|
920
|
+
await self._report_error_event(model_uid, str(ve))
|
|
921
|
+
raise HTTPException(status_code=400, detail=str(ve))
|
|
922
|
+
except Exception as e:
|
|
923
|
+
logger.error(e, exc_info=True)
|
|
924
|
+
await self._report_error_event(model_uid, str(e))
|
|
925
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
926
|
+
|
|
927
|
+
try:
|
|
928
|
+
if kwargs is not None:
|
|
929
|
+
parsed_kwargs = json.loads(kwargs)
|
|
930
|
+
else:
|
|
931
|
+
parsed_kwargs = {}
|
|
932
|
+
transcription = await model_ref.transcriptions(
|
|
933
|
+
audio=await file.read(),
|
|
934
|
+
language=language,
|
|
935
|
+
prompt=prompt,
|
|
936
|
+
response_format=response_format,
|
|
937
|
+
temperature=temperature,
|
|
938
|
+
**parsed_kwargs,
|
|
939
|
+
)
|
|
940
|
+
return Response(content=transcription, media_type="application/json")
|
|
941
|
+
except RuntimeError as re:
|
|
942
|
+
logger.error(re, exc_info=True)
|
|
943
|
+
await self._report_error_event(model_uid, str(re))
|
|
944
|
+
raise HTTPException(status_code=400, detail=str(re))
|
|
945
|
+
except Exception as e:
|
|
946
|
+
logger.error(e, exc_info=True)
|
|
947
|
+
await self._report_error_event(model_uid, str(e))
|
|
948
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
949
|
+
|
|
950
|
+
async def create_translations(
|
|
951
|
+
self,
|
|
952
|
+
model: str = Form(...),
|
|
953
|
+
file: UploadFile = File(media_type="application/octet-stream"),
|
|
954
|
+
prompt: Optional[str] = Form(None),
|
|
955
|
+
response_format: Optional[str] = Form("json"),
|
|
956
|
+
temperature: Optional[float] = Form(0),
|
|
957
|
+
kwargs: Optional[str] = Form(None),
|
|
958
|
+
) -> Response:
|
|
959
|
+
model_uid = model
|
|
960
|
+
try:
|
|
961
|
+
model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
|
|
962
|
+
except ValueError as ve:
|
|
963
|
+
logger.error(str(ve), exc_info=True)
|
|
964
|
+
await self._report_error_event(model_uid, str(ve))
|
|
965
|
+
raise HTTPException(status_code=400, detail=str(ve))
|
|
966
|
+
except Exception as e:
|
|
967
|
+
logger.error(e, exc_info=True)
|
|
968
|
+
await self._report_error_event(model_uid, str(e))
|
|
969
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
970
|
+
|
|
971
|
+
try:
|
|
972
|
+
if kwargs is not None:
|
|
973
|
+
parsed_kwargs = json.loads(kwargs)
|
|
974
|
+
else:
|
|
975
|
+
parsed_kwargs = {}
|
|
976
|
+
translation = await model_ref.translations(
|
|
977
|
+
audio=await file.read(),
|
|
978
|
+
prompt=prompt,
|
|
979
|
+
response_format=response_format,
|
|
980
|
+
temperature=temperature,
|
|
981
|
+
**parsed_kwargs,
|
|
982
|
+
)
|
|
983
|
+
return Response(content=translation, media_type="application/json")
|
|
984
|
+
except RuntimeError as re:
|
|
985
|
+
logger.error(re, exc_info=True)
|
|
986
|
+
await self._report_error_event(model_uid, str(re))
|
|
987
|
+
raise HTTPException(status_code=400, detail=str(re))
|
|
988
|
+
except Exception as e:
|
|
989
|
+
logger.error(e, exc_info=True)
|
|
990
|
+
await self._report_error_event(model_uid, str(e))
|
|
824
991
|
raise HTTPException(status_code=500, detail=str(e))
|
|
825
992
|
|
|
826
993
|
async def create_images(self, request: TextToImageRequest) -> Response:
|
|
@@ -829,9 +996,11 @@ class RESTfulAPI:
|
|
|
829
996
|
model = await (await self._get_supervisor_ref()).get_model(model_uid)
|
|
830
997
|
except ValueError as ve:
|
|
831
998
|
logger.error(str(ve), exc_info=True)
|
|
999
|
+
await self._report_error_event(model_uid, str(ve))
|
|
832
1000
|
raise HTTPException(status_code=400, detail=str(ve))
|
|
833
1001
|
except Exception as e:
|
|
834
1002
|
logger.error(e, exc_info=True)
|
|
1003
|
+
await self._report_error_event(model_uid, str(e))
|
|
835
1004
|
raise HTTPException(status_code=500, detail=str(e))
|
|
836
1005
|
|
|
837
1006
|
try:
|
|
@@ -846,10 +1015,12 @@ class RESTfulAPI:
|
|
|
846
1015
|
return Response(content=image_list, media_type="application/json")
|
|
847
1016
|
except RuntimeError as re:
|
|
848
1017
|
logger.error(re, exc_info=True)
|
|
1018
|
+
await self._report_error_event(model_uid, str(re))
|
|
849
1019
|
self.handle_request_limit_error(re)
|
|
850
1020
|
raise HTTPException(status_code=400, detail=str(re))
|
|
851
1021
|
except Exception as e:
|
|
852
1022
|
logger.error(e, exc_info=True)
|
|
1023
|
+
await self._report_error_event(model_uid, str(e))
|
|
853
1024
|
raise HTTPException(status_code=500, detail=str(e))
|
|
854
1025
|
|
|
855
1026
|
async def create_variations(
|
|
@@ -868,14 +1039,18 @@ class RESTfulAPI:
|
|
|
868
1039
|
model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
|
|
869
1040
|
except ValueError as ve:
|
|
870
1041
|
logger.error(str(ve), exc_info=True)
|
|
1042
|
+
await self._report_error_event(model_uid, str(ve))
|
|
871
1043
|
raise HTTPException(status_code=400, detail=str(ve))
|
|
872
1044
|
except Exception as e:
|
|
873
1045
|
logger.error(e, exc_info=True)
|
|
1046
|
+
await self._report_error_event(model_uid, str(e))
|
|
874
1047
|
raise HTTPException(status_code=500, detail=str(e))
|
|
875
1048
|
|
|
876
1049
|
try:
|
|
877
1050
|
if kwargs is not None:
|
|
878
|
-
|
|
1051
|
+
parsed_kwargs = json.loads(kwargs)
|
|
1052
|
+
else:
|
|
1053
|
+
parsed_kwargs = {}
|
|
879
1054
|
image_list = await model_ref.image_to_image(
|
|
880
1055
|
image=Image.open(image.file),
|
|
881
1056
|
prompt=prompt,
|
|
@@ -883,14 +1058,16 @@ class RESTfulAPI:
|
|
|
883
1058
|
n=n,
|
|
884
1059
|
size=size,
|
|
885
1060
|
response_format=response_format,
|
|
886
|
-
**
|
|
1061
|
+
**parsed_kwargs,
|
|
887
1062
|
)
|
|
888
1063
|
return Response(content=image_list, media_type="application/json")
|
|
889
1064
|
except RuntimeError as re:
|
|
890
1065
|
logger.error(re, exc_info=True)
|
|
1066
|
+
await self._report_error_event(model_uid, str(re))
|
|
891
1067
|
raise HTTPException(status_code=400, detail=str(re))
|
|
892
1068
|
except Exception as e:
|
|
893
1069
|
logger.error(e, exc_info=True)
|
|
1070
|
+
await self._report_error_event(model_uid, str(e))
|
|
894
1071
|
raise HTTPException(status_code=500, detail=str(e))
|
|
895
1072
|
|
|
896
1073
|
async def create_chat_completion(
|
|
@@ -909,6 +1086,9 @@ class RESTfulAPI:
|
|
|
909
1086
|
}
|
|
910
1087
|
kwargs = body.dict(exclude_unset=True, exclude=exclude)
|
|
911
1088
|
|
|
1089
|
+
if body.max_tokens is None:
|
|
1090
|
+
kwargs["max_tokens"] = max_tokens_field.default
|
|
1091
|
+
|
|
912
1092
|
if body.logit_bias is not None:
|
|
913
1093
|
raise HTTPException(status_code=501, detail="Not implemented")
|
|
914
1094
|
|
|
@@ -958,31 +1138,32 @@ class RESTfulAPI:
|
|
|
958
1138
|
model = await (await self._get_supervisor_ref()).get_model(model_uid)
|
|
959
1139
|
except ValueError as ve:
|
|
960
1140
|
logger.error(str(ve), exc_info=True)
|
|
1141
|
+
await self._report_error_event(model_uid, str(ve))
|
|
961
1142
|
raise HTTPException(status_code=400, detail=str(ve))
|
|
962
1143
|
except Exception as e:
|
|
963
1144
|
logger.error(e, exc_info=True)
|
|
1145
|
+
await self._report_error_event(model_uid, str(e))
|
|
964
1146
|
raise HTTPException(status_code=500, detail=str(e))
|
|
965
1147
|
|
|
966
1148
|
try:
|
|
967
1149
|
desc = await (await self._get_supervisor_ref()).describe_model(model_uid)
|
|
968
1150
|
except ValueError as ve:
|
|
969
1151
|
logger.error(str(ve), exc_info=True)
|
|
1152
|
+
await self._report_error_event(model_uid, str(ve))
|
|
970
1153
|
raise HTTPException(status_code=400, detail=str(ve))
|
|
971
1154
|
except Exception as e:
|
|
972
1155
|
logger.error(e, exc_info=True)
|
|
1156
|
+
await self._report_error_event(model_uid, str(e))
|
|
973
1157
|
raise HTTPException(status_code=500, detail=str(e))
|
|
974
1158
|
|
|
975
1159
|
model_name = desc.get("model_name", "")
|
|
976
|
-
is_chatglm_ggml = (
|
|
977
|
-
desc.get("model_format") == "ggmlv3" and "chatglm" in model_name
|
|
978
|
-
)
|
|
979
1160
|
function_call_models = ["chatglm3", "gorilla-openfunctions-v1", "qwen-chat"]
|
|
980
1161
|
|
|
981
1162
|
is_qwen = desc.get("model_format") == "ggmlv3" and "qwen" in model_name
|
|
982
1163
|
|
|
983
|
-
if
|
|
1164
|
+
if is_qwen and system_prompt is not None:
|
|
984
1165
|
raise HTTPException(
|
|
985
|
-
status_code=400, detail="
|
|
1166
|
+
status_code=400, detail="Qwen ggml does not have system prompt"
|
|
986
1167
|
)
|
|
987
1168
|
|
|
988
1169
|
if not any(name in model_name for name in function_call_models):
|
|
@@ -1007,31 +1188,34 @@ class RESTfulAPI:
|
|
|
1007
1188
|
iterator = None
|
|
1008
1189
|
try:
|
|
1009
1190
|
try:
|
|
1010
|
-
if
|
|
1191
|
+
if is_qwen:
|
|
1011
1192
|
iterator = await model.chat(prompt, chat_history, kwargs)
|
|
1012
1193
|
else:
|
|
1013
1194
|
iterator = await model.chat(
|
|
1014
1195
|
prompt, system_prompt, chat_history, kwargs
|
|
1015
1196
|
)
|
|
1016
1197
|
except RuntimeError as re:
|
|
1198
|
+
await self._report_error_event(model_uid, str(re))
|
|
1017
1199
|
self.handle_request_limit_error(re)
|
|
1018
1200
|
async for item in iterator:
|
|
1019
1201
|
yield item
|
|
1020
1202
|
except Exception as ex:
|
|
1021
1203
|
logger.exception("Chat completion stream got an error: %s", ex)
|
|
1204
|
+
await self._report_error_event(model_uid, str(ex))
|
|
1022
1205
|
# https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
|
|
1023
1206
|
yield dict(data=json.dumps({"error": str(ex)}))
|
|
1024
1207
|
|
|
1025
1208
|
return EventSourceResponse(stream_results())
|
|
1026
1209
|
else:
|
|
1027
1210
|
try:
|
|
1028
|
-
if
|
|
1211
|
+
if is_qwen:
|
|
1029
1212
|
data = await model.chat(prompt, chat_history, kwargs)
|
|
1030
1213
|
else:
|
|
1031
1214
|
data = await model.chat(prompt, system_prompt, chat_history, kwargs)
|
|
1032
1215
|
return Response(content=data, media_type="application/json")
|
|
1033
1216
|
except Exception as e:
|
|
1034
1217
|
logger.error(e, exc_info=True)
|
|
1218
|
+
await self._report_error_event(model_uid, str(e))
|
|
1035
1219
|
self.handle_request_limit_error(e)
|
|
1036
1220
|
raise HTTPException(status_code=500, detail=str(e))
|
|
1037
1221
|
|
|
@@ -1096,6 +1280,26 @@ class RESTfulAPI:
|
|
|
1096
1280
|
logger.error(e, exc_info=True)
|
|
1097
1281
|
raise HTTPException(status_code=500, detail=str(e))
|
|
1098
1282
|
|
|
1283
|
+
async def get_model_events(self, model_uid: str) -> JSONResponse:
|
|
1284
|
+
try:
|
|
1285
|
+
event_collector_ref = await self._get_event_collector_ref()
|
|
1286
|
+
events = await event_collector_ref.get_model_events(model_uid)
|
|
1287
|
+
return JSONResponse(content=events)
|
|
1288
|
+
except ValueError as re:
|
|
1289
|
+
logger.error(re, exc_info=True)
|
|
1290
|
+
raise HTTPException(status_code=400, detail=str(re))
|
|
1291
|
+
except Exception as e:
|
|
1292
|
+
logger.error(e, exc_info=True)
|
|
1293
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
1294
|
+
|
|
1295
|
+
async def get_cluster_device_info(self) -> JSONResponse:
|
|
1296
|
+
try:
|
|
1297
|
+
data = await (await self._get_supervisor_ref()).get_cluster_device_info()
|
|
1298
|
+
return JSONResponse(content=data)
|
|
1299
|
+
except Exception as e:
|
|
1300
|
+
logger.error(e, exc_info=True)
|
|
1301
|
+
raise HTTPException(status_code=500, detail=str(e))
|
|
1302
|
+
|
|
1099
1303
|
|
|
1100
1304
|
def run(
|
|
1101
1305
|
supervisor_address: str,
|