xinference 0.13.1__py3-none-any.whl → 0.13.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (82)
  1. xinference/__init__.py +0 -1
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +99 -5
  4. xinference/client/restful/restful_client.py +98 -1
  5. xinference/core/chat_interface.py +2 -2
  6. xinference/core/model.py +85 -26
  7. xinference/core/scheduler.py +4 -4
  8. xinference/model/audio/chattts.py +40 -8
  9. xinference/model/audio/core.py +5 -2
  10. xinference/model/audio/cosyvoice.py +136 -0
  11. xinference/model/audio/model_spec.json +24 -0
  12. xinference/model/audio/model_spec_modelscope.json +27 -0
  13. xinference/model/flexible/launchers/__init__.py +1 -0
  14. xinference/model/flexible/launchers/image_process_launcher.py +70 -0
  15. xinference/model/image/core.py +3 -0
  16. xinference/model/image/model_spec.json +21 -0
  17. xinference/model/image/stable_diffusion/core.py +49 -7
  18. xinference/model/llm/llm_family.json +1065 -106
  19. xinference/model/llm/llm_family.py +26 -6
  20. xinference/model/llm/llm_family_csghub.json +39 -0
  21. xinference/model/llm/llm_family_modelscope.json +460 -47
  22. xinference/model/llm/pytorch/chatglm.py +243 -5
  23. xinference/model/llm/pytorch/cogvlm2.py +1 -1
  24. xinference/model/llm/sglang/core.py +7 -2
  25. xinference/model/llm/utils.py +78 -1
  26. xinference/model/llm/vllm/core.py +11 -0
  27. xinference/thirdparty/cosyvoice/__init__.py +0 -0
  28. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  29. xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
  30. xinference/thirdparty/cosyvoice/bin/train.py +136 -0
  31. xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
  32. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
  33. xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
  34. xinference/thirdparty/cosyvoice/cli/model.py +60 -0
  35. xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
  36. xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
  37. xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
  38. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  39. xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
  40. xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
  41. xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
  42. xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
  43. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  44. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
  45. xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
  46. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  47. xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
  48. xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
  49. xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
  50. xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
  51. xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
  52. xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
  53. xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
  54. xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
  55. xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
  56. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
  57. xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
  58. xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  59. xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
  60. xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
  61. xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
  62. xinference/thirdparty/cosyvoice/utils/common.py +103 -0
  63. xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
  64. xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
  65. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
  66. xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
  67. xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
  68. xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
  69. xinference/web/ui/build/asset-manifest.json +3 -3
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
  72. xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
  74. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/METADATA +18 -8
  75. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/RECORD +80 -36
  76. xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
  78. /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
  79. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
  80. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
  81. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
  82. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
xinference/__init__.py CHANGED
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from . import _version
 
 __version__ = _version.get_versions()["version"]
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-07-12T17:56:13+0800",
+ "date": "2024-07-26T18:42:50+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "5e3f254d48383f37d849dd16db564ad9449e5163",
- "version": "0.13.1"
+ "full-revisionid": "aa51ff22dbfb5644554436270deaf57a7ebaf066",
+ "version": "0.13.3"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -129,6 +129,8 @@ class SpeechRequest(BaseModel):
     voice: Optional[str]
     response_format: Optional[str] = "mp3"
     speed: Optional[float] = 1.0
+    stream: Optional[bool] = False
+    kwargs: Optional[str] = None
 
 
 class RegisterModelRequest(BaseModel):
@@ -491,6 +493,17 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/images/inpainting",
+            self.create_inpainting,
+            methods=["POST"],
+            response_model=ImageList,
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/chat/completions",
             self.create_chat_completion,
@@ -1297,8 +1310,18 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
-    async def create_speech(self, request: Request) -> Response:
-        body = SpeechRequest.parse_obj(await request.json())
+    async def create_speech(
+        self,
+        request: Request,
+        prompt_speech: Optional[UploadFile] = File(
+            None, media_type="application/octet-stream"
+        ),
+    ) -> Response:
+        if prompt_speech:
+            f = await request.form()
+        else:
+            f = await request.json()
+        body = SpeechRequest.parse_obj(f)
         model_uid = body.model
         try:
             model = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -1312,13 +1335,26 @@ class RESTfulAPI:
             raise HTTPException(status_code=500, detail=str(e))
 
         try:
+            if body.kwargs is not None:
+                parsed_kwargs = json.loads(body.kwargs)
+            else:
+                parsed_kwargs = {}
+            if prompt_speech is not None:
+                parsed_kwargs["prompt_speech"] = await prompt_speech.read()
             out = await model.speech(
                 input=body.input,
                 voice=body.voice,
                 response_format=body.response_format,
                 speed=body.speed,
+                stream=body.stream,
+                **parsed_kwargs,
             )
-            return Response(media_type="application/octet-stream", content=out)
+            if body.stream:
+                return EventSourceResponse(
+                    media_type="application/octet-stream", content=out
+                )
+            else:
+                return Response(media_type="application/octet-stream", content=out)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
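
These two hunks change the request handling for the speech endpoint: a JSON body is still accepted, but attaching a reference clip switches the handler to multipart form data, and stream=True switches the response to SSE. A minimal sketch of both request shapes against a local server, assuming the usual /v1/audio/speech route (the route registration is not shown in this excerpt) and placeholder host and model UIDs:

import json
import requests

base = "http://127.0.0.1:9997"  # placeholder server address

# Plain JSON body, non-streaming: the encoded audio comes back in one response.
r = requests.post(
    f"{base}/v1/audio/speech",
    json={"model": "my-tts-model", "input": "Hello!", "stream": False},
)
r.raise_for_status()
open("hello.mp3", "wb").write(r.content)

# Multipart form with a prompt_speech file: the handler reads request.form(),
# and the clip is forwarded to the model as raw bytes via parsed_kwargs.
with open("reference.wav", "rb") as f:
    r = requests.post(
        f"{base}/v1/audio/speech",
        data={"model": "my-tts-model", "input": "Hello!", "kwargs": json.dumps({})},
        files={"prompt_speech": ("prompt_speech", f, "application/octet-stream")},
    )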
@@ -1410,6 +1446,60 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def create_inpainting(
+        self,
+        model: str = Form(...),
+        image: UploadFile = File(media_type="application/octet-stream"),
+        mask_image: UploadFile = File(media_type="application/octet-stream"),
+        prompt: Optional[Union[str, List[str]]] = Form(None),
+        negative_prompt: Optional[Union[str, List[str]]] = Form(None),
+        n: Optional[int] = Form(1),
+        response_format: Optional[str] = Form("url"),
+        size: Optional[str] = Form(None),
+        kwargs: Optional[str] = Form(None),
+    ) -> Response:
+        model_uid = model
+        try:
+            model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            if kwargs is not None:
+                parsed_kwargs = json.loads(kwargs)
+            else:
+                parsed_kwargs = {}
+            im = Image.open(image.file)
+            mask_im = Image.open(mask_image.file)
+            if not size:
+                w, h = im.size
+                size = f"{w}*{h}"
+            image_list = await model_ref.inpainting(
+                image=im,
+                mask_image=mask_im,
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                n=n,
+                size=size,
+                response_format=response_format,
+                **parsed_kwargs,
+            )
+            return Response(content=image_list, media_type="application/json")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_flexible_infer(self, request: Request) -> Response:
         payload = await request.json()
 
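
The new endpoint is multipart only: image and mask_image travel as file parts, everything else as form fields, and extra model kwargs ride along as a JSON-encoded string. A hedged sketch of a raw call; host, model UID, and file names are placeholders:

import json
import requests

url = "http://127.0.0.1:9997/v1/images/inpainting"
with open("dog.png", "rb") as img, open("dog_mask.png", "rb") as mask:
    r = requests.post(
        url,
        data={
            "model": "my-sd-inpainting-model",  # placeholder model UID
            "prompt": "a corgi wearing sunglasses",
            "n": 1,
            "response_format": "b64_json",
            "kwargs": json.dumps({}),  # extra pipeline kwargs, JSON-encoded
        },
        files={
            "image": ("image", img, "application/octet-stream"),
            "mask_image": ("mask_image", mask, "application/octet-stream"),
        },
    )
r.raise_for_status()
image_list = r.json()  # ImageList payload

Note that when size is omitted, the handler falls back to the input image's own dimensions.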
@@ -1554,10 +1644,14 @@ class RESTfulAPI:
         if body.tools and body.stream:
             is_vllm = await model.is_vllm_backend()
 
-            if not is_vllm or model_family not in QWEN_TOOL_CALL_FAMILY:
+            if not (
+                (is_vllm and model_family in QWEN_TOOL_CALL_FAMILY)
+                or (not is_vllm and model_family in GLM4_TOOL_CALL_FAMILY)
+            ):
                 raise HTTPException(
                     status_code=400,
-                    detail="Streaming support for tool calls is available only when using vLLM backend and Qwen models.",
+                    detail="Streaming support for tool calls is available only when using "
+                    "Qwen models with vLLM backend or GLM4-chat models without vLLM backend.",
                 )
 
         if body.stream:
xinference/client/restful/restful_client.py CHANGED
@@ -294,6 +294,81 @@ class RESTfulImageModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def inpainting(
+        self,
+        image: Union[str, bytes],
+        mask_image: Union[str, bytes],
+        prompt: str,
+        negative_prompt: Optional[str] = None,
+        n: int = 1,
+        size: Optional[str] = None,
+        response_format: str = "url",
+        **kwargs,
+    ) -> "ImageList":
+        """
+        Inpaint an image guided by the input text.
+
+        Parameters
+        ----------
+        image: `Union[str, bytes]`
+            an image batch to be inpainted (the parts of the image to
+            be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch
+            tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list of tensors, the
+            expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the
+            expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image latents as `image`, but
+            if passing latents directly they are not encoded again.
+        mask_image: `Union[str, bytes]`
+            representing an image batch to mask `image`. White pixels in the mask
+            are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+            single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+            color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`, `(B,
+            H, W)`, `(1, H, W)`, or `(H, W)`. And for a numpy array it would be `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
+            1)`, or `(H, W)`.
+        prompt: `str` or `List[str]`
+            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+        negative_prompt: `str` or `List[str]`, *optional*
+            The prompt or prompts not to guide the image generation. If not defined, one has to pass
+            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+            less than `1`).
+        n: `int`, defaults to 1
+            The number of images to generate per prompt. Must be between 1 and 10.
+        size: `str`, defaults to None
+            The width*height in pixels of the generated image.
+        response_format: `str`, defaults to `url`
+            The format in which the generated images are returned. Must be one of url or b64_json.
+
+        Returns
+        -------
+        ImageList
+            A list of image objects.
+        """
+        url = f"{self._base_url}/v1/images/inpainting"
+        params = {
+            "model": self._model_uid,
+            "prompt": prompt,
+            "negative_prompt": negative_prompt,
+            "n": n,
+            "size": size,
+            "response_format": response_format,
+            "kwargs": json.dumps(kwargs),
+        }
+        files: List[Any] = []
+        for key, value in params.items():
+            files.append((key, (None, value)))
+        files.append(("image", ("image", image, "application/octet-stream")))
+        files.append(
+            ("mask_image", ("mask_image", mask_image, "application/octet-stream"))
+        )
+        response = requests.post(url, files=files, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to inpaint the images, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
+
 
 class RESTfulGenerateModelHandle(RESTfulModelHandle):
     def generate(
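
The same call through the client SDK; a sketch assuming a running server and an already-launched image model that supports inpainting (model UID and file paths are placeholders):

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model = client.get_model("my-sd-inpainting-model")  # RESTfulImageModelHandle

with open("dog.png", "rb") as img, open("dog_mask.png", "rb") as mask:
    result = model.inpainting(
        image=img.read(),
        mask_image=mask.read(),
        prompt="a corgi wearing sunglasses",
        response_format="url",
    )
print(result["data"][0]["url"])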
@@ -692,6 +767,9 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
         voice: str = "",
         response_format: str = "mp3",
         speed: float = 1.0,
+        stream: bool = False,
+        prompt_speech: Optional[bytes] = None,
+        **kwargs,
     ):
         """
         Generates audio from the input text.
@@ -707,6 +785,8 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
             The format to audio in.
         speed: str
             The speed of the generated audio.
+        stream: bool
+            Use stream or not.
 
         Returns
         -------
@@ -720,13 +800,30 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
             "voice": voice,
             "response_format": response_format,
             "speed": speed,
+            "stream": stream,
+            "kwargs": json.dumps(kwargs),
         }
-        response = requests.post(url, json=params, headers=self.auth_headers)
+        if prompt_speech:
+            files: List[Any] = []
+            files.append(
+                (
+                    "prompt_speech",
+                    ("prompt_speech", prompt_speech, "application/octet-stream"),
+                )
+            )
+            response = requests.post(
+                url, data=params, files=files, headers=self.auth_headers
+            )
+        else:
+            response = requests.post(url, json=params, headers=self.auth_headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to speech the text, detail: {_get_error_string(response)}"
             )
 
+        if stream:
+            return response.iter_content(chunk_size=1024)
+
         return response.content
 
 
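
From the caller's side, stream=True now returns an iterator of encoded audio chunks (1024-byte pieces from response.iter_content) instead of a single bytes blob. A minimal consumption sketch with a placeholder model UID:

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model = client.get_model("my-chattts-model")  # a TTS model that supports streaming

with open("speech.mp3", "wb") as out:
    for chunk in model.speech("Hello, streaming world!", stream=True):
        out.write(chunk)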
xinference/core/chat_interface.py CHANGED
@@ -428,7 +428,7 @@ class GradioInterface:
             }
 
         hist.append(response_content)
-        return {
+        return {  # type: ignore
             textbox: response_content,
             history: hist,
         }
@@ -467,7 +467,7 @@ class GradioInterface:
             }
 
         hist.append(response_content)
-        return {
+        return {  # type: ignore
             textbox: response_content,
             history: hist,
         }
xinference/core/model.py CHANGED
@@ -310,7 +310,7 @@ class ModelActor(xo.StatelessActor):
             )
         )
 
-    def _to_json_generator(self, gen: types.GeneratorType):
+    def _to_generator(self, output_type: str, gen: types.GeneratorType):
         start_time = time.time()
         time_to_first_token = None
         final_usage = None
@@ -318,8 +318,13 @@ class ModelActor(xo.StatelessActor):
             for v in gen:
                 if time_to_first_token is None:
                     time_to_first_token = (time.time() - start_time) * 1000
-                final_usage = v.get("usage", None)
-                v = dict(data=json.dumps(v, ensure_ascii=False))
+                if output_type == "json":
+                    final_usage = v.get("usage", None)
+                    v = dict(data=json.dumps(v, ensure_ascii=False))
+                else:
+                    assert (
+                        output_type == "binary"
+                    ), f"Unknown output type '{output_type}'"
                 yield sse_starlette.sse.ensure_bytes(v, None)
         except OutOfMemoryError:
             logger.exception(
@@ -342,7 +347,7 @@ class ModelActor(xo.StatelessActor):
             )
             asyncio.run_coroutine_threadsafe(coro, loop=self._loop)
 
-    async def _to_json_async_gen(self, gen: types.AsyncGeneratorType):
+    async def _to_async_gen(self, output_type: str, gen: types.AsyncGeneratorType):
         start_time = time.time()
         time_to_first_token = None
         final_usage = None
@@ -351,8 +356,13 @@ class ModelActor(xo.StatelessActor):
                 if time_to_first_token is None:
                     time_to_first_token = (time.time() - start_time) * 1000
                 final_usage = v.get("usage", None)
-                v = await asyncio.to_thread(json.dumps, v)
-                v = dict(data=v)  # noqa: F821
+                if output_type == "json":
+                    v = await asyncio.to_thread(json.dumps, v, ensure_ascii=False)
+                    v = dict(data=v)  # noqa: F821
+                else:
+                    assert (
+                        output_type == "binary"
+                    ), f"Unknown output type '{output_type}'"
                 yield await asyncio.to_thread(sse_starlette.sse.ensure_bytes, v, None)
         except OutOfMemoryError:
             logger.exception(
@@ -379,8 +389,14 @@ class ModelActor(xo.StatelessActor):
             )
             await asyncio.gather(*coros)
 
+    async def _call_wrapper_json(self, fn: Callable, *args, **kwargs):
+        return await self._call_wrapper("json", fn, *args, **kwargs)
+
+    async def _call_wrapper_binary(self, fn: Callable, *args, **kwargs):
+        return await self._call_wrapper("binary", fn, *args, **kwargs)
+
     @oom_check
-    async def _call_wrapper(self, fn: Callable, *args, **kwargs):
+    async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs):
         if self._lock is None:
             if inspect.iscoroutinefunction(fn):
                 ret = await fn(*args, **kwargs)
@@ -397,16 +413,18 @@ class ModelActor(xo.StatelessActor):
             raise Exception("Parallel generation is not supported by ggml.")
 
         if inspect.isgenerator(ret):
-            gen = self._to_json_generator(ret)
+            gen = self._to_generator(output_type, ret)
             self._current_generator = weakref.ref(gen)
             return gen
         if inspect.isasyncgen(ret):
-            gen = self._to_json_async_gen(ret)
+            gen = self._to_async_gen(output_type, ret)
             self._current_generator = weakref.ref(gen)
             return gen
-        if isinstance(ret, bytes):
+        if output_type == "json":
+            return await asyncio.to_thread(json_dumps, ret)
+        else:
+            assert output_type == "binary", f"Unknown output type '{output_type}'"
             return ret
-        return await asyncio.to_thread(json_dumps, ret)
 
     @log_async(logger=logger)
     @request_limit
@@ -419,11 +437,11 @@ class ModelActor(xo.StatelessActor):
         else:
             kwargs.pop("raw_params", None)
         if hasattr(self._model, "generate"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.generate, prompt, *args, **kwargs
             )
         if hasattr(self._model, "async_generate"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.async_generate, prompt, *args, **kwargs
             )
         raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
@@ -471,7 +489,7 @@ class ModelActor(xo.StatelessActor):
             queue: Queue[Any] = Queue()
             ret = self._queue_consumer(queue)
             await self._scheduler_ref.add_request(prompt, queue, *args, **kwargs)
-            gen = self._to_json_async_gen(ret)
+            gen = self._to_async_gen("json", ret)
             self._current_generator = weakref.ref(gen)
             return gen
         else:
@@ -502,12 +520,12 @@ class ModelActor(xo.StatelessActor):
         else:
             kwargs.pop("raw_params", None)
         if hasattr(self._model, "chat"):
-            response = await self._call_wrapper(
+            response = await self._call_wrapper_json(
                 self._model.chat, prompt, *args, **kwargs
             )
             return response
         if hasattr(self._model, "async_chat"):
-            response = await self._call_wrapper(
+            response = await self._call_wrapper_json(
                 self._model.async_chat, prompt, *args, **kwargs
             )
             return response
@@ -543,7 +561,7 @@ class ModelActor(xo.StatelessActor):
     @request_limit
     async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
         if hasattr(self._model, "create_embedding"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.create_embedding, input, *args, **kwargs
             )
 
@@ -565,7 +583,7 @@ class ModelActor(xo.StatelessActor):
         **kwargs,
     ):
         if hasattr(self._model, "rerank"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.rerank,
                 documents,
                 query,
@@ -590,7 +608,7 @@ class ModelActor(xo.StatelessActor):
         timestamp_granularities: Optional[List[str]] = None,
     ):
         if hasattr(self._model, "transcriptions"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.transcriptions,
                 audio,
                 language,
@@ -615,7 +633,7 @@ class ModelActor(xo.StatelessActor):
         timestamp_granularities: Optional[List[str]] = None,
    ):
         if hasattr(self._model, "translations"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.translations,
                 audio,
                 language,
@@ -628,18 +646,30 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating translations."
         )
 
-    @log_async(logger=logger)
+    @log_async(
+        logger=logger,
+        args_formatter=lambda _, kwargs: kwargs.pop("prompt_speech", None),
+    )
     @request_limit
+    @xo.generator
     async def speech(
-        self, input: str, voice: str, response_format: str = "mp3", speed: float = 1.0
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
     ):
         if hasattr(self._model, "speech"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_binary(
                 self._model.speech,
                 input,
                 voice,
                 response_format,
                 speed,
+                stream,
+                **kwargs,
             )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating speech."
@@ -657,7 +687,7 @@ class ModelActor(xo.StatelessActor):
         **kwargs,
     ):
         if hasattr(self._model, "text_to_image"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.text_to_image,
                 prompt,
                 n,
@@ -682,7 +712,7 @@ class ModelActor(xo.StatelessActor):
         **kwargs,
     ):
         if hasattr(self._model, "image_to_image"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.image_to_image,
                 image,
                 prompt,
@@ -697,6 +727,35 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating image."
         )
 
+    async def inpainting(
+        self,
+        image: "PIL.Image",
+        mask_image: "PIL.Image",
+        prompt: str,
+        negative_prompt: str,
+        n: int = 1,
+        size: str = "1024*1024",
+        response_format: str = "url",
+        *args,
+        **kwargs,
+    ):
+        if hasattr(self._model, "inpainting"):
+            return await self._call_wrapper_json(
+                self._model.inpainting,
+                image,
+                mask_image,
+                prompt,
+                negative_prompt,
+                n,
+                size,
+                response_format,
+                *args,
+                **kwargs,
+            )
+        raise AttributeError(
+            f"Model {self._model.model_spec} is not for creating image."
+        )
+
     @log_async(logger=logger)
     @request_limit
     async def infer(
@@ -704,7 +763,7 @@ class ModelActor(xo.StatelessActor):
         **kwargs,
    ):
         if hasattr(self._model, "infer"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.infer,
                 **kwargs,
             )
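
Taken together, these hunks replace the old isinstance(ret, bytes) sniffing with an output type chosen explicitly at each call site and threaded through one generic wrapper. A toy reduction of the pattern, with illustrative names rather than the actual ModelActor internals:

import inspect
import json
from typing import Any, Callable


async def call_wrapper(output_type: str, fn: Callable, *args, **kwargs) -> Any:
    # Await coroutine functions, call plain ones.
    if inspect.iscoroutinefunction(fn):
        ret = await fn(*args, **kwargs)
    else:
        ret = fn(*args, **kwargs)
    # Generators pass through; the real code wraps them so json results become
    # SSE events while binary chunks (e.g. audio) are yielded untouched.
    if inspect.isgenerator(ret) or inspect.isasyncgen(ret):
        return ret
    if output_type == "json":
        return json.dumps(ret, ensure_ascii=False).encode()
    assert output_type == "binary", f"Unknown output type '{output_type}'"
    return ret  # e.g. raw audio bytes from speech()


async def call_wrapper_json(fn: Callable, *args, **kwargs) -> Any:
    return await call_wrapper("json", fn, *args, **kwargs)


async def call_wrapper_binary(fn: Callable, *args, **kwargs) -> Any:
    return await call_wrapper("binary", fn, *args, **kwargs)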
xinference/core/scheduler.py CHANGED
@@ -81,7 +81,7 @@ class InferenceRequest:
         self.future_or_queue = future_or_queue
         # Record error message when this request has error.
         # Must set stopped=True when this field is set.
-        self.error_msg: Optional[str] = None
+        self.error_msg: Optional[str] = None  # type: ignore
         # For compatibility. Record some extra parameters for some special cases.
         self.extra_kwargs = {}
 
@@ -295,11 +295,11 @@ class SchedulerActor(xo.StatelessActor):
 
     def __init__(self):
         super().__init__()
-        self._waiting_queue: deque[InferenceRequest] = deque()
-        self._running_queue: deque[InferenceRequest] = deque()
+        self._waiting_queue: deque[InferenceRequest] = deque()  # type: ignore
+        self._running_queue: deque[InferenceRequest] = deque()  # type: ignore
         self._model = None
         self._id_to_req = {}
-        self._abort_req_ids: Set[str] = set()
+        self._abort_req_ids: Set[str] = set()  # type: ignore
         self._isolation = None
 
     async def __post_create__(self):
xinference/model/audio/chattts.py CHANGED
@@ -48,7 +48,12 @@ class ChatTTSModel:
         self._model.load(source="custom", custom_path=self._model_path, compile=True)
 
     def speech(
-        self, input: str, voice: str, response_format: str = "mp3", speed: float = 1.0
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
     ):
         import ChatTTS
         import numpy as np
@@ -74,11 +79,38 @@ class ChatTTSModel:
         )
 
         assert self._model is not None
-        wavs = self._model.infer([input], params_infer_code=params_infer_code)
-
-        # Save the generated audio
-        with BytesIO() as out:
-            torchaudio.save(
-                out, torch.from_numpy(wavs[0]), 24000, format=response_format
+        if stream:
+            iter = self._model.infer(
+                [input], params_infer_code=params_infer_code, stream=True
             )
-            return out.getvalue()
+
+            def _generator():
+                with BytesIO() as out:
+                    writer = torchaudio.io.StreamWriter(out, format=response_format)
+                    writer.add_audio_stream(sample_rate=24000, num_channels=1)
+                    i = 0
+                    last_pos = 0
+                    with writer.open():
+                        for it in iter:
+                            for itt in it:
+                                for chunk in itt:
+                                    chunk = np.array([chunk]).transpose()
+                                    writer.write_audio_chunk(i, torch.from_numpy(chunk))
+                                    new_last_pos = out.tell()
+                                    if new_last_pos != last_pos:
+                                        out.seek(last_pos)
+                                        encoded_bytes = out.read()
+                                        print(len(encoded_bytes))
+                                        yield encoded_bytes
+                                        last_pos = new_last_pos
+
+            return _generator()
+        else:
+            wavs = self._model.infer([input], params_infer_code=params_infer_code)
+
+            # Save the generated audio
+            with BytesIO() as out:
+                torchaudio.save(
+                    out, torch.from_numpy(wavs[0]), 24000, format=response_format
+                )
+                return out.getvalue()