xinference 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (104)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +23 -1
  4. xinference/core/model.py +1 -6
  5. xinference/core/utils.py +10 -6
  6. xinference/model/audio/core.py +5 -0
  7. xinference/model/audio/cosyvoice.py +25 -3
  8. xinference/model/audio/f5tts.py +15 -10
  9. xinference/model/audio/f5tts_mlx.py +260 -0
  10. xinference/model/audio/fish_speech.py +35 -111
  11. xinference/model/audio/model_spec.json +19 -3
  12. xinference/model/audio/model_spec_modelscope.json +9 -0
  13. xinference/model/audio/utils.py +32 -0
  14. xinference/model/image/core.py +69 -1
  15. xinference/model/image/model_spec.json +127 -4
  16. xinference/model/image/model_spec_modelscope.json +130 -4
  17. xinference/model/image/stable_diffusion/core.py +45 -13
  18. xinference/model/llm/llm_family.json +47 -0
  19. xinference/model/llm/llm_family.py +15 -36
  20. xinference/model/llm/llm_family_modelscope.json +49 -0
  21. xinference/model/llm/mlx/core.py +68 -13
  22. xinference/model/llm/transformers/core.py +1 -0
  23. xinference/model/llm/transformers/qwen2_vl.py +2 -0
  24. xinference/model/llm/utils.py +1 -0
  25. xinference/model/llm/vllm/core.py +11 -2
  26. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  27. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  28. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  29. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  30. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  31. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  32. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  33. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  34. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  35. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  36. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  37. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  38. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  39. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  40. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  41. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  42. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  43. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  44. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  45. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  46. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  47. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  48. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  49. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  50. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  51. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  52. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  53. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  54. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  55. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  56. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  57. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  58. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  59. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  60. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  61. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  62. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  63. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  64. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  65. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  66. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  67. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  68. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  69. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  70. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  71. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  72. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  73. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  74. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  75. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  76. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  77. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  78. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  79. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  80. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  81. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  82. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  83. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  84. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  85. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  86. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  87. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  88. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  89. xinference/thirdparty/matcha/utils/utils.py +2 -2
  90. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/METADATA +11 -6
  91. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/RECORD +95 -74
  92. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  93. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  94. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  95. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  96. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  97. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  98. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  99. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  100. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  101. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
  102. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
  103. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
  104. {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/_compat.py CHANGED
@@ -72,6 +72,7 @@ OpenAIChatCompletionToolParam = create_model_from_typeddict(ChatCompletionToolPa
 OpenAIChatCompletionNamedToolChoiceParam = create_model_from_typeddict(
     ChatCompletionNamedToolChoiceParam
 )
+from openai._types import Body


 class JSONSchema(BaseModel):
@@ -120,4 +121,5 @@ class CreateChatCompletionOpenAI(BaseModel):
     tools: Optional[Iterable[OpenAIChatCompletionToolParam]]  # type: ignore
     top_logprobs: Optional[int]
     top_p: Optional[float]
+    extra_body: Optional[Body]
     user: Optional[str]
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json

 version_json = '''
 {
- "date": "2024-12-13T18:21:03+0800",
+ "date": "2024-12-27T18:14:37+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "b132fca91f3e1b11b111f9b89f68a55e4b7605c6",
- "version": "1.1.0"
+ "full-revisionid": "d3428697115cc4666b38b32925ba28bdc1a21957",
+ "version": "1.1.1"
 }
 '''  # END VERSION_JSON

xinference/api/restful_api.py CHANGED
@@ -2346,7 +2346,8 @@ class RESTfulAPI(CancelMixin):
     @staticmethod
     def extract_guided_params(raw_body: dict) -> dict:
         kwargs = {}
-        if raw_body.get("guided_json") is not None:
+        raw_extra_body: dict = raw_body.get("extra_body")  # type: ignore
+        if raw_body.get("guided_json"):
             kwargs["guided_json"] = raw_body.get("guided_json")
         if raw_body.get("guided_regex") is not None:
             kwargs["guided_regex"] = raw_body.get("guided_regex")
@@ -2362,6 +2363,27 @@
             kwargs["guided_whitespace_pattern"] = raw_body.get(
                 "guided_whitespace_pattern"
             )
+        # Parse OpenAI extra_body
+        if raw_extra_body is not None:
+            if raw_extra_body.get("guided_json"):
+                kwargs["guided_json"] = raw_extra_body.get("guided_json")
+            if raw_extra_body.get("guided_regex") is not None:
+                kwargs["guided_regex"] = raw_extra_body.get("guided_regex")
+            if raw_extra_body.get("guided_choice") is not None:
+                kwargs["guided_choice"] = raw_extra_body.get("guided_choice")
+            if raw_extra_body.get("guided_grammar") is not None:
+                kwargs["guided_grammar"] = raw_extra_body.get("guided_grammar")
+            if raw_extra_body.get("guided_json_object") is not None:
+                kwargs["guided_json_object"] = raw_extra_body.get("guided_json_object")
+            if raw_extra_body.get("guided_decoding_backend") is not None:
+                kwargs["guided_decoding_backend"] = raw_extra_body.get(
+                    "guided_decoding_backend"
+                )
+            if raw_extra_body.get("guided_whitespace_pattern") is not None:
+                kwargs["guided_whitespace_pattern"] = raw_extra_body.get(
+                    "guided_whitespace_pattern"
+                )
+
         return kwargs

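Together with the new extra_body field on CreateChatCompletionOpenAI (_compat.py above), this lets guided-decoding options ride along on a standard OpenAI-client request. A minimal illustrative sketch, not taken from the release; the endpoint URL, model UID, and schema below are placeholders:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:9997/v1", api_key="not-used")

# Placeholder JSON schema used to constrain the model output.
weather_schema = {
    "type": "object",
    "properties": {"city": {"type": "string"}, "temperature_c": {"type": "number"}},
    "required": ["city", "temperature_c"],
}

response = client.chat.completions.create(
    model="my-model-uid",  # placeholder model UID
    messages=[{"role": "user", "content": "Report the weather in Paris as JSON."}],
    # The OpenAI SDK merges extra_body into the request payload;
    # extract_guided_params then reads the guided_* keys from it.
    extra_body={"guided_json": weather_schema},
)
print(response.choices[0].message.content)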
xinference/core/model.py CHANGED
@@ -78,7 +78,6 @@ XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = [
 ]

 XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS = ["FLUX.1-dev", "FLUX.1-schnell"]
-XINFERENCE_BATCHING_BLACK_LIST = ["glm4-chat"]


 def request_limit(fn):
@@ -373,11 +372,7 @@ class ModelActor(xo.StatelessActor, CancelMixin):
                 f"Your model {self._model.model_family.model_name} with model family {self._model.model_family.model_family} is disqualified."
             )
             return False
-        return (
-            condition
-            and self._model.model_family.model_name
-            not in XINFERENCE_BATCHING_BLACK_LIST
-        )
+        return condition

     def allow_batching_for_text_to_image(self) -> bool:
         from ..model.image.stable_diffusion.core import DiffusionModel
xinference/core/utils.py CHANGED
@@ -62,12 +62,16 @@ def log_async(

     @wraps(func)
     async def wrapped(*args, **kwargs):
-        try:
-            bound_args = sig.bind_partial(*args, **kwargs)
-            arguments = bound_args.arguments
-        except TypeError:
-            arguments = {}
-        request_id_str = arguments.get("request_id", "")
+        request_id_str = kwargs.get("request_id")
+        if not request_id_str:
+            # sometimes `request_id` not in kwargs
+            # we try to bind the arguments
+            try:
+                bound_args = sig.bind_partial(*args, **kwargs)
+                arguments = bound_args.arguments
+            except TypeError:
+                arguments = {}
+            request_id_str = arguments.get("request_id", "")
         if not request_id_str:
             request_id_str = uuid.uuid1()
         if func_name == "text_to_image":
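The reworked lookup now prefers an explicit request_id keyword and only falls back to binding the call signature when the keyword is absent. A standalone sketch of that fallback path; the wrapped function and its arguments are made up for illustration:

import inspect
import uuid

def handler(prompt, request_id=None):  # illustrative wrapped function
    ...

sig = inspect.signature(handler)

def resolve_request_id(*args, **kwargs):
    # Fast path: request_id passed as a keyword argument.
    request_id = kwargs.get("request_id")
    if not request_id:
        # Fallback: bind the arguments so a positionally passed request_id is still found.
        try:
            arguments = sig.bind_partial(*args, **kwargs).arguments
        except TypeError:
            arguments = {}
        request_id = arguments.get("request_id", "")
    return request_id or uuid.uuid1()

print(resolve_request_id("hi", request_id="req-1"))  # -> req-1 (keyword)
print(resolve_request_id("hi", "req-2"))             # -> req-2 (recovered via bind_partial)
print(resolve_request_id("hi"))                      # -> freshly generated UUID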
xinference/model/audio/core.py CHANGED
@@ -22,6 +22,7 @@ from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
 from .cosyvoice import CosyVoiceModel
 from .f5tts import F5TTSModel
+from .f5tts_mlx import F5TTSMLXModel
 from .fish_speech import FishSpeechModel
 from .funasr import FunASRModel
 from .whisper import WhisperModel
@@ -171,6 +172,7 @@ def create_audio_model_instance(
         CosyVoiceModel,
         FishSpeechModel,
         F5TTSModel,
+        F5TTSMLXModel,
     ],
     AudioModelDescription,
 ]:
@@ -185,6 +187,7 @@ def create_audio_model_instance(
         CosyVoiceModel,
         FishSpeechModel,
         F5TTSModel,
+        F5TTSMLXModel,
     ]
     if model_spec.model_family == "whisper":
         if not model_spec.engine:
@@ -201,6 +204,8 @@ def create_audio_model_instance(
         model = FishSpeechModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "F5-TTS":
         model = F5TTSModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "F5-TTS-MLX":
+        model = F5TTSMLXModel(model_uid, model_path, model_spec, **kwargs)
     else:
         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(
xinference/model/audio/cosyvoice.py CHANGED
@@ -39,6 +39,7 @@ class CosyVoiceModel:
         self._device = device
         self._model = None
         self._kwargs = kwargs
+        self._is_cosyvoice2 = False

     @property
     def model_ability(self):
@@ -51,7 +52,14 @@ class CosyVoiceModel:
         # The yaml config loaded from model has hard-coded the import paths. please refer to: load_hyperpyyaml
         sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../thirdparty"))

-        from cosyvoice.cli.cosyvoice import CosyVoice
+        if "CosyVoice2" in self._model_spec.model_name:
+            from cosyvoice.cli.cosyvoice import CosyVoice2 as CosyVoice
+
+            self._is_cosyvoice2 = True
+        else:
+            from cosyvoice.cli.cosyvoice import CosyVoice
+
+            self._is_cosyvoice2 = False

         self._model = CosyVoice(
             self._model_path, load_jit=self._kwargs.get("load_jit", False)
@@ -78,12 +86,22 @@ class CosyVoiceModel:
                 output = self._model.inference_zero_shot(
                     input, prompt_text, prompt_speech_16k, stream=stream
                 )
+            elif instruct_text:
+                assert self._is_cosyvoice2
+                logger.info("CosyVoice inference_instruct")
+                output = self._model.inference_instruct2(
+                    input,
+                    instruct_text=instruct_text,
+                    prompt_speech_16k=prompt_speech_16k,
+                    stream=stream,
+                )
             else:
                 logger.info("CosyVoice inference_cross_lingual")
                 output = self._model.inference_cross_lingual(
                     input, prompt_speech_16k, stream=stream
                 )
         else:
+            assert not self._is_cosyvoice2
             available_speakers = self._model.list_avaliable_spks()
             if not voice:
                 voice = available_speakers[0]
@@ -106,7 +124,9 @@ class CosyVoiceModel:
         def _generator_stream():
             with BytesIO() as out:
                 writer = torchaudio.io.StreamWriter(out, format=response_format)
-                writer.add_audio_stream(sample_rate=22050, num_channels=1)
+                writer.add_audio_stream(
+                    sample_rate=self._model.sample_rate, num_channels=1
+                )
                 i = 0
                 last_pos = 0
                 with writer.open():
@@ -125,7 +145,7 @@ class CosyVoiceModel:
             chunks = [o["tts_speech"] for o in output]
             t = torch.cat(chunks, dim=1)
             with BytesIO() as out:
-                torchaudio.save(out, t, 22050, format=response_format)
+                torchaudio.save(out, t, self._model.sample_rate, format=response_format)
             return out.getvalue()

         return _generator_stream() if stream else _generator_block()
@@ -163,6 +183,8 @@ class CosyVoiceModel:
             assert (
                 prompt_text is None
             ), "CosyVoice Instruct model does not support prompt_text"
+        elif self._is_cosyvoice2:
+            assert prompt_speech is not None, "CosyVoice2 requires prompt_speech"
         else:
             # inference_zero_shot
             # inference_cross_lingual
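For CosyVoice2 models, the new branch routes an instruct_text plus reference audio to inference_instruct2. A rough client-side sketch, assuming a CosyVoice2 model is already launched and that the audio handle forwards extra keyword arguments through to speech(); the model UID and file names are placeholders:

from xinference.client import Client

client = Client("http://localhost:9997")
model = client.get_model("my-cosyvoice2-uid")  # placeholder UID

# CosyVoice2 requires reference audio; instruct_text selects inference_instruct2.
with open("reference.wav", "rb") as f:
    prompt_speech = f.read()

audio_bytes = model.speech(
    "Hello, nice to meet you.",
    prompt_speech=prompt_speech,
    instruct_text="Speak slowly, in a warm and friendly tone.",
    response_format="wav",
)
with open("output.wav", "wb") as f:
    f.write(audio_bytes)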
xinference/model/audio/f5tts.py CHANGED
@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import io
 import logging
 import os
 import re
 from io import BytesIO
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union

 if TYPE_CHECKING:
     from .core import AudioModelFamilyV1
@@ -106,9 +106,9 @@ class F5TTSModel:
             ) = preprocess_ref_audio_text(
                 voices[voice]["ref_audio"], voices[voice]["ref_text"]
             )
-            print("Voice:", voice)
-            print("Ref_audio:", voices[voice]["ref_audio"])
-            print("Ref_text:", voices[voice]["ref_text"])
+            logger.info("Voice:", voice)
+            logger.info("Ref_audio:", voices[voice]["ref_audio"])
+            logger.info("Ref_text:", voices[voice]["ref_text"])

         final_sample_rate = None
         generated_audio_segments = []
@@ -122,16 +122,16 @@ class F5TTSModel:
             if match:
                 voice = match[1]
             else:
-                print("No voice tag found, using main.")
+                logger.info("No voice tag found, using main.")
                 voice = "main"
             if voice not in voices:
-                print(f"Voice {voice} not found, using main.")
+                logger.info(f"Voice {voice} not found, using main.")
                 voice = "main"
             text = re.sub(reg2, "", text)
             gen_text = text.strip()
             ref_audio = voices[voice]["ref_audio"]
             ref_text = voices[voice]["ref_text"]
-            print(f"Voice: {voice}")
+            logger.info(f"Voice: {voice}")
             audio, final_sample_rate, spectragram = infer_process(
                 ref_audio,
                 ref_text,
@@ -167,18 +167,23 @@ class F5TTSModel:
         prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
         prompt_text: Optional[str] = kwargs.pop("prompt_text", None)

+        ref_audio: Union[str, io.BytesIO]
         if prompt_speech is None:
             base = os.path.dirname(f5_tts.__file__)
             config = os.path.join(base, "infer/examples/basic/basic.toml")
             with open(config, "rb") as f:
                 config_dict = tomli.load(f)
-            prompt_speech = os.path.join(base, config_dict["ref_audio"])
+            ref_audio = os.path.join(base, config_dict["ref_audio"])
             prompt_text = config_dict["ref_text"]
+        else:
+            ref_audio = io.BytesIO(prompt_speech)
+            if prompt_text is None:
+                raise ValueError("`prompt_text` cannot be empty")

         assert self._model is not None
         vocoder_name = self._kwargs.get("vocoder_name", "vocos")
         sample_rate, wav = self._infer(
-            ref_audio=prompt_speech,
+            ref_audio=ref_audio,
             ref_text=prompt_text,
             text_gen=input,
             model_obj=self._model,
xinference/model/audio/f5tts_mlx.py ADDED
@@ -0,0 +1,260 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import io
+import logging
+import os
+from io import BytesIO
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal, Optional, Union
+
+import numpy as np
+from tqdm import tqdm
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+class F5TTSMLXModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._kwargs = kwargs
+        self._model = None
+
+    @property
+    def model_ability(self):
+        return self._model_spec.model_ability
+
+    def load(self):
+        try:
+            import mlx.core as mx
+            from f5_tts_mlx.cfm import F5TTS
+            from f5_tts_mlx.dit import DiT
+            from f5_tts_mlx.duration import DurationPredictor, DurationTransformer
+            from vocos_mlx import Vocos
+        except ImportError:
+            error_message = "Failed to import module 'f5_tts_mlx'"
+            installation_guide = [
+                "Please make sure 'f5_tts_mlx' is installed.\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        path = Path(self._model_path)
+        # vocab
+
+        vocab_path = path / "vocab.txt"
+        vocab = {v: i for i, v in enumerate(Path(vocab_path).read_text().split("\n"))}
+        if len(vocab) == 0:
+            raise ValueError(f"Could not load vocab from {vocab_path}")
+
+        # duration predictor
+
+        duration_model_path = path / "duration_v2.safetensors"
+        duration_predictor = None
+
+        if duration_model_path.exists():
+            duration_predictor = DurationPredictor(
+                transformer=DurationTransformer(
+                    dim=512,
+                    depth=8,
+                    heads=8,
+                    text_dim=512,
+                    ff_mult=2,
+                    conv_layers=2,
+                    text_num_embeds=len(vocab) - 1,
+                ),
+                vocab_char_map=vocab,
+            )
+            weights = mx.load(duration_model_path.as_posix(), format="safetensors")
+            duration_predictor.load_weights(list(weights.items()))
+
+        # vocoder
+
+        vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz")
+
+        # model
+
+        model_path = path / "model.safetensors"
+
+        f5tts = F5TTS(
+            transformer=DiT(
+                dim=1024,
+                depth=22,
+                heads=16,
+                ff_mult=2,
+                text_dim=512,
+                conv_layers=4,
+                text_num_embeds=len(vocab) - 1,
+            ),
+            vocab_char_map=vocab,
+            vocoder=vocos.decode,
+            duration_predictor=duration_predictor,
+        )
+
+        weights = mx.load(model_path.as_posix(), format="safetensors")
+        f5tts.load_weights(list(weights.items()))
+        mx.eval(f5tts.parameters())
+
+        self._model = f5tts
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        import mlx.core as mx
+        import soundfile as sf
+        import tomli
+        from f5_tts_mlx.generate import (
+            FRAMES_PER_SEC,
+            SAMPLE_RATE,
+            TARGET_RMS,
+            convert_char_to_pinyin,
+            split_sentences,
+        )
+
+        from .utils import ensure_sample_rate
+
+        if stream:
+            raise Exception("F5-TTS does not support stream generation.")
+
+        prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
+        prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
+        duration: Optional[float] = kwargs.pop("duration", None)
+        steps: Optional[int] = kwargs.pop("steps", 8)
+        cfg_strength: Optional[float] = kwargs.pop("cfg_strength", 2.0)
+        method: Literal["euler", "midpoint"] = kwargs.pop("method", "rk4")
+        sway_sampling_coef: float = kwargs.pop("sway_sampling_coef", -1.0)
+        seed: Optional[int] = kwargs.pop("seed", None)
+
+        prompt_speech_path: Union[str, io.BytesIO]
+        if prompt_speech is None:
+            base = os.path.join(os.path.dirname(__file__), "../../thirdparty/f5_tts")
+            config = os.path.join(base, "infer/examples/basic/basic.toml")
+            with open(config, "rb") as f:
+                config_dict = tomli.load(f)
+            prompt_speech_path = os.path.join(base, config_dict["ref_audio"])
+            prompt_text = config_dict["ref_text"]
+        else:
+            prompt_speech_path = io.BytesIO(prompt_speech)
+
+        if prompt_text is None:
+            raise ValueError("`prompt_text` cannot be empty")
+
+        audio, sr = sf.read(prompt_speech_path)
+        audio = ensure_sample_rate(audio, sr, SAMPLE_RATE)
+
+        audio = mx.array(audio)
+        ref_audio_duration = audio.shape[0] / SAMPLE_RATE
+        logger.debug(
+            f"Got reference audio with duration: {ref_audio_duration:.2f} seconds"
+        )
+
+        rms = mx.sqrt(mx.mean(mx.square(audio)))
+        if rms < TARGET_RMS:
+            audio = audio * TARGET_RMS / rms
+
+        sentences = split_sentences(input)
+        is_single_generation = len(sentences) <= 1 or duration is not None
+
+        if is_single_generation:
+            generation_text = convert_char_to_pinyin([prompt_text + " " + input])  # type: ignore
+
+            if duration is not None:
+                duration = int(duration * FRAMES_PER_SEC)
+
+            start_date = datetime.datetime.now()
+
+            wave, _ = self._model.sample(  # type: ignore
+                mx.expand_dims(audio, axis=0),
+                text=generation_text,
+                duration=duration,
+                steps=steps,
+                method=method,
+                speed=speed,
+                cfg_strength=cfg_strength,
+                sway_sampling_coef=sway_sampling_coef,
+                seed=seed,
+            )
+
+            wave = wave[audio.shape[0] :]
+            mx.eval(wave)
+
+            generated_duration = wave.shape[0] / SAMPLE_RATE
+            print(
+                f"Generated {generated_duration:.2f}s of audio in {datetime.datetime.now() - start_date}."
+            )
+
+        else:
+            start_date = datetime.datetime.now()
+
+            output = []
+
+            for sentence_text in tqdm(split_sentences(input)):
+                text = convert_char_to_pinyin([prompt_text + " " + sentence_text])  # type: ignore
+
+                if duration is not None:
+                    duration = int(duration * FRAMES_PER_SEC)
+
+                wave, _ = self._model.sample(  # type: ignore
+                    mx.expand_dims(audio, axis=0),
+                    text=text,
+                    duration=duration,
+                    steps=steps,
+                    method=method,
+                    speed=speed,
+                    cfg_strength=cfg_strength,
+                    sway_sampling_coef=sway_sampling_coef,
+                    seed=seed,
+                )
+
+                # trim the reference audio
+                wave = wave[audio.shape[0] :]
+                mx.eval(wave)
+
+                output.append(wave)
+
+            wave = mx.concatenate(output, axis=0)
+
+            generated_duration = wave.shape[0] / SAMPLE_RATE
+            logger.debug(
+                f"Generated {generated_duration:.2f}s of audio in {datetime.datetime.now() - start_date}."
+            )
+
+        # Save the generated audio
+        with BytesIO() as out:
+            with sf.SoundFile(
+                out, "w", SAMPLE_RATE, 1, format=response_format.upper()
+            ) as f:
+                f.write(np.array(wave))
+            return out.getvalue()
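Taken together with the registration changes in xinference/model/audio/core.py and model_spec.json, the new MLX backend should be reachable like any other audio family. A rough sketch, assuming the registered model name is "F5-TTS-MLX" and that the default reference audio from the bundled basic.toml is acceptable; host, port, and file names are placeholders:

from xinference.client import Client

client = Client("http://localhost:9997")
model_uid = client.launch_model(model_name="F5-TTS-MLX", model_type="audio")
model = client.get_model(model_uid)

# Uses the default reference audio/text from the bundled config when no
# prompt_speech/prompt_text kwargs are supplied.
audio_bytes = model.speech("Hello from the MLX backend.", response_format="wav")
with open("output.wav", "wb") as f:
    f.write(audio_bytes)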