xinference 0.11.2.post1__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release has been flagged as potentially problematic.

Files changed (36)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +83 -8
  3. xinference/client/restful/restful_client.py +70 -0
  4. xinference/constants.py +8 -0
  5. xinference/core/__init__.py +0 -2
  6. xinference/core/cache_tracker.py +22 -1
  7. xinference/core/chat_interface.py +71 -10
  8. xinference/core/model.py +141 -12
  9. xinference/core/scheduler.py +428 -0
  10. xinference/core/supervisor.py +31 -3
  11. xinference/core/worker.py +8 -3
  12. xinference/isolation.py +9 -2
  13. xinference/model/audio/chattts.py +84 -0
  14. xinference/model/audio/core.py +10 -3
  15. xinference/model/audio/model_spec.json +20 -0
  16. xinference/model/llm/__init__.py +6 -0
  17. xinference/model/llm/llm_family.json +1063 -260
  18. xinference/model/llm/llm_family_modelscope.json +686 -13
  19. xinference/model/llm/pytorch/baichuan.py +2 -1
  20. xinference/model/llm/pytorch/chatglm.py +2 -1
  21. xinference/model/llm/pytorch/cogvlm2.py +316 -0
  22. xinference/model/llm/pytorch/core.py +92 -6
  23. xinference/model/llm/pytorch/glm4v.py +258 -0
  24. xinference/model/llm/pytorch/intern_vl.py +5 -10
  25. xinference/model/llm/pytorch/minicpmv25.py +232 -0
  26. xinference/model/llm/pytorch/utils.py +386 -2
  27. xinference/model/llm/vllm/core.py +7 -1
  28. xinference/thirdparty/ChatTTS/__init__.py +1 -0
  29. xinference/thirdparty/ChatTTS/core.py +200 -0
  30. xinference/types.py +3 -0
  31. {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/METADATA +28 -11
  32. {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/RECORD +36 -29
  33. {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/LICENSE +0 -0
  34. {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/WHEEL +0 -0
  35. {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/entry_points.txt +0 -0
  36. {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-05-24T19:39:58+0800",
+ "date": "2024-06-07T15:04:33+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "ac8f33439c25e6fb05eba79e7932cbbadd068174",
- "version": "0.11.2.post1"
+ "full-revisionid": "55c5636f2b6022842d1827eae373c8e5f162a1a3",
+ "version": "0.12.0"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -52,7 +52,7 @@ from xoscar.utils import get_next_port
 
 from .._compat import BaseModel, Field
 from .._version import get_versions
-from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT
+from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT, XINFERENCE_DISABLE_METRICS
 from ..core.event import Event, EventCollectorActor, EventType
 from ..core.supervisor import SupervisorActor
 from ..core.utils import json_dumps
@@ -122,6 +122,14 @@ class TextToImageRequest(BaseModel):
     user: Optional[str] = None
 
 
+class SpeechRequest(BaseModel):
+    model: str
+    input: str
+    voice: Optional[str]
+    response_format: Optional[str] = "mp3"
+    speed: Optional[float] = 1.0
+
+
 class RegisterModelRequest(BaseModel):
     model: str
     persist: bool
@@ -337,6 +345,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/models/{model_uid}/requests/{request_id}/abort",
+            self.abort_request,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/models/instance",
             self.launch_model_by_version,
@@ -418,6 +436,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/audio/speech",
+            self.create_speech,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/images/generations",
             self.create_images,
@@ -504,13 +532,19 @@ class RESTfulAPI:
             ),
         )
 
-        # Clear the global Registry for the MetricsMiddleware, or
-        # the MetricsMiddleware will register duplicated metrics if the port
-        # conflict (This serve method run more than once).
-        REGISTRY.clear()
-        self._app.add_middleware(MetricsMiddleware)
-        self._app.include_router(self._router)
-        self._app.add_route("/metrics", metrics)
+        if XINFERENCE_DISABLE_METRICS:
+            logger.info(
+                "Supervisor metrics is disabled due to the environment XINFERENCE_DISABLE_METRICS=1"
+            )
+            self._app.include_router(self._router)
+        else:
+            # Clear the global Registry for the MetricsMiddleware, or
+            # the MetricsMiddleware will register duplicated metrics if the port
+            # conflict (This serve method run more than once).
+            REGISTRY.clear()
+            self._app.add_middleware(MetricsMiddleware)
+            self._app.include_router(self._router)
+            self._app.add_route("/metrics", metrics)
 
         # Check all the routes returns Response.
         # This is to avoid `jsonable_encoder` performance issue:
@@ -1173,6 +1207,38 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def create_speech(self, request: Request) -> Response:
+        body = SpeechRequest.parse_obj(await request.json())
+        model_uid = body.model
+        try:
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            out = await model.speech(
+                input=body.input,
+                voice=body.voice,
+                response_format=body.response_format,
+                speed=body.speed,
+            )
+            return Response(media_type="application/octet-stream", content=out)
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_images(self, request: Request) -> Response:
         body = TextToImageRequest.parse_obj(await request.json())
         model_uid = body.model
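Illustrative example (not part of the diff): a minimal sketch of calling the new POST /v1/audio/speech route added above. The endpoint address and the model uid "chattts" are assumptions; any launched audio model uid works.

    import requests

    # Hypothetical endpoint and model uid; adjust to your deployment.
    resp = requests.post(
        "http://127.0.0.1:9997/v1/audio/speech",
        json={
            "model": "chattts",              # uid of a launched audio model
            "input": "Hello from xinference 0.12.0",
            "voice": "",                     # optional, per SpeechRequest above
            "response_format": "mp3",
            "speed": 1.0,
        },
    )
    resp.raise_for_status()
    with open("speech.mp3", "wb") as f:
        f.write(resp.content)                # raw audio bytes (application/octet-stream)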
@@ -1512,6 +1578,15 @@ class RESTfulAPI:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def abort_request(self, model_uid: str, request_id: str) -> JSONResponse:
+        try:
+            supervisor_ref = await self._get_supervisor_ref()
+            res = await supervisor_ref.abort_request(model_uid, request_id)
+            return JSONResponse(content=res)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def list_vllm_supported_model_families(self) -> JSONResponse:
         try:
             from ..model.llm.vllm.core import (
xinference/client/restful/restful_client.py CHANGED
@@ -684,6 +684,49 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def speech(
+        self,
+        input: str,
+        voice: str = "",
+        response_format: str = "mp3",
+        speed: float = 1.0,
+    ):
+        """
+        Generates audio from the input text.
+
+        Parameters
+        ----------
+
+        input: str
+            The text to generate audio for. The maximum length is 4096 characters.
+        voice: str
+            The voice to use when generating the audio.
+        response_format: str
+            The format to audio in.
+        speed: str
+            The speed of the generated audio.
+
+        Returns
+        -------
+        bytes
+            The generated audio binary.
+        """
+        url = f"{self._base_url}/v1/audio/speech"
+        params = {
+            "model": self._model_uid,
+            "input": input,
+            "voice": voice,
+            "response_format": response_format,
+            "speed": speed,
+        }
+        response = requests.post(url, json=params, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to speech the text, detail: {_get_error_string(response)}"
+            )
+
+        return response.content
+
 
 class Client:
     def __init__(self, base_url, api_key: Optional[str] = None):
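Illustrative example (not part of the diff): using the new speech method through the Python client. Assumptions: a running server, an audio model launched under the uid "chattts", and the public import path xinference.client.

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    model = client.get_model("chattts")      # audio models are returned as RESTfulAudioModelHandle
    audio_bytes = model.speech(
        input="Hello from xinference 0.12.0",
        response_format="mp3",
        speed=1.0,
    )
    with open("speech.mp3", "wb") as f:
        f.write(audio_bytes)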
@@ -1181,3 +1224,30 @@ class Client:
 
         response_data = response.json()
         return response_data
+
+    def abort_request(self, model_uid: str, request_id: str):
+        """
+        Abort a request.
+        Abort a submitted request. If the request is finished or not found, this method will be a no-op.
+        Currently, this interface is only supported when batching is enabled for models on transformers backend.
+
+        Parameters
+        ----------
+        model_uid: str
+            Model uid.
+        request_id: str
+            Request id.
+        Returns
+        -------
+        Dict
+            Return empty dict.
+        """
+        url = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort"
+        response = requests.post(url, headers=self._headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to abort request, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
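Illustrative example (not part of the diff): a hedged usage sketch of the new abort_request helper. The model uid and request id below are hypothetical; per the docstring it is a no-op unless transformers batching is enabled and the request is still in flight.

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    # request_id must identify an in-flight request; the value here is made up.
    result = client.abort_request(model_uid="my-llm", request_id="req-123")
    print(result)   # an empty dict on success, per the docstring above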
xinference/constants.py CHANGED
@@ -26,6 +26,8 @@ XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT = "XINFERENCE_HEALTH_CHECK_TIMEOUT"
 XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
 XINFERENCE_ENV_DISABLE_VLLM = "XINFERENCE_DISABLE_VLLM"
 XINFERENCE_ENV_ENABLE_SGLANG = "XINFERENCE_ENABLE_SGLANG"
+XINFERENCE_ENV_DISABLE_METRICS = "XINFERENCE_DISABLE_METRICS"
+XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING = "XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"
 
 
 def get_xinference_home() -> str:
@@ -66,3 +68,9 @@ XINFERENCE_DISABLE_HEALTH_CHECK = bool(
 )
 XINFERENCE_DISABLE_VLLM = bool(int(os.environ.get(XINFERENCE_ENV_DISABLE_VLLM, 0)))
 XINFERENCE_ENABLE_SGLANG = bool(int(os.environ.get(XINFERENCE_ENV_ENABLE_SGLANG, 0)))
+XINFERENCE_DISABLE_METRICS = bool(
+    int(os.environ.get(XINFERENCE_ENV_DISABLE_METRICS, 0))
+)
+XINFERENCE_TRANSFORMERS_ENABLE_BATCHING = bool(
+    int(os.environ.get(XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING, 0))
+)
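Note that both new settings are parsed with bool(int(...)), so they expect integer-style values and must be set before the xinference process starts (constants.py reads them at import time). A small sketch of the behaviour, for illustration only:

    import os

    os.environ["XINFERENCE_DISABLE_METRICS"] = "1"               # "1" -> disabled, "0" (default) -> enabled
    os.environ["XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"] = "1"  # enables the new transformers batching path

    # Mirrors the parsing above; a value like "true" would raise ValueError in int().
    disable_metrics = bool(int(os.environ.get("XINFERENCE_DISABLE_METRICS", 0)))
    print(disable_metrics)  # True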
xinference/core/__init__.py CHANGED
@@ -11,5 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from .model import ModelActor
xinference/core/cache_tracker.py CHANGED
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from logging import getLogger
 from typing import Any, Dict, List, Optional
 
@@ -105,9 +106,29 @@ class CacheTrackerActor(xo.Actor):
         cached_models = []
         for model_name, model_versions in self._model_name_to_version_info.items():
             for version_info in model_versions:
-                if version_info["cache_status"]:
+                cache_status = version_info.get("cache_status", None)
+                if cache_status == True:
                     ret = version_info.copy()
                     ret["model_name"] = model_name
+
+                    re_dict = version_info.get("model_file_location", None)
+                    if re_dict is not None and isinstance(re_dict, dict):
+                        if re_dict:
+                            actor_ip_address, path = next(iter(re_dict.items()))
+                        else:
+                            raise ValueError("The dictionary is empty.")
+                    else:
+                        raise ValueError("re_dict must be a non-empty dictionary.")
+
+                    ret["actor_ip_address"] = actor_ip_address
+                    ret["path"] = path
+                    if os.path.isdir(path):
+                        files = os.listdir(path)
+                        resolved_file = os.path.realpath(os.path.join(path, files[0]))
+                        if resolved_file:
+                            ret["real_path"] = os.path.dirname(resolved_file)
+                    else:
+                        ret["real_path"] = os.path.realpath(path)
                     cached_models.append(ret)
         cached_models = sorted(cached_models, key=lambda x: x["model_name"])
         return cached_models
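For orientation, a hedged sketch of the extra fields each cached-model entry now carries (all values below are hypothetical; the remaining version_info fields are copied through unchanged):

    cached_model_entry = {
        "model_name": "qwen1.5-chat",                # key of the outer loop
        "cache_status": True,                        # only entries with cache_status == True are returned
        "actor_ip_address": "192.168.1.10",          # first key of model_file_location
        "path": "/root/.xinference/cache/qwen1_5-chat-pytorch-7b",  # first value of model_file_location
        "real_path": "/data/models/qwen1_5-chat-7b", # directory of the resolved (symlinked) first file
        # ...other version_info fields unchanged...
    }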
xinference/core/chat_interface.py CHANGED
@@ -186,8 +186,7 @@ class GradioInterface:
     def build_chat_vl_interface(
         self,
     ) -> "gr.Blocks":
-        def predict(history, bot):
-            logger.debug("Predict model: %s, history: %s", self.model_uid, history)
+        def predict(history, bot, max_tokens, temperature, stream):
             from ..client import RESTfulClient
 
             client = RESTfulClient(self.endpoint)
@@ -199,10 +198,46 @@ class GradioInterface:
             assert prompt["role"] == "user"
             prompt = prompt["content"]
             # multimodal chat does not support stream.
-            response = model.chat(prompt=prompt, chat_history=history[:-1])
-            history.append(response["choices"][0]["message"])
-            bot[-1][1] = history[-1]["content"]
-            return history, bot
+            if stream:
+                response_content = ""
+                for chunk in model.chat(
+                    prompt=prompt,
+                    chat_history=history[:-1],
+                    generate_config={
+                        "max_tokens": max_tokens,
+                        "temperature": temperature,
+                        "stream": stream,
+                    },
+                ):
+                    assert isinstance(chunk, dict)
+                    delta = chunk["choices"][0]["delta"]
+                    if "content" not in delta:
+                        continue
+                    else:
+                        response_content += delta["content"]
+                        bot[-1][1] = response_content
+                        yield history, bot
+                history.append(
+                    {
+                        "content": response_content,
+                        "role": "assistant",
+                    }
+                )
+                bot[-1][1] = response_content
+                yield history, bot
+            else:
+                response = model.chat(
+                    prompt=prompt,
+                    chat_history=history[:-1],
+                    generate_config={
+                        "max_tokens": max_tokens,
+                        "temperature": temperature,
+                        "stream": stream,
+                    },
+                )
+                history.append(response["choices"][0]["message"])
+                bot[-1][1] = history[-1]["content"]
+                yield history, bot
 
         def add_text(history, bot, text, image):
             logger.debug("Add text, text: %s, image: %s", text, image)
@@ -217,14 +252,19 @@ class GradioInterface:
                     "role": "user",
                     "content": [
                         {"type": "text", "text": text},
-                        {"type": "image_url", "image_url": {"url": image}},
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/png;base64,{img_b64_str}"
+                            },
+                        },
                     ],
                 }
             else:
                 display_content = text
                 message = {"role": "user", "content": text}
             history = history + [message]
-            bot = bot + [(display_content, None)]
+            bot = bot + [[display_content, None]]
             return history, bot, "", None
 
         def clear_history():
@@ -286,6 +326,19 @@ class GradioInterface:
             )
             clear_btn = gr.Button(value="Clear")
 
+            with gr.Accordion("Additional Inputs", open=False):
+                max_tokens = gr.Slider(
+                    minimum=1,
+                    maximum=self.context_length,
+                    value=512,
+                    step=1,
+                    label="Max Tokens",
+                )
+                temperature = gr.Slider(
+                    minimum=0, maximum=2, value=1, step=0.01, label="Temperature"
+                )
+                stream = gr.Checkbox(label="Stream", value=False)
+
             textbox.change(update_button, [textbox], [submit_btn], queue=False)
 
             textbox.submit(
@@ -293,14 +346,22 @@ class GradioInterface:
                 [state, chatbot, textbox, imagebox],
                 [state, chatbot, textbox, imagebox],
                 queue=False,
-            ).then(predict, [state, chatbot], [state, chatbot])
+            ).then(
+                predict,
+                [state, chatbot, max_tokens, temperature, stream],
+                [state, chatbot],
+            )
 
             submit_btn.click(
                 add_text,
                 [state, chatbot, textbox, imagebox],
                 [state, chatbot, textbox, imagebox],
                 queue=False,
-            ).then(predict, [state, chatbot], [state, chatbot])
+            ).then(
+                predict,
+                [state, chatbot, max_tokens, temperature, stream],
+                [state, chatbot],
+            )
 
             clear_btn.click(
                 clear_history, None, [state, chatbot, textbox, imagebox], queue=False
xinference/core/model.py CHANGED
@@ -20,9 +20,14 @@ import os
 import time
 import types
 import weakref
+from asyncio.queues import Queue
+from asyncio.tasks import wait_for
+from concurrent.futures import Future as ConcurrentFuture
 from typing import (
     TYPE_CHECKING,
+    Any,
     AsyncGenerator,
+    AsyncIterator,
     Callable,
     Dict,
     Generator,
@@ -35,6 +40,8 @@ from typing import (
 import sse_starlette.sse
 import xoscar as xo
 
+from ..constants import XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
+
 if TYPE_CHECKING:
     from .worker import WorkerActor
     from ..model.llm.core import LLM
@@ -125,6 +132,16 @@ class ModelActor(xo.StatelessActor):
         from ..model.llm.pytorch.core import PytorchModel as LLMPytorchModel
         from ..model.llm.vllm.core import VLLMModel as LLMVLLMModel
 
+        if self.allow_batching():
+            try:
+                assert self._scheduler_ref is not None
+                await xo.destroy_actor(self._scheduler_ref)
+                del self._scheduler_ref
+            except Exception as e:
+                logger.debug(
+                    f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
+                )
+
         if (
             isinstance(self._model, (LLMPytorchModel, LLMVLLMModel))
             and self._model.model_spec.model_format == "pytorch"
@@ -181,9 +198,20 @@ class ModelActor(xo.StatelessActor):
         }
         self._loop: Optional[asyncio.AbstractEventLoop] = None
 
+        self._scheduler_ref = None
+
     async def __post_create__(self):
         self._loop = asyncio.get_running_loop()
 
+        if self.allow_batching():
+            from .scheduler import SchedulerActor
+
+            self._scheduler_ref = await xo.create_actor(
+                SchedulerActor,
+                address=self.address,
+                uid=SchedulerActor.gen_uid(self.model_uid(), self._model.rep_id),
+            )
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):
@@ -235,8 +263,22 @@ class ModelActor(xo.StatelessActor):
 
         return isinstance(self._model, VLLMModel)
 
-    def load(self):
+    def allow_batching(self) -> bool:
+        from ..model.llm.pytorch.core import PytorchChatModel
+
+        return (
+            XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
+            and isinstance(self._model, PytorchChatModel)
+            and self._model.__class__.__name__ == PytorchChatModel.__name__
+        )
+
+    async def load(self):
         self._model.load()
+        if self.allow_batching():
+            await self._scheduler_ref.set_model(self._model)
+            logger.debug(
+                f"Batching enabled for model: {self.model_uid()}, max_num_seqs: {self._model.get_max_num_seqs()}"
+            )
 
     def model_uid(self):
         return (
@@ -343,6 +385,8 @@ class ModelActor(xo.StatelessActor):
             gen = self._to_json_async_gen(ret)
             self._current_generator = weakref.ref(gen)
             return gen
+        if isinstance(ret, bytes):
+            return ret
         return await asyncio.to_thread(json_dumps, ret)
 
     @log_async(logger=logger)
@@ -359,6 +403,36 @@ class ModelActor(xo.StatelessActor):
             )
         raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
 
+    async def _queue_consumer(
+        self, queue: Queue, timeout: Optional[float] = None
+    ) -> AsyncIterator[Any]:
+        from .scheduler import (
+            XINFERENCE_STREAMING_ABORT_FLAG,
+            XINFERENCE_STREAMING_DONE_FLAG,
+            XINFERENCE_STREAMING_ERROR_FLAG,
+        )
+
+        while True:
+            # TODO: timeout setting
+            res = await wait_for(queue.get(), timeout)
+            if res == XINFERENCE_STREAMING_DONE_FLAG:
+                break
+            elif res == XINFERENCE_STREAMING_ABORT_FLAG:
+                raise RuntimeError(
+                    f"This request has been cancelled by another `abort_request` request."
+                )
+            elif isinstance(res, str) and res.startswith(
+                XINFERENCE_STREAMING_ERROR_FLAG
+            ):
+                raise RuntimeError(res[len(XINFERENCE_STREAMING_ERROR_FLAG) :])
+            else:
+                yield res
+
+    @staticmethod
+    def get_stream_from_args(*args) -> bool:
+        assert args[2] is None or isinstance(args[2], dict)
+        return False if args[2] is None else args[2].get("stream", False)
+
     @log_async(logger=logger)
     @request_limit
     @xo.generator
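Illustrative example (not part of the diff) of the stream-flag extraction added above. The positional layout is assumed to be the chat call's (system_prompt, chat_history, generate_config), which matches the args[2] check:

    from xinference.core.model import ModelActor

    # generate_config at position 2 carries the stream flag.
    print(ModelActor.get_stream_from_args("You are helpful.", [], {"stream": True}))  # True
    print(ModelActor.get_stream_from_args("You are helpful.", [], None))              # False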
@@ -366,17 +440,46 @@ class ModelActor(xo.StatelessActor):
         start_time = time.time()
         response = None
         try:
-            if hasattr(self._model, "chat"):
-                response = await self._call_wrapper(
-                    self._model.chat, prompt, *args, **kwargs
-                )
-                return response
-            if hasattr(self._model, "async_chat"):
-                response = await self._call_wrapper(
-                    self._model.async_chat, prompt, *args, **kwargs
-                )
-                return response
-            raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
+            if self.allow_batching():
+                stream = self.get_stream_from_args(*args)
+                assert self._scheduler_ref is not None
+                if stream:
+                    assert self._scheduler_ref is not None
+                    queue: Queue[Any] = Queue()
+                    ret = self._queue_consumer(queue)
+                    await self._scheduler_ref.add_request(
+                        prompt, queue, *args, **kwargs
+                    )
+                    gen = self._to_json_async_gen(ret)
+                    self._current_generator = weakref.ref(gen)
+                    return gen
+                else:
+                    from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
+
+                    assert self._loop is not None
+                    future = ConcurrentFuture()
+                    await self._scheduler_ref.add_request(
+                        prompt, future, *args, **kwargs
+                    )
+                    fut = asyncio.wrap_future(future, loop=self._loop)
+                    result = await fut
+                    if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
+                        raise RuntimeError(
+                            f"This request has been cancelled by another `abort_request` request."
+                        )
+                    return await asyncio.to_thread(json_dumps, result)
+            else:
+                if hasattr(self._model, "chat"):
+                    response = await self._call_wrapper(
+                        self._model.chat, prompt, *args, **kwargs
+                    )
+                    return response
+                if hasattr(self._model, "async_chat"):
+                    response = await self._call_wrapper(
+                        self._model.async_chat, prompt, *args, **kwargs
+                    )
+                    return response
+                raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
         finally:
             # For the non stream result.
             record = None
@@ -395,6 +498,15 @@ class ModelActor(xo.StatelessActor):
                 prompt_tokens,
             )
 
+    async def abort_request(self, request_id: str) -> str:
+        from .scheduler import AbortRequestMessage
+
+        if self.allow_batching():
+            if self._scheduler_ref is None:
+                return AbortRequestMessage.NOT_FOUND.name
+            return await self._scheduler_ref.abort_request(request_id)
+        return AbortRequestMessage.NO_OP.name
+
     @log_async(logger=logger)
     @request_limit
     async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
@@ -482,6 +594,23 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating translations."
         )
 
+    @log_async(logger=logger)
+    @request_limit
+    async def speech(
+        self, input: str, voice: str, response_format: str = "mp3", speed: float = 1.0
+    ):
+        if hasattr(self._model, "speech"):
+            return await self._call_wrapper(
+                self._model.speech,
+                input,
+                voice,
+                response_format,
+                speed,
+            )
+        raise AttributeError(
+            f"Model {self._model.model_spec} is not for creating speech."
+        )
+
     @log_async(logger=logger)
     @request_limit
     async def text_to_image(