xinference 1.6.1__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff shows the changes between publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only.


Files changed (76)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +79 -2
  3. xinference/client/restful/restful_client.py +64 -2
  4. xinference/core/media_interface.py +123 -0
  5. xinference/core/model.py +31 -0
  6. xinference/core/supervisor.py +8 -17
  7. xinference/core/worker.py +5 -17
  8. xinference/deploy/cmdline.py +6 -2
  9. xinference/model/audio/chattts.py +24 -39
  10. xinference/model/audio/cosyvoice.py +18 -30
  11. xinference/model/audio/funasr.py +42 -0
  12. xinference/model/audio/model_spec.json +18 -0
  13. xinference/model/audio/model_spec_modelscope.json +19 -1
  14. xinference/model/audio/utils.py +75 -0
  15. xinference/model/core.py +1 -0
  16. xinference/model/embedding/__init__.py +74 -18
  17. xinference/model/embedding/core.py +98 -597
  18. xinference/model/embedding/embed_family.py +133 -0
  19. xinference/model/embedding/flag/__init__.py +13 -0
  20. xinference/model/embedding/flag/core.py +282 -0
  21. xinference/model/embedding/model_spec.json +24 -0
  22. xinference/model/embedding/model_spec_modelscope.json +24 -0
  23. xinference/model/embedding/sentence_transformers/__init__.py +13 -0
  24. xinference/model/embedding/sentence_transformers/core.py +399 -0
  25. xinference/model/embedding/vllm/__init__.py +0 -0
  26. xinference/model/embedding/vllm/core.py +95 -0
  27. xinference/model/image/model_spec.json +20 -2
  28. xinference/model/image/model_spec_modelscope.json +21 -2
  29. xinference/model/image/stable_diffusion/core.py +144 -53
  30. xinference/model/llm/llama_cpp/memory.py +4 -2
  31. xinference/model/llm/llm_family.json +57 -0
  32. xinference/model/llm/llm_family_modelscope.json +61 -0
  33. xinference/model/llm/sglang/core.py +4 -0
  34. xinference/model/llm/utils.py +11 -0
  35. xinference/model/llm/vllm/core.py +3 -0
  36. xinference/model/rerank/core.py +86 -4
  37. xinference/model/rerank/model_spec.json +24 -0
  38. xinference/model/rerank/model_spec_modelscope.json +24 -0
  39. xinference/model/rerank/utils.py +4 -3
  40. xinference/model/utils.py +38 -1
  41. xinference/model/video/diffusers.py +65 -3
  42. xinference/model/video/model_spec.json +31 -4
  43. xinference/model/video/model_spec_modelscope.json +32 -4
  44. xinference/web/ui/build/asset-manifest.json +6 -6
  45. xinference/web/ui/build/index.html +1 -1
  46. xinference/web/ui/build/static/css/main.013f296b.css +2 -0
  47. xinference/web/ui/build/static/css/main.013f296b.css.map +1 -0
  48. xinference/web/ui/build/static/js/main.8a9e3ba0.js +3 -0
  49. xinference/web/ui/build/static/js/main.8a9e3ba0.js.map +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/34cfbfb7836e136ba3261cfd411cc554bf99ba24b35dcceebeaa4f008cb3c9dc.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/6595880facebca7ceace6f17cf21c3a5a9219a2f52fb0ba9f3cf1131eddbcf6b.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/aa998bc2d9c11853add6b8a2e08f50327f56d8824ccaaec92d6dde1b305f0d85.json +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/c748246b1d7bcebc16153be69f37e955bb2145526c47dd425aeeff70d3004dbc.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/e31234e95d60a5a7883fbcd70de2475dc1c88c90705df1a530abb68f86f80a51.json +1 -0
  56. xinference/web/ui/src/locales/en.json +18 -7
  57. xinference/web/ui/src/locales/ja.json +224 -0
  58. xinference/web/ui/src/locales/ko.json +224 -0
  59. xinference/web/ui/src/locales/zh.json +18 -7
  60. {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/METADATA +9 -8
  61. {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/RECORD +66 -57
  62. xinference/web/ui/build/static/css/main.337afe76.css +0 -2
  63. xinference/web/ui/build/static/css/main.337afe76.css.map +0 -1
  64. xinference/web/ui/build/static/js/main.ddf9eaee.js +0 -3
  65. xinference/web/ui/build/static/js/main.ddf9eaee.js.map +0 -1
  66. xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +0 -1
  67. xinference/web/ui/node_modules/.cache/babel-loader/12e637ed5fa9ca6491b03892b6949c03afd4960fe36ac25744488e7e1982aa19.json +0 -1
  68. xinference/web/ui/node_modules/.cache/babel-loader/77ac2665a784e99501ae95d32ef5937837a0439a47e965d291b38e99cb619f5b.json +0 -1
  69. xinference/web/ui/node_modules/.cache/babel-loader/d4ed4e82bfe69915999ec83f5feaa4301c75ecc6bdf1c78f2d03e4671ecbefc8.json +0 -1
  70. xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +0 -1
  71. xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +0 -1
  72. /xinference/web/ui/build/static/js/{main.ddf9eaee.js.LICENSE.txt → main.8a9e3ba0.js.LICENSE.txt} +0 -0
  73. {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/WHEEL +0 -0
  74. {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/entry_points.txt +0 -0
  75. {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/licenses/LICENSE +0 -0
  76. {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2025-05-30T19:36:43+0800",
+ "date": "2025-06-13T18:51:07+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "72cc5e39040bdc49981b240c2b59af998554a75f",
- "version": "1.6.1"
+ "full-revisionid": "a362dba7334ef08c758bbc4a3d4904fe53cefe78",
+ "version": "1.7.0"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -387,6 +387,7 @@ class RESTfulAPI(CancelMixin):
         self._router.add_api_route(
             "/v1/cluster/auth", self.is_cluster_authenticated, methods=["GET"]
         )
+        # just for compatibility, LLM only
         self._router.add_api_route(
             "/v1/engines/{model_name}",
             self.query_engines_by_model_name,
@@ -397,6 +398,17 @@ class RESTfulAPI(CancelMixin):
                 else None
             ),
         )
+        # engines for all model types
+        self._router.add_api_route(
+            "/v1/engines/{model_type}/{model_name}",
+            self.query_engines_by_model_name,
+            methods=["GET"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:list"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         # running instances
         self._router.add_api_route(
             "/v1/models/instances",
@@ -708,6 +720,17 @@ class RESTfulAPI(CancelMixin):
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/video/generations/flf",
+            self.create_videos_from_first_last_frame,
+            methods=["POST"],
+            response_model=VideoList,
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/chat/completions",
             self.create_chat_completion,
@@ -2084,6 +2107,57 @@ class RESTfulAPI(CancelMixin):
             self.handle_request_limit_error(e)
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def create_videos_from_first_last_frame(
+        self,
+        model: str = Form(...),
+        first_frame: UploadFile = File(media_type="application/octet-stream"),
+        last_frame: UploadFile = File(media_type="application/octet-stream"),
+        prompt: Optional[Union[str, List[str]]] = Form(None),
+        negative_prompt: Optional[Union[str, List[str]]] = Form(None),
+        n: Optional[int] = Form(1),
+        kwargs: Optional[str] = Form(None),
+    ) -> Response:
+        model_uid = model
+        try:
+            model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        request_id = None
+        try:
+            if kwargs is not None:
+                parsed_kwargs = json.loads(kwargs)
+            else:
+                parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
+            video_list = await model_ref.flf_to_video(
+                first_frame=Image.open(first_frame.file),
+                last_frame=Image.open(last_frame.file),
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                n=n,
+                **parsed_kwargs,
+            )
+            return Response(content=video_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
+        except Exception as e:
+            e = await self._get_model_last_error(model_ref.uid, e)
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            self.handle_request_limit_error(e)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_chat_completion(self, request: Request) -> Response:
         raw_body = await request.json()
         body = CreateChatCompletion.parse_obj(raw_body)
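Because the handler above reads multipart form fields, the endpoint can be exercised directly over HTTP; a minimal sketch (the endpoint address, model uid, and image paths are placeholders, not taken from this diff):

    import json
    import requests

    base_url = "http://127.0.0.1:9997"  # assumed local Xinference endpoint
    with open("first.png", "rb") as f1, open("last.png", "rb") as f2:
        resp = requests.post(
            f"{base_url}/v1/video/generations/flf",
            files={"first_frame": f1, "last_frame": f2},
            data={
                "model": "my-video-model",  # placeholder model uid
                "prompt": "a sailboat drifting toward the sunset",
                "n": 1,
                "kwargs": json.dumps({"num_frames": 16}),
            },
        )
    print(resp.json())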
@@ -2234,11 +2308,14 @@ class RESTfulAPI(CancelMixin):
             self.handle_request_limit_error(e)
             raise HTTPException(status_code=500, detail=str(e))
 
-    async def query_engines_by_model_name(self, model_name: str) -> JSONResponse:
+    async def query_engines_by_model_name(
+        self, request: Request, model_name: str, model_type: Optional[str] = None
+    ) -> JSONResponse:
         try:
+            model_type = model_type or request.path_params.get("model_type", "LLM")
             content = await (
                 await self._get_supervisor_ref()
-            ).query_engines_by_model_name(model_name)
+            ).query_engines_by_model_name(model_name, model_type=model_type)
             return JSONResponse(content=content)
         except ValueError as re:
             logger.error(re, exc_info=True)
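With this handler in place, the old and new engine-query routes compare as follows from a client's point of view; a minimal sketch (the endpoint address and model names are placeholders, not taken from this diff):

    import requests

    base_url = "http://127.0.0.1:9997"  # assumed local Xinference endpoint

    # Legacy LLM-only route, kept for compatibility:
    resp = requests.get(f"{base_url}/v1/engines/qwen2.5-instruct")

    # New route, parameterized by model type (e.g. "LLM", "embedding", "rerank"):
    resp = requests.get(f"{base_url}/v1/engines/embedding/bge-m3")
    print(resp.json())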
xinference/client/restful/restful_client.py CHANGED
@@ -510,6 +510,59 @@ class RESTfulVideoModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def flf_to_video(
+        self,
+        first_frame: Union[str, bytes],
+        last_frame: Union[str, bytes],
+        prompt: str,
+        negative_prompt: Optional[str] = None,
+        n: int = 1,
+        **kwargs,
+    ) -> "VideoList":
+        """
+        Creates a video from the first frame, the last frame, and a text prompt.
+
+        Parameters
+        ----------
+        first_frame: `Union[str, bytes]`
+            The first frame to condition the generation on.
+        last_frame: `Union[str, bytes]`
+            The last frame to condition the generation on.
+        prompt: `str` or `List[str]`
+            The prompt or prompts to guide video generation.
+        negative_prompt: `str` or `List[str]`, *optional*
+            The prompt or prompts not to guide the video generation.
+        n: `int`, defaults to 1
+            The number of videos to generate per prompt. Must be between 1 and 10.
+        Returns
+        -------
+        VideoList
+            A list of video objects.
+        """
+        url = f"{self._base_url}/v1/video/generations/flf"
+        params = {
+            "model": self._model_uid,
+            "prompt": prompt,
+            "negative_prompt": negative_prompt,
+            "n": n,
+            "kwargs": json.dumps(kwargs),
+        }
+        files: List[Any] = []
+        for key, value in params.items():
+            files.append((key, (None, value)))
+        files.append(
+            ("first_frame", ("image", first_frame, "application/octet-stream"))
+        )
+        files.append(("last_frame", ("image", last_frame, "application/octet-stream")))
+        response = requests.post(url, files=files, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to create the video from the first/last frames, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
+
 
 class RESTfulGenerateModelHandle(RESTfulModelHandle):
     def generate(
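Typical usage through this handle might look as follows (the endpoint, model uid, and file names are placeholders):

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")  # assumed endpoint
    model = client.get_model("my-video-model")  # a model with the firstlastframe2video ability

    with open("first.png", "rb") as f1, open("last.png", "rb") as f2:
        video_list = model.flf_to_video(
            first_frame=f1.read(),
            last_frame=f2.read(),
            prompt="a sailboat drifting toward the sunset",
            n=1,
        )
    print(video_list)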
@@ -637,6 +690,7 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
         response_format: Optional[str] = "json",
         temperature: Optional[float] = 0,
         timestamp_granularities: Optional[List[str]] = None,
+        **kwargs,
     ):
         """
         Transcribes audio into the input language.
@@ -678,6 +732,7 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
             "response_format": response_format,
             "temperature": temperature,
             "timestamp_granularities[]": timestamp_granularities,
+            "kwargs": json.dumps(kwargs),
         }
         files: List[Any] = []
         files.append(("file", ("file", audio, "application/octet-stream")))
@@ -1502,7 +1557,9 @@ class Client:
         response_data = response.json()
         return response_data
 
-    def query_engine_by_model_name(self, model_name: str):
+    def query_engine_by_model_name(
+        self, model_name: str, model_type: Optional[str] = "LLM"
+    ):
         """
         Get the engine parameters for a model name registered on the server.
 
@@ -1510,12 +1567,17 @@ class Client:
         ----------
         model_name: str
             The name of the model.
+        model_type: str
+            The model type, "LLM" by default.
         Returns
         -------
         Dict[str, List[Dict[str, Any]]]
             The supported engine parameters of registered models on the server.
         """
-        url = f"{self.base_url}/v1/engines/{model_name}"
+        if not model_type:
+            url = f"{self.base_url}/v1/engines/{model_name}"
+        else:
+            url = f"{self.base_url}/v1/engines/{model_type}/{model_name}"
         response = requests.get(url, headers=self._headers)
         if response.status_code != 200:
             raise RuntimeError(
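A usage sketch for the extended client method (the endpoint and model names are placeholders):

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")  # assumed endpoint

    # Defaults to model_type="LLM", matching the old behavior:
    llm_engines = client.query_engine_by_model_name("qwen2.5-instruct")

    # An explicit type goes through the new /v1/engines/{model_type}/{model_name} route:
    emb_engines = client.query_engine_by_model_name("bge-m3", model_type="embedding")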
xinference/core/media_interface.py CHANGED
@@ -577,6 +577,126 @@ class MediaInterface:
 
         return image2video_ui
 
+    def flf2video_interface(self) -> "gr.Blocks":
+        def generate_video_from_flf(
+            first_frame: "PIL.Image.Image",
+            last_frame: "PIL.Image.Image",
+            prompt: str,
+            negative_prompt: str,
+            num_frames: int,
+            fps: int,
+            num_inference_steps: int,
+            guidance_scale: float,
+            width: int,
+            height: int,
+            progress=gr.Progress(),
+        ) -> List[Tuple[str, str]]:
+            from ..client import RESTfulClient
+
+            client = RESTfulClient(self.endpoint)
+            client._set_token(self.access_token)
+            model = client.get_model(self.model_uid)
+            assert hasattr(model, "flf_to_video")
+
+            request_id = str(uuid.uuid4())
+            response = None
+            exc = None
+
+            buffer_first = io.BytesIO()
+            buffer_last = io.BytesIO()
+            first_frame.save(buffer_first, format="PNG")
+            last_frame.save(buffer_last, format="PNG")
+
+            def run_in_thread():
+                nonlocal exc, response
+                try:
+                    response = model.flf_to_video(
+                        first_frame=buffer_first.getvalue(),
+                        last_frame=buffer_last.getvalue(),
+                        prompt=prompt,
+                        negative_prompt=negative_prompt,
+                        n=1,
+                        num_frames=num_frames,
+                        fps=fps,
+                        num_inference_steps=num_inference_steps,
+                        guidance_scale=guidance_scale,
+                        width=width,
+                        height=height,
+                        response_format="b64_json",
+                        request_id=request_id,
+                    )
+                except Exception as e:
+                    exc = e
+
+            t = threading.Thread(target=run_in_thread)
+            t.start()
+
+            while t.is_alive():
+                try:
+                    cur_progress = client.get_progress(request_id)["progress"]
+                except Exception:
+                    cur_progress = 0.0
+                progress(cur_progress, desc="Generating video from first/last frames")
+                time.sleep(1)
+
+            if exc:
+                raise exc
+
+            videos = []
+            for video_dict in response["data"]:  # type: ignore
+                video_data = base64.b64decode(video_dict["b64_json"])
+                video_path = f"/tmp/{uuid.uuid4()}.mp4"
+                with open(video_path, "wb") as f:
+                    f.write(video_data)
+                videos.append((video_path, "Generated Video"))
+
+            return videos
+
+        # Gradio UI
+        with gr.Blocks() as flf2video_ui:
+            with gr.Row():
+                first_frame = gr.Image(label="First Frame", type="pil")
+                last_frame = gr.Image(label="Last Frame", type="pil")
+
+            prompt = gr.Textbox(label="Prompt", placeholder="Enter video prompt")
+            negative_prompt = gr.Textbox(
+                label="Negative Prompt", placeholder="Enter negative prompt"
+            )
+
+            with gr.Row():
+                with gr.Column():
+                    width = gr.Number(label="Width", value=512)
+                    num_frames = gr.Number(label="Frames", value=16)
+                    steps = gr.Number(label="Inference Steps", value=25)
+                with gr.Column():
+                    height = gr.Number(label="Height", value=512)
+                    fps = gr.Number(label="FPS", value=8)
+                    guidance_scale = gr.Slider(
+                        label="Guidance Scale", minimum=1, maximum=20, value=7.5
+                    )
+
+            generate = gr.Button("Generate")
+            gallery = gr.Gallery(label="Generated Videos", columns=2)
+
+            generate.click(
+                fn=generate_video_from_flf,
+                inputs=[
+                    first_frame,
+                    last_frame,
+                    prompt,
+                    negative_prompt,
+                    num_frames,
+                    fps,
+                    steps,
+                    guidance_scale,
+                    width,
+                    height,
+                ],
+                outputs=gallery,
+            )
+
+        return flf2video_ui
+
     def audio2text_interface(self) -> "gr.Blocks":
         def transcribe_audio(
             audio_path: str,
@@ -750,6 +870,9 @@ class MediaInterface:
         if "image2video" in self.model_ability:
             with gr.Tab("Image to Video"):
                 self.image2video_interface()
+        if "firstlastframe2video" in self.model_ability:
+            with gr.Tab("FirstLastFrame to Video"):
+                self.flf2video_interface()
         if "audio2text" in self.model_ability:
             with gr.Tab("Audio to Text"):
                 self.audio2text_interface()
xinference/core/model.py CHANGED
@@ -1289,6 +1289,37 @@ class ModelActor(xo.StatelessActor, CancelMixin):
                 f"Model {self._model.model_spec} is not for creating video from image."
             )
 
+    @request_limit
+    @log_async(logger=logger)
+    async def flf_to_video(
+        self,
+        first_frame: "PIL.Image.Image",
+        last_frame: "PIL.Image.Image",
+        prompt: str,
+        negative_prompt: Optional[str] = None,
+        n: int = 1,
+        *args,
+        **kwargs,
+    ):
+        kwargs["negative_prompt"] = negative_prompt
+        progressor = kwargs["progressor"] = await self._get_progressor(
+            kwargs.pop("request_id", None)
+        )
+        with progressor:
+            if hasattr(self._model, "firstlastframe_to_video"):
+                return await self._call_wrapper_json(
+                    self._model.firstlastframe_to_video,
+                    first_frame,
+                    last_frame,
+                    prompt,
+                    n,
+                    *args,
+                    **kwargs,
+                )
+            raise AttributeError(
+                f"Model {self._model.model_spec} is not for creating video from first-last-frame."
+            )
+
     async def record_metrics(self, name, op, kwargs):
         worker_ref = await self._get_worker_ref()
         await worker_ref.record_metrics(name, op, kwargs)
xinference/core/supervisor.py CHANGED
@@ -45,6 +45,7 @@ from ..constants import (
 )
 from ..core.model import ModelActor
 from ..core.status_guard import InstanceInfo, LaunchStatus
+from ..model.utils import get_engine_params_by_name
 from ..types import PeftModelConfig
 from .metrics import record_metrics
 from .resource import GPUStatus, ResourceStatus
@@ -780,29 +781,19 @@ class SupervisorActor(xo.StatelessActor):
             raise ValueError(f"Unsupported model type: {model_type}")
 
     @log_async(logger=logger)
-    async def query_engines_by_model_name(self, model_name: str):
-        from copy import deepcopy
-
-        from ..model.llm.llm_family import LLM_ENGINES
-
+    async def query_engines_by_model_name(
+        self, model_name: str, model_type: Optional[str] = None
+    ):
         # search in worker first
         workers = list(self._worker_address_to_worker.values())
         for worker in workers:
-            res = await worker.query_engines_by_model_name(model_name)
+            res = await worker.query_engines_by_model_name(
+                model_name, model_type=model_type
+            )
             if res is not None:
                 return res
 
-        if model_name not in LLM_ENGINES:
-            raise ValueError(f"Model {model_name} not found")
-
-        # filter llm_class
-        engine_params = deepcopy(LLM_ENGINES[model_name])
-        for engine in engine_params:
-            params = engine_params[engine]
-            for param in params:
-                del param["llm_class"]
-
-        return engine_params
+        return get_engine_params_by_name(model_type, model_name)
 
     @log_async(logger=logger)
     async def register_model(
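get_engine_params_by_name itself is not shown in this diff; it lives in xinference/model/utils.py (+38 lines per the file list). A plausible sketch of its shape, inferred from the LLM-only code it replaces here and in worker.py below; the EMBEDDING_ENGINES registry name and the whole embedding branch are assumptions, not confirmed by this diff:

    from copy import deepcopy
    from typing import Any, Dict, List, Optional


    def get_engine_params_by_name(
        model_type: Optional[str], model_name: str
    ) -> Optional[Dict[str, List[Dict[str, Any]]]]:
        # "LLM" (the default) preserves the old behavior: look up LLM_ENGINES
        # and strip the non-serializable class objects before returning.
        if model_type is None or model_type == "LLM":
            from .llm.llm_family import LLM_ENGINES

            if model_name not in LLM_ENGINES:
                return None
            engine_params = deepcopy(LLM_ENGINES[model_name])
            for params in engine_params.values():
                for param in params:
                    param.pop("llm_class", None)
            return engine_params
        if model_type == "embedding":
            # Hypothetical: the new embed_family module would expose a similar registry.
            from .embedding.embed_family import EMBEDDING_ENGINES

            if model_name not in EMBEDDING_ENGINES:
                return None
            engine_params = deepcopy(EMBEDDING_ENGINES[model_name])
            for params in engine_params.values():
                for param in params:
                    param.pop("embedding_class", None)
            return engine_params
        return None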
xinference/core/worker.py CHANGED
@@ -53,7 +53,7 @@ from ..core.model import ModelActor
 from ..core.status_guard import LaunchStatus
 from ..device_utils import get_available_device_env_name, gpu_count
 from ..model.core import ModelDescription, VirtualEnvSettings, create_model_instance
-from ..model.utils import CancellableDownloader
+from ..model.utils import CancellableDownloader, get_engine_params_by_name
 from ..types import PeftModelConfig
 from ..utils import get_pip_config_args, get_real_path
 from .cache_tracker import CacheTrackerActor
@@ -747,22 +747,10 @@ class WorkerActor(xo.StatelessActor):
         return None
 
     @log_async(logger=logger)
-    async def query_engines_by_model_name(self, model_name: str):
-        from copy import deepcopy
-
-        from ..model.llm.llm_family import LLM_ENGINES
-
-        if model_name not in LLM_ENGINES:
-            return None
-
-        # filter llm_class
-        engine_params = deepcopy(LLM_ENGINES[model_name])
-        for engine in engine_params:
-            params = engine_params[engine]
-            for param in params:
-                del param["llm_class"]
-
-        return engine_params
+    async def query_engines_by_model_name(
+        self, model_name: str, model_type: Optional[str] = None
+    ):
+        return get_engine_params_by_name(model_type, model_name)
 
     async def _get_model_ability(self, model: Any, model_type: str) -> List[str]:
         from ..model.llm.core import LLM
xinference/deploy/cmdline.py CHANGED
@@ -1315,8 +1315,12 @@ def model_chat(
             if "content" not in delta:
                 continue
             else:
-                response_content += delta["content"]
-                print(delta["content"], end="", flush=True, file=sys.stdout)
+                # The first chunk of stream output may have no content (None). Related PRs:
+                # https://github.com/ggml-org/llama.cpp/pull/13634
+                # https://github.com/ggml-org/llama.cpp/pull/12379
+                content = delta["content"] or ""
+                response_content += content
+                print(content, end="", flush=True, file=sys.stdout)
     print("", file=sys.stdout)
     messages.append(dict(role="assistant", content=response_content))
 
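The guard matters because the first delta of an OpenAI-style stream may carry only the role with "content": None, and += would then raise a TypeError; a toy illustration:

    response_content = ""
    delta = {"role": "assistant", "content": None}  # typical first stream chunk
    content = delta["content"] or ""  # None becomes "" instead of blowing up
    response_content += content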
xinference/model/audio/chattts.py CHANGED
@@ -71,9 +71,10 @@ class ChatTTSModel:
         import ChatTTS
         import numpy as np
         import torch
-        import torchaudio
         import xxhash
 
+        from .utils import audio_stream_generator, audio_to_bytes
+
         rnd_spk_emb = None
 
         if len(voice) > 400:
@@ -105,44 +106,28 @@ class ChatTTSModel:
             )
 
         assert self._model is not None
+
+        output = self._model.infer(
+            [input], params_infer_code=params_infer_code, stream=stream
+        )
         if stream:
-            iter = self._model.infer(
-                [input], params_infer_code=params_infer_code, stream=True
-            )
 
-            def _generator():
-                with BytesIO() as out:
-                    writer = torchaudio.io.StreamWriter(out, format=response_format)
-                    writer.add_audio_stream(sample_rate=24000, num_channels=1)
-                    i = 0
-                    last_pos = 0
-                    with writer.open():
-                        for it in iter:
-                            for chunk in it:
-                                chunk = np.array([chunk]).transpose()
-                                writer.write_audio_chunk(i, torch.from_numpy(chunk))
-                                new_last_pos = out.tell()
-                                if new_last_pos != last_pos:
-                                    out.seek(last_pos)
-                                    encoded_bytes = out.read()
-                                    yield encoded_bytes
-                                    last_pos = new_last_pos
-
-            return _generator()
+            def _gen_chunk():
+                for it in output:
+                    for chunk in it:
+                        yield chunk
+
+            return audio_stream_generator(
+                response_format=response_format,
+                sample_rate=24000,
+                output_generator=_gen_chunk(),
+                output_chunk_transformer=lambda c: torch.from_numpy(
+                    np.array([c]).transpose()
+                ),
+            )
         else:
-            wavs = self._model.infer([input], params_infer_code=params_infer_code)
-
-            # Save the generated audio
-            with BytesIO() as out:
-                try:
-                    torchaudio.save(
-                        out,
-                        torch.from_numpy(wavs[0]).unsqueeze(0),
-                        24000,
-                        format=response_format,
-                    )
-                except:
-                    torchaudio.save(
-                        out, torch.from_numpy(wavs[0]), 24000, format=response_format
-                    )
-            return out.getvalue()
+            return audio_to_bytes(
+                response_format=response_format,
+                sample_rate=24000,
+                tensor=torch.from_numpy(output[0]).unsqueeze(0),
+            )
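chattts.py above and cosyvoice.py below now share the streaming and encoding logic through xinference/model/audio/utils.py (+75 lines per the file list). The helpers' bodies are not part of this diff; a plausible sketch, reconstructed from the torchaudio code they replace and the keyword arguments used at the call sites:

    from io import BytesIO
    from typing import Any, Callable, Generator, Iterable


    def audio_stream_generator(
        response_format: str,
        sample_rate: int,
        output_generator: Iterable[Any],
        output_chunk_transformer: Callable[[Any], Any],
    ) -> Generator[bytes, None, None]:
        # Incrementally encode chunks and yield whatever new bytes the writer produced.
        import torchaudio

        with BytesIO() as out:
            writer = torchaudio.io.StreamWriter(out, format=response_format)
            writer.add_audio_stream(sample_rate=sample_rate, num_channels=1)
            last_pos = 0
            with writer.open():
                for chunk in output_generator:
                    writer.write_audio_chunk(0, output_chunk_transformer(chunk))
                    new_last_pos = out.tell()
                    if new_last_pos != last_pos:
                        out.seek(last_pos)
                        yield out.read()
                        last_pos = new_last_pos


    def audio_to_bytes(response_format: str, sample_rate: int, tensor: Any) -> bytes:
        # One-shot encode of a complete (channels, samples) tensor.
        import torchaudio

        with BytesIO() as out:
            torchaudio.save(out, tensor, sample_rate, format=response_format)
            return out.getvalue()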
xinference/model/audio/cosyvoice.py CHANGED
@@ -13,7 +13,6 @@
 # limitations under the License.
 import io
 import logging
-from io import BytesIO
 from typing import TYPE_CHECKING, Optional
 
 from ..utils import set_all_random_seed
@@ -132,36 +131,25 @@ class CosyVoiceModel:
         output = self._model.inference_sft(input, voice, stream=stream)
 
         import torch
-        import torchaudio
 
-        def _generator_stream():
-            with BytesIO() as out:
-                writer = torchaudio.io.StreamWriter(out, format=response_format)
-                writer.add_audio_stream(
-                    sample_rate=self._model.sample_rate, num_channels=1
-                )
-                i = 0
-                last_pos = 0
-                with writer.open():
-                    for chunk in output:
-                        chunk = chunk["tts_speech"]
-                        trans_chunk = torch.transpose(chunk, 0, 1)
-                        writer.write_audio_chunk(i, trans_chunk)
-                        new_last_pos = out.tell()
-                        if new_last_pos != last_pos:
-                            out.seek(last_pos)
-                            encoded_bytes = out.read()
-                            yield encoded_bytes
-                            last_pos = new_last_pos
-
-        def _generator_block():
-            chunks = [o["tts_speech"] for o in output]
-            t = torch.cat(chunks, dim=1)
-            with BytesIO() as out:
-                torchaudio.save(out, t, self._model.sample_rate, format=response_format)
-                return out.getvalue()
-
-        return _generator_stream() if stream else _generator_block()
+        from .utils import audio_stream_generator, audio_to_bytes
+
+        return (
+            audio_stream_generator(
+                response_format=response_format,
+                sample_rate=self._model.sample_rate,
+                output_generator=output,
+                output_chunk_transformer=lambda c: torch.transpose(
+                    c["tts_speech"], 0, 1
+                ),
+            )
+            if stream
+            else audio_to_bytes(
+                response_format=response_format,
+                sample_rate=self._model.sample_rate,
+                tensor=torch.cat([o["tts_speech"] for o in output], dim=1),
+            )
+        )
 
     def speech(
         self,