xinference 0.8.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of xinference might be problematic.

Files changed (35)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +35 -1
  3. xinference/client/oscar/actor_client.py +2 -2
  4. xinference/client/restful/restful_client.py +2 -2
  5. xinference/conftest.py +5 -1
  6. xinference/core/metrics.py +83 -0
  7. xinference/core/model.py +148 -8
  8. xinference/core/status_guard.py +86 -0
  9. xinference/core/supervisor.py +57 -7
  10. xinference/core/worker.py +132 -13
  11. xinference/deploy/cmdline.py +57 -4
  12. xinference/deploy/local.py +32 -6
  13. xinference/deploy/worker.py +33 -5
  14. xinference/fields.py +4 -1
  15. xinference/model/llm/__init__.py +7 -0
  16. xinference/model/llm/ggml/llamacpp.py +3 -2
  17. xinference/model/llm/llm_family.json +70 -3
  18. xinference/model/llm/llm_family.py +11 -1
  19. xinference/model/llm/llm_family_modelscope.json +72 -3
  20. xinference/model/llm/pytorch/chatglm.py +70 -28
  21. xinference/model/llm/pytorch/core.py +11 -30
  22. xinference/model/llm/pytorch/internlm2.py +155 -0
  23. xinference/model/llm/pytorch/utils.py +0 -153
  24. xinference/model/llm/utils.py +37 -8
  25. xinference/model/llm/vllm/core.py +15 -3
  26. xinference/model/multimodal/__init__.py +15 -8
  27. xinference/model/multimodal/model_spec_modelscope.json +45 -0
  28. xinference/model/utils.py +7 -2
  29. xinference/types.py +2 -0
  30. {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/METADATA +2 -1
  31. {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/RECORD +35 -31
  32. {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/LICENSE +0 -0
  33. {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/WHEEL +0 -0
  34. {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/entry_points.txt +0 -0
  35. {xinference-0.8.0.dist-info → xinference-0.8.1.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-01-11T21:54:21+0800",
+ "date": "2024-01-19T17:14:28+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "e4c892c5ea9b459ac60c70ce82db73a0b0a9adfa",
- "version": "0.8.0"
+ "full-revisionid": "fb3985e95fbb3e6cb51a321d6d6a9a10661128fe",
+ "version": "0.8.1"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -27,6 +27,8 @@ from typing import Any, List, Optional, Union
 import gradio as gr
 import pydantic
 import xoscar as xo
+from aioprometheus import REGISTRY, MetricsMiddleware
+from aioprometheus.asgi.starlette import metrics
 from fastapi import (
     APIRouter,
     FastAPI,
@@ -252,6 +254,15 @@ class RESTfulAPI:
         self._router.add_api_route(
             "/v1/cluster/auth", self.is_cluster_authenticated, methods=["GET"]
         )
+        # running instances
+        self._router.add_api_route(
+            "/v1/models/instances",
+            self.get_instance_info,
+            methods=["GET"],
+            dependencies=[Security(verify_token, scopes=["models:list"])]
+            if self.is_authenticated()
+            else None,
+        )
         self._router.add_api_route(
             "/v1/models",
             self.list_models,
@@ -380,7 +391,13 @@ class RESTfulAPI:
             else None,
         )
 
+        # Clear the global Registry for the MetricsMiddleware, or
+        # the MetricsMiddleware will register duplicated metrics if the port
+        # conflict (This serve method run more than once).
+        REGISTRY.clear()
+        self._app.add_middleware(MetricsMiddleware)
         self._app.include_router(self._router)
+        self._app.add_route("/metrics", metrics)
 
         # Check all the routes returns Response.
         # This is to avoid `jsonable_encoder` performance issue:
@@ -546,7 +563,9 @@ class RESTfulAPI:
 
         return JSONResponse(content={"model_uid": model_uid})
 
-    async def launch_model(self, request: Request) -> JSONResponse:
+    async def launch_model(
+        self, request: Request, wait_ready: bool = Query(True)
+    ) -> JSONResponse:
         payload = await request.json()
         model_uid = payload.get("model_uid")
         model_name = payload.get("model_name")
@@ -591,6 +610,7 @@ class RESTfulAPI:
             replica=replica,
             n_gpu=n_gpu,
             request_limits=request_limits,
+            wait_ready=wait_ready,
             **kwargs,
         )
 
@@ -606,6 +626,20 @@ class RESTfulAPI:
 
         return JSONResponse(content={"model_uid": model_uid})
 
+    async def get_instance_info(
+        self,
+        model_name: Optional[str] = Query(None),
+        model_uid: Optional[str] = Query(None),
+    ) -> JSONResponse:
+        try:
+            infos = await (await self._get_supervisor_ref()).get_instance_info(
+                model_name, model_uid
+            )
+        except Exception as e:
+            logger.error(str(e), exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+        return JSONResponse(content=infos)
+
     async def build_gradio_interface(
         self, model_uid: str, body: BuildGradioInterfaceRequest, request: Request
     ) -> JSONResponse:
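
The new `/v1/models/instances` route and the `wait_ready` query parameter together make non-blocking launches observable over REST. A minimal client-side sketch, assuming a local server at http://127.0.0.1:9997, that models are still launched via POST /v1/models as in earlier releases, and a hypothetical model named internlm2-chat:

```python
import time

import requests

BASE = "http://127.0.0.1:9997"  # hypothetical local Xinference endpoint

# Launch without blocking until the model is ready.
resp = requests.post(
    f"{BASE}/v1/models",
    params={"wait_ready": "false"},
    json={"model_uid": "my-model", "model_name": "internlm2-chat"},
)
model_uid = resp.json()["model_uid"]

# Poll the new instances endpoint until the instance leaves CREATING.
while True:
    infos = requests.get(
        f"{BASE}/v1/models/instances", params={"model_uid": model_uid}
    ).json()
    if infos and infos[0]["status"] not in ("CREATING", "UPDATING"):
        break
    time.sleep(1)
print(infos[0]["status"])  # READY on success, ERROR on a failed launch
```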
xinference/client/oscar/actor_client.py CHANGED
@@ -171,7 +171,7 @@ class RerankModelHandle(ModelHandle):
         return results
 
 
-class GenerateModelHandle(EmbeddingModelHandle):
+class GenerateModelHandle(ModelHandle):
     def generate(
         self,
         prompt: str,
@@ -255,7 +255,7 @@ class ChatModelHandle(GenerateModelHandle):
         return ClientIteratorWrapper(r)
 
 
-class ChatglmCppChatModelHandle(EmbeddingModelHandle):
+class ChatglmCppChatModelHandle(ModelHandle):
     def chat(
         self,
         prompt: str,
xinference/client/restful/restful_client.py CHANGED
@@ -257,7 +257,7 @@ class RESTfulImageModelHandle(RESTfulModelHandle):
         return response_data
 
 
-class RESTfulGenerateModelHandle(RESTfulEmbeddingModelHandle):
+class RESTfulGenerateModelHandle(RESTfulModelHandle):
     def generate(
         self,
         prompt: str,
@@ -486,7 +486,7 @@ class RESTfulMultimodalModelHandle(RESTfulModelHandle):
         return response_data
 
 
-class RESTfulChatglmCppChatModelHandle(RESTfulEmbeddingModelHandle):
+class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle):
     def chat(
         self,
         prompt: str,
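
The two client files above get the same one-line fix: generate and ChatGLM-cpp handles previously derived from the embedding handle and therefore exposed an embedding API they cannot serve. An illustrative check of the corrected hierarchy, assuming the embedding handle's method is named create_embedding:

```python
from xinference.client.restful.restful_client import (
    RESTfulChatglmCppChatModelHandle,
    RESTfulGenerateModelHandle,
)

# With RESTfulModelHandle as the base class, generate/chat handles no
# longer inherit the embedding API from RESTfulEmbeddingModelHandle.
assert not hasattr(RESTfulGenerateModelHandle, "create_embedding")
assert not hasattr(RESTfulChatglmCppChatModelHandle, "create_embedding")
```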
xinference/conftest.py CHANGED
@@ -144,7 +144,11 @@ async def _start_test_cluster(
             SupervisorActor, address=address, uid=SupervisorActor.uid()
         )
         await start_worker_components(
-            address=address, supervisor_address=address, main_pool=pool
+            address=address,
+            supervisor_address=address,
+            main_pool=pool,
+            metrics_exporter_host=None,
+            metrics_exporter_port=None,
         )
         await pool.join()
     except asyncio.CancelledError:
xinference/core/metrics.py ADDED
@@ -0,0 +1,83 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+
+import uvicorn
+from aioprometheus import Counter, Gauge
+from aioprometheus.asgi.starlette import metrics
+from fastapi import FastAPI
+from fastapi.responses import RedirectResponse
+
+DEFAULT_METRICS_SERVER_LOG_LEVEL = "warning"
+
+
+generate_throughput = Gauge(
+    "xinference:generate_tokens_per_s", "Generate throughput in tokens/s."
+)
+# Latency
+time_to_first_token = Gauge(
+    "xinference:time_to_first_token_ms", "First token latency in ms."
+)
+# Tokens counter
+input_tokens_total_counter = Counter(
+    "xinference:input_tokens_total_counter", "Total number of input tokens."
+)
+output_tokens_total_counter = Counter(
+    "xinference:output_tokens_total_counter", "Total number of output tokens."
+)
+
+
+def record_metrics(name, op, kwargs):
+    collector = globals().get(name)
+    getattr(collector, op)(**kwargs)
+
+
+def launch_metrics_export_server(q, host=None, port=None):
+    app = FastAPI()
+    app.add_route("/metrics", metrics)
+
+    @app.get("/")
+    async def root():
+        response = RedirectResponse(url="/metrics")
+        return response
+
+    async def main():
+        if host is not None and port is not None:
+            config = uvicorn.Config(
+                app, host=host, port=port, log_level=DEFAULT_METRICS_SERVER_LOG_LEVEL
+            )
+        elif host is not None:
+            config = uvicorn.Config(
+                app, host=host, port=0, log_level=DEFAULT_METRICS_SERVER_LOG_LEVEL
+            )
+        elif port is not None:
+            config = uvicorn.Config(
+                app, port=port, log_level=DEFAULT_METRICS_SERVER_LOG_LEVEL
+            )
+        else:
+            config = uvicorn.Config(app, log_level=DEFAULT_METRICS_SERVER_LOG_LEVEL)
+
+        server = uvicorn.Server(config)
+        task = asyncio.create_task(server.serve())
+
+        while not server.started and not task.done():
+            await asyncio.sleep(0.1)
+
+        for server in server.servers:
+            for socket in server.sockets:
+                q.put(socket.getsockname())
+        await task
+
+    asyncio.run(main())
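
Note that launch_metrics_export_server blocks its caller (it ends in asyncio.run) and reports the address it actually bound back through the queue, so it is intended to run in a separate process. A minimal usage sketch under those assumptions:

```python
import multiprocessing

from xinference.core.metrics import launch_metrics_export_server

if __name__ == "__main__":
    q: multiprocessing.Queue = multiprocessing.Queue()
    # Run the exporter in its own process; port=0 lets the OS pick a free port.
    proc = multiprocessing.Process(
        target=launch_metrics_export_server, args=(q, "127.0.0.1", 0), daemon=True
    )
    proc.start()
    host, port = q.get()  # the server reports the socket it bound
    print(f"metrics available at http://{host}:{port}/metrics")
```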
xinference/core/model.py CHANGED
@@ -17,6 +17,7 @@ import functools
 import inspect
 import json
 import os
+import time
 import types
 import weakref
 from typing import (
@@ -34,7 +35,9 @@ import sse_starlette.sse
 import xoscar as xo
 
 if TYPE_CHECKING:
+    from .worker import WorkerActor
     from ..model.llm.core import LLM
+    from ..model.core import ModelDescription
     import PIL
 
 import logging
@@ -140,13 +143,23 @@ class ModelActor(xo.StatelessActor):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def __init__(self, model: "LLM", request_limits: Optional[int] = None):
+    def __init__(
+        self,
+        worker_address: str,
+        model: "LLM",
+        model_description: Optional["ModelDescription"] = None,
+        request_limits: Optional[int] = None,
+    ):
         super().__init__()
         from ..model.llm.pytorch.core import PytorchModel
         from ..model.llm.pytorch.spec_model import SpeculativeModel
         from ..model.llm.vllm.core import VLLMModel
 
+        self._worker_address = worker_address
         self._model = model
+        self._model_description = (
+            model_description.to_dict() if model_description else {}
+        )
         self._request_limits = request_limits
 
         self._generators: Dict[str, Union[Iterator, AsyncGenerator]] = {}
@@ -156,7 +169,65 @@ class ModelActor(xo.StatelessActor):
             if isinstance(self._model, (PytorchModel, SpeculativeModel, VLLMModel))
             else asyncio.locks.Lock()
         )
+        self._worker_ref = None
         self._serve_count = 0
+        self._metrics_labels = {
+            "type": self._model_description.get("model_type", "unknown"),
+            "model": self.model_uid(),
+            "node": self._worker_address,
+            "format": self._model_description.get("model_format", "unknown"),
+            "quantization": self._model_description.get("quantization", "none"),
+        }
+        self._loop: Optional[asyncio.AbstractEventLoop] = None
+
+    async def __post_create__(self):
+        self._loop = asyncio.get_running_loop()
+
+    async def _record_completion_metrics(
+        self, duration, completion_tokens, prompt_tokens
+    ):
+        coros = []
+        if completion_tokens > 0:
+            coros.append(
+                self.record_metrics(
+                    "output_tokens_total_counter",
+                    "add",
+                    {
+                        "labels": self._metrics_labels,
+                        "value": completion_tokens,
+                    },
+                )
+            )
+        if prompt_tokens > 0:
+            coros.append(
+                self.record_metrics(
+                    "input_tokens_total_counter",
+                    "add",
+                    {"labels": self._metrics_labels, "value": prompt_tokens},
+                )
+            )
+        if completion_tokens > 0:
+            generate_throughput = completion_tokens / duration
+            coros.append(
+                self.record_metrics(
+                    "generate_throughput",
+                    "set",
+                    {
+                        "labels": self._metrics_labels,
+                        "value": generate_throughput,
+                    },
+                )
+            )
+        await asyncio.gather(*coros)
+
+    async def _get_worker_ref(self) -> xo.ActorRefType["WorkerActor"]:
+        from .worker import WorkerActor
+
+        if self._worker_ref is None:
+            self._worker_ref = await xo.actor_ref(
+                address=self._worker_address, uid=WorkerActor.uid()
+            )
+        return self._worker_ref
 
     def is_vllm_backend(self) -> bool:
         from ..model.llm.vllm.core import VLLMModel
@@ -178,8 +249,14 @@ class ModelActor(xo.StatelessActor):
         )
 
     def _to_json_generator(self, gen: types.GeneratorType):
+        start_time = time.time()
+        time_to_first_token = None
+        final_usage = None
         try:
             for v in gen:
+                if time_to_first_token is None:
+                    time_to_first_token = (time.time() - start_time) * 1000
+                final_usage = v.pop("usage", None)
                 v = dict(data=json.dumps(v))
                 yield sse_starlette.sse.ensure_bytes(v, None)
         except OutOfMemoryError:
@@ -187,10 +264,31 @@ class ModelActor(xo.StatelessActor):
                 "Model actor is out of memory, model id: %s", self.model_uid()
             )
             os._exit(1)
+        finally:
+            if self._loop is not None and time_to_first_token is not None:
+                coro = self.record_metrics(
+                    "time_to_first_token",
+                    "set",
+                    {"labels": self._metrics_labels, "value": time_to_first_token},
+                )
+                asyncio.run_coroutine_threadsafe(coro, loop=self._loop)
+            if self._loop is not None and final_usage is not None:
+                coro = self._record_completion_metrics(
+                    time.time() - start_time,
+                    completion_tokens=final_usage["completion_tokens"],
+                    prompt_tokens=final_usage["prompt_tokens"],
+                )
+                asyncio.run_coroutine_threadsafe(coro, loop=self._loop)
 
     async def _to_json_async_gen(self, gen: types.AsyncGeneratorType):
+        start_time = time.time()
+        time_to_first_token = None
+        final_usage = None
         try:
             async for v in gen:
+                if time_to_first_token is None:
+                    time_to_first_token = (time.time() - start_time) * 1000
+                final_usage = v.pop("usage", None)
                 v = await asyncio.to_thread(json.dumps, v)
                 v = dict(data=v)  # noqa: F821
                 yield await asyncio.to_thread(sse_starlette.sse.ensure_bytes, v, None)
@@ -199,6 +297,25 @@ class ModelActor(xo.StatelessActor):
                 "Model actor is out of memory, model id: %s", self.model_uid()
             )
             os._exit(1)
+        finally:
+            coros = []
+            if time_to_first_token is not None:
+                coros.append(
+                    self.record_metrics(
+                        "time_to_first_token",
+                        "set",
+                        {"labels": self._metrics_labels, "value": time_to_first_token},
+                    )
+                )
+            if final_usage is not None:
+                coros.append(
+                    self._record_completion_metrics(
+                        time.time() - start_time,
+                        completion_tokens=final_usage["completion_tokens"],
+                        prompt_tokens=final_usage["prompt_tokens"],
+                    )
+                )
+            await asyncio.gather(*coros)
 
     @oom_check
     async def _call_wrapper(self, fn: Callable, *args, **kwargs):
@@ -245,13 +362,32 @@ class ModelActor(xo.StatelessActor):
     @request_limit
     @xo.generator
     async def chat(self, prompt: str, *args, **kwargs):
-        if hasattr(self._model, "chat"):
-            return await self._call_wrapper(self._model.chat, prompt, *args, **kwargs)
-        if hasattr(self._model, "async_chat"):
-            return await self._call_wrapper(
-                self._model.async_chat, prompt, *args, **kwargs
-            )
-        raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
+        start_time = time.time()
+        response = None
+        try:
+            if hasattr(self._model, "chat"):
+                response = await self._call_wrapper(
+                    self._model.chat, prompt, *args, **kwargs
+                )
+                return response
+            if hasattr(self._model, "async_chat"):
+                response = await self._call_wrapper(
+                    self._model.async_chat, prompt, *args, **kwargs
+                )
+                return response
+            raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
+        finally:
+            # For the non stream result.
+            if response is not None and isinstance(response, dict):
+                usage = response["usage"]
+                # Some backends may not have a valid usage, we just skip them.
+                completion_tokens = usage["completion_tokens"]
+                prompt_tokens = usage["prompt_tokens"]
+                await self._record_completion_metrics(
+                    time.time() - start_time,
+                    completion_tokens,
+                    prompt_tokens,
+                )
 
     @log_async(logger=logger)
     @request_limit
@@ -341,3 +477,7 @@ class ModelActor(xo.StatelessActor):
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating image."
         )
+
+    async def record_metrics(self, name, op, kwargs):
+        worker_ref = await self._get_worker_ref()
+        await worker_ref.record_metrics(name, op, kwargs)
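
The flow above is: ModelActor measures time-to-first-token and token usage, forwards each sample to its WorkerActor via record_metrics, and the worker applies it to the aioprometheus collectors defined in xinference/core/metrics.py. The name/op/kwargs triple maps directly onto a collector method call; a sketch of the equivalent direct calls, with illustrative label values:

```python
from xinference.core.metrics import record_metrics

# Illustrative labels; the real ones also carry format and quantization.
labels = {"type": "LLM", "model": "my-model", "node": "127.0.0.1:62345"}

# Resolves the collector by name and calls Gauge.set(labels=..., value=...).
record_metrics("time_to_first_token", "set", {"labels": labels, "value": 42.0})
# Likewise Counter.add(labels=..., value=...) for the token counters.
record_metrics("input_tokens_total_counter", "add", {"labels": labels, "value": 17})
```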
xinference/core/status_guard.py ADDED
@@ -0,0 +1,86 @@
+# Copyright 2022-2024 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from enum import Enum
+from logging import getLogger
+from typing import Dict, List, Optional
+
+import xoscar as xo
+from pydantic import BaseModel
+
+logger = getLogger(__name__)
+
+
+class LaunchStatus(Enum):
+    CREATING = 1
+    UPDATING = 2
+    TERMINATING = 3
+    TERMINATED = 4
+    READY = 5
+    ERROR = 6
+
+
+class InstanceInfo(BaseModel):
+    model_name: str
+    model_uid: str
+    model_ability: List[str]
+    replica: int
+    status: str
+    instance_created_ts: int
+
+    def update(self, **kwargs):
+        for field, value in kwargs.items():
+            setattr(self, field, value)
+
+
+class StatusGuardActor(xo.StatelessActor):
+    def __init__(self):
+        super().__init__()
+        self._model_uid_to_info: Dict[str, InstanceInfo] = {}
+
+    @classmethod
+    def uid(cls) -> str:
+        return "status_guard"
+
+    @staticmethod
+    def _drop_terminated_info(instance_infos: List[InstanceInfo]) -> List[InstanceInfo]:
+        return [
+            info
+            for info in instance_infos
+            if info.status != LaunchStatus.TERMINATED.name
+        ]
+
+    def set_instance_info(self, model_uid: str, info: InstanceInfo):
+        self._model_uid_to_info[model_uid] = info
+
+    def get_instance_info(
+        self, model_name: Optional[str] = None, model_uid: Optional[str] = None
+    ) -> List[InstanceInfo]:
+        if model_uid is not None:
+            return (
+                self._drop_terminated_info([self._model_uid_to_info[model_uid]])
+                if model_uid in self._model_uid_to_info
+                else []
+            )
+        all_infos: List[InstanceInfo] = list(self._model_uid_to_info.values())
+        filtered_infos: List[InstanceInfo] = list(
+            filter(lambda info: info.model_name == model_name, all_infos)
+        )
+        return (
+            self._drop_terminated_info(filtered_infos)
+            if model_name is not None
+            else self._drop_terminated_info(all_infos)
+        )
+
+    def update_instance_info(self, model_uid: str, info: Dict):
+        self._model_uid_to_info[model_uid].update(**info)
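
StatusGuardActor keeps one InstanceInfo per model UID and hides TERMINATED entries on every read. The pydantic model can be exercised on its own; a small sketch mirroring the supervisor's launch bookkeeping, with a hypothetical model name:

```python
import time

from xinference.core.status_guard import InstanceInfo, LaunchStatus

info = InstanceInfo(
    model_name="internlm2-chat",  # hypothetical model
    model_uid="my-model",
    model_ability=[],
    replica=1,
    status=LaunchStatus.CREATING.name,
    instance_created_ts=int(time.time()),
)
# update() mutates fields in place, as update_instance_info does on failure.
info.update(status=LaunchStatus.ERROR.name)
assert info.status == "ERROR"
```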
xinference/core/supervisor.py CHANGED
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
 import xoscar as xo
 
 from ..core import ModelActor
+from ..core.status_guard import InstanceInfo, LaunchStatus
+from .metrics import record_metrics
 from .resource import ResourceStatus
 from .utils import (
     build_replica_model_uid,
@@ -46,6 +48,12 @@ logger = getLogger(__name__)
 
 
 DEFAULT_NODE_TIMEOUT = 60
+ASYNC_LAUNCH_TASKS = {}  # type: ignore
+
+
+def callback_for_async_launch(model_uid: str):
+    ASYNC_LAUNCH_TASKS.pop(model_uid, None)
+    logger.debug(f"Model uid: {model_uid} async launch completes.")
 
 
 @dataclass
@@ -81,6 +89,13 @@ class SupervisorActor(xo.StatelessActor):
         # comment this line to avoid worker lost
         # self._check_dead_nodes_task = asyncio.create_task(self._check_dead_nodes())
         logger.info(f"Xinference supervisor {self.address} started")
+        from .status_guard import StatusGuardActor
+
+        self._status_guard_ref: xo.ActorRefType[
+            "StatusGuardActor"
+        ] = await xo.create_actor(
+            StatusGuardActor, address=self.address, uid=StatusGuardActor.uid()
+        )
 
         from ..model.embedding import (
             CustomEmbeddingModelSpec,
@@ -119,11 +134,13 @@ class SupervisorActor(xo.StatelessActor):
         from ..model.llm.llm_family import (
             BUILTIN_LLM_MODEL_CHAT_FAMILIES,
             BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
+            BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES,
         )
 
         return {
            "chat": list(BUILTIN_LLM_MODEL_CHAT_FAMILIES),
            "generate": list(BUILTIN_LLM_MODEL_GENERATE_FAMILIES),
+           "tool_call": list(BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES),
         }
 
     async def get_devices_count(self) -> int:
@@ -511,6 +528,7 @@ class SupervisorActor(xo.StatelessActor):
         replica: int = 1,
         n_gpu: Optional[Union[int, str]] = "auto",
         request_limits: Optional[int] = None,
+        wait_ready: bool = True,
         **kwargs,
     ) -> str:
         if model_uid is None:
@@ -552,6 +570,18 @@ class SupervisorActor(xo.StatelessActor):
             )
             self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
 
+        async def _launch_model():
+            try:
+                for rep_model_uid in iter_replica_model_uid(model_uid, replica):
+                    await _launch_one_model(rep_model_uid)
+            except Exception:
+                # terminate_model will remove the replica info.
+                await self.terminate_model(model_uid, suppress_exception=True)
+                await self._status_guard_ref.update_instance_info(
+                    model_uid, {"status": LaunchStatus.ERROR.name}
+                )
+                raise
+
         if not is_valid_model_uid(model_uid):
             raise ValueError(
                 "The model UID is invalid. Please specify the model UID by 0 < length <= 100."
@@ -568,15 +598,31 @@ class SupervisorActor(xo.StatelessActor):
         self._model_uid_to_replica_info[model_uid] = ReplicaInfo(
             replica=replica, scheduler=itertools.cycle(range(replica))
         )
-        try:
-            for rep_model_uid in iter_replica_model_uid(model_uid, replica):
-                await _launch_one_model(rep_model_uid)
-        except Exception:
-            # terminate_model will remove the replica info.
-            await self.terminate_model(model_uid, suppress_exception=True)
-            raise
+        instance_info = InstanceInfo(
+            model_name=model_name,
+            model_uid=model_uid,
+            model_ability=[],
+            replica=replica,
+            status=LaunchStatus.CREATING.name,
+            instance_created_ts=int(time.time()),
+        )
+        await self._status_guard_ref.set_instance_info(model_uid, instance_info)
+        if wait_ready:
+            await _launch_model()
+        else:
+            task = asyncio.create_task(_launch_model())
+            ASYNC_LAUNCH_TASKS[model_uid] = task
+            task.add_done_callback(lambda _: callback_for_async_launch(model_uid))
         return model_uid
 
+    async def get_instance_info(
+        self, model_name: Optional[str], model_uid: Optional[str]
+    ) -> List[Dict]:
+        infos = await self._status_guard_ref.get_instance_info(
+            model_name=model_name, model_uid=model_uid
+        )
+        return [info.dict() for info in sorted(infos, key=lambda info: info.model_uid)]
+
     async def _check_dead_nodes(self):
         while True:
             dead_nodes = []
@@ -705,3 +751,7 @@ class SupervisorActor(xo.StatelessActor):
         self._worker_status[worker_address] = WorkerStatus(
             update_time=time.time(), status=status
         )
+
+    @staticmethod
+    def record_metrics(name, op, kwargs):
+        record_metrics(name, op, kwargs)
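
When wait_ready is false, the supervisor records the instance as CREATING, schedules _launch_model as a background task, tracks it in ASYNC_LAUNCH_TASKS, and removes the bookkeeping entry in a done-callback. A self-contained sketch of that fire-and-forget pattern:

```python
import asyncio

ASYNC_LAUNCH_TASKS = {}  # model_uid -> asyncio.Task, as in the supervisor

def callback_for_async_launch(model_uid: str):
    ASYNC_LAUNCH_TASKS.pop(model_uid, None)

async def demo():
    async def _launch_model():
        await asyncio.sleep(0.1)  # stand-in for launching each replica

    task = asyncio.create_task(_launch_model())
    ASYNC_LAUNCH_TASKS["my-model"] = task
    task.add_done_callback(lambda _: callback_for_async_launch("my-model"))
    await task
    await asyncio.sleep(0)  # give the done-callback a chance to run
    assert "my-model" not in ASYNC_LAUNCH_TASKS  # bookkeeping cleaned up

asyncio.run(demo())
```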