xinference 0.13.1__py3-none-any.whl → 0.13.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of xinference might be problematic.
- xinference/__init__.py +0 -1
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +99 -5
- xinference/client/restful/restful_client.py +98 -1
- xinference/core/chat_interface.py +2 -2
- xinference/core/model.py +85 -26
- xinference/core/scheduler.py +4 -4
- xinference/model/audio/chattts.py +40 -8
- xinference/model/audio/core.py +5 -2
- xinference/model/audio/cosyvoice.py +136 -0
- xinference/model/audio/model_spec.json +24 -0
- xinference/model/audio/model_spec_modelscope.json +27 -0
- xinference/model/flexible/launchers/__init__.py +1 -0
- xinference/model/flexible/launchers/image_process_launcher.py +70 -0
- xinference/model/image/core.py +3 -0
- xinference/model/image/model_spec.json +21 -0
- xinference/model/image/stable_diffusion/core.py +49 -7
- xinference/model/llm/llm_family.json +1065 -106
- xinference/model/llm/llm_family.py +26 -6
- xinference/model/llm/llm_family_csghub.json +39 -0
- xinference/model/llm/llm_family_modelscope.json +460 -47
- xinference/model/llm/pytorch/chatglm.py +243 -5
- xinference/model/llm/pytorch/cogvlm2.py +1 -1
- xinference/model/llm/sglang/core.py +7 -2
- xinference/model/llm/utils.py +78 -1
- xinference/model/llm/vllm/core.py +11 -0
- xinference/thirdparty/cosyvoice/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
- xinference/thirdparty/cosyvoice/bin/train.py +136 -0
- xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
- xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
- xinference/thirdparty/cosyvoice/cli/model.py +60 -0
- xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
- xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
- xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
- xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
- xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
- xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
- xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
- xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
- xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
- xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
- xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
- xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
- xinference/thirdparty/cosyvoice/utils/common.py +103 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
- xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/METADATA +18 -8
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/RECORD +80 -36
- xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
- /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
xinference/__init__.py
CHANGED
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
-    "date": "2024-07-…",
+    "date": "2024-07-26T18:42:50+0800",
     "dirty": false,
     "error": null,
-    "full-revisionid": "…",
-    "version": "0.13.…",
+    "full-revisionid": "aa51ff22dbfb5644554436270deaf57a7ebaf066",
+    "version": "0.13.3"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py
CHANGED
@@ -129,6 +129,8 @@ class SpeechRequest(BaseModel):
     voice: Optional[str]
     response_format: Optional[str] = "mp3"
     speed: Optional[float] = 1.0
+    stream: Optional[bool] = False
+    kwargs: Optional[str] = None
 
 
 class RegisterModelRequest(BaseModel):

@@ -491,6 +493,17 @@
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/images/inpainting",
+            self.create_inpainting,
+            methods=["POST"],
+            response_model=ImageList,
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/chat/completions",
             self.create_chat_completion,

@@ -1297,8 +1310,18 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
-    async def create_speech(
-        …
+    async def create_speech(
+        self,
+        request: Request,
+        prompt_speech: Optional[UploadFile] = File(
+            None, media_type="application/octet-stream"
+        ),
+    ) -> Response:
+        if prompt_speech:
+            f = await request.form()
+        else:
+            f = await request.json()
+        body = SpeechRequest.parse_obj(f)
         model_uid = body.model
         try:
             model = await (await self._get_supervisor_ref()).get_model(model_uid)

@@ -1312,13 +1335,26 @@
             raise HTTPException(status_code=500, detail=str(e))
 
         try:
+            if body.kwargs is not None:
+                parsed_kwargs = json.loads(body.kwargs)
+            else:
+                parsed_kwargs = {}
+            if prompt_speech is not None:
+                parsed_kwargs["prompt_speech"] = await prompt_speech.read()
             out = await model.speech(
                 input=body.input,
                 voice=body.voice,
                 response_format=body.response_format,
                 speed=body.speed,
+                stream=body.stream,
+                **parsed_kwargs,
             )
-            …
+            if body.stream:
+                return EventSourceResponse(
+                    media_type="application/octet-stream", content=out
+                )
+            else:
+                return Response(media_type="application/octet-stream", content=out)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))

@@ -1410,6 +1446,60 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def create_inpainting(
+        self,
+        model: str = Form(...),
+        image: UploadFile = File(media_type="application/octet-stream"),
+        mask_image: UploadFile = File(media_type="application/octet-stream"),
+        prompt: Optional[Union[str, List[str]]] = Form(None),
+        negative_prompt: Optional[Union[str, List[str]]] = Form(None),
+        n: Optional[int] = Form(1),
+        response_format: Optional[str] = Form("url"),
+        size: Optional[str] = Form(None),
+        kwargs: Optional[str] = Form(None),
+    ) -> Response:
+        model_uid = model
+        try:
+            model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            if kwargs is not None:
+                parsed_kwargs = json.loads(kwargs)
+            else:
+                parsed_kwargs = {}
+            im = Image.open(image.file)
+            mask_im = Image.open(mask_image.file)
+            if not size:
+                w, h = im.size
+                size = f"{w}*{h}"
+            image_list = await model_ref.inpainting(
+                image=im,
+                mask_image=mask_im,
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                n=n,
+                size=size,
+                response_format=response_format,
+                **parsed_kwargs,
+            )
+            return Response(content=image_list, media_type="application/json")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_flexible_infer(self, request: Request) -> Response:
         payload = await request.json()

@@ -1554,10 +1644,14 @@
         if body.tools and body.stream:
             is_vllm = await model.is_vllm_backend()
 
-            if not …
+            if not (
+                (is_vllm and model_family in QWEN_TOOL_CALL_FAMILY)
+                or (not is_vllm and model_family in GLM4_TOOL_CALL_FAMILY)
+            ):
                 raise HTTPException(
                     status_code=400,
-                    detail="Streaming support for tool calls is available only when using …
+                    detail="Streaming support for tool calls is available only when using "
+                    "Qwen models with vLLM backend or GLM4-chat models without vLLM backend.",
                 )
 
         if body.stream:
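The hunks above add a streaming mode to /v1/audio/speech and register a new /v1/images/inpainting route that consumes a multipart form whose fields mirror the client method added below. A minimal sketch of calling the new inpainting route directly with requests; the endpoint address, model uid, and file names are assumptions for illustration, not values from the diff:

    import json
    import requests

    # Assumed local endpoint and a hypothetical launched image model uid.
    url = "http://127.0.0.1:9997/v1/images/inpainting"
    form = {
        "model": "my-sd-inpaint-model",   # hypothetical uid
        "prompt": "a red brick fireplace",
        "n": 1,
        "response_format": "b64_json",
        "kwargs": json.dumps({}),         # extra kwargs travel as a JSON string
        # "size" may be omitted; the handler then derives it from the image.
    }
    files = {
        "image": open("room.png", "rb"),
        "mask_image": open("room_mask.png", "rb"),
    }
    resp = requests.post(url, data=form, files=files)
    resp.raise_for_status()
    print(resp.json())  # an ImageList-shaped payload
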
xinference/client/restful/restful_client.py
CHANGED

@@ -294,6 +294,81 @@ class RESTfulImageModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def inpainting(
+        self,
+        image: Union[str, bytes],
+        mask_image: Union[str, bytes],
+        prompt: str,
+        negative_prompt: Optional[str] = None,
+        n: int = 1,
+        size: Optional[str] = None,
+        response_format: str = "url",
+        **kwargs,
+    ) -> "ImageList":
+        """
+        Inpaint an image by the input text.
+
+        Parameters
+        ----------
+        image: `Union[str, bytes]`
+            an image batch to be inpainted (which parts of the image to
+            be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch
+            tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list of tensors, the
+            expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the
+            expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image latents as `image`, but
+            if passing latents directly it is not encoded again.
+        mask_image: `Union[str, bytes]`
+            representing an image batch to mask `image`. White pixels in the mask
+            are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+            single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+            color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`,
+            `(B, H, W)`, `(1, H, W)`, or `(H, W)`. And for a numpy array it would be `(B, H, W, 1)`, `(B, H, W)`,
+            `(H, W, 1)`, or `(H, W)`.
+        prompt: `str` or `List[str]`
+            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+        negative_prompt (`str` or `List[str]`, *optional*):
+            The prompt or prompts not to guide the image generation. If not defined, one has to pass
+            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+            less than `1`).
+        n: `int`, defaults to 1
+            The number of images to generate per prompt. Must be between 1 and 10.
+        size: `str`, defaults to None
+            The width*height in pixels of the generated image.
+        response_format: `str`, defaults to `url`
+            The format in which the generated images are returned. Must be one of url or b64_json.
+
+        Returns
+        -------
+        ImageList
+            A list of image objects.
+        """
+        url = f"{self._base_url}/v1/images/inpainting"
+        params = {
+            "model": self._model_uid,
+            "prompt": prompt,
+            "negative_prompt": negative_prompt,
+            "n": n,
+            "size": size,
+            "response_format": response_format,
+            "kwargs": json.dumps(kwargs),
+        }
+        files: List[Any] = []
+        for key, value in params.items():
+            files.append((key, (None, value)))
+        files.append(("image", ("image", image, "application/octet-stream")))
+        files.append(
+            ("mask_image", ("mask_image", mask_image, "application/octet-stream"))
+        )
+        response = requests.post(url, files=files, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to inpaint the images, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
+
 
 class RESTfulGenerateModelHandle(RESTfulModelHandle):
     def generate(

@@ -692,6 +767,9 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
         voice: str = "",
         response_format: str = "mp3",
         speed: float = 1.0,
+        stream: bool = False,
+        prompt_speech: Optional[bytes] = None,
+        **kwargs,
     ):
         """
         Generates audio from the input text.

@@ -707,6 +785,8 @@
             The format to audio in.
         speed: str
             The speed of the generated audio.
+        stream: bool
+            Use stream or not.
 
         Returns
         -------

@@ -720,13 +800,30 @@
             "voice": voice,
             "response_format": response_format,
             "speed": speed,
+            "stream": stream,
+            "kwargs": json.dumps(kwargs),
         }
-        …
+        if prompt_speech:
+            files: List[Any] = []
+            files.append(
+                (
+                    "prompt_speech",
+                    ("prompt_speech", prompt_speech, "application/octet-stream"),
+                )
+            )
+            response = requests.post(
+                url, data=params, files=files, headers=self.auth_headers
+            )
+        else:
+            response = requests.post(url, json=params, headers=self.auth_headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to speech the text, detail: {_get_error_string(response)}"
             )
 
+        if stream:
+            return response.iter_content(chunk_size=1024)
+
         return response.content
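Together with the server-side changes above, the audio handle can now stream encoded audio incrementally. A minimal sketch of consuming it from the Python client; the server address and model uid are assumptions:

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")      # assumed endpoint
    model = client.get_model("my-chattts-model")  # hypothetical audio model uid

    # With stream=True the handle returns an iterator over 1 KiB chunks
    # (response.iter_content) instead of a single bytes payload.
    with open("out.mp3", "wb") as f:
        for chunk in model.speech("Hello from xinference!", stream=True):
            f.write(chunk)
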
xinference/core/chat_interface.py
CHANGED

@@ -428,7 +428,7 @@ class GradioInterface:
             }
 
             hist.append(response_content)
-            return {
+            return {  # type: ignore
                 textbox: response_content,
                 history: hist,
             }

@@ -467,7 +467,7 @@
             }
 
             hist.append(response_content)
-            return {
+            return {  # type: ignore
                 textbox: response_content,
                 history: hist,
             }
xinference/core/model.py
CHANGED
@@ -310,7 +310,7 @@ class ModelActor(xo.StatelessActor):
             )
         )
 
-    def …
+    def _to_generator(self, output_type: str, gen: types.GeneratorType):
         start_time = time.time()
         time_to_first_token = None
         final_usage = None

@@ -318,8 +318,13 @@
             for v in gen:
                 if time_to_first_token is None:
                     time_to_first_token = (time.time() - start_time) * 1000
-                …
-                …
+                if output_type == "json":
+                    final_usage = v.get("usage", None)
+                    v = dict(data=json.dumps(v, ensure_ascii=False))
+                else:
+                    assert (
+                        output_type == "binary"
+                    ), f"Unknown output type '{output_type}'"
                 yield sse_starlette.sse.ensure_bytes(v, None)
         except OutOfMemoryError:
             logger.exception(

@@ -342,7 +347,7 @@
             )
             asyncio.run_coroutine_threadsafe(coro, loop=self._loop)
 
-    async def …
+    async def _to_async_gen(self, output_type: str, gen: types.AsyncGeneratorType):
         start_time = time.time()
         time_to_first_token = None
         final_usage = None

@@ -351,8 +356,13 @@
                 if time_to_first_token is None:
                     time_to_first_token = (time.time() - start_time) * 1000
                 final_usage = v.get("usage", None)
-                …
-                …
+                if output_type == "json":
+                    v = await asyncio.to_thread(json.dumps, v, ensure_ascii=False)
+                    v = dict(data=v)  # noqa: F821
+                else:
+                    assert (
+                        output_type == "binary"
+                    ), f"Unknown output type '{output_type}'"
                 yield await asyncio.to_thread(sse_starlette.sse.ensure_bytes, v, None)
         except OutOfMemoryError:
             logger.exception(

@@ -379,8 +389,14 @@
             )
             await asyncio.gather(*coros)
 
+    async def _call_wrapper_json(self, fn: Callable, *args, **kwargs):
+        return await self._call_wrapper("json", fn, *args, **kwargs)
+
+    async def _call_wrapper_binary(self, fn: Callable, *args, **kwargs):
+        return await self._call_wrapper("binary", fn, *args, **kwargs)
+
     @oom_check
-    async def _call_wrapper(self, fn: Callable, *args, **kwargs):
+    async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs):
         if self._lock is None:
             if inspect.iscoroutinefunction(fn):
                 ret = await fn(*args, **kwargs)

@@ -397,16 +413,18 @@
             raise Exception("Parallel generation is not supported by ggml.")
 
         if inspect.isgenerator(ret):
-            gen = self.…
+            gen = self._to_generator(output_type, ret)
             self._current_generator = weakref.ref(gen)
             return gen
         if inspect.isasyncgen(ret):
-            gen = self.…
+            gen = self._to_async_gen(output_type, ret)
             self._current_generator = weakref.ref(gen)
             return gen
-        if …
+        if output_type == "json":
+            return await asyncio.to_thread(json_dumps, ret)
+        else:
+            assert output_type == "binary", f"Unknown output type '{output_type}'"
             return ret
-        return await asyncio.to_thread(json_dumps, ret)
 
     @log_async(logger=logger)
     @request_limit

@@ -419,11 +437,11 @@
         else:
             kwargs.pop("raw_params", None)
         if hasattr(self._model, "generate"):
-            return await self.…
+            return await self._call_wrapper_json(
                 self._model.generate, prompt, *args, **kwargs
             )
         if hasattr(self._model, "async_generate"):
-            return await self.…
+            return await self._call_wrapper_json(
                 self._model.async_generate, prompt, *args, **kwargs
             )
         raise AttributeError(f"Model {self._model.model_spec} is not for generate.")

@@ -471,7 +489,7 @@
             queue: Queue[Any] = Queue()
             ret = self._queue_consumer(queue)
             await self._scheduler_ref.add_request(prompt, queue, *args, **kwargs)
-            gen = self.…
+            gen = self._to_async_gen("json", ret)
             self._current_generator = weakref.ref(gen)
             return gen
         else:

@@ -502,12 +520,12 @@
         else:
             kwargs.pop("raw_params", None)
         if hasattr(self._model, "chat"):
-            response = await self.…
+            response = await self._call_wrapper_json(
                 self._model.chat, prompt, *args, **kwargs
             )
             return response
         if hasattr(self._model, "async_chat"):
-            response = await self.…
+            response = await self._call_wrapper_json(
                 self._model.async_chat, prompt, *args, **kwargs
             )
             return response

@@ -543,7 +561,7 @@
     @request_limit
     async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
         if hasattr(self._model, "create_embedding"):
-            return await self.…
+            return await self._call_wrapper_json(
                 self._model.create_embedding, input, *args, **kwargs
             )

@@ -565,7 +583,7 @@
         **kwargs,
     ):
         if hasattr(self._model, "rerank"):
-            return await self.…
+            return await self._call_wrapper_json(
                 self._model.rerank,
                 documents,
                 query,

@@ -590,7 +608,7 @@
         timestamp_granularities: Optional[List[str]] = None,
     ):
         if hasattr(self._model, "transcriptions"):
-            return await self.…
+            return await self._call_wrapper_json(
                 self._model.transcriptions,
                 audio,
                 language,

@@ -615,7 +633,7 @@
         timestamp_granularities: Optional[List[str]] = None,
     ):
         if hasattr(self._model, "translations"):
-            return await self.…
+            return await self._call_wrapper_json(
                 self._model.translations,
                 audio,
                 language,

@@ -628,18 +646,30 @@
             f"Model {self._model.model_spec} is not for creating translations."
         )
 
-    @log_async(…
+    @log_async(
+        logger=logger,
+        args_formatter=lambda _, kwargs: kwargs.pop("prompt_speech", None),
+    )
     @request_limit
+    @xo.generator
     async def speech(
-        self, …
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
     ):
         if hasattr(self._model, "speech"):
-            return await self.…
+            return await self._call_wrapper_binary(
                 self._model.speech,
                 input,
                 voice,
                 response_format,
                 speed,
+                stream,
+                **kwargs,
             )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating speech."

@@ -657,7 +687,7 @@
         **kwargs,
     ):
         if hasattr(self._model, "text_to_image"):
-            return await self.…
+            return await self._call_wrapper_json(
                 self._model.text_to_image,
                 prompt,
                 n,

@@ -682,7 +712,7 @@
         **kwargs,
     ):
         if hasattr(self._model, "image_to_image"):
-            return await self.…
+            return await self._call_wrapper_json(
                 self._model.image_to_image,
                 image,
                 prompt,

@@ -697,6 +727,35 @@
             f"Model {self._model.model_spec} is not for creating image."
         )
 
+    async def inpainting(
+        self,
+        image: "PIL.Image",
+        mask_image: "PIL.Image",
+        prompt: str,
+        negative_prompt: str,
+        n: int = 1,
+        size: str = "1024*1024",
+        response_format: str = "url",
+        *args,
+        **kwargs,
+    ):
+        if hasattr(self._model, "inpainting"):
+            return await self._call_wrapper_json(
+                self._model.inpainting,
+                image,
+                mask_image,
+                prompt,
+                negative_prompt,
+                n,
+                size,
+                response_format,
+                *args,
+                **kwargs,
+            )
+        raise AttributeError(
+            f"Model {self._model.model_spec} is not for creating image."
+        )
+
     @log_async(logger=logger)
     @request_limit
     async def infer(

@@ -704,7 +763,7 @@
         **kwargs,
     ):
         if hasattr(self._model, "infer"):
-            return await self.…
+            return await self._call_wrapper_json(
                 self._model.infer,
                 **kwargs,
             )
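The ModelActor refactor above threads an output_type through _call_wrapper: "json" results are serialized before leaving the actor, while "binary" results such as speech bytes or audio generators pass through untouched. A minimal standalone sketch of that dispatch pattern, using hypothetical names (call_wrapper, call_wrapper_json, call_wrapper_binary), not the actor's actual plumbing:

    import json
    from functools import partial

    def call_wrapper(output_type: str, fn, *args, **kwargs):
        # Illustration only: serialize JSON-typed results, pass binary through.
        ret = fn(*args, **kwargs)
        if output_type == "json":
            return json.dumps(ret, ensure_ascii=False).encode()
        assert output_type == "binary", f"Unknown output type '{output_type}'"
        return ret  # bytes (or a generator) are returned unchanged

    call_wrapper_json = partial(call_wrapper, "json")
    call_wrapper_binary = partial(call_wrapper, "binary")

    print(call_wrapper_json(lambda: {"ok": True}))   # b'{"ok": true}'
    print(call_wrapper_binary(lambda: b"\x00\x01"))  # b'\x00\x01'
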
xinference/core/scheduler.py
CHANGED
@@ -81,7 +81,7 @@ class InferenceRequest:
         self.future_or_queue = future_or_queue
         # Record error message when this request has error.
         # Must set stopped=True when this field is set.
-        self.error_msg: Optional[str] = None
+        self.error_msg: Optional[str] = None  # type: ignore
         # For compatibility. Record some extra parameters for some special cases.
         self.extra_kwargs = {}

@@ -295,11 +295,11 @@ class SchedulerActor(xo.StatelessActor):
 
     def __init__(self):
         super().__init__()
-        self._waiting_queue: deque[InferenceRequest] = deque()
-        self._running_queue: deque[InferenceRequest] = deque()
+        self._waiting_queue: deque[InferenceRequest] = deque()  # type: ignore
+        self._running_queue: deque[InferenceRequest] = deque()  # type: ignore
         self._model = None
         self._id_to_req = {}
-        self._abort_req_ids: Set[str] = set()
+        self._abort_req_ids: Set[str] = set()  # type: ignore
         self._isolation = None
 
     async def __post_create__(self):
xinference/model/audio/chattts.py
CHANGED

@@ -48,7 +48,12 @@ class ChatTTSModel:
         self._model.load(source="custom", custom_path=self._model_path, compile=True)
 
     def speech(
-        self, …
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
     ):
         import ChatTTS
         import numpy as np

@@ -74,11 +79,38 @@
         )
 
         assert self._model is not None
-        …
-        …
-        …
-        with BytesIO() as out:
-            torchaudio.save(
-                out, torch.from_numpy(wavs[0]), 24000, format=response_format
+        if stream:
+            iter = self._model.infer(
+                [input], params_infer_code=params_infer_code, stream=True
             )
-        …
+
+            def _generator():
+                with BytesIO() as out:
+                    writer = torchaudio.io.StreamWriter(out, format=response_format)
+                    writer.add_audio_stream(sample_rate=24000, num_channels=1)
+                    i = 0
+                    last_pos = 0
+                    with writer.open():
+                        for it in iter:
+                            for itt in it:
+                                for chunk in itt:
+                                    chunk = np.array([chunk]).transpose()
+                                    writer.write_audio_chunk(i, torch.from_numpy(chunk))
+                                    new_last_pos = out.tell()
+                                    if new_last_pos != last_pos:
+                                        out.seek(last_pos)
+                                        encoded_bytes = out.read()
+                                        print(len(encoded_bytes))
+                                        yield encoded_bytes
+                                        last_pos = new_last_pos
+
+            return _generator()
+        else:
+            wavs = self._model.infer([input], params_infer_code=params_infer_code)
+
+            # Save the generated audio
+            with BytesIO() as out:
+                torchaudio.save(
+                    out, torch.from_numpy(wavs[0]), 24000, format=response_format
+                )
+                return out.getvalue()