xinference 0.13.2__py3-none-any.whl → 0.13.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (103)
  1. xinference/__init__.py +0 -1
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +30 -5
  4. xinference/client/restful/restful_client.py +18 -3
  5. xinference/constants.py +0 -4
  6. xinference/core/chat_interface.py +2 -2
  7. xinference/core/image_interface.py +6 -3
  8. xinference/core/model.py +9 -4
  9. xinference/core/scheduler.py +4 -4
  10. xinference/core/supervisor.py +2 -0
  11. xinference/core/worker.py +7 -0
  12. xinference/deploy/utils.py +6 -0
  13. xinference/model/audio/core.py +9 -4
  14. xinference/model/audio/cosyvoice.py +136 -0
  15. xinference/model/audio/model_spec.json +24 -0
  16. xinference/model/audio/model_spec_modelscope.json +27 -0
  17. xinference/model/core.py +25 -4
  18. xinference/model/embedding/core.py +88 -13
  19. xinference/model/embedding/model_spec.json +8 -0
  20. xinference/model/embedding/model_spec_modelscope.json +8 -0
  21. xinference/model/flexible/core.py +8 -2
  22. xinference/model/flexible/launchers/__init__.py +1 -0
  23. xinference/model/flexible/launchers/image_process_launcher.py +70 -0
  24. xinference/model/image/core.py +8 -5
  25. xinference/model/image/model_spec.json +36 -5
  26. xinference/model/image/model_spec_modelscope.json +21 -3
  27. xinference/model/image/stable_diffusion/core.py +36 -28
  28. xinference/model/llm/core.py +6 -4
  29. xinference/model/llm/ggml/llamacpp.py +7 -5
  30. xinference/model/llm/llm_family.json +802 -82
  31. xinference/model/llm/llm_family.py +6 -6
  32. xinference/model/llm/llm_family_csghub.json +39 -0
  33. xinference/model/llm/llm_family_modelscope.json +295 -47
  34. xinference/model/llm/mlx/core.py +7 -0
  35. xinference/model/llm/pytorch/chatglm.py +246 -5
  36. xinference/model/llm/pytorch/cogvlm2.py +1 -1
  37. xinference/model/llm/pytorch/deepseek_vl.py +2 -1
  38. xinference/model/llm/pytorch/falcon.py +2 -1
  39. xinference/model/llm/pytorch/llama_2.py +4 -2
  40. xinference/model/llm/pytorch/omnilmm.py +2 -1
  41. xinference/model/llm/pytorch/qwen_vl.py +2 -1
  42. xinference/model/llm/pytorch/vicuna.py +2 -1
  43. xinference/model/llm/pytorch/yi_vl.py +2 -1
  44. xinference/model/llm/sglang/core.py +12 -6
  45. xinference/model/llm/utils.py +78 -1
  46. xinference/model/llm/vllm/core.py +9 -5
  47. xinference/model/rerank/core.py +4 -3
  48. xinference/thirdparty/cosyvoice/__init__.py +0 -0
  49. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  50. xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
  51. xinference/thirdparty/cosyvoice/bin/train.py +136 -0
  52. xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
  53. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
  54. xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
  55. xinference/thirdparty/cosyvoice/cli/model.py +60 -0
  56. xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
  57. xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
  58. xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
  59. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  60. xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
  61. xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
  62. xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
  63. xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
  64. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  65. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
  66. xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
  67. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  68. xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
  69. xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
  70. xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
  71. xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
  72. xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
  73. xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
  74. xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
  77. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
  78. xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
  79. xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  80. xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
  81. xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
  82. xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
  83. xinference/thirdparty/cosyvoice/utils/common.py +103 -0
  84. xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
  85. xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
  86. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
  87. xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
  88. xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
  89. xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
  90. xinference/web/ui/build/asset-manifest.json +3 -3
  91. xinference/web/ui/build/index.html +1 -1
  92. xinference/web/ui/build/static/js/{main.95c1d652.js → main.af906659.js} +3 -3
  93. xinference/web/ui/build/static/js/main.af906659.js.map +1 -0
  94. xinference/web/ui/node_modules/.cache/babel-loader/2cd5e4279ad7e13a1f41d486e9fca7756295bfad5bd77d90992f4ac3e10b496d.json +1 -0
  95. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/METADATA +39 -11
  96. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/RECORD +101 -57
  97. xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
  98. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
  99. /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.af906659.js.LICENSE.txt} +0 -0
  100. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/LICENSE +0 -0
  101. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/WHEEL +0 -0
  102. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/entry_points.txt +0 -0
  103. {xinference-0.13.2.dist-info → xinference-0.13.4.dist-info}/top_level.txt +0 -0
xinference/__init__.py CHANGED
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from . import _version
 
 __version__ = _version.get_versions()["version"]
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-07-19T19:15:54+0800",
+ "date": "2024-08-02T16:08:07+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "880929cbbc73e5206ca069591b03d9d16dd858bf",
- "version": "0.13.2"
+ "full-revisionid": "dd85cfe015c9cd2d8110c79213640aa0e21f3a6a",
+ "version": "0.13.4"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -130,6 +130,7 @@ class SpeechRequest(BaseModel):
     response_format: Optional[str] = "mp3"
     speed: Optional[float] = 1.0
     stream: Optional[bool] = False
+    kwargs: Optional[str] = None
 
 
 class RegisterModelRequest(BaseModel):
@@ -796,6 +797,7 @@ class RESTfulAPI:
         worker_ip = payload.get("worker_ip", None)
         gpu_idx = payload.get("gpu_idx", None)
         download_hub = payload.get("download_hub", None)
+        model_path = payload.get("model_path", None)
 
         exclude_keys = {
             "model_uid",
@@ -812,6 +814,7 @@
             "worker_ip",
             "gpu_idx",
             "download_hub",
+            "model_path",
         }
 
         kwargs = {
@@ -860,6 +863,7 @@
                 worker_ip=worker_ip,
                 gpu_idx=gpu_idx,
                 download_hub=download_hub,
+                model_path=model_path,
                 **kwargs,
             )
         except ValueError as ve:
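
The three hunks above thread a new `model_path` launch option through the request payload. A minimal sketch of launching against the raw REST endpoint (the `/v1/models` route and port 9997 are xinference defaults; the model name and path are illustrative):

    import requests

    resp = requests.post(
        "http://127.0.0.1:9997/v1/models",
        json={
            "model_uid": "my-llm",
            "model_name": "qwen2-instruct",      # illustrative
            "model_type": "LLM",
            "model_path": "/data/models/qwen2",  # new in 0.13.4; must exist on the worker
        },
    )
    resp.raise_for_status()

The worker-side validation for this path appears in xinference/core/worker.py further down.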
@@ -1309,8 +1313,18 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
-    async def create_speech(self, request: Request) -> Response:
-        body = SpeechRequest.parse_obj(await request.json())
+    async def create_speech(
+        self,
+        request: Request,
+        prompt_speech: Optional[UploadFile] = File(
+            None, media_type="application/octet-stream"
+        ),
+    ) -> Response:
+        if prompt_speech:
+            f = await request.form()
+        else:
+            f = await request.json()
+        body = SpeechRequest.parse_obj(f)
         model_uid = body.model
         try:
             model = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -1324,12 +1338,19 @@
             raise HTTPException(status_code=500, detail=str(e))
 
         try:
+            if body.kwargs is not None:
+                parsed_kwargs = json.loads(body.kwargs)
+            else:
+                parsed_kwargs = {}
+            if prompt_speech is not None:
+                parsed_kwargs["prompt_speech"] = await prompt_speech.read()
             out = await model.speech(
                 input=body.input,
                 voice=body.voice,
                 response_format=body.response_format,
                 speed=body.speed,
                 stream=body.stream,
+                **parsed_kwargs,
             )
             if body.stream:
                 return EventSourceResponse(
@@ -1389,7 +1410,7 @@
         negative_prompt: Optional[Union[str, List[str]]] = Form(None),
         n: Optional[int] = Form(1),
         response_format: Optional[str] = Form("url"),
-        size: Optional[str] = Form("1024*1024"),
+        size: Optional[str] = Form(None),
         kwargs: Optional[str] = Form(None),
     ) -> Response:
         model_uid = model
@@ -1626,10 +1647,14 @@
         if body.tools and body.stream:
             is_vllm = await model.is_vllm_backend()
 
-            if not is_vllm or model_family not in QWEN_TOOL_CALL_FAMILY:
+            if not (
+                (is_vllm and model_family in QWEN_TOOL_CALL_FAMILY)
+                or (not is_vllm and model_family in GLM4_TOOL_CALL_FAMILY)
+            ):
                 raise HTTPException(
                     status_code=400,
-                    detail="Streaming support for tool calls is available only when using vLLM backend and Qwen models.",
+                    detail="Streaming support for tool calls is available only when using "
+                    "Qwen models with vLLM backend or GLM4-chat models without vLLM backend.",
                 )
 
         if body.stream:
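
The reworked `create_speech` accepts an optional `prompt_speech` upload: when present, the request is parsed as multipart form data rather than JSON, and extra options travel as a JSON-encoded `kwargs` form field that is unpacked into `model.speech(...)`. A sketch of calling the multipart branch directly (route and port are the defaults; uid and file names are illustrative):

    import json
    import requests

    with open("reference.wav", "rb") as f:
        prompt_speech = f.read()

    resp = requests.post(
        "http://127.0.0.1:9997/v1/audio/speech",
        data={
            "model": "CosyVoice-300M",  # illustrative model uid
            "input": "Hello from a cloned voice.",
            # extra options ride in the JSON-encoded kwargs field
            "kwargs": json.dumps({"prompt_text": "Transcript of the reference clip."}),
        },
        files={
            "prompt_speech": ("prompt_speech", prompt_speech, "application/octet-stream")
        },
    )
    audio_bytes = resp.content  # raw audio in the requested response_format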
xinference/client/restful/restful_client.py CHANGED
@@ -234,9 +234,9 @@ class RESTfulImageModelHandle(RESTfulModelHandle):
         self,
         image: Union[str, bytes],
         prompt: str,
-        negative_prompt: str,
+        negative_prompt: Optional[str] = None,
         n: int = 1,
-        size: str = "1024*1024",
+        size: Optional[str] = None,
         response_format: str = "url",
         **kwargs,
     ) -> "ImageList":
@@ -768,6 +768,8 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
         response_format: str = "mp3",
         speed: float = 1.0,
         stream: bool = False,
+        prompt_speech: Optional[bytes] = None,
+        **kwargs,
     ):
         """
         Generates audio from the input text.
@@ -799,8 +801,21 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
             "response_format": response_format,
             "speed": speed,
             "stream": stream,
+            "kwargs": json.dumps(kwargs),
         }
-        response = requests.post(url, json=params, headers=self.auth_headers)
+        if prompt_speech:
+            files: List[Any] = []
+            files.append(
+                (
+                    "prompt_speech",
+                    ("prompt_speech", prompt_speech, "application/octet-stream"),
+                )
+            )
+            response = requests.post(
+                url, data=params, files=files, headers=self.auth_headers
+            )
+        else:
+            response = requests.post(url, json=params, headers=self.auth_headers)
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to speech the text, detail: {_get_error_string(response)}"
xinference/constants.py CHANGED
@@ -26,8 +26,6 @@ XINFERENCE_ENV_HEALTH_CHECK_FAILURE_THRESHOLD = (
 XINFERENCE_ENV_HEALTH_CHECK_INTERVAL = "XINFERENCE_HEALTH_CHECK_INTERVAL"
 XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT = "XINFERENCE_HEALTH_CHECK_TIMEOUT"
 XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
-XINFERENCE_ENV_DISABLE_VLLM = "XINFERENCE_DISABLE_VLLM"
-XINFERENCE_ENV_ENABLE_SGLANG = "XINFERENCE_ENABLE_SGLANG"
 XINFERENCE_ENV_DISABLE_METRICS = "XINFERENCE_DISABLE_METRICS"
 XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING = "XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"
 
@@ -72,8 +70,6 @@ XINFERENCE_HEALTH_CHECK_TIMEOUT = int(
 XINFERENCE_DISABLE_HEALTH_CHECK = bool(
     int(os.environ.get(XINFERENCE_ENV_DISABLE_HEALTH_CHECK, 0))
 )
-XINFERENCE_DISABLE_VLLM = bool(int(os.environ.get(XINFERENCE_ENV_DISABLE_VLLM, 0)))
-XINFERENCE_ENABLE_SGLANG = bool(int(os.environ.get(XINFERENCE_ENV_ENABLE_SGLANG, 0)))
 XINFERENCE_DISABLE_METRICS = bool(
     int(os.environ.get(XINFERENCE_ENV_DISABLE_METRICS, 0))
 )
xinference/core/chat_interface.py CHANGED
@@ -428,7 +428,7 @@ class GradioInterface:
             }
 
             hist.append(response_content)
-            return {
+            return {  # type: ignore
                 textbox: response_content,
                 history: hist,
             }
@@ -467,7 +467,7 @@
             }
 
             hist.append(response_content)
-            return {
+            return {  # type: ignore
                 textbox: response_content,
                 history: hist,
             }
xinference/core/image_interface.py CHANGED
@@ -153,7 +153,10 @@ class ImageInterface:
         model = client.get_model(self.model_uid)
         assert isinstance(model, RESTfulImageModelHandle)
 
-        size = f"{int(size_width)}*{int(size_height)}"
+        if size_width > 0 and size_height > 0:
+            size = f"{int(size_width)}*{int(size_height)}"
+        else:
+            size = None
 
         bio = io.BytesIO()
         image.save(bio, format="png")
@@ -195,8 +198,8 @@
 
         with gr.Row():
             n = gr.Number(label="Number of image", value=1)
-            size_width = gr.Number(label="Width", value=512)
-            size_height = gr.Number(label="Height", value=512)
+            size_width = gr.Number(label="Width", value=-1)
+            size_height = gr.Number(label="Height", value=-1)
 
         with gr.Row():
             with gr.Column(scale=1):
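
With `size` optional end to end, leaving Width/Height at the new -1 defaults (or omitting `size` in an API call) lets the image model fall back to its own native resolution instead of a forced 1024*1024. A sketch using the client handle changed earlier in this diff (uid and file names illustrative):

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")
    model = client.get_model("my-sd-model")  # illustrative image model uid

    with open("input.png", "rb") as f:
        result = model.image_to_image(
            image=f.read(),
            prompt="a watercolor landscape",
            # size omitted -> None -> the model picks its default resolution
        )
    print(result["data"][0]["url"])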
xinference/core/model.py CHANGED
@@ -646,7 +646,10 @@ class ModelActor(xo.StatelessActor):
                 f"Model {self._model.model_spec} is not for creating translations."
             )
 
-    @log_async(logger=logger)
+    @log_async(
+        logger=logger,
+        args_formatter=lambda _, kwargs: kwargs.pop("prompt_speech", None),
+    )
     @request_limit
     @xo.generator
     async def speech(
@@ -656,6 +659,7 @@
         response_format: str = "mp3",
         speed: float = 1.0,
         stream: bool = False,
+        **kwargs,
     ):
         if hasattr(self._model, "speech"):
             return await self._call_wrapper_binary(
@@ -665,6 +669,7 @@
                 response_format,
                 speed,
                 stream,
+                **kwargs,
             )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating speech."
@@ -701,7 +706,7 @@
         prompt: str,
         negative_prompt: str,
         n: int = 1,
-        size: str = "1024*1024",
+        size: Optional[str] = None,
         response_format: str = "url",
         *args,
         **kwargs,
@@ -735,7 +740,7 @@
         **kwargs,
     ):
         if hasattr(self._model, "inpainting"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.inpainting,
                 image,
                 mask_image,
@@ -758,7 +763,7 @@
         **kwargs,
     ):
         if hasattr(self._model, "infer"):
-            return await self._call_wrapper(
+            return await self._call_wrapper_json(
                 self._model.infer,
                 **kwargs,
             )
xinference/core/scheduler.py CHANGED
@@ -81,7 +81,7 @@ class InferenceRequest:
         self.future_or_queue = future_or_queue
         # Record error message when this request has error.
         # Must set stopped=True when this field is set.
-        self.error_msg: Optional[str] = None
+        self.error_msg: Optional[str] = None  # type: ignore
         # For compatibility. Record some extra parameters for some special cases.
         self.extra_kwargs = {}
 
@@ -295,11 +295,11 @@ class SchedulerActor(xo.StatelessActor):
 
     def __init__(self):
         super().__init__()
-        self._waiting_queue: deque[InferenceRequest] = deque()
-        self._running_queue: deque[InferenceRequest] = deque()
+        self._waiting_queue: deque[InferenceRequest] = deque()  # type: ignore
+        self._running_queue: deque[InferenceRequest] = deque()  # type: ignore
         self._model = None
         self._id_to_req = {}
-        self._abort_req_ids: Set[str] = set()
+        self._abort_req_ids: Set[str] = set()  # type: ignore
         self._isolation = None
 
     async def __post_create__(self):
xinference/core/supervisor.py CHANGED
@@ -859,6 +859,7 @@ class SupervisorActor(xo.StatelessActor):
         worker_ip: Optional[str] = None,
         gpu_idx: Optional[Union[int, List[int]]] = None,
         download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+        model_path: Optional[str] = None,
         **kwargs,
     ) -> str:
         # search in worker first
@@ -942,6 +943,7 @@
             peft_model_config=peft_model_config,
             gpu_idx=replica_gpu_idx,
             download_hub=download_hub,
+            model_path=model_path,
             **kwargs,
         )
         self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
xinference/core/worker.py CHANGED
@@ -743,6 +743,7 @@ class WorkerActor(xo.StatelessActor):
         request_limits: Optional[int] = None,
         gpu_idx: Optional[Union[int, List[int]]] = None,
         download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+        model_path: Optional[str] = None,
         **kwargs,
     ):
         # !!! Note that The following code must be placed at the very beginning of this function,
@@ -799,6 +800,11 @@
             raise ValueError(
                 f"PEFT adaptors can only be applied to pytorch-like models"
             )
+        if model_path is not None:
+            if not os.path.exists(model_path):
+                raise ValueError(
+                    f"Invalid input. `model_path`: {model_path} File or directory does not exist."
+                )
 
         assert model_uid not in self._model_uid_to_model
         self._check_model_is_valid(model_name, model_format)
@@ -826,6 +832,7 @@
             quantization,
             peft_model_config,
             download_hub,
+            model_path,
             **kwargs,
         )
         await self.update_cache_status(model_name, model_description)
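
This is the enforcement point for the REST-level `model_path` shown earlier: a nonexistent path is rejected before any load starts. From the Python client the same option can be supplied at launch time (a sketch, assuming the client forwards extra keyword arguments into the launch payload as it does for other launch options; name and path illustrative):

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")
    uid = client.launch_model(
        model_name="ChatTTS",
        model_type="audio",
        model_path="/data/models/ChatTTS",  # must exist on the worker, else ValueError
    )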
xinference/deploy/utils.py CHANGED
@@ -27,6 +27,9 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+# mainly for k8s
+XINFERENCE_POD_NAME_ENV_KEY = "XINFERENCE_POD_NAME"
+
 
 class LoggerNameFilter(logging.Filter):
     def filter(self, record):
@@ -40,6 +43,9 @@ def get_log_file(sub_dir: str):
     """
     sub_dir should contain a timestamp.
     """
+    pod_name = os.environ.get(XINFERENCE_POD_NAME_ENV_KEY, None)
+    if pod_name is not None:
+        sub_dir = sub_dir + "_" + pod_name
     log_dir = os.path.join(XINFERENCE_LOG_DIR, sub_dir)
     # Here should be creating a new directory each time, so `exist_ok=False`
     os.makedirs(log_dir, exist_ok=False)
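
The new hook is mainly for Kubernetes, where several replicas sharing a volume would otherwise race to create the same timestamped log directory (note the deliberate `exist_ok=False`). A sketch of the resulting behavior, with the variable normally injected via the pod spec rather than set in code:

    import os

    os.environ["XINFERENCE_POD_NAME"] = "xinference-worker-0"  # e.g. from the k8s downward API

    sub_dir = "local_1722580800"  # illustrative timestamped name
    pod_name = os.environ.get("XINFERENCE_POD_NAME")
    if pod_name is not None:
        sub_dir = sub_dir + "_" + pod_name
    print(sub_dir)  # local_1722580800_xinference-worker-0, unique per pod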
xinference/model/audio/core.py CHANGED
@@ -20,6 +20,7 @@ from ...constants import XINFERENCE_CACHE_DIR
 from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
+from .cosyvoice import CosyVoiceModel
 from .whisper import WhisperModel
 
 MAX_ATTEMPTS = 3
@@ -149,18 +150,22 @@ def create_audio_model_instance(
     model_uid: str,
     model_name: str,
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    model_path: Optional[str] = None,
     **kwargs,
-) -> Tuple[Union[WhisperModel, ChatTTSModel], AudioModelDescription]:
+) -> Tuple[Union[WhisperModel, ChatTTSModel, CosyVoiceModel], AudioModelDescription]:
     model_spec = match_audio(model_name, download_hub)
-    model_path = cache(model_spec)
-    model: Union[WhisperModel, ChatTTSModel]
+    if model_path is None:
+        model_path = cache(model_spec)
+    model: Union[WhisperModel, ChatTTSModel, CosyVoiceModel]
     if model_spec.model_family == "whisper":
         model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "ChatTTS":
         model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "CosyVoice":
+        model = CosyVoiceModel(model_uid, model_path, model_spec, **kwargs)
     else:
         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(
-        subpool_addr, devices, model_spec, model_path=model_path
+        subpool_addr, devices, model_spec, model_path
     )
     return model, model_description
xinference/model/audio/cosyvoice.py ADDED
@@ -0,0 +1,136 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import io
+import logging
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+class CosyVoiceModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._kwargs = kwargs
+
+    def load(self):
+        import os
+        import sys
+
+        # The yaml config loaded from model has hard-coded the import paths. please refer to: load_hyperpyyaml
+        sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../thirdparty"))
+
+        from cosyvoice.cli.cosyvoice import CosyVoice
+
+        self._model = CosyVoice(self._model_path)
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        if stream:
+            raise Exception("CosyVoiceModel does not support stream.")
+
+        import torchaudio
+        from cosyvoice.utils.file_utils import load_wav
+
+        prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
+        prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
+        instruct_text: Optional[str] = kwargs.pop("instruct_text", None)
+
+        if "SFT" in self._model_spec.model_name:
+            # inference_sft
+            assert (
+                prompt_speech is None
+            ), "CosyVoice SFT model does not support prompt_speech"
+            assert (
+                prompt_text is None
+            ), "CosyVoice SFT model does not support prompt_text"
+            assert (
+                instruct_text is None
+            ), "CosyVoice SFT model does not support instruct_text"
+        elif "Instruct" in self._model_spec.model_name:
+            # inference_instruct
+            assert (
+                prompt_speech is None
+            ), "CosyVoice Instruct model does not support prompt_speech"
+            assert (
+                prompt_text is None
+            ), "CosyVoice Instruct model does not support prompt_text"
+            assert (
+                instruct_text is not None
+            ), "CosyVoice Instruct model expect a instruct_text"
+        else:
+            # inference_zero_shot
+            # inference_cross_lingual
+            assert prompt_speech is not None, "CosyVoice model expect a prompt_speech"
+            assert (
+                instruct_text is None
+            ), "CosyVoice model does not support instruct_text"
+
+        assert self._model is not None
+        if prompt_speech:
+            assert not voice, "voice can't be set with prompt speech."
+            with io.BytesIO(prompt_speech) as prompt_speech_io:
+                prompt_speech_16k = load_wav(prompt_speech_io, 16000)
+                if prompt_text:
+                    logger.info("CosyVoice inference_zero_shot")
+                    output = self._model.inference_zero_shot(
+                        input, prompt_text, prompt_speech_16k
+                    )
+                else:
+                    logger.info("CosyVoice inference_cross_lingual")
+                    output = self._model.inference_cross_lingual(
+                        input, prompt_speech_16k
+                    )
+        else:
+            available_speakers = self._model.list_avaliable_spks()
+            if not voice:
+                voice = available_speakers[0]
+            else:
+                assert (
+                    voice in available_speakers
+                ), f"Invalid voice {voice}, CosyVoice available speakers: {available_speakers}"
+            if instruct_text:
+                logger.info("CosyVoice inference_instruct")
+                output = self._model.inference_instruct(
+                    input, voice, instruct_text=instruct_text
+                )
+            else:
+                logger.info("CosyVoice inference_sft")
+                output = self._model.inference_sft(input, voice)
+
+        # Save the generated audio
+        with BytesIO() as out:
+            torchaudio.save(out, output["tts_speech"], 22050, format=response_format)
+            return out.getvalue()
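
To recap the dispatch above: SFT variants accept only a `voice`; Instruct variants require `instruct_text` alongside a `voice`; the base model requires `prompt_speech` and runs zero-shot cloning when `prompt_text` is also supplied, cross-lingual synthesis otherwise. A sketch of the instruct path through the client (uid and speaker name illustrative; `instruct_text` rides in the JSON-encoded kwargs field):

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")
    model = client.get_model("CosyVoice-300M-Instruct")  # illustrative uid

    audio = model.speech(
        "Welcome to the show.",
        voice="中文女",  # must be one of the model's list_avaliable_spks()
        instruct_text="Speak slowly, in a warm and friendly tone.",
    )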
xinference/model/audio/model_spec.json CHANGED
@@ -102,5 +102,29 @@
     "model_revision": "ce5913842aebd78e4a01a02d47244b8d62ac4ee3",
     "ability": "text-to-audio",
     "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M",
+    "model_family": "CosyVoice",
+    "model_id": "model-scope/CosyVoice-300M",
+    "model_revision": "ca4e036d2db2aa4731cc1747859a68044b6a4694",
+    "ability": "audio-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-SFT",
+    "model_family": "CosyVoice",
+    "model_id": "model-scope/CosyVoice-300M-SFT",
+    "model_revision": "ab918940c6c134b1fc1f069246e67bad6b66abcb",
+    "ability": "text-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-Instruct",
+    "model_family": "CosyVoice",
+    "model_id": "model-scope/CosyVoice-300M-Instruct",
+    "model_revision": "fb5f676733139f35670bed9b59a77d476b1aa898",
+    "ability": "text-to-audio",
+    "multilingual": true
   }
 ]
xinference/model/audio/model_spec_modelscope.json CHANGED
@@ -16,5 +16,32 @@
     "model_revision": "master",
     "ability": "text-to-audio",
     "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M",
+    "model_family": "CosyVoice",
+    "model_hub": "modelscope",
+    "model_id": "iic/CosyVoice-300M",
+    "model_revision": "master",
+    "ability": "audio-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-SFT",
+    "model_family": "CosyVoice",
+    "model_hub": "modelscope",
+    "model_id": "iic/CosyVoice-300M-SFT",
+    "model_revision": "master",
+    "ability": "text-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice-300M-Instruct",
+    "model_family": "CosyVoice",
+    "model_hub": "modelscope",
+    "model_id": "iic/CosyVoice-300M-Instruct",
+    "model_revision": "master",
+    "ability": "text-to-audio",
+    "multilingual": true
   }
 ]
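
These entries make the three CosyVoice variants launchable from ModelScope without Hugging Face access. A sketch (assuming `download_hub` is accepted by the client's launch call the same way the REST payload accepts it above):

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")
    uid = client.launch_model(
        model_name="CosyVoice-300M-SFT",
        model_type="audio",
        download_hub="modelscope",  # pulls iic/CosyVoice-300M-SFT at revision "master"
    )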
xinference/model/core.py CHANGED
@@ -56,6 +56,7 @@ def create_model_instance(
     quantization: Optional[str] = None,
     peft_model_config: Optional[PeftModelConfig] = None,
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[Any, ModelDescription]:
     from .audio.core import create_audio_model_instance
@@ -77,13 +78,20 @@
             quantization,
             peft_model_config,
             download_hub,
+            model_path,
             **kwargs,
         )
     elif model_type == "embedding":
         # embedding model doesn't accept trust_remote_code
         kwargs.pop("trust_remote_code", None)
         return create_embedding_model_instance(
-            subpool_addr, devices, model_uid, model_name, download_hub, **kwargs
+            subpool_addr,
+            devices,
+            model_uid,
+            model_name,
+            download_hub,
+            model_path,
+            **kwargs,
         )
     elif model_type == "image":
         kwargs.pop("trust_remote_code", None)
@@ -94,22 +102,35 @@
             model_name,
             peft_model_config,
             download_hub,
+            model_path,
             **kwargs,
         )
     elif model_type == "rerank":
         kwargs.pop("trust_remote_code", None)
         return create_rerank_model_instance(
-            subpool_addr, devices, model_uid, model_name, download_hub, **kwargs
+            subpool_addr,
+            devices,
+            model_uid,
+            model_name,
+            download_hub,
+            model_path,
+            **kwargs,
         )
     elif model_type == "audio":
         kwargs.pop("trust_remote_code", None)
         return create_audio_model_instance(
-            subpool_addr, devices, model_uid, model_name, download_hub, **kwargs
+            subpool_addr,
+            devices,
+            model_uid,
+            model_name,
+            download_hub,
+            model_path,
+            **kwargs,
        )
     elif model_type == "flexible":
         kwargs.pop("trust_remote_code", None)
         return create_flexible_model_instance(
-            subpool_addr, devices, model_uid, model_name, **kwargs
+            subpool_addr, devices, model_uid, model_name, model_path, **kwargs
         )
     else:
         raise ValueError(f"Unsupported model type: {model_type}.")