xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of xinference might be problematic.

Files changed (170)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +28 -6
  4. xinference/core/utils.py +10 -6
  5. xinference/deploy/cmdline.py +3 -1
  6. xinference/deploy/test/test_cmdline.py +56 -0
  7. xinference/isolation.py +24 -0
  8. xinference/model/audio/core.py +10 -0
  9. xinference/model/audio/cosyvoice.py +25 -3
  10. xinference/model/audio/f5tts.py +200 -0
  11. xinference/model/audio/f5tts_mlx.py +260 -0
  12. xinference/model/audio/fish_speech.py +36 -111
  13. xinference/model/audio/model_spec.json +27 -3
  14. xinference/model/audio/model_spec_modelscope.json +18 -0
  15. xinference/model/audio/utils.py +32 -0
  16. xinference/model/embedding/core.py +203 -142
  17. xinference/model/embedding/model_spec.json +7 -0
  18. xinference/model/embedding/model_spec_modelscope.json +8 -0
  19. xinference/model/image/core.py +69 -1
  20. xinference/model/image/model_spec.json +127 -4
  21. xinference/model/image/model_spec_modelscope.json +130 -4
  22. xinference/model/image/stable_diffusion/core.py +45 -13
  23. xinference/model/llm/__init__.py +2 -2
  24. xinference/model/llm/llm_family.json +219 -53
  25. xinference/model/llm/llm_family.py +15 -36
  26. xinference/model/llm/llm_family_modelscope.json +167 -20
  27. xinference/model/llm/mlx/core.py +287 -51
  28. xinference/model/llm/sglang/core.py +1 -0
  29. xinference/model/llm/transformers/chatglm.py +9 -5
  30. xinference/model/llm/transformers/core.py +1 -0
  31. xinference/model/llm/transformers/qwen2_vl.py +2 -0
  32. xinference/model/llm/transformers/utils.py +16 -8
  33. xinference/model/llm/utils.py +5 -1
  34. xinference/model/llm/vllm/core.py +16 -2
  35. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  36. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  37. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  38. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  39. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  40. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  41. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  42. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  43. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  44. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  45. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  46. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  47. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  48. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  49. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  50. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  51. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  52. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  53. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  54. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  55. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  56. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  57. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  58. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  59. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  60. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  61. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  62. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  63. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  64. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  65. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  66. xinference/thirdparty/f5_tts/api.py +166 -0
  67. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  68. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  69. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  70. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  71. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  72. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  73. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  74. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  75. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  76. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  77. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  78. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  79. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  80. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  81. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  82. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  83. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  84. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  85. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  86. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  87. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  88. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  89. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  90. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  91. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  92. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  93. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  94. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  95. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  96. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  97. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  98. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  99. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  100. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  101. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  102. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  103. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  104. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  105. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  106. xinference/thirdparty/f5_tts/train/README.md +77 -0
  107. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  108. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  109. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  110. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  111. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  112. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  113. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  114. xinference/thirdparty/f5_tts/train/train.py +75 -0
  115. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  116. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  117. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  118. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  119. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  120. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  121. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  122. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  123. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  124. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  125. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  126. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  127. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  128. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  129. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  130. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  131. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  132. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  133. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  134. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  135. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  136. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  137. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  138. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  139. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  140. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  141. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  142. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  143. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  144. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  145. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  146. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  147. xinference/thirdparty/matcha/utils/utils.py +2 -2
  148. xinference/web/ui/build/asset-manifest.json +3 -3
  149. xinference/web/ui/build/index.html +1 -1
  150. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  151. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  153. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
  154. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
  155. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  156. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  157. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  158. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  159. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  160. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  161. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  162. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  163. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  164. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  165. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  166. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  167. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
  168. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
  169. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
  170. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/_compat.py CHANGED
@@ -72,6 +72,7 @@ OpenAIChatCompletionToolParam = create_model_from_typeddict(ChatCompletionToolPa
 OpenAIChatCompletionNamedToolChoiceParam = create_model_from_typeddict(
     ChatCompletionNamedToolChoiceParam
 )
+from openai._types import Body


 class JSONSchema(BaseModel):
@@ -120,4 +121,5 @@ class CreateChatCompletionOpenAI(BaseModel):
     tools: Optional[Iterable[OpenAIChatCompletionToolParam]]  # type: ignore
     top_logprobs: Optional[int]
     top_p: Optional[float]
+    extra_body: Optional[Body]
     user: Optional[str]
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json

 version_json = '''
 {
- "date": "2024-11-29T16:57:04+0800",
+ "date": "2024-12-27T18:14:37+0800",
 "dirty": false,
 "error": null,
- "full-revisionid": "eb8ddd431f5c5fcb2216e25e0d43745f8455d9b9",
- "version": "1.0.1"
+ "full-revisionid": "d3428697115cc4666b38b32925ba28bdc1a21957",
+ "version": "1.1.1"
 }
 ''' # END VERSION_JSON

xinference/api/restful_api.py CHANGED
@@ -94,9 +94,9 @@ class CreateCompletionRequest(CreateCompletion):

 class CreateEmbeddingRequest(BaseModel):
     model: str
-    input: Union[str, List[str], List[int], List[List[int]]] = Field(
-        description="The input to embed."
-    )
+    input: Union[
+        str, List[str], List[int], List[List[int]], Dict[str, str], List[Dict[str, str]]
+    ] = Field(description="The input to embed.")
     user: Optional[str] = None

     class Config:
@@ -2044,7 +2044,6 @@ class RESTfulAPI(CancelMixin):
         )
         if body.tools and body.stream:
             is_vllm = await model.is_vllm_backend()
-
             if not (
                 (is_vllm and model_family in QWEN_TOOL_CALL_FAMILY)
                 or (not is_vllm and model_family in GLM4_TOOL_CALL_FAMILY)
@@ -2054,7 +2053,8 @@ class RESTfulAPI(CancelMixin):
                     detail="Streaming support for tool calls is available only when using "
                     "Qwen models with vLLM backend or GLM4-chat models without vLLM backend.",
                 )
-
+        if "skip_special_tokens" in raw_kwargs and await model.is_vllm_backend():
+            kwargs["skip_special_tokens"] = raw_kwargs["skip_special_tokens"]
         if body.stream:

             async def stream_results():
@@ -2346,7 +2346,8 @@ class RESTfulAPI(CancelMixin):
     @staticmethod
     def extract_guided_params(raw_body: dict) -> dict:
         kwargs = {}
-        if raw_body.get("guided_json") is not None:
+        raw_extra_body: dict = raw_body.get("extra_body")  # type: ignore
+        if raw_body.get("guided_json"):
             kwargs["guided_json"] = raw_body.get("guided_json")
         if raw_body.get("guided_regex") is not None:
             kwargs["guided_regex"] = raw_body.get("guided_regex")
@@ -2362,6 +2363,27 @@ class RESTfulAPI(CancelMixin):
         kwargs["guided_whitespace_pattern"] = raw_body.get(
             "guided_whitespace_pattern"
         )
+        # Parse OpenAI extra_body
+        if raw_extra_body is not None:
+            if raw_extra_body.get("guided_json"):
+                kwargs["guided_json"] = raw_extra_body.get("guided_json")
+            if raw_extra_body.get("guided_regex") is not None:
+                kwargs["guided_regex"] = raw_extra_body.get("guided_regex")
+            if raw_extra_body.get("guided_choice") is not None:
+                kwargs["guided_choice"] = raw_extra_body.get("guided_choice")
+            if raw_extra_body.get("guided_grammar") is not None:
+                kwargs["guided_grammar"] = raw_extra_body.get("guided_grammar")
+            if raw_extra_body.get("guided_json_object") is not None:
+                kwargs["guided_json_object"] = raw_extra_body.get("guided_json_object")
+            if raw_extra_body.get("guided_decoding_backend") is not None:
+                kwargs["guided_decoding_backend"] = raw_extra_body.get(
+                    "guided_decoding_backend"
+                )
+            if raw_extra_body.get("guided_whitespace_pattern") is not None:
+                kwargs["guided_whitespace_pattern"] = raw_extra_body.get(
+                    "guided_whitespace_pattern"
+                )
+
         return kwargs

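With extra_body parsed server-side, vLLM guided decoding can now be driven from the official OpenAI SDK without custom request plumbing. A minimal sketch, assuming a local server on the default port and a launched model uid of qwen2.5-instruct (both assumptions); guided decoding applies on the vLLM backend:

from openai import OpenAI

# A local xinference server does not check the API key, but the SDK requires one.
client = OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="not-used")
response = client.chat.completions.create(
    model="qwen2.5-instruct",  # assumed model uid
    messages=[{"role": "user", "content": "Is this review positive or negative: 'Great product!'"}],
    # The SDK forwards extra_body verbatim; extract_guided_params() above
    # now reads the guided_* keys out of it.
    extra_body={"guided_choice": ["positive", "negative"]},
)
print(response.choices[0].message.content)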
xinference/core/utils.py CHANGED
@@ -62,12 +62,16 @@ def log_async(

     @wraps(func)
     async def wrapped(*args, **kwargs):
-        try:
-            bound_args = sig.bind_partial(*args, **kwargs)
-            arguments = bound_args.arguments
-        except TypeError:
-            arguments = {}
-        request_id_str = arguments.get("request_id", "")
+        request_id_str = kwargs.get("request_id")
+        if not request_id_str:
+            # sometimes `request_id` not in kwargs
+            # we try to bind the arguments
+            try:
+                bound_args = sig.bind_partial(*args, **kwargs)
+                arguments = bound_args.arguments
+            except TypeError:
+                arguments = {}
+            request_id_str = arguments.get("request_id", "")
        if not request_id_str:
            request_id_str = uuid.uuid1()
        if func_name == "text_to_image":
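The rewrite adds a fast path: read request_id straight from kwargs and only fall back to signature binding when it was passed positionally. A standalone sketch of that lookup order (the helper and function names here are illustrative, not from the codebase):

import inspect

def get_request_id(func, args, kwargs):
    # Fast path: request_id passed as a keyword argument.
    request_id = kwargs.get("request_id")
    if not request_id:
        # Slow path: bind the call to recover a positional request_id.
        try:
            bound = inspect.signature(func).bind_partial(*args, **kwargs)
            request_id = bound.arguments.get("request_id", "")
        except TypeError:
            request_id = ""
    return request_id

def chat(prompt, request_id=None):
    pass

print(get_request_id(chat, ("hi",), {"request_id": "abc"}))  # abc, no binding needed
print(get_request_id(chat, ("hi", "xyz"), {}))               # xyz, via bind_partial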
xinference/deploy/cmdline.py CHANGED
@@ -846,7 +846,9 @@ def model_launch(
     kwargs = {}
     for i in range(0, len(ctx.args), 2):
         if not ctx.args[i].startswith("--"):
-            raise ValueError("You must specify extra kwargs with `--` prefix.")
+            raise ValueError(
+                f"You must specify extra kwargs with `--` prefix. There is an error in parameter passing that is {ctx.args[i]}."
+            )
         kwargs[ctx.args[i][2:]] = handle_click_args_type(ctx.args[i + 1])
     print(f"Launch model name: {model_name} with kwargs: {kwargs}", file=sys.stderr)

xinference/deploy/test/test_cmdline.py CHANGED
@@ -23,6 +23,7 @@ from ..cmdline import (
     list_model_registrations,
     model_chat,
     model_generate,
+    model_launch,
     model_list,
     model_terminate,
     register_model,
@@ -311,3 +312,58 @@ def test_remove_cache(setup):

     assert result.exit_code == 0
     assert "Cache directory qwen1.5-chat has been deleted."
+
+
+def test_launch_error_in_passing_parameters():
+    runner = CliRunner()
+
+    # Known parameter but not provided with value.
+    result = runner.invoke(
+        model_launch,
+        [
+            "--model-engine",
+            "transformers",
+            "--model-name",
+            "qwen2.5-instruct",
+            "--model-uid",
+            "-s",
+            "0.5",
+            "-f",
+            "gptq",
+            "-q",
+            "INT4",
+            "111",
+            "-l",
+        ],
+    )
+    assert result.exit_code == 1
+    assert (
+        "You must specify extra kwargs with `--` prefix. There is an error in parameter passing that is 0.5."
+        in str(result)
+    )
+
+    # Unknown parameter
+    result = runner.invoke(
+        model_launch,
+        [
+            "--model-engine",
+            "transformers",
+            "--model-name",
+            "qwen2.5-instruct",
+            "--model-uid",
+            "123",
+            "-s",
+            "0.5",
+            "-f",
+            "gptq",
+            "-q",
+            "INT4",
+            "-l",
+            "111",
+        ],
+    )
+    assert result.exit_code == 1
+    assert (
+        "You must specify extra kwargs with `--` prefix. There is an error in parameter passing that is -l."
+        in str(result)
+    )
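For contrast, a sketch of a well-formed invocation under the same CliRunner harness: every extra engine kwarg is a complete `--key value` pair, so the new ValueError is not raised (a running endpoint is still needed for the launch itself to succeed):

from click.testing import CliRunner
from xinference.deploy.cmdline import model_launch

runner = CliRunner()
result = runner.invoke(
    model_launch,
    [
        "--model-engine", "transformers",
        "--model-name", "qwen2.5-instruct",
        "-s", "0.5",
        "-f", "gptq",
        "-q", "INT4",
        # Extra engine kwargs must come as complete `--key value` pairs:
        "--trust_remote_code", "true",
    ],
)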
xinference/isolation.py CHANGED
@@ -37,6 +37,30 @@ class Isolation:
         asyncio.set_event_loop(self._loop)
         self._stopped = asyncio.Event()
         self._loop.run_until_complete(self._stopped.wait())
+        self._cancel_all_tasks(self._loop)
+
+    @staticmethod
+    def _cancel_all_tasks(loop):
+        to_cancel = asyncio.all_tasks(loop)
+        if not to_cancel:
+            return
+
+        for task in to_cancel:
+            task.cancel()
+
+        loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
+
+        for task in to_cancel:
+            if task.cancelled():
+                continue
+            if task.exception() is not None:
+                loop.call_exception_handler(
+                    {
+                        "message": "unhandled exception during asyncio.run() shutdown",
+                        "exception": task.exception(),
+                        "task": task,
+                    }
+                )

     def start(self):
         if self._threaded:
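The added _cancel_all_tasks mirrors the shutdown behavior of asyncio.run() (CPython's asyncio.runners): cancel whatever is still pending, drain the cancellations, then surface exceptions that were not plain cancellations. A self-contained sketch of the same pattern:

import asyncio

async def worker():
    await asyncio.sleep(3600)  # a task that would otherwise outlive the loop

loop = asyncio.new_event_loop()
loop.create_task(worker())
loop.run_until_complete(asyncio.sleep(0))  # let the task start

pending = asyncio.all_tasks(loop)
for task in pending:
    task.cancel()
# Drain the cancelled tasks; return_exceptions=True keeps gather from raising.
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.close()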
xinference/model/audio/core.py CHANGED
@@ -21,6 +21,8 @@ from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
 from .cosyvoice import CosyVoiceModel
+from .f5tts import F5TTSModel
+from .f5tts_mlx import F5TTSMLXModel
 from .fish_speech import FishSpeechModel
 from .funasr import FunASRModel
 from .whisper import WhisperModel
@@ -169,6 +171,8 @@ def create_audio_model_instance(
         ChatTTSModel,
         CosyVoiceModel,
         FishSpeechModel,
+        F5TTSModel,
+        F5TTSMLXModel,
     ],
     AudioModelDescription,
 ]:
@@ -182,6 +186,8 @@ def create_audio_model_instance(
         ChatTTSModel,
         CosyVoiceModel,
         FishSpeechModel,
+        F5TTSModel,
+        F5TTSMLXModel,
     ]
     if model_spec.model_family == "whisper":
         if not model_spec.engine:
@@ -196,6 +202,10 @@ def create_audio_model_instance(
         model = CosyVoiceModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "FishAudio":
         model = FishSpeechModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "F5-TTS":
+        model = F5TTSModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "F5-TTS-MLX":
+        model = F5TTSMLXModel(model_uid, model_path, model_spec, **kwargs)
     else:
         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(
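Once a family is wired into this dispatch, it is reachable through the ordinary client flow. A minimal sketch, assuming a server already running on the default endpoint (the output path is illustrative):

from xinference.client import Client

client = Client("http://127.0.0.1:9997")
# model_type="audio" routes through create_audio_model_instance() above.
model_uid = client.launch_model(model_name="F5-TTS", model_type="audio")
model = client.get_model(model_uid)
audio_bytes = model.speech("Dispatch now covers the F5-TTS family.")
with open("out.mp3", "wb") as f:
    f.write(audio_bytes)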
xinference/model/audio/cosyvoice.py CHANGED
@@ -39,6 +39,7 @@ class CosyVoiceModel:
         self._device = device
         self._model = None
         self._kwargs = kwargs
+        self._is_cosyvoice2 = False

     @property
     def model_ability(self):
@@ -51,7 +52,14 @@
         # The yaml config loaded from model has hard-coded the import paths. please refer to: load_hyperpyyaml
         sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../thirdparty"))

-        from cosyvoice.cli.cosyvoice import CosyVoice
+        if "CosyVoice2" in self._model_spec.model_name:
+            from cosyvoice.cli.cosyvoice import CosyVoice2 as CosyVoice
+
+            self._is_cosyvoice2 = True
+        else:
+            from cosyvoice.cli.cosyvoice import CosyVoice
+
+            self._is_cosyvoice2 = False

         self._model = CosyVoice(
             self._model_path, load_jit=self._kwargs.get("load_jit", False)
@@ -78,12 +86,22 @@
                 output = self._model.inference_zero_shot(
                     input, prompt_text, prompt_speech_16k, stream=stream
                 )
+            elif instruct_text:
+                assert self._is_cosyvoice2
+                logger.info("CosyVoice inference_instruct")
+                output = self._model.inference_instruct2(
+                    input,
+                    instruct_text=instruct_text,
+                    prompt_speech_16k=prompt_speech_16k,
+                    stream=stream,
+                )
             else:
                 logger.info("CosyVoice inference_cross_lingual")
                 output = self._model.inference_cross_lingual(
                     input, prompt_speech_16k, stream=stream
                 )
         else:
+            assert not self._is_cosyvoice2
             available_speakers = self._model.list_avaliable_spks()
             if not voice:
                 voice = available_speakers[0]
@@ -106,7 +124,9 @@
         def _generator_stream():
             with BytesIO() as out:
                 writer = torchaudio.io.StreamWriter(out, format=response_format)
-                writer.add_audio_stream(sample_rate=22050, num_channels=1)
+                writer.add_audio_stream(
+                    sample_rate=self._model.sample_rate, num_channels=1
+                )
                 i = 0
                 last_pos = 0
                 with writer.open():
@@ -125,7 +145,7 @@
             chunks = [o["tts_speech"] for o in output]
             t = torch.cat(chunks, dim=1)
             with BytesIO() as out:
-                torchaudio.save(out, t, 22050, format=response_format)
+                torchaudio.save(out, t, self._model.sample_rate, format=response_format)
                 return out.getvalue()

         return _generator_stream() if stream else _generator_block()
@@ -163,6 +183,8 @@
             assert (
                 prompt_text is None
             ), "CosyVoice Instruct model does not support prompt_text"
+        elif self._is_cosyvoice2:
+            assert prompt_speech is not None, "CosyVoice2 requires prompt_speech"
         else:
             # inference_zero_shot
             # inference_cross_lingual
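Taken together, the new branches route a CosyVoice2 checkpoint through inference_instruct2 when instruct_text is supplied, and a prompt_speech reference clip is now mandatory for CosyVoice2. A minimal sketch of the instruct path, assuming cosyvoice_model is an already loaded CosyVoiceModel for a CosyVoice2 checkpoint and that the speech entry point forwards these keywords as the hunks above suggest:

with open("reference_voice.wav", "rb") as f:  # assumed reference clip
    prompt_speech = f.read()

audio_bytes = cosyvoice_model.speech(
    "Tell a quiet bedtime story.",
    voice="",                                  # preset speakers are asserted away for CosyVoice2
    response_format="wav",
    prompt_speech=prompt_speech,               # now asserted to be present
    instruct_text="Speak softly and slowly.",  # routes to inference_instruct2
)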
xinference/model/audio/f5tts.py ADDED
@@ -0,0 +1,200 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import io
+import logging
+import os
+import re
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional, Union
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+class F5TTSModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._vocoder = None
+        self._kwargs = kwargs
+
+    @property
+    def model_ability(self):
+        return self._model_spec.model_ability
+
+    def load(self):
+        import os
+        import sys
+
+        # The yaml config loaded from model has hard-coded the import paths. please refer to: load_hyperpyyaml
+        sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../thirdparty"))
+
+        from f5_tts.infer.utils_infer import load_model, load_vocoder
+        from f5_tts.model import DiT
+
+        vocoder_name = self._kwargs.get("vocoder_name", "vocos")
+        vocoder_path = self._kwargs.get("vocoder_path")
+
+        if vocoder_name not in ["vocos", "bigvgan"]:
+            raise Exception(f"Unsupported vocoder name: {vocoder_name}")
+
+        if vocoder_path is not None:
+            self._vocoder = load_vocoder(
+                vocoder_name=vocoder_name, is_local=True, local_path=vocoder_path
+            )
+        else:
+            self._vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=False)
+
+        model_cls = DiT
+        model_cfg = dict(
+            dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
+        )
+        if vocoder_name == "vocos":
+            exp_name = "F5TTS_Base"
+            ckpt_step = 1200000
+        elif vocoder_name == "bigvgan":
+            exp_name = "F5TTS_Base_bigvgan"
+            ckpt_step = 1250000
+        else:
+            assert False
+        ckpt_file = os.path.join(
+            self._model_path, exp_name, f"model_{ckpt_step}.safetensors"
+        )
+        logger.info(f"Loading %s...", ckpt_file)
+        self._model = load_model(
+            model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name
+        )
+
+    def _infer(self, ref_audio, ref_text, text_gen, model_obj, mel_spec_type, speed):
+        import numpy as np
+        from f5_tts.infer.utils_infer import infer_process, preprocess_ref_audio_text
+
+        config = {}
+        main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
+        if "voices" not in config:
+            voices = {"main": main_voice}
+        else:
+            voices = config["voices"]
+            voices["main"] = main_voice
+        for voice in voices:
+            (
+                voices[voice]["ref_audio"],
+                voices[voice]["ref_text"],
+            ) = preprocess_ref_audio_text(
+                voices[voice]["ref_audio"], voices[voice]["ref_text"]
+            )
+            logger.info("Voice:", voice)
+            logger.info("Ref_audio:", voices[voice]["ref_audio"])
+            logger.info("Ref_text:", voices[voice]["ref_text"])
+
+        final_sample_rate = None
+        generated_audio_segments = []
+        reg1 = r"(?=\[\w+\])"
+        chunks = re.split(reg1, text_gen)
+        reg2 = r"\[(\w+)\]"
+        for text in chunks:
+            if not text.strip():
+                continue
+            match = re.match(reg2, text)
+            if match:
+                voice = match[1]
+            else:
+                logger.info("No voice tag found, using main.")
+                voice = "main"
+            if voice not in voices:
+                logger.info(f"Voice {voice} not found, using main.")
+                voice = "main"
+            text = re.sub(reg2, "", text)
+            gen_text = text.strip()
+            ref_audio = voices[voice]["ref_audio"]
+            ref_text = voices[voice]["ref_text"]
+            logger.info(f"Voice: {voice}")
+            audio, final_sample_rate, spectragram = infer_process(
+                ref_audio,
+                ref_text,
+                gen_text,
+                model_obj,
+                self._vocoder,
+                mel_spec_type=mel_spec_type,
+                speed=speed,
+            )
+            generated_audio_segments.append(audio)
+
+        if generated_audio_segments:
+            final_wave = np.concatenate(generated_audio_segments)
+            return final_sample_rate, final_wave
+        return None, None
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        import f5_tts
+        import soundfile
+        import tomli
+
+        if stream:
+            raise Exception("F5-TTS does not support stream generation.")
+
+        prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
+        prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
+
+        ref_audio: Union[str, io.BytesIO]
+        if prompt_speech is None:
+            base = os.path.dirname(f5_tts.__file__)
+            config = os.path.join(base, "infer/examples/basic/basic.toml")
+            with open(config, "rb") as f:
+                config_dict = tomli.load(f)
+            ref_audio = os.path.join(base, config_dict["ref_audio"])
+            prompt_text = config_dict["ref_text"]
+        else:
+            ref_audio = io.BytesIO(prompt_speech)
+            if prompt_text is None:
+                raise ValueError("`prompt_text` cannot be empty")
+
+        assert self._model is not None
+        vocoder_name = self._kwargs.get("vocoder_name", "vocos")
+        sample_rate, wav = self._infer(
+            ref_audio=ref_audio,
+            ref_text=prompt_text,
+            text_gen=input,
+            model_obj=self._model,
+            mel_spec_type=vocoder_name,
+            speed=speed,
+        )
+
+        # Save the generated audio
+        with BytesIO() as out:
+            with soundfile.SoundFile(
+                out, "w", sample_rate, 1, format=response_format.upper()
+            ) as f:
+                f.write(wav)
+            return out.getvalue()
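A short sketch of driving the new class directly, following the constructor and speech signatures above (model_spec stands in for the F5-TTS entry from model_spec.json, and the paths are assumptions):

from xinference.model.audio.f5tts import F5TTSModel

model = F5TTSModel(
    "f5-tts-demo",          # model_uid
    "/path/to/F5-TTS",      # model_path holding F5TTS_Base/model_1200000.safetensors
    model_spec,             # AudioModelFamilyV1 entry for "F5-TTS" (assumed in scope)
    vocoder_name="vocos",   # or "bigvgan"
)
model.load()
wav_bytes = model.speech("Hello from F5-TTS.", voice="", response_format="wav")
with open("f5tts_demo.wav", "wb") as f:
    f.write(wav_bytes)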