xinference 0.8.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +35 -1
- xinference/client/oscar/actor_client.py +2 -2
- xinference/client/restful/restful_client.py +2 -2
- xinference/conftest.py +5 -1
- xinference/core/metrics.py +83 -0
- xinference/core/model.py +148 -8
- xinference/core/status_guard.py +86 -0
- xinference/core/supervisor.py +57 -7
- xinference/core/worker.py +132 -13
- xinference/deploy/cmdline.py +57 -4
- xinference/deploy/local.py +32 -6
- xinference/deploy/worker.py +33 -5
- xinference/fields.py +4 -1
- xinference/model/llm/__init__.py +7 -0
- xinference/model/llm/ggml/llamacpp.py +3 -2
- xinference/model/llm/llm_family.json +70 -3
- xinference/model/llm/llm_family.py +11 -1
- xinference/model/llm/llm_family_modelscope.json +72 -3
- xinference/model/llm/pytorch/chatglm.py +70 -28
- xinference/model/llm/pytorch/core.py +11 -30
- xinference/model/llm/pytorch/internlm2.py +155 -0
- xinference/model/llm/pytorch/utils.py +0 -153
- xinference/model/llm/utils.py +37 -8
- xinference/model/llm/vllm/core.py +15 -3
- xinference/model/multimodal/__init__.py +15 -8
- xinference/model/multimodal/model_spec_modelscope.json +45 -0
- xinference/model/utils.py +7 -2
- xinference/types.py +2 -0
- {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/METADATA +2 -1
- {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/RECORD +35 -31
- {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/LICENSE +0 -0
- {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/WHEEL +0 -0
- {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-01-
+ "date": "2024-01-19T17:14:28+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "
- "version": "0.8.
+ "full-revisionid": "fb3985e95fbb3e6cb51a321d6d6a9a10661128fe",
+ "version": "0.8.1"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py
CHANGED
@@ -27,6 +27,8 @@ from typing import Any, List, Optional, Union
 import gradio as gr
 import pydantic
 import xoscar as xo
+from aioprometheus import REGISTRY, MetricsMiddleware
+from aioprometheus.asgi.starlette import metrics
 from fastapi import (
     APIRouter,
     FastAPI,
@@ -252,6 +254,15 @@ class RESTfulAPI:
         self._router.add_api_route(
             "/v1/cluster/auth", self.is_cluster_authenticated, methods=["GET"]
         )
+        # running instances
+        self._router.add_api_route(
+            "/v1/models/instances",
+            self.get_instance_info,
+            methods=["GET"],
+            dependencies=[Security(verify_token, scopes=["models:list"])]
+            if self.is_authenticated()
+            else None,
+        )
         self._router.add_api_route(
             "/v1/models",
             self.list_models,
@@ -380,7 +391,13 @@ class RESTfulAPI:
             else None,
         )
 
+        # Clear the global Registry for the MetricsMiddleware, or
+        # the MetricsMiddleware will register duplicated metrics if the port
+        # conflict (This serve method run more than once).
+        REGISTRY.clear()
+        self._app.add_middleware(MetricsMiddleware)
         self._app.include_router(self._router)
+        self._app.add_route("/metrics", metrics)
 
         # Check all the routes returns Response.
         # This is to avoid `jsonable_encoder` performance issue:
@@ -546,7 +563,9 @@ class RESTfulAPI:
 
         return JSONResponse(content={"model_uid": model_uid})
 
-    async def launch_model(
+    async def launch_model(
+        self, request: Request, wait_ready: bool = Query(True)
+    ) -> JSONResponse:
         payload = await request.json()
         model_uid = payload.get("model_uid")
         model_name = payload.get("model_name")
@@ -591,6 +610,7 @@ class RESTfulAPI:
             replica=replica,
             n_gpu=n_gpu,
             request_limits=request_limits,
+            wait_ready=wait_ready,
             **kwargs,
         )
 
@@ -606,6 +626,20 @@ class RESTfulAPI:
 
         return JSONResponse(content={"model_uid": model_uid})
 
+    async def get_instance_info(
+        self,
+        model_name: Optional[str] = Query(None),
+        model_uid: Optional[str] = Query(None),
+    ) -> JSONResponse:
+        try:
+            infos = await (await self._get_supervisor_ref()).get_instance_info(
+                model_name, model_uid
+            )
+        except Exception as e:
+            logger.error(str(e), exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+        return JSONResponse(content=infos)
+
     async def build_gradio_interface(
         self, model_uid: str, body: BuildGradioInterfaceRequest, request: Request
     ) -> JSONResponse:
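As an aside, the two API changes above can be exercised over plain HTTP. A minimal sketch, assuming a supervisor listening on 127.0.0.1:9997 and an illustrative model name (both are assumptions, not taken from this diff):

import requests

BASE = "http://127.0.0.1:9997"  # assumed local supervisor endpoint

# Launch without blocking until the model is ready (the new wait_ready flag).
resp = requests.post(
    f"{BASE}/v1/models?wait_ready=false",
    json={"model_name": "chatglm3"},  # illustrative payload
)
model_uid = resp.json()["model_uid"]

# Poll the new running-instances endpoint to watch the launch status.
infos = requests.get(
    f"{BASE}/v1/models/instances", params={"model_uid": model_uid}
).json()
for info in infos:
    print(info["model_uid"], info["status"])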
xinference/client/oscar/actor_client.py
CHANGED
@@ -171,7 +171,7 @@ class RerankModelHandle(ModelHandle):
         return results
 
 
-class GenerateModelHandle(
+class GenerateModelHandle(ModelHandle):
     def generate(
         self,
         prompt: str,
@@ -255,7 +255,7 @@ class ChatModelHandle(GenerateModelHandle):
         return ClientIteratorWrapper(r)
 
 
-class ChatglmCppChatModelHandle(
+class ChatglmCppChatModelHandle(ModelHandle):
     def chat(
         self,
         prompt: str,
xinference/client/restful/restful_client.py
CHANGED
@@ -257,7 +257,7 @@ class RESTfulImageModelHandle(RESTfulModelHandle):
         return response_data
 
 
-class RESTfulGenerateModelHandle(
+class RESTfulGenerateModelHandle(RESTfulModelHandle):
     def generate(
         self,
         prompt: str,
@@ -486,7 +486,7 @@ class RESTfulMultimodalModelHandle(RESTfulModelHandle):
         return response_data
 
 
-class RESTfulChatglmCppChatModelHandle(
+class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle):
     def chat(
         self,
         prompt: str,
xinference/conftest.py
CHANGED
@@ -144,7 +144,11 @@ async def _start_test_cluster(
             SupervisorActor, address=address, uid=SupervisorActor.uid()
         )
         await start_worker_components(
-            address=address,
+            address=address,
+            supervisor_address=address,
+            main_pool=pool,
+            metrics_exporter_host=None,
+            metrics_exporter_port=None,
         )
         await pool.join()
     except asyncio.CancelledError:
xinference/core/metrics.py
ADDED
@@ -0,0 +1,83 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+
+import uvicorn
+from aioprometheus import Counter, Gauge
+from aioprometheus.asgi.starlette import metrics
+from fastapi import FastAPI
+from fastapi.responses import RedirectResponse
+
+DEFAULT_METRICS_SERVER_LOG_LEVEL = "warning"
+
+
+generate_throughput = Gauge(
+    "xinference:generate_tokens_per_s", "Generate throughput in tokens/s."
+)
+# Latency
+time_to_first_token = Gauge(
+    "xinference:time_to_first_token_ms", "First token latency in ms."
+)
+# Tokens counter
+input_tokens_total_counter = Counter(
+    "xinference:input_tokens_total_counter", "Total number of input tokens."
+)
+output_tokens_total_counter = Counter(
+    "xinference:output_tokens_total_counter", "Total number of output tokens."
+)
+
+
+def record_metrics(name, op, kwargs):
+    collector = globals().get(name)
+    getattr(collector, op)(**kwargs)
+
+
+def launch_metrics_export_server(q, host=None, port=None):
+    app = FastAPI()
+    app.add_route("/metrics", metrics)
+
+    @app.get("/")
+    async def root():
+        response = RedirectResponse(url="/metrics")
+        return response
+
+    async def main():
+        if host is not None and port is not None:
+            config = uvicorn.Config(
+                app, host=host, port=port, log_level=DEFAULT_METRICS_SERVER_LOG_LEVEL
+            )
+        elif host is not None:
+            config = uvicorn.Config(
+                app, host=host, port=0, log_level=DEFAULT_METRICS_SERVER_LOG_LEVEL
+            )
+        elif port is not None:
+            config = uvicorn.Config(
+                app, port=port, log_level=DEFAULT_METRICS_SERVER_LOG_LEVEL
+            )
+        else:
+            config = uvicorn.Config(app, log_level=DEFAULT_METRICS_SERVER_LOG_LEVEL)
+
+        server = uvicorn.Server(config)
+        task = asyncio.create_task(server.serve())
+
+        while not server.started and not task.done():
+            await asyncio.sleep(0.1)
+
+        for server in server.servers:
+            for socket in server.sockets:
+                q.put(socket.getsockname())
+        await task
+
+    asyncio.run(main())
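Worth noting: record_metrics resolves the collector by name through globals() and dispatches the operation with getattr, so every (name, op, kwargs) triple is just an indirect aioprometheus call. A minimal sketch of that equivalence (the label values are illustrative):

from xinference.core.metrics import output_tokens_total_counter, record_metrics

labels = {"model": "demo-model", "node": "127.0.0.1:1234"}  # illustrative labels

# This indirect call...
record_metrics("output_tokens_total_counter", "add", {"labels": labels, "value": 42})
# ...resolves to the direct collector operation:
output_tokens_total_counter.add(labels, 42)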
xinference/core/model.py
CHANGED
@@ -17,6 +17,7 @@ import functools
 import inspect
 import json
 import os
+import time
 import types
 import weakref
 from typing import (
@@ -34,7 +35,9 @@ import sse_starlette.sse
 import xoscar as xo
 
 if TYPE_CHECKING:
+    from .worker import WorkerActor
     from ..model.llm.core import LLM
+    from ..model.core import ModelDescription
     import PIL
 
 import logging
@@ -140,13 +143,23 @@ class ModelActor(xo.StatelessActor):
             gc.collect()
             torch.cuda.empty_cache()
 
-    def __init__(
+    def __init__(
+        self,
+        worker_address: str,
+        model: "LLM",
+        model_description: Optional["ModelDescription"] = None,
+        request_limits: Optional[int] = None,
+    ):
         super().__init__()
         from ..model.llm.pytorch.core import PytorchModel
         from ..model.llm.pytorch.spec_model import SpeculativeModel
         from ..model.llm.vllm.core import VLLMModel
 
+        self._worker_address = worker_address
         self._model = model
+        self._model_description = (
+            model_description.to_dict() if model_description else {}
+        )
         self._request_limits = request_limits
 
         self._generators: Dict[str, Union[Iterator, AsyncGenerator]] = {}
@@ -156,7 +169,65 @@ class ModelActor(xo.StatelessActor):
             if isinstance(self._model, (PytorchModel, SpeculativeModel, VLLMModel))
             else asyncio.locks.Lock()
         )
+        self._worker_ref = None
         self._serve_count = 0
+        self._metrics_labels = {
+            "type": self._model_description.get("model_type", "unknown"),
+            "model": self.model_uid(),
+            "node": self._worker_address,
+            "format": self._model_description.get("model_format", "unknown"),
+            "quantization": self._model_description.get("quantization", "none"),
+        }
+        self._loop: Optional[asyncio.AbstractEventLoop] = None
+
+    async def __post_create__(self):
+        self._loop = asyncio.get_running_loop()
+
+    async def _record_completion_metrics(
+        self, duration, completion_tokens, prompt_tokens
+    ):
+        coros = []
+        if completion_tokens > 0:
+            coros.append(
+                self.record_metrics(
+                    "output_tokens_total_counter",
+                    "add",
+                    {
+                        "labels": self._metrics_labels,
+                        "value": completion_tokens,
+                    },
+                )
+            )
+        if prompt_tokens > 0:
+            coros.append(
+                self.record_metrics(
+                    "input_tokens_total_counter",
+                    "add",
+                    {"labels": self._metrics_labels, "value": prompt_tokens},
+                )
+            )
+        if completion_tokens > 0:
+            generate_throughput = completion_tokens / duration
+            coros.append(
+                self.record_metrics(
+                    "generate_throughput",
+                    "set",
+                    {
+                        "labels": self._metrics_labels,
+                        "value": generate_throughput,
+                    },
+                )
+            )
+        await asyncio.gather(*coros)
+
+    async def _get_worker_ref(self) -> xo.ActorRefType["WorkerActor"]:
+        from .worker import WorkerActor
+
+        if self._worker_ref is None:
+            self._worker_ref = await xo.actor_ref(
+                address=self._worker_address, uid=WorkerActor.uid()
+            )
+        return self._worker_ref
 
     def is_vllm_backend(self) -> bool:
         from ..model.llm.vllm.core import VLLMModel
@@ -178,8 +249,14 @@ class ModelActor(xo.StatelessActor):
         )
 
     def _to_json_generator(self, gen: types.GeneratorType):
+        start_time = time.time()
+        time_to_first_token = None
+        final_usage = None
         try:
             for v in gen:
+                if time_to_first_token is None:
+                    time_to_first_token = (time.time() - start_time) * 1000
+                final_usage = v.pop("usage", None)
                 v = dict(data=json.dumps(v))
                 yield sse_starlette.sse.ensure_bytes(v, None)
         except OutOfMemoryError:
@@ -187,10 +264,31 @@
                 "Model actor is out of memory, model id: %s", self.model_uid()
             )
             os._exit(1)
+        finally:
+            if self._loop is not None and time_to_first_token is not None:
+                coro = self.record_metrics(
+                    "time_to_first_token",
+                    "set",
+                    {"labels": self._metrics_labels, "value": time_to_first_token},
+                )
+                asyncio.run_coroutine_threadsafe(coro, loop=self._loop)
+            if self._loop is not None and final_usage is not None:
+                coro = self._record_completion_metrics(
+                    time.time() - start_time,
+                    completion_tokens=final_usage["completion_tokens"],
+                    prompt_tokens=final_usage["prompt_tokens"],
+                )
+                asyncio.run_coroutine_threadsafe(coro, loop=self._loop)
 
     async def _to_json_async_gen(self, gen: types.AsyncGeneratorType):
+        start_time = time.time()
+        time_to_first_token = None
+        final_usage = None
         try:
             async for v in gen:
+                if time_to_first_token is None:
+                    time_to_first_token = (time.time() - start_time) * 1000
+                final_usage = v.pop("usage", None)
                 v = await asyncio.to_thread(json.dumps, v)
                 v = dict(data=v)  # noqa: F821
                 yield await asyncio.to_thread(sse_starlette.sse.ensure_bytes, v, None)
@@ -199,6 +297,25 @@
                 "Model actor is out of memory, model id: %s", self.model_uid()
             )
             os._exit(1)
+        finally:
+            coros = []
+            if time_to_first_token is not None:
+                coros.append(
+                    self.record_metrics(
+                        "time_to_first_token",
+                        "set",
+                        {"labels": self._metrics_labels, "value": time_to_first_token},
+                    )
+                )
+            if final_usage is not None:
+                coros.append(
+                    self._record_completion_metrics(
+                        time.time() - start_time,
+                        completion_tokens=final_usage["completion_tokens"],
+                        prompt_tokens=final_usage["prompt_tokens"],
+                    )
+                )
+            await asyncio.gather(*coros)
 
     @oom_check
     async def _call_wrapper(self, fn: Callable, *args, **kwargs):
@@ -245,13 +362,32 @@
     @request_limit
     @xo.generator
     async def chat(self, prompt: str, *args, **kwargs):
-
-
-
-
-
-
-
+        start_time = time.time()
+        response = None
+        try:
+            if hasattr(self._model, "chat"):
+                response = await self._call_wrapper(
+                    self._model.chat, prompt, *args, **kwargs
+                )
+                return response
+            if hasattr(self._model, "async_chat"):
+                response = await self._call_wrapper(
+                    self._model.async_chat, prompt, *args, **kwargs
+                )
+                return response
+            raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
+        finally:
+            # For the non stream result.
+            if response is not None and isinstance(response, dict):
+                usage = response["usage"]
+                # Some backends may not have a valid usage, we just skip them.
+                completion_tokens = usage["completion_tokens"]
+                prompt_tokens = usage["prompt_tokens"]
+                await self._record_completion_metrics(
+                    time.time() - start_time,
+                    completion_tokens,
+                    prompt_tokens,
+                )
 
     @log_async(logger=logger)
     @request_limit
@@ -341,3 +477,7 @@
             raise AttributeError(
                 f"Model {self._model.model_spec} is not for creating image."
            )
+
+    async def record_metrics(self, name, op, kwargs):
+        worker_ref = await self._get_worker_ref()
+        await worker_ref.record_metrics(name, op, kwargs)
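The generator instrumentation above reduces to three measurements: time to first token in milliseconds, token counters taken from the final usage dict, and throughput as completion tokens over wall-clock duration. A standalone sketch of the same bookkeeping over any stream of chunk dicts (synthetic, not the actor code itself):

import time

def instrumented(gen):
    # Mirror _to_json_generator's bookkeeping on a stream of chunk dicts.
    start = time.time()
    ttft_ms = None
    final_usage = None
    for chunk in gen:
        if ttft_ms is None:
            ttft_ms = (time.time() - start) * 1000  # first-token latency
        final_usage = chunk.pop("usage", None)  # last chunk's usage wins
        yield chunk
    duration = time.time() - start
    if final_usage is not None:
        throughput = final_usage["completion_tokens"] / duration  # tokens/s
        print(f"ttft={ttft_ms:.1f}ms, throughput={throughput:.1f} tok/s")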
xinference/core/status_guard.py
ADDED
@@ -0,0 +1,86 @@
+# Copyright 2022-2024 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from enum import Enum
+from logging import getLogger
+from typing import Dict, List, Optional
+
+import xoscar as xo
+from pydantic import BaseModel
+
+logger = getLogger(__name__)
+
+
+class LaunchStatus(Enum):
+    CREATING = 1
+    UPDATING = 2
+    TERMINATING = 3
+    TERMINATED = 4
+    READY = 5
+    ERROR = 6
+
+
+class InstanceInfo(BaseModel):
+    model_name: str
+    model_uid: str
+    model_ability: List[str]
+    replica: int
+    status: str
+    instance_created_ts: int
+
+    def update(self, **kwargs):
+        for field, value in kwargs.items():
+            setattr(self, field, value)
+
+
+class StatusGuardActor(xo.StatelessActor):
+    def __init__(self):
+        super().__init__()
+        self._model_uid_to_info: Dict[str, InstanceInfo] = {}
+
+    @classmethod
+    def uid(cls) -> str:
+        return "status_guard"
+
+    @staticmethod
+    def _drop_terminated_info(instance_infos: List[InstanceInfo]) -> List[InstanceInfo]:
+        return [
+            info
+            for info in instance_infos
+            if info.status != LaunchStatus.TERMINATED.name
+        ]
+
+    def set_instance_info(self, model_uid: str, info: InstanceInfo):
+        self._model_uid_to_info[model_uid] = info
+
+    def get_instance_info(
+        self, model_name: Optional[str] = None, model_uid: Optional[str] = None
+    ) -> List[InstanceInfo]:
+        if model_uid is not None:
+            return (
+                self._drop_terminated_info([self._model_uid_to_info[model_uid]])
+                if model_uid in self._model_uid_to_info
+                else []
+            )
+        all_infos: List[InstanceInfo] = list(self._model_uid_to_info.values())
+        filtered_infos: List[InstanceInfo] = list(
+            filter(lambda info: info.model_name == model_name, all_infos)
+        )
+        return (
+            self._drop_terminated_info(filtered_infos)
+            if model_name is not None
+            else self._drop_terminated_info(all_infos)
+        )
+
+    def update_instance_info(self, model_uid: str, info: Dict):
+        self._model_uid_to_info[model_uid].update(**info)
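The guard never deletes entries: terminated instances stay in the map but are filtered out of every read via _drop_terminated_info. A small sketch of that visibility rule using the pydantic model directly (the uids are made up):

from xinference.core.status_guard import InstanceInfo, LaunchStatus

live = InstanceInfo(
    model_name="demo", model_uid="demo-1", model_ability=[],
    replica=1, status=LaunchStatus.READY.name, instance_created_ts=0,
)
dead = live.copy(update={"model_uid": "demo-2",
                         "status": LaunchStatus.TERMINATED.name})

# get_instance_info applies this filter on every read:
visible = [i for i in (live, dead) if i.status != LaunchStatus.TERMINATED.name]
assert [i.model_uid for i in visible] == ["demo-1"]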
xinference/core/supervisor.py
CHANGED
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Un
 import xoscar as xo
 
 from ..core import ModelActor
+from ..core.status_guard import InstanceInfo, LaunchStatus
+from .metrics import record_metrics
 from .resource import ResourceStatus
 from .utils import (
     build_replica_model_uid,
@@ -46,6 +48,12 @@ logger = getLogger(__name__)
 
 
 DEFAULT_NODE_TIMEOUT = 60
+ASYNC_LAUNCH_TASKS = {}  # type: ignore
+
+
+def callback_for_async_launch(model_uid: str):
+    ASYNC_LAUNCH_TASKS.pop(model_uid, None)
+    logger.debug(f"Model uid: {model_uid} async launch completes.")
 
 
 @dataclass
@@ -81,6 +89,13 @@ class SupervisorActor(xo.StatelessActor):
         # comment this line to avoid worker lost
         # self._check_dead_nodes_task = asyncio.create_task(self._check_dead_nodes())
         logger.info(f"Xinference supervisor {self.address} started")
+        from .status_guard import StatusGuardActor
+
+        self._status_guard_ref: xo.ActorRefType[
+            "StatusGuardActor"
+        ] = await xo.create_actor(
+            StatusGuardActor, address=self.address, uid=StatusGuardActor.uid()
+        )
 
         from ..model.embedding import (
             CustomEmbeddingModelSpec,
@@ -119,11 +134,13 @@ class SupervisorActor(xo.StatelessActor):
         from ..model.llm.llm_family import (
             BUILTIN_LLM_MODEL_CHAT_FAMILIES,
             BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
+            BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES,
         )
 
         return {
             "chat": list(BUILTIN_LLM_MODEL_CHAT_FAMILIES),
             "generate": list(BUILTIN_LLM_MODEL_GENERATE_FAMILIES),
+            "tool_call": list(BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES),
         }
 
     async def get_devices_count(self) -> int:
@@ -511,6 +528,7 @@ class SupervisorActor(xo.StatelessActor):
         replica: int = 1,
         n_gpu: Optional[Union[int, str]] = "auto",
         request_limits: Optional[int] = None,
+        wait_ready: bool = True,
         **kwargs,
     ) -> str:
         if model_uid is None:
@@ -552,6 +570,18 @@ class SupervisorActor(xo.StatelessActor):
             )
             self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
 
+        async def _launch_model():
+            try:
+                for rep_model_uid in iter_replica_model_uid(model_uid, replica):
+                    await _launch_one_model(rep_model_uid)
+            except Exception:
+                # terminate_model will remove the replica info.
+                await self.terminate_model(model_uid, suppress_exception=True)
+                await self._status_guard_ref.update_instance_info(
+                    model_uid, {"status": LaunchStatus.ERROR.name}
+                )
+                raise
+
         if not is_valid_model_uid(model_uid):
             raise ValueError(
                 "The model UID is invalid. Please specify the model UID by 0 < length <= 100."
@@ -568,15 +598,31 @@ class SupervisorActor(xo.StatelessActor):
         self._model_uid_to_replica_info[model_uid] = ReplicaInfo(
             replica=replica, scheduler=itertools.cycle(range(replica))
         )
-
-
-
-
-
-
-
+        instance_info = InstanceInfo(
+            model_name=model_name,
+            model_uid=model_uid,
+            model_ability=[],
+            replica=replica,
+            status=LaunchStatus.CREATING.name,
+            instance_created_ts=int(time.time()),
+        )
+        await self._status_guard_ref.set_instance_info(model_uid, instance_info)
+        if wait_ready:
+            await _launch_model()
+        else:
+            task = asyncio.create_task(_launch_model())
+            ASYNC_LAUNCH_TASKS[model_uid] = task
+            task.add_done_callback(lambda _: callback_for_async_launch(model_uid))
         return model_uid
 
+    async def get_instance_info(
+        self, model_name: Optional[str], model_uid: Optional[str]
+    ) -> List[Dict]:
+        infos = await self._status_guard_ref.get_instance_info(
+            model_name=model_name, model_uid=model_uid
+        )
+        return [info.dict() for info in sorted(infos, key=lambda info: info.model_uid)]
+
     async def _check_dead_nodes(self):
         while True:
             dead_nodes = []
@@ -705,3 +751,7 @@ class SupervisorActor(xo.StatelessActor):
         self._worker_status[worker_address] = WorkerStatus(
             update_time=time.time(), status=status
         )
+
+    @staticmethod
+    def record_metrics(name, op, kwargs):
+        record_metrics(name, op, kwargs)
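The wait_ready=False path above is a plain fire-and-forget task pattern: the launch coroutine is scheduled, tracked in ASYNC_LAUNCH_TASKS, and removed by the done-callback. A self-contained sketch of that pattern (the launch body is a stand-in, not the real replica loop):

import asyncio

ASYNC_LAUNCH_TASKS = {}

def callback_for_async_launch(model_uid: str):
    ASYNC_LAUNCH_TASKS.pop(model_uid, None)
    print(f"Model uid: {model_uid} async launch completes.")

async def _launch_model(model_uid: str):
    await asyncio.sleep(0.1)  # stand-in for launching each replica

async def launch(model_uid: str, wait_ready: bool) -> str:
    if wait_ready:
        await _launch_model(model_uid)  # block until the model is ready
    else:
        task = asyncio.create_task(_launch_model(model_uid))
        ASYNC_LAUNCH_TASKS[model_uid] = task  # caller returns immediately
        task.add_done_callback(lambda _: callback_for_async_launch(model_uid))
    return model_uid

async def main():
    uid = await launch("demo", wait_ready=False)  # returns right away
    await asyncio.gather(*ASYNC_LAUNCH_TASKS.values())  # launch finishes later

asyncio.run(main())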