xinference 0.13.0__py3-none-any.whl → 0.13.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of xinference might be problematic.
Files changed (70)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +123 -3
  3. xinference/client/restful/restful_client.py +131 -2
  4. xinference/core/model.py +93 -24
  5. xinference/core/supervisor.py +132 -15
  6. xinference/core/worker.py +165 -8
  7. xinference/deploy/cmdline.py +5 -0
  8. xinference/model/audio/chattts.py +46 -14
  9. xinference/model/audio/core.py +23 -15
  10. xinference/model/core.py +12 -3
  11. xinference/model/embedding/core.py +25 -16
  12. xinference/model/flexible/__init__.py +40 -0
  13. xinference/model/flexible/core.py +228 -0
  14. xinference/model/flexible/launchers/__init__.py +15 -0
  15. xinference/model/flexible/launchers/transformers_launcher.py +63 -0
  16. xinference/model/flexible/utils.py +33 -0
  17. xinference/model/image/core.py +21 -14
  18. xinference/model/image/custom.py +1 -1
  19. xinference/model/image/model_spec.json +14 -0
  20. xinference/model/image/stable_diffusion/core.py +43 -6
  21. xinference/model/llm/__init__.py +0 -2
  22. xinference/model/llm/core.py +3 -2
  23. xinference/model/llm/ggml/llamacpp.py +1 -10
  24. xinference/model/llm/llm_family.json +292 -36
  25. xinference/model/llm/llm_family.py +97 -52
  26. xinference/model/llm/llm_family_modelscope.json +220 -27
  27. xinference/model/llm/pytorch/core.py +0 -80
  28. xinference/model/llm/sglang/core.py +7 -2
  29. xinference/model/llm/utils.py +4 -2
  30. xinference/model/llm/vllm/core.py +3 -0
  31. xinference/model/rerank/core.py +24 -25
  32. xinference/types.py +0 -1
  33. xinference/web/ui/build/asset-manifest.json +3 -3
  34. xinference/web/ui/build/index.html +1 -1
  35. xinference/web/ui/build/static/js/{main.0fb6f3ab.js → main.95c1d652.js} +3 -3
  36. xinference/web/ui/build/static/js/main.95c1d652.js.map +1 -0
  37. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +1 -0
  38. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +1 -0
  39. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +1 -0
  40. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +1 -0
  41. xinference/web/ui/node_modules/.cache/babel-loader/70fa8c07463a5fe57c68bf92502910105a8f647371836fe8c3a7408246ca7ba0.json +1 -0
  42. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +1 -0
  43. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/METADATA +9 -11
  44. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/RECORD +49 -58
  45. xinference/model/llm/ggml/chatglm.py +0 -457
  46. xinference/thirdparty/ChatTTS/__init__.py +0 -1
  47. xinference/thirdparty/ChatTTS/core.py +0 -200
  48. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  49. xinference/thirdparty/ChatTTS/experimental/llm.py +0 -40
  50. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  51. xinference/thirdparty/ChatTTS/infer/api.py +0 -125
  52. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  53. xinference/thirdparty/ChatTTS/model/dvae.py +0 -155
  54. xinference/thirdparty/ChatTTS/model/gpt.py +0 -265
  55. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  56. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +0 -23
  57. xinference/thirdparty/ChatTTS/utils/infer_utils.py +0 -141
  58. xinference/thirdparty/ChatTTS/utils/io_utils.py +0 -14
  59. xinference/web/ui/build/static/js/main.0fb6f3ab.js.map +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/0f6b391abec76271137faad13a3793fe7acc1024e8cd2269c147b653ecd3a73b.json +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/30a0c79d8025d6441eb75b2df5bc2750a14f30119c869ef02570d294dff65c2f.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/40486e655c3c5801f087e2cf206c0b5511aaa0dfdba78046b7181bf9c17e54c5.json +0 -1
  63. xinference/web/ui/node_modules/.cache/babel-loader/b5507cd57f16a3a230aa0128e39fe103e928de139ea29e2679e4c64dcbba3b3a.json +0 -1
  64. xinference/web/ui/node_modules/.cache/babel-loader/d779b915f83f9c7b5a72515b6932fdd114f1822cef90ae01cc0d12bca59abc2d.json +0 -1
  65. xinference/web/ui/node_modules/.cache/babel-loader/d87824cb266194447a9c0c69ebab2d507bfc3e3148976173760d18c035e9dd26.json +0 -1
  66. /xinference/web/ui/build/static/js/{main.0fb6f3ab.js.LICENSE.txt → main.95c1d652.js.LICENSE.txt} +0 -0
  67. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/LICENSE +0 -0
  68. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/WHEEL +0 -0
  69. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/entry_points.txt +0 -0
  70. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-07-05T18:19:09+0800",
+ "date": "2024-07-19T19:15:54+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "007408c55272bc343821dd152df780de5dc9c037",
- "version": "0.13.0"
+ "full-revisionid": "880929cbbc73e5206ca069591b03d9d16dd858bf",
+ "version": "0.13.2"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -129,10 +129,12 @@ class SpeechRequest(BaseModel):
     voice: Optional[str]
     response_format: Optional[str] = "mp3"
     speed: Optional[float] = 1.0
+    stream: Optional[bool] = False
 
 
 class RegisterModelRequest(BaseModel):
     model: str
+    worker_ip: Optional[str]
     persist: bool
 
 
@@ -490,6 +492,17 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/images/inpainting",
+            self.create_inpainting,
+            methods=["POST"],
+            response_model=ImageList,
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/chat/completions",
             self.create_chat_completion,
@@ -501,6 +514,16 @@
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/flexible/infers",
+            self.create_flexible_infer,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
 
         # for custom models
         self._router.add_api_route(
@@ -772,6 +795,7 @@ class RESTfulAPI:
         peft_model_config = payload.get("peft_model_config", None)
         worker_ip = payload.get("worker_ip", None)
         gpu_idx = payload.get("gpu_idx", None)
+        download_hub = payload.get("download_hub", None)
 
         exclude_keys = {
             "model_uid",
@@ -787,6 +811,7 @@
             "peft_model_config",
             "worker_ip",
             "gpu_idx",
+            "download_hub",
         }
 
         kwargs = {
@@ -834,9 +859,9 @@
                 peft_model_config=peft_model_config,
                 worker_ip=worker_ip,
                 gpu_idx=gpu_idx,
+                download_hub=download_hub,
                 **kwargs,
             )
-
         except ValueError as ve:
             logger.error(str(ve), exc_info=True)
             raise HTTPException(status_code=400, detail=str(ve))
@@ -1304,8 +1329,14 @@
                 voice=body.voice,
                 response_format=body.response_format,
                 speed=body.speed,
+                stream=body.stream,
             )
-            return Response(media_type="application/octet-stream", content=out)
+            if body.stream:
+                return EventSourceResponse(
+                    media_type="application/octet-stream", content=out
+                )
+            else:
+                return Response(media_type="application/octet-stream", content=out)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1397,6 +1428,94 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def create_inpainting(
+        self,
+        model: str = Form(...),
+        image: UploadFile = File(media_type="application/octet-stream"),
+        mask_image: UploadFile = File(media_type="application/octet-stream"),
+        prompt: Optional[Union[str, List[str]]] = Form(None),
+        negative_prompt: Optional[Union[str, List[str]]] = Form(None),
+        n: Optional[int] = Form(1),
+        response_format: Optional[str] = Form("url"),
+        size: Optional[str] = Form(None),
+        kwargs: Optional[str] = Form(None),
+    ) -> Response:
+        model_uid = model
+        try:
+            model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            if kwargs is not None:
+                parsed_kwargs = json.loads(kwargs)
+            else:
+                parsed_kwargs = {}
+            im = Image.open(image.file)
+            mask_im = Image.open(mask_image.file)
+            if not size:
+                w, h = im.size
+                size = f"{w}*{h}"
+            image_list = await model_ref.inpainting(
+                image=im,
+                mask_image=mask_im,
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                n=n,
+                size=size,
+                response_format=response_format,
+                **parsed_kwargs,
+            )
+            return Response(content=image_list, media_type="application/json")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+    async def create_flexible_infer(self, request: Request) -> Response:
+        payload = await request.json()
+
+        model_uid = payload.get("model")
+
+        exclude = {
+            "model",
+        }
+        kwargs = {key: value for key, value in payload.items() if key not in exclude}
+
+        try:
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            result = await model.infer(**kwargs)
+            return Response(result, media_type="application/json")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_chat_completion(self, request: Request) -> Response:
         raw_body = await request.json()
         body = CreateChatCompletion.parse_obj(raw_body)
@@ -1593,11 +1712,12 @@
     async def register_model(self, model_type: str, request: Request) -> JSONResponse:
         body = RegisterModelRequest.parse_obj(await request.json())
         model = body.model
+        worker_ip = body.worker_ip
         persist = body.persist
 
         try:
             await (await self._get_supervisor_ref()).register_model(
-                model_type, model, persist
+                model_type, model, persist, worker_ip
             )
         except ValueError as re:
             logger.error(re, exc_info=True)
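The hunks above give the REST layer four new capabilities: a multipart `/v1/images/inpainting` route, a JSON `/v1/flexible/infers` route, optional streaming for the speech endpoint, and `worker_ip`/`download_hub` passthrough for launch and registration. Below is a minimal sketch of how the two new routes might be called with plain `requests`, matching the handler signatures shown above; the base URL, model uids, and file names are placeholder assumptions, and the corresponding models must already be launched:

```python
import json

import requests

BASE = "http://localhost:9997"  # assumed default xinference endpoint

# /v1/flexible/infers takes a JSON body; every key except "model"
# is forwarded verbatim to the model's infer(**kwargs).
resp = requests.post(
    f"{BASE}/v1/flexible/infers",
    json={"model": "my-flexible-model", "text": "hello"},  # placeholder uid/args
)
resp.raise_for_status()
print(resp.json())

# /v1/images/inpainting is multipart/form-data: scalar fields travel as
# form parts, the image and its mask as file parts (cf. the Form(...)/File(...)
# parameters of create_inpainting above).
with open("photo.png", "rb") as img, open("mask.png", "rb") as mask:
    resp = requests.post(
        f"{BASE}/v1/images/inpainting",
        files={
            "model": (None, "sd-inpaint"),  # placeholder uid
            "prompt": (None, "a red couch"),
            "n": (None, "1"),
            "response_format": (None, "url"),
            "kwargs": (None, json.dumps({})),
            "image": ("image", img, "application/octet-stream"),
            "mask_image": ("mask_image", mask, "application/octet-stream"),
        },
    )
resp.raise_for_status()
print(resp.json())
```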
xinference/client/restful/restful_client.py CHANGED
@@ -294,6 +294,81 @@ class RESTfulImageModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def inpainting(
+        self,
+        image: Union[str, bytes],
+        mask_image: Union[str, bytes],
+        prompt: str,
+        negative_prompt: Optional[str] = None,
+        n: int = 1,
+        size: Optional[str] = None,
+        response_format: str = "url",
+        **kwargs,
+    ) -> "ImageList":
+        """
+        Inpaint an image by the input text.
+
+        Parameters
+        ----------
+        image: `Union[str, bytes]`
+            an image batch to be inpainted (which parts of the image to
+            be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch
+            tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the
+            expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the
+            expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but
+            if passing latents directly it is not encoded again.
+        mask_image: `Union[str, bytes]`
+            representing an image batch to mask `image`. White pixels in the mask
+            are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+            single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+            color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
+            H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
+            1)`, or `(H, W)`.
+        prompt: `str` or `List[str]`
+            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+        negative_prompt (`str` or `List[str]`, *optional*):
+            The prompt or prompts not to guide the image generation. If not defined, one has to pass
+            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+            less than `1`).
+        n: `int`, defaults to 1
+            The number of images to generate per prompt. Must be between 1 and 10.
+        size: `str`, defaults to None
+            The width*height in pixels of the generated image.
+        response_format: `str`, defaults to `url`
+            The format in which the generated images are returned. Must be one of url or b64_json.
+        Returns
+        -------
+        ImageList
+            A list of image objects.
+        :param prompt:
+        :param image:
+        """
+        url = f"{self._base_url}/v1/images/inpainting"
+        params = {
+            "model": self._model_uid,
+            "prompt": prompt,
+            "negative_prompt": negative_prompt,
+            "n": n,
+            "size": size,
+            "response_format": response_format,
+            "kwargs": json.dumps(kwargs),
+        }
+        files: List[Any] = []
+        for key, value in params.items():
+            files.append((key, (None, value)))
+        files.append(("image", ("image", image, "application/octet-stream")))
+        files.append(
+            ("mask_image", ("mask_image", mask_image, "application/octet-stream"))
+        )
+        response = requests.post(url, files=files, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to inpaint the images, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
+
 
 class RESTfulGenerateModelHandle(RESTfulModelHandle):
     def generate(
@@ -692,6 +767,7 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
         voice: str = "",
         response_format: str = "mp3",
         speed: float = 1.0,
+        stream: bool = False,
     ):
         """
         Generates audio from the input text.
@@ -707,6 +783,8 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
             The format to audio in.
         speed: str
             The speed of the generated audio.
+        stream: bool
+            Use stream or not.
 
         Returns
         -------
@@ -720,6 +798,7 @@
             "voice": voice,
             "response_format": response_format,
             "speed": speed,
+            "stream": stream,
         }
         response = requests.post(url, json=params, headers=self.auth_headers)
         if response.status_code != 200:
@@ -727,6 +806,44 @@
                 f"Failed to speech the text, detail: {_get_error_string(response)}"
             )
 
+        if stream:
+            return response.iter_content(chunk_size=1024)
+
+        return response.content
+
+
+class RESTfulFlexibleModelHandle(RESTfulModelHandle):
+    def infer(
+        self,
+        **kwargs,
+    ):
+        """
+        Call flexible model.
+
+        Parameters
+        ----------
+
+        kwargs: dict
+            The inference arguments.
+
+
+        Returns
+        -------
+        bytes
+            The inference result.
+        """
+        url = f"{self._base_url}/v1/flexible/infers"
+        params = {
+            "model": self._model_uid,
+        }
+        params.update(kwargs)
+
+        response = requests.post(url, json=params, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to predict, detail: {_get_error_string(response)}"
+            )
+
         return response.content
 
 
@@ -1009,6 +1126,10 @@
             return RESTfulAudioModelHandle(
                 model_uid, self.base_url, auth_headers=self._headers
             )
+        elif desc["model_type"] == "flexible":
+            return RESTfulFlexibleModelHandle(
+                model_uid, self.base_url, auth_headers=self._headers
+            )
         else:
             raise ValueError(f"Unknown model type:{desc['model_type']}")
 
@@ -1062,7 +1183,13 @@
         )
         return response.json()
 
-    def register_model(self, model_type: str, model: str, persist: bool):
+    def register_model(
+        self,
+        model_type: str,
+        model: str,
+        persist: bool,
+        worker_ip: Optional[str] = None,
+    ):
         """
         Register a custom model.
 
@@ -1072,6 +1199,8 @@
             The type of model.
         model: str
             The model definition. (refer to: https://inference.readthedocs.io/en/latest/models/custom.html)
+        worker_ip: Optional[str]
+            The IP address of the worker on which the model is running.
         persist: bool
 
 
@@ -1081,7 +1210,7 @@
         Report failure to register the custom model. Provide details of failure through error message.
         """
         url = f"{self.base_url}/v1/model_registrations/{model_type}"
-        request_body = {"model": model, "persist": persist}
+        request_body = {"model": model, "worker_ip": worker_ip, "persist": persist}
         response = requests.post(url, json=request_body, headers=self._headers)
         if response.status_code != 200:
             raise RuntimeError(
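On the client side the same release mirrors those routes: image handles gain `inpainting()`, audio handles gain a `stream` flag, flexible models get their own handle, and `register_model()` learns `worker_ip`. A minimal usage sketch against a local server; the uids (`sd-inpaint`, `chattts-demo`, `my-flexible-model`) and file names are placeholder assumptions:

```python
from xinference.client import RESTfulClient

client = RESTfulClient("http://localhost:9997")  # assumed server address

# Inpainting on an image model handle (payload mirrors the REST route).
image_model = client.get_model("sd-inpaint")  # placeholder uid
with open("photo.png", "rb") as f, open("mask.png", "rb") as m:
    result = image_model.inpainting(
        image=f.read(), mask_image=m.read(), prompt="a red couch"
    )
print(result["data"][0]["url"])  # assumes response_format="url" (the default)

# speech() with stream=True now returns an iterator of 1024-byte chunks
# instead of the complete payload.
audio_model = client.get_model("chattts-demo")  # placeholder uid
with open("out.mp3", "wb") as out:
    for chunk in audio_model.speech("hello world", stream=True):
        out.write(chunk)

# Flexible models get their own handle; every kwarg is forwarded to the
# server-side infer().
flex = client.get_model("my-flexible-model")  # placeholder uid
print(flex.infer(text="hello"))
```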
xinference/core/model.py CHANGED
@@ -310,7 +310,7 @@ class ModelActor(xo.StatelessActor):
             )
         )
 
-    def _to_json_generator(self, gen: types.GeneratorType):
+    def _to_generator(self, output_type: str, gen: types.GeneratorType):
         start_time = time.time()
         time_to_first_token = None
         final_usage = None
@@ -318,8 +318,13 @@
             for v in gen:
                 if time_to_first_token is None:
                     time_to_first_token = (time.time() - start_time) * 1000
-                final_usage = v.get("usage", None)
-                v = dict(data=json.dumps(v, ensure_ascii=False))
+                if output_type == "json":
+                    final_usage = v.get("usage", None)
+                    v = dict(data=json.dumps(v, ensure_ascii=False))
+                else:
+                    assert (
+                        output_type == "binary"
+                    ), f"Unknown output type '{output_type}'"
                 yield sse_starlette.sse.ensure_bytes(v, None)
         except OutOfMemoryError:
             logger.exception(
@@ -342,7 +347,7 @@
             )
         asyncio.run_coroutine_threadsafe(coro, loop=self._loop)
 
-    async def _to_json_async_gen(self, gen: types.AsyncGeneratorType):
+    async def _to_async_gen(self, output_type: str, gen: types.AsyncGeneratorType):
         start_time = time.time()
         time_to_first_token = None
         final_usage = None
@@ -351,8 +356,13 @@
                 if time_to_first_token is None:
                     time_to_first_token = (time.time() - start_time) * 1000
                 final_usage = v.get("usage", None)
-                v = await asyncio.to_thread(json.dumps, v)
-                v = dict(data=v)  # noqa: F821
+                if output_type == "json":
+                    v = await asyncio.to_thread(json.dumps, v, ensure_ascii=False)
+                    v = dict(data=v)  # noqa: F821
+                else:
+                    assert (
+                        output_type == "binary"
+                    ), f"Unknown output type '{output_type}'"
                 yield await asyncio.to_thread(sse_starlette.sse.ensure_bytes, v, None)
         except OutOfMemoryError:
             logger.exception(
@@ -379,8 +389,14 @@
             )
         await asyncio.gather(*coros)
 
+    async def _call_wrapper_json(self, fn: Callable, *args, **kwargs):
+        return await self._call_wrapper("json", fn, *args, **kwargs)
+
+    async def _call_wrapper_binary(self, fn: Callable, *args, **kwargs):
+        return await self._call_wrapper("binary", fn, *args, **kwargs)
+
     @oom_check
-    async def _call_wrapper(self, fn: Callable, *args, **kwargs):
+    async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs):
         if self._lock is None:
             if inspect.iscoroutinefunction(fn):
                 ret = await fn(*args, **kwargs)
@@ -397,16 +413,18 @@
                     raise Exception("Parallel generation is not supported by ggml.")
 
         if inspect.isgenerator(ret):
-            gen = self._to_json_generator(ret)
+            gen = self._to_generator(output_type, ret)
             self._current_generator = weakref.ref(gen)
             return gen
         if inspect.isasyncgen(ret):
-            gen = self._to_json_async_gen(ret)
+            gen = self._to_async_gen(output_type, ret)
             self._current_generator = weakref.ref(gen)
             return gen
-        if isinstance(ret, bytes):
+        if output_type == "json":
+            return await asyncio.to_thread(json_dumps, ret)
+        else:
+            assert output_type == "binary", f"Unknown output type '{output_type}'"
             return ret
-        return await asyncio.to_thread(json_dumps, ret)
 
     @log_async(logger=logger)
     @request_limit
@@ -419,11 +437,11 @@
         else:
             kwargs.pop("raw_params", None)
         if hasattr(self._model, "generate"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.generate, prompt, *args, **kwargs
             )
         if hasattr(self._model, "async_generate"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.async_generate, prompt, *args, **kwargs
             )
         raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
@@ -471,7 +489,7 @@
             queue: Queue[Any] = Queue()
             ret = self._queue_consumer(queue)
             await self._scheduler_ref.add_request(prompt, queue, *args, **kwargs)
-            gen = self._to_json_async_gen(ret)
+            gen = self._to_async_gen("json", ret)
             self._current_generator = weakref.ref(gen)
             return gen
         else:
@@ -502,12 +520,12 @@
         else:
             kwargs.pop("raw_params", None)
         if hasattr(self._model, "chat"):
-            response = await self._call_wrapper(
+            response = await self._call_wrapper_json(
                 self._model.chat, prompt, *args, **kwargs
            )
             return response
         if hasattr(self._model, "async_chat"):
-            response = await self._call_wrapper(
+            response = await self._call_wrapper_json(
                 self._model.async_chat, prompt, *args, **kwargs
             )
             return response
@@ -543,7 +561,7 @@
     @request_limit
     async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
         if hasattr(self._model, "create_embedding"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.create_embedding, input, *args, **kwargs
             )
 
@@ -565,7 +583,7 @@
         **kwargs,
     ):
         if hasattr(self._model, "rerank"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.rerank,
                 documents,
                 query,
@@ -590,7 +608,7 @@
         timestamp_granularities: Optional[List[str]] = None,
     ):
         if hasattr(self._model, "transcriptions"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.transcriptions,
                 audio,
                 language,
@@ -615,7 +633,7 @@
         timestamp_granularities: Optional[List[str]] = None,
     ):
         if hasattr(self._model, "translations"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.translations,
                 audio,
                 language,
@@ -630,16 +648,23 @@
 
     @log_async(logger=logger)
     @request_limit
+    @xo.generator
     async def speech(
-        self, input: str, voice: str, response_format: str = "mp3", speed: float = 1.0
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
     ):
         if hasattr(self._model, "speech"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_binary(
                 self._model.speech,
                 input,
                 voice,
                 response_format,
                 speed,
+                stream,
             )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating speech."
@@ -657,7 +682,7 @@
         **kwargs,
     ):
         if hasattr(self._model, "text_to_image"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.text_to_image,
                 prompt,
                 n,
@@ -682,7 +707,7 @@
         **kwargs,
     ):
         if hasattr(self._model, "image_to_image"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.image_to_image,
                 image,
                 prompt,
@@ -697,6 +722,50 @@
             f"Model {self._model.model_spec} is not for creating image."
         )
 
+    async def inpainting(
+        self,
+        image: "PIL.Image",
+        mask_image: "PIL.Image",
+        prompt: str,
+        negative_prompt: str,
+        n: int = 1,
+        size: str = "1024*1024",
+        response_format: str = "url",
+        *args,
+        **kwargs,
+    ):
+        if hasattr(self._model, "inpainting"):
+            return await self._call_wrapper(
+                self._model.inpainting,
+                image,
+                mask_image,
+                prompt,
+                negative_prompt,
+                n,
+                size,
+                response_format,
+                *args,
+                **kwargs,
+            )
+        raise AttributeError(
+            f"Model {self._model.model_spec} is not for creating image."
+        )
+
+    @log_async(logger=logger)
+    @request_limit
+    async def infer(
+        self,
+        **kwargs,
+    ):
+        if hasattr(self._model, "infer"):
+            return await self._call_wrapper(
+                self._model.infer,
+                **kwargs,
+            )
+        raise AttributeError(
+            f"Model {self._model.model_spec} is not for flexible infer."
+        )
+
     async def record_metrics(self, name, op, kwargs):
         worker_ref = await self._get_worker_ref()
         await worker_ref.record_metrics(name, op, kwargs)
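The core of the `ModelActor` change: `_call_wrapper` no longer sniffs the result type (`isinstance(ret, bytes)`); each endpoint now declares up front whether it wants JSON or raw bytes via the `_call_wrapper_json`/`_call_wrapper_binary` shims, and generators carry the same flag so streaming speech can yield binary chunks. A standalone sketch of that dispatch pattern, with illustrative names rather than xinference's actual classes:

```python
import asyncio
import inspect
import json
from typing import Any, Callable


async def call_wrapper(output_type: str, fn: Callable, *args: Any, **kwargs: Any):
    """Dispatch on an explicit output type instead of sniffing the result:
    "json" serializes off the event loop thread, "binary" passes bytes through.
    """
    if inspect.iscoroutinefunction(fn):
        ret = await fn(*args, **kwargs)
    else:
        ret = fn(*args, **kwargs)
    if inspect.isgenerator(ret) or inspect.isasyncgen(ret):
        return ret  # streamed results keep the flag and are wrapped elsewhere
    if output_type == "json":
        return await asyncio.to_thread(json.dumps, ret, ensure_ascii=False)
    assert output_type == "binary", f"Unknown output type '{output_type}'"
    return ret


async def main() -> None:
    print(await call_wrapper("json", lambda: {"ok": True}))   # '{"ok": true}'
    print(await call_wrapper("binary", lambda: b"\x00\x01"))  # b'\x00\x01'


asyncio.run(main())
```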