xinference 0.11.3__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-05-31T17:12:13+0800",
+ "date": "2024-06-07T15:04:33+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "69c09cd068a530cd2fdcac07e4e81f03d48f04f9",
- "version": "0.11.3"
+ "full-revisionid": "55c5636f2b6022842d1827eae373c8e5f162a1a3",
+ "version": "0.12.0"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -122,6 +122,14 @@ class TextToImageRequest(BaseModel):
     user: Optional[str] = None
 
 
+class SpeechRequest(BaseModel):
+    model: str
+    input: str
+    voice: Optional[str]
+    response_format: Optional[str] = "mp3"
+    speed: Optional[float] = 1.0
+
+
 class RegisterModelRequest(BaseModel):
     model: str
     persist: bool
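For reference, a request body matching the new SpeechRequest schema would look like the following sketch (the model uid and voice value are placeholders):

    # Hypothetical payload for POST /v1/audio/speech, matching SpeechRequest above.
    payload = {
        "model": "my-tts-model",    # uid of a launched audio model (placeholder)
        "input": "Hello, world!",   # text to synthesize, up to 4096 characters
        "voice": "default",         # model-specific voice name (placeholder)
        "response_format": "mp3",   # optional, defaults to "mp3"
        "speed": 1.0,               # optional, defaults to 1.0
    }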
@@ -337,6 +345,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/models/{model_uid}/requests/{request_id}/abort",
+            self.abort_request,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/models/instance",
             self.launch_model_by_version,
@@ -418,6 +436,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/audio/speech",
+            self.create_speech,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/images/generations",
             self.create_images,
@@ -1179,6 +1207,38 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def create_speech(self, request: Request) -> Response:
+        body = SpeechRequest.parse_obj(await request.json())
+        model_uid = body.model
+        try:
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            out = await model.speech(
+                input=body.input,
+                voice=body.voice,
+                response_format=body.response_format,
+                speed=body.speed,
+            )
+            return Response(media_type="application/octet-stream", content=out)
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_images(self, request: Request) -> Response:
         body = TextToImageRequest.parse_obj(await request.json())
         model_uid = body.model
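Because the handler returns raw bytes with media type application/octet-stream, the new endpoint can be exercised with plain requests. A minimal sketch, assuming a local server on the default port and an already-launched audio model (the uid is a placeholder):

    import requests

    resp = requests.post(
        "http://127.0.0.1:9997/v1/audio/speech",
        json={"model": "my-tts-model", "input": "Hello, world!", "voice": "default"},
    )
    resp.raise_for_status()
    with open("speech.mp3", "wb") as f:
        f.write(resp.content)  # raw audio bytes, mp3 by default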
@@ -1518,6 +1578,15 @@ class RESTfulAPI:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def abort_request(self, model_uid: str, request_id: str) -> JSONResponse:
+        try:
+            supervisor_ref = await self._get_supervisor_ref()
+            res = await supervisor_ref.abort_request(model_uid, request_id)
+            return JSONResponse(content=res)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def list_vllm_supported_model_families(self) -> JSONResponse:
         try:
             from ..model.llm.vllm.core import (
xinference/client/restful/restful_client.py CHANGED
@@ -684,6 +684,49 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def speech(
+        self,
+        input: str,
+        voice: str = "",
+        response_format: str = "mp3",
+        speed: float = 1.0,
+    ):
+        """
+        Generates audio from the input text.
+
+        Parameters
+        ----------
+
+        input: str
+            The text to generate audio for. The maximum length is 4096 characters.
+        voice: str
+            The voice to use when generating the audio.
+        response_format: str
+            The format of the generated audio.
+        speed: float
+            The speed of the generated audio.
+
+        Returns
+        -------
+        bytes
+            The generated audio binary.
+        """
+        url = f"{self._base_url}/v1/audio/speech"
+        params = {
+            "model": self._model_uid,
+            "input": input,
+            "voice": voice,
+            "response_format": response_format,
+            "speed": speed,
+        }
+        response = requests.post(url, json=params, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to generate speech, detail: {_get_error_string(response)}"
+            )
+
+        return response.content
+
 
 class Client:
     def __init__(self, base_url, api_key: Optional[str] = None):
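A sketch of the new client method in use, assuming an audio-capable model has already been launched (the model uid is a placeholder; get_model returns a RESTfulAudioModelHandle for audio models):

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    model = client.get_model("my-tts-model")  # RESTfulAudioModelHandle
    audio = model.speech("Hello, world!", response_format="mp3", speed=1.0)
    with open("speech.mp3", "wb") as f:
        f.write(audio)  # bytes returned by the endpoint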
@@ -1181,3 +1224,30 @@ class Client:
 
         response_data = response.json()
         return response_data
+
+    def abort_request(self, model_uid: str, request_id: str):
+        """
+        Abort a request.
+        If the request has already finished or cannot be found, this call is a no-op.
+        Currently, this interface is only supported when batching is enabled for models on the transformers backend.
+
+        Parameters
+        ----------
+        model_uid: str
+            Model uid.
+        request_id: str
+            Request id.
+        Returns
+        -------
+        Dict
+            An empty dict.
+        """
+        url = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort"
+        response = requests.post(url, headers=self._headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to abort request, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
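Usage mirrors the REST route added above. A sketch with placeholder uids; per the docstring, the call only has an effect when transformers batching is enabled and the request is still in flight:

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    # Abort an in-flight request by id; finished or unknown requests are a no-op.
    client.abort_request("my-model-uid", "my-request-id")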
xinference/constants.py CHANGED
@@ -27,6 +27,7 @@ XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
 XINFERENCE_ENV_DISABLE_VLLM = "XINFERENCE_DISABLE_VLLM"
 XINFERENCE_ENV_ENABLE_SGLANG = "XINFERENCE_ENABLE_SGLANG"
 XINFERENCE_ENV_DISABLE_METRICS = "XINFERENCE_DISABLE_METRICS"
+XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING = "XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"
 
 
 def get_xinference_home() -> str:
@@ -70,3 +71,6 @@ XINFERENCE_ENABLE_SGLANG = bool(int(os.environ.get(XINFERENCE_ENV_ENABLE_SGLANG,
 XINFERENCE_DISABLE_METRICS = bool(
     int(os.environ.get(XINFERENCE_ENV_DISABLE_METRICS, 0))
 )
+XINFERENCE_TRANSFORMERS_ENABLE_BATCHING = bool(
+    int(os.environ.get(XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING, 0))
+)
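The flag is parsed once at import time with bool(int(...)), so it expects an integer-like string such as "0" or "1" and must be set before the server process starts. A sketch:

    import os

    # Must be set before xinference is imported, e.g. in the launch environment.
    os.environ["XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"] = "1"  # "0" (default) disables batching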
xinference/core/model.py CHANGED
@@ -20,9 +20,14 @@ import os
 import time
 import types
 import weakref
+from asyncio.queues import Queue
+from asyncio.tasks import wait_for
+from concurrent.futures import Future as ConcurrentFuture
 from typing import (
     TYPE_CHECKING,
+    Any,
     AsyncGenerator,
+    AsyncIterator,
     Callable,
     Dict,
     Generator,
@@ -35,6 +40,8 @@ from typing import (
 import sse_starlette.sse
 import xoscar as xo
 
+from ..constants import XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
+
 if TYPE_CHECKING:
     from .worker import WorkerActor
     from ..model.llm.core import LLM
@@ -125,6 +132,16 @@ class ModelActor(xo.StatelessActor):
         from ..model.llm.pytorch.core import PytorchModel as LLMPytorchModel
         from ..model.llm.vllm.core import VLLMModel as LLMVLLMModel
 
+        if self.allow_batching():
+            try:
+                assert self._scheduler_ref is not None
+                await xo.destroy_actor(self._scheduler_ref)
+                del self._scheduler_ref
+            except Exception as e:
+                logger.debug(
+                    f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
+                )
+
         if (
             isinstance(self._model, (LLMPytorchModel, LLMVLLMModel))
             and self._model.model_spec.model_format == "pytorch"
@@ -181,9 +198,20 @@ class ModelActor(xo.StatelessActor):
         }
         self._loop: Optional[asyncio.AbstractEventLoop] = None
 
+        self._scheduler_ref = None
+
     async def __post_create__(self):
         self._loop = asyncio.get_running_loop()
 
+        if self.allow_batching():
+            from .scheduler import SchedulerActor
+
+            self._scheduler_ref = await xo.create_actor(
+                SchedulerActor,
+                address=self.address,
+                uid=SchedulerActor.gen_uid(self.model_uid(), self._model.rep_id),
+            )
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):
@@ -235,8 +263,22 @@ class ModelActor(xo.StatelessActor):
 
         return isinstance(self._model, VLLMModel)
 
-    def load(self):
+    def allow_batching(self) -> bool:
+        from ..model.llm.pytorch.core import PytorchChatModel
+
+        return (
+            XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
+            and isinstance(self._model, PytorchChatModel)
+            and self._model.__class__.__name__ == PytorchChatModel.__name__
+        )
+
+    async def load(self):
         self._model.load()
+        if self.allow_batching():
+            await self._scheduler_ref.set_model(self._model)
+            logger.debug(
+                f"Batching enabled for model: {self.model_uid()}, max_num_seqs: {self._model.get_max_num_seqs()}"
+            )
 
     def model_uid(self):
         return (
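Note that allow_batching combines isinstance with an exact class-name comparison, so batching applies only to PytorchChatModel itself, not to its subclasses. A minimal sketch of the distinction (names are hypothetical):

    class Base:
        pass

    class Sub(Base):
        pass

    obj = Sub()
    print(isinstance(obj, Base))                    # True: subclasses pass isinstance
    print(obj.__class__.__name__ == Base.__name__)  # False: the name check rejects subclasses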
@@ -343,6 +385,8 @@ class ModelActor(xo.StatelessActor):
             gen = self._to_json_async_gen(ret)
             self._current_generator = weakref.ref(gen)
             return gen
+        if isinstance(ret, bytes):
+            return ret
         return await asyncio.to_thread(json_dumps, ret)
 
     @log_async(logger=logger)
@@ -359,6 +403,36 @@ class ModelActor(xo.StatelessActor):
             )
         raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
 
+    async def _queue_consumer(
+        self, queue: Queue, timeout: Optional[float] = None
+    ) -> AsyncIterator[Any]:
+        from .scheduler import (
+            XINFERENCE_STREAMING_ABORT_FLAG,
+            XINFERENCE_STREAMING_DONE_FLAG,
+            XINFERENCE_STREAMING_ERROR_FLAG,
+        )
+
+        while True:
+            # TODO: timeout setting
+            res = await wait_for(queue.get(), timeout)
+            if res == XINFERENCE_STREAMING_DONE_FLAG:
+                break
+            elif res == XINFERENCE_STREAMING_ABORT_FLAG:
+                raise RuntimeError(
+                    f"This request has been cancelled by another `abort_request` request."
+                )
+            elif isinstance(res, str) and res.startswith(
+                XINFERENCE_STREAMING_ERROR_FLAG
+            ):
+                raise RuntimeError(res[len(XINFERENCE_STREAMING_ERROR_FLAG) :])
+            else:
+                yield res
+
+    @staticmethod
+    def get_stream_from_args(*args) -> bool:
+        assert args[2] is None or isinstance(args[2], dict)
+        return False if args[2] is None else args[2].get("stream", False)
+
     @log_async(logger=logger)
     @request_limit
     @xo.generator
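_queue_consumer drains an asyncio queue until it sees a sentinel flag pushed by the scheduler. scheduler.py is not part of this diff, so the producer below is a hypothetical sketch of the same protocol with made-up flag values:

    import asyncio

    DONE_FLAG = "<done>"    # stand-ins for the real flags in xinference/core/scheduler.py
    ERROR_FLAG = "<error>"

    async def producer(queue: asyncio.Queue) -> None:
        try:
            for chunk in ("Hel", "lo"):           # stand-in for generated chunks
                await queue.put(chunk)
            await queue.put(DONE_FLAG)            # tells the consumer to stop
        except Exception as e:
            await queue.put(ERROR_FLAG + str(e))  # consumer re-raises as RuntimeError

    async def main() -> None:
        q: asyncio.Queue = asyncio.Queue()
        asyncio.create_task(producer(q))
        while True:
            res = await q.get()
            if res == DONE_FLAG:
                break
            print(res, end="")

    asyncio.run(main())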
@@ -366,17 +440,46 @@
         start_time = time.time()
         response = None
         try:
-            if hasattr(self._model, "chat"):
-                response = await self._call_wrapper(
-                    self._model.chat, prompt, *args, **kwargs
-                )
-                return response
-            if hasattr(self._model, "async_chat"):
-                response = await self._call_wrapper(
-                    self._model.async_chat, prompt, *args, **kwargs
-                )
-                return response
-            raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
+            if self.allow_batching():
+                stream = self.get_stream_from_args(*args)
+                assert self._scheduler_ref is not None
+                if stream:
+                    assert self._scheduler_ref is not None
+                    queue: Queue[Any] = Queue()
+                    ret = self._queue_consumer(queue)
+                    await self._scheduler_ref.add_request(
+                        prompt, queue, *args, **kwargs
+                    )
+                    gen = self._to_json_async_gen(ret)
+                    self._current_generator = weakref.ref(gen)
+                    return gen
+                else:
+                    from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
+
+                    assert self._loop is not None
+                    future = ConcurrentFuture()
+                    await self._scheduler_ref.add_request(
+                        prompt, future, *args, **kwargs
+                    )
+                    fut = asyncio.wrap_future(future, loop=self._loop)
+                    result = await fut
+                    if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
+                        raise RuntimeError(
+                            f"This request has been cancelled by another `abort_request` request."
+                        )
+                    return await asyncio.to_thread(json_dumps, result)
+            else:
+                if hasattr(self._model, "chat"):
+                    response = await self._call_wrapper(
+                        self._model.chat, prompt, *args, **kwargs
+                    )
+                    return response
+                if hasattr(self._model, "async_chat"):
+                    response = await self._call_wrapper(
+                        self._model.async_chat, prompt, *args, **kwargs
+                    )
+                    return response
+                raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
         finally:
             # For the non stream result.
             record = None
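In the non-streaming batching path above, the actor hands a concurrent.futures.Future to the scheduler and awaits it via asyncio.wrap_future. A self-contained sketch of that bridge:

    import asyncio
    from concurrent.futures import Future as ConcurrentFuture

    async def main() -> None:
        loop = asyncio.get_running_loop()
        future: ConcurrentFuture = ConcurrentFuture()
        # Stand-in for the scheduler completing the request from another context.
        loop.call_later(0.1, future.set_result, {"text": "hi"})
        result = await asyncio.wrap_future(future, loop=loop)
        print(result)

    asyncio.run(main())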
@@ -395,6 +498,15 @@
                         prompt_tokens,
                     )
 
+    async def abort_request(self, request_id: str) -> str:
+        from .scheduler import AbortRequestMessage
+
+        if self.allow_batching():
+            if self._scheduler_ref is None:
+                return AbortRequestMessage.NOT_FOUND.name
+            return await self._scheduler_ref.abort_request(request_id)
+        return AbortRequestMessage.NO_OP.name
+
     @log_async(logger=logger)
     @request_limit
     async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
@@ -482,6 +594,23 @@ class ModelActor(xo.StatelessActor):
                 f"Model {self._model.model_spec} is not for creating translations."
             )
 
+    @log_async(logger=logger)
+    @request_limit
+    async def speech(
+        self, input: str, voice: str, response_format: str = "mp3", speed: float = 1.0
+    ):
+        if hasattr(self._model, "speech"):
+            return await self._call_wrapper(
+                self._model.speech,
+                input,
+                voice,
+                response_format,
+                speed,
+            )
+        raise AttributeError(
+            f"Model {self._model.model_spec} is not for creating speech."
+        )
+
     @log_async(logger=logger)
     @request_limit
     async def text_to_image(