xinference 0.11.3__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +69 -0
- xinference/client/restful/restful_client.py +70 -0
- xinference/constants.py +4 -0
- xinference/core/model.py +141 -12
- xinference/core/scheduler.py +428 -0
- xinference/core/supervisor.py +26 -0
- xinference/isolation.py +9 -2
- xinference/model/audio/chattts.py +84 -0
- xinference/model/audio/core.py +10 -3
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/llm/__init__.py +4 -0
- xinference/model/llm/llm_family.json +507 -1
- xinference/model/llm/llm_family_modelscope.json +409 -2
- xinference/model/llm/pytorch/chatglm.py +2 -1
- xinference/model/llm/pytorch/cogvlm2.py +76 -17
- xinference/model/llm/pytorch/core.py +91 -6
- xinference/model/llm/pytorch/glm4v.py +258 -0
- xinference/model/llm/pytorch/minicpmv25.py +232 -0
- xinference/model/llm/pytorch/utils.py +386 -2
- xinference/model/llm/vllm/core.py +6 -0
- xinference/thirdparty/ChatTTS/__init__.py +1 -0
- xinference/thirdparty/ChatTTS/core.py +200 -0
- xinference/types.py +3 -0
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/METADATA +26 -9
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/RECORD +30 -24
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/LICENSE +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/WHEEL +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.0.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json

 version_json = '''
 {
- "date": "2024-
+ "date": "2024-06-07T15:04:33+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "
- "version": "0.
+ "full-revisionid": "55c5636f2b6022842d1827eae373c8e5f162a1a3",
+ "version": "0.12.0"
 }
 ''' # END VERSION_JSON

xinference/api/restful_api.py
CHANGED
@@ -122,6 +122,14 @@ class TextToImageRequest(BaseModel):
     user: Optional[str] = None


+class SpeechRequest(BaseModel):
+    model: str
+    input: str
+    voice: Optional[str]
+    response_format: Optional[str] = "mp3"
+    speed: Optional[float] = 1.0
+
+
 class RegisterModelRequest(BaseModel):
     model: str
     persist: bool

@@ -337,6 +345,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/models/{model_uid}/requests/{request_id}/abort",
+            self.abort_request,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/models/instance",
             self.launch_model_by_version,

@@ -418,6 +436,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/audio/speech",
+            self.create_speech,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/images/generations",
             self.create_images,

@@ -1179,6 +1207,38 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))

+    async def create_speech(self, request: Request) -> Response:
+        body = SpeechRequest.parse_obj(await request.json())
+        model_uid = body.model
+        try:
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            out = await model.speech(
+                input=body.input,
+                voice=body.voice,
+                response_format=body.response_format,
+                speed=body.speed,
+            )
+            return Response(media_type="application/octet-stream", content=out)
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_images(self, request: Request) -> Response:
         body = TextToImageRequest.parse_obj(await request.json())
         model_uid = body.model

@@ -1518,6 +1578,15 @@ class RESTfulAPI:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))

+    async def abort_request(self, model_uid: str, request_id: str) -> JSONResponse:
+        try:
+            supervisor_ref = await self._get_supervisor_ref()
+            res = await supervisor_ref.abort_request(model_uid, request_id)
+            return JSONResponse(content=res)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def list_vllm_supported_model_families(self) -> JSONResponse:
         try:
             from ..model.llm.vllm.core import (
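The two routes registered above map onto plain HTTP calls. A minimal sketch using `requests`, assuming an unauthenticated server; the address, model UIDs, and request id below are placeholders, not values from this release:

    import requests

    base_url = "http://127.0.0.1:9997"  # placeholder address of a running xinference server

    # POST /v1/audio/speech returns the synthesized audio as raw bytes
    # (the handler responds with media type application/octet-stream).
    resp = requests.post(
        f"{base_url}/v1/audio/speech",
        json={
            "model": "my-tts-model",  # placeholder UID of a launched audio model
            "input": "Hello from xinference!",
            "voice": "",
            "response_format": "mp3",
            "speed": 1.0,
        },
    )
    resp.raise_for_status()
    with open("hello.mp3", "wb") as f:
        f.write(resp.content)

    # POST /v1/models/{model_uid}/requests/{request_id}/abort cancels an in-flight request;
    # both path segments below are placeholders.
    requests.post(f"{base_url}/v1/models/my-chat-model/requests/some-request-id/abort")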
xinference/client/restful/restful_client.py
CHANGED

@@ -684,6 +684,49 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data

+    def speech(
+        self,
+        input: str,
+        voice: str = "",
+        response_format: str = "mp3",
+        speed: float = 1.0,
+    ):
+        """
+        Generates audio from the input text.
+
+        Parameters
+        ----------
+
+        input: str
+            The text to generate audio for. The maximum length is 4096 characters.
+        voice: str
+            The voice to use when generating the audio.
+        response_format: str
+            The format to return the audio in.
+        speed: float
+            The speed of the generated audio.
+
+        Returns
+        -------
+        bytes
+            The generated audio binary.
+        """
+        url = f"{self._base_url}/v1/audio/speech"
+        params = {
+            "model": self._model_uid,
+            "input": input,
+            "voice": voice,
+            "response_format": response_format,
+            "speed": speed,
+        }
+        response = requests.post(url, json=params, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to speech the text, detail: {_get_error_string(response)}"
+            )
+
+        return response.content
+

 class Client:
     def __init__(self, base_url, api_key: Optional[str] = None):

@@ -1181,3 +1224,30 @@ class Client:

         response_data = response.json()
         return response_data
+
+    def abort_request(self, model_uid: str, request_id: str):
+        """
+        Abort a request.
+        Abort a submitted request. If the request is finished or not found, this method will be a no-op.
+        Currently, this interface is only supported when batching is enabled for models on the transformers backend.
+
+        Parameters
+        ----------
+        model_uid: str
+            Model uid.
+        request_id: str
+            Request id.
+        Returns
+        -------
+        Dict
+            Return empty dict.
+        """
+        url = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort"
+        response = requests.post(url, headers=self._headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to abort request, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
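At the client level, the same functionality is reachable through the methods added above. A minimal sketch, assuming a TTS-capable audio model and a transformers chat model (with batching enabled) have already been launched; all UIDs are placeholders:

    from xinference.client.restful.restful_client import Client

    client = Client("http://127.0.0.1:9997")  # placeholder address

    # speech() on the audio model handle wraps POST /v1/audio/speech and returns raw bytes.
    tts_model = client.get_model("my-chattts-model")  # placeholder UID
    audio = tts_model.speech(input="Hello from xinference!", response_format="mp3", speed=1.0)
    with open("hello.mp3", "wb") as f:
        f.write(audio)

    # abort_request() wraps the new abort route; it only takes effect for models running
    # on the transformers backend with batching enabled, otherwise it is a no-op.
    client.abort_request(model_uid="my-chat-model", request_id="some-request-id")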
xinference/constants.py
CHANGED
@@ -27,6 +27,7 @@ XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
 XINFERENCE_ENV_DISABLE_VLLM = "XINFERENCE_DISABLE_VLLM"
 XINFERENCE_ENV_ENABLE_SGLANG = "XINFERENCE_ENABLE_SGLANG"
 XINFERENCE_ENV_DISABLE_METRICS = "XINFERENCE_DISABLE_METRICS"
+XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING = "XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"


 def get_xinference_home() -> str:

@@ -70,3 +71,6 @@ XINFERENCE_ENABLE_SGLANG = bool(int(os.environ.get(XINFERENCE_ENV_ENABLE_SGLANG,
 XINFERENCE_DISABLE_METRICS = bool(
     int(os.environ.get(XINFERENCE_ENV_DISABLE_METRICS, 0))
 )
+XINFERENCE_TRANSFORMERS_ENABLE_BATCHING = bool(
+    int(os.environ.get(XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING, 0))
+)
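Batching on the transformers backend is opted into through the new environment variable; a minimal sketch of how the flag above is consumed (it must be set before xinference.constants is imported):

    import os

    # "1" enables batching for the transformers backend; the default "0" leaves it off.
    # The value is parsed as bool(int(...)) when xinference.constants is imported.
    os.environ["XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"] = "1"

    from xinference.constants import XINFERENCE_TRANSFORMERS_ENABLE_BATCHING

    assert XINFERENCE_TRANSFORMERS_ENABLE_BATCHING is True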
xinference/core/model.py
CHANGED
@@ -20,9 +20,14 @@ import os
 import time
 import types
 import weakref
+from asyncio.queues import Queue
+from asyncio.tasks import wait_for
+from concurrent.futures import Future as ConcurrentFuture
 from typing import (
     TYPE_CHECKING,
+    Any,
     AsyncGenerator,
+    AsyncIterator,
     Callable,
     Dict,
     Generator,

@@ -35,6 +40,8 @@ from typing import (
 import sse_starlette.sse
 import xoscar as xo

+from ..constants import XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
+
 if TYPE_CHECKING:
     from .worker import WorkerActor
     from ..model.llm.core import LLM

@@ -125,6 +132,16 @@ class ModelActor(xo.StatelessActor):
         from ..model.llm.pytorch.core import PytorchModel as LLMPytorchModel
         from ..model.llm.vllm.core import VLLMModel as LLMVLLMModel

+        if self.allow_batching():
+            try:
+                assert self._scheduler_ref is not None
+                await xo.destroy_actor(self._scheduler_ref)
+                del self._scheduler_ref
+            except Exception as e:
+                logger.debug(
+                    f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
+                )
+
         if (
             isinstance(self._model, (LLMPytorchModel, LLMVLLMModel))
             and self._model.model_spec.model_format == "pytorch"

@@ -181,9 +198,20 @@
         }
         self._loop: Optional[asyncio.AbstractEventLoop] = None

+        self._scheduler_ref = None
+
     async def __post_create__(self):
         self._loop = asyncio.get_running_loop()

+        if self.allow_batching():
+            from .scheduler import SchedulerActor
+
+            self._scheduler_ref = await xo.create_actor(
+                SchedulerActor,
+                address=self.address,
+                uid=SchedulerActor.gen_uid(self.model_uid(), self._model.rep_id),
+            )
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):

@@ -235,8 +263,22 @@

         return isinstance(self._model, VLLMModel)

-    def load(self):
+    def allow_batching(self) -> bool:
+        from ..model.llm.pytorch.core import PytorchChatModel
+
+        return (
+            XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
+            and isinstance(self._model, PytorchChatModel)
+            and self._model.__class__.__name__ == PytorchChatModel.__name__
+        )
+
+    async def load(self):
         self._model.load()
+        if self.allow_batching():
+            await self._scheduler_ref.set_model(self._model)
+            logger.debug(
+                f"Batching enabled for model: {self.model_uid()}, max_num_seqs: {self._model.get_max_num_seqs()}"
+            )

     def model_uid(self):
         return (

@@ -343,6 +385,8 @@
             gen = self._to_json_async_gen(ret)
             self._current_generator = weakref.ref(gen)
             return gen
+        if isinstance(ret, bytes):
+            return ret
         return await asyncio.to_thread(json_dumps, ret)

     @log_async(logger=logger)

@@ -359,6 +403,36 @@
         )
         raise AttributeError(f"Model {self._model.model_spec} is not for generate.")

+    async def _queue_consumer(
+        self, queue: Queue, timeout: Optional[float] = None
+    ) -> AsyncIterator[Any]:
+        from .scheduler import (
+            XINFERENCE_STREAMING_ABORT_FLAG,
+            XINFERENCE_STREAMING_DONE_FLAG,
+            XINFERENCE_STREAMING_ERROR_FLAG,
+        )
+
+        while True:
+            # TODO: timeout setting
+            res = await wait_for(queue.get(), timeout)
+            if res == XINFERENCE_STREAMING_DONE_FLAG:
+                break
+            elif res == XINFERENCE_STREAMING_ABORT_FLAG:
+                raise RuntimeError(
+                    f"This request has been cancelled by another `abort_request` request."
+                )
+            elif isinstance(res, str) and res.startswith(
+                XINFERENCE_STREAMING_ERROR_FLAG
+            ):
+                raise RuntimeError(res[len(XINFERENCE_STREAMING_ERROR_FLAG) :])
+            else:
+                yield res
+
+    @staticmethod
+    def get_stream_from_args(*args) -> bool:
+        assert args[2] is None or isinstance(args[2], dict)
+        return False if args[2] is None else args[2].get("stream", False)
+
     @log_async(logger=logger)
     @request_limit
     @xo.generator
@@ -366,17 +440,46 @@
         start_time = time.time()
         response = None
         try:
-            if hasattr(self._model, "chat"):
-                response = await self._call_wrapper(
-                    self._model.chat, prompt, *args, **kwargs
-                )
-                return response
-            if hasattr(self._model, "async_chat"):
-                response = await self._call_wrapper(
-                    self._model.async_chat, prompt, *args, **kwargs
-                )
-                return response
-            raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
+            if self.allow_batching():
+                stream = self.get_stream_from_args(*args)
+                assert self._scheduler_ref is not None
+                if stream:
+                    assert self._scheduler_ref is not None
+                    queue: Queue[Any] = Queue()
+                    ret = self._queue_consumer(queue)
+                    await self._scheduler_ref.add_request(
+                        prompt, queue, *args, **kwargs
+                    )
+                    gen = self._to_json_async_gen(ret)
+                    self._current_generator = weakref.ref(gen)
+                    return gen
+                else:
+                    from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
+
+                    assert self._loop is not None
+                    future = ConcurrentFuture()
+                    await self._scheduler_ref.add_request(
+                        prompt, future, *args, **kwargs
+                    )
+                    fut = asyncio.wrap_future(future, loop=self._loop)
+                    result = await fut
+                    if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
+                        raise RuntimeError(
+                            f"This request has been cancelled by another `abort_request` request."
+                        )
+                    return await asyncio.to_thread(json_dumps, result)
+            else:
+                if hasattr(self._model, "chat"):
+                    response = await self._call_wrapper(
+                        self._model.chat, prompt, *args, **kwargs
+                    )
+                    return response
+                if hasattr(self._model, "async_chat"):
+                    response = await self._call_wrapper(
+                        self._model.async_chat, prompt, *args, **kwargs
+                    )
+                    return response
+                raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
         finally:
             # For the non stream result.
             record = None
@@ -395,6 +498,15 @@
                 prompt_tokens,
             )

+    async def abort_request(self, request_id: str) -> str:
+        from .scheduler import AbortRequestMessage
+
+        if self.allow_batching():
+            if self._scheduler_ref is None:
+                return AbortRequestMessage.NOT_FOUND.name
+            return await self._scheduler_ref.abort_request(request_id)
+        return AbortRequestMessage.NO_OP.name
+
     @log_async(logger=logger)
     @request_limit
     async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):

@@ -482,6 +594,23 @@
             f"Model {self._model.model_spec} is not for creating translations."
         )

+    @log_async(logger=logger)
+    @request_limit
+    async def speech(
+        self, input: str, voice: str, response_format: str = "mp3", speed: float = 1.0
+    ):
+        if hasattr(self._model, "speech"):
+            return await self._call_wrapper(
+                self._model.speech,
+                input,
+                voice,
+                response_format,
+                speed,
+            )
+        raise AttributeError(
+            f"Model {self._model.model_spec} is not for creating speech."
+        )
+
     @log_async(logger=logger)
     @request_limit
     async def text_to_image(