xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of xinference has been flagged as potentially problematic.

Files changed (170)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +28 -6
  4. xinference/core/utils.py +10 -6
  5. xinference/deploy/cmdline.py +3 -1
  6. xinference/deploy/test/test_cmdline.py +56 -0
  7. xinference/isolation.py +24 -0
  8. xinference/model/audio/core.py +10 -0
  9. xinference/model/audio/cosyvoice.py +25 -3
  10. xinference/model/audio/f5tts.py +200 -0
  11. xinference/model/audio/f5tts_mlx.py +260 -0
  12. xinference/model/audio/fish_speech.py +36 -111
  13. xinference/model/audio/model_spec.json +27 -3
  14. xinference/model/audio/model_spec_modelscope.json +18 -0
  15. xinference/model/audio/utils.py +32 -0
  16. xinference/model/embedding/core.py +203 -142
  17. xinference/model/embedding/model_spec.json +7 -0
  18. xinference/model/embedding/model_spec_modelscope.json +8 -0
  19. xinference/model/image/core.py +69 -1
  20. xinference/model/image/model_spec.json +127 -4
  21. xinference/model/image/model_spec_modelscope.json +130 -4
  22. xinference/model/image/stable_diffusion/core.py +45 -13
  23. xinference/model/llm/__init__.py +2 -2
  24. xinference/model/llm/llm_family.json +219 -53
  25. xinference/model/llm/llm_family.py +15 -36
  26. xinference/model/llm/llm_family_modelscope.json +167 -20
  27. xinference/model/llm/mlx/core.py +287 -51
  28. xinference/model/llm/sglang/core.py +1 -0
  29. xinference/model/llm/transformers/chatglm.py +9 -5
  30. xinference/model/llm/transformers/core.py +1 -0
  31. xinference/model/llm/transformers/qwen2_vl.py +2 -0
  32. xinference/model/llm/transformers/utils.py +16 -8
  33. xinference/model/llm/utils.py +5 -1
  34. xinference/model/llm/vllm/core.py +16 -2
  35. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  36. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  37. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  38. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  39. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  40. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  41. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  42. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  43. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  44. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  45. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  46. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  47. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  48. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  49. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  50. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  51. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  52. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  53. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  54. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  55. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  56. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  57. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  58. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  59. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  60. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  61. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  62. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  63. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  64. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  65. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  66. xinference/thirdparty/f5_tts/api.py +166 -0
  67. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  68. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  69. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  70. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  71. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  72. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  73. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  74. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  75. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  76. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  77. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  78. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  79. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  80. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  81. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  82. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  83. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  84. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  85. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  86. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  87. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  88. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  89. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  90. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  91. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  92. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  93. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  94. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  95. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  96. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  97. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  98. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  99. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  100. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  101. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  102. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  103. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  104. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  105. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  106. xinference/thirdparty/f5_tts/train/README.md +77 -0
  107. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  108. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  109. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  110. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  111. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  112. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  113. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  114. xinference/thirdparty/f5_tts/train/train.py +75 -0
  115. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  116. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  117. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  118. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  119. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  120. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  121. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  122. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  123. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  124. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  125. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  126. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  127. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  128. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  129. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  130. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  131. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  132. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  133. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  134. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  135. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  136. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  137. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  138. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  139. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  140. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  141. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  142. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  143. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  144. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  145. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  146. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  147. xinference/thirdparty/matcha/utils/utils.py +2 -2
  148. xinference/web/ui/build/asset-manifest.json +3 -3
  149. xinference/web/ui/build/index.html +1 -1
  150. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  151. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  153. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
  154. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
  155. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  156. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  157. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  158. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  159. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  160. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  161. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  162. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  163. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  164. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  165. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  166. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  167. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
  168. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
  169. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
  170. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
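
Most of this release is new text-to-speech support: F5-TTS (plus an MLX port for Apple Silicon), CosyVoice2-0.5B, and FishSpeech-1.5, whose hunks are shown below. As a rough sketch of how the new models are reached from client code (the endpoint URL and file names are illustrative placeholders, not taken from the diff):

# Sketch only: launching one of the audio models added in 1.1.1 through
# the xinference client. Endpoint and output path are placeholders.
from xinference.client import Client

client = Client("http://localhost:9997")
model_uid = client.launch_model(model_name="F5-TTS", model_type="audio")
model = client.get_model(model_uid)

# speech() returns the encoded audio bytes (mp3 by default).
audio = model.speech("Hello from xinference 1.1.1.")
with open("out.mp3", "wb") as f:
    f.write(audio)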
xinference/model/audio/f5tts_mlx.py
@@ -0,0 +1,260 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import io
+import logging
+import os
+from io import BytesIO
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal, Optional, Union
+
+import numpy as np
+from tqdm import tqdm
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+class F5TTSMLXModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._kwargs = kwargs
+        self._model = None
+
+    @property
+    def model_ability(self):
+        return self._model_spec.model_ability
+
+    def load(self):
+        try:
+            import mlx.core as mx
+            from f5_tts_mlx.cfm import F5TTS
+            from f5_tts_mlx.dit import DiT
+            from f5_tts_mlx.duration import DurationPredictor, DurationTransformer
+            from vocos_mlx import Vocos
+        except ImportError:
+            error_message = "Failed to import module 'f5_tts_mlx'"
+            installation_guide = [
+                "Please make sure 'f5_tts_mlx' is installed.\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        path = Path(self._model_path)
+        # vocab
+
+        vocab_path = path / "vocab.txt"
+        vocab = {v: i for i, v in enumerate(Path(vocab_path).read_text().split("\n"))}
+        if len(vocab) == 0:
+            raise ValueError(f"Could not load vocab from {vocab_path}")
+
+        # duration predictor
+
+        duration_model_path = path / "duration_v2.safetensors"
+        duration_predictor = None
+
+        if duration_model_path.exists():
+            duration_predictor = DurationPredictor(
+                transformer=DurationTransformer(
+                    dim=512,
+                    depth=8,
+                    heads=8,
+                    text_dim=512,
+                    ff_mult=2,
+                    conv_layers=2,
+                    text_num_embeds=len(vocab) - 1,
+                ),
+                vocab_char_map=vocab,
+            )
+            weights = mx.load(duration_model_path.as_posix(), format="safetensors")
+            duration_predictor.load_weights(list(weights.items()))
+
+        # vocoder
+
+        vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz")
+
+        # model
+
+        model_path = path / "model.safetensors"
+
+        f5tts = F5TTS(
+            transformer=DiT(
+                dim=1024,
+                depth=22,
+                heads=16,
+                ff_mult=2,
+                text_dim=512,
+                conv_layers=4,
+                text_num_embeds=len(vocab) - 1,
+            ),
+            vocab_char_map=vocab,
+            vocoder=vocos.decode,
+            duration_predictor=duration_predictor,
+        )
+
+        weights = mx.load(model_path.as_posix(), format="safetensors")
+        f5tts.load_weights(list(weights.items()))
+        mx.eval(f5tts.parameters())
+
+        self._model = f5tts
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        import mlx.core as mx
+        import soundfile as sf
+        import tomli
+        from f5_tts_mlx.generate import (
+            FRAMES_PER_SEC,
+            SAMPLE_RATE,
+            TARGET_RMS,
+            convert_char_to_pinyin,
+            split_sentences,
+        )
+
+        from .utils import ensure_sample_rate
+
+        if stream:
+            raise Exception("F5-TTS does not support stream generation.")
+
+        prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
+        prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
+        duration: Optional[float] = kwargs.pop("duration", None)
+        steps: Optional[int] = kwargs.pop("steps", 8)
+        cfg_strength: Optional[float] = kwargs.pop("cfg_strength", 2.0)
+        method: Literal["euler", "midpoint"] = kwargs.pop("method", "rk4")
+        sway_sampling_coef: float = kwargs.pop("sway_sampling_coef", -1.0)
+        seed: Optional[int] = kwargs.pop("seed", None)
+
+        prompt_speech_path: Union[str, io.BytesIO]
+        if prompt_speech is None:
+            base = os.path.join(os.path.dirname(__file__), "../../thirdparty/f5_tts")
+            config = os.path.join(base, "infer/examples/basic/basic.toml")
+            with open(config, "rb") as f:
+                config_dict = tomli.load(f)
+            prompt_speech_path = os.path.join(base, config_dict["ref_audio"])
+            prompt_text = config_dict["ref_text"]
+        else:
+            prompt_speech_path = io.BytesIO(prompt_speech)
+
+        if prompt_text is None:
+            raise ValueError("`prompt_text` cannot be empty")
+
+        audio, sr = sf.read(prompt_speech_path)
+        audio = ensure_sample_rate(audio, sr, SAMPLE_RATE)
+
+        audio = mx.array(audio)
+        ref_audio_duration = audio.shape[0] / SAMPLE_RATE
+        logger.debug(
+            f"Got reference audio with duration: {ref_audio_duration:.2f} seconds"
+        )
+
+        rms = mx.sqrt(mx.mean(mx.square(audio)))
+        if rms < TARGET_RMS:
+            audio = audio * TARGET_RMS / rms
+
+        sentences = split_sentences(input)
+        is_single_generation = len(sentences) <= 1 or duration is not None
+
+        if is_single_generation:
+            generation_text = convert_char_to_pinyin([prompt_text + " " + input])  # type: ignore
+
+            if duration is not None:
+                duration = int(duration * FRAMES_PER_SEC)
+
+            start_date = datetime.datetime.now()
+
+            wave, _ = self._model.sample(  # type: ignore
+                mx.expand_dims(audio, axis=0),
+                text=generation_text,
+                duration=duration,
+                steps=steps,
+                method=method,
+                speed=speed,
+                cfg_strength=cfg_strength,
+                sway_sampling_coef=sway_sampling_coef,
+                seed=seed,
+            )
+
+            wave = wave[audio.shape[0] :]
+            mx.eval(wave)
+
+            generated_duration = wave.shape[0] / SAMPLE_RATE
+            print(
+                f"Generated {generated_duration:.2f}s of audio in {datetime.datetime.now() - start_date}."
+            )
+
+        else:
+            start_date = datetime.datetime.now()
+
+            output = []
+
+            for sentence_text in tqdm(split_sentences(input)):
+                text = convert_char_to_pinyin([prompt_text + " " + sentence_text])  # type: ignore
+
+                if duration is not None:
+                    duration = int(duration * FRAMES_PER_SEC)
+
+                wave, _ = self._model.sample(  # type: ignore
+                    mx.expand_dims(audio, axis=0),
+                    text=text,
+                    duration=duration,
+                    steps=steps,
+                    method=method,
+                    speed=speed,
+                    cfg_strength=cfg_strength,
+                    sway_sampling_coef=sway_sampling_coef,
+                    seed=seed,
+                )
+
+                # trim the reference audio
+                wave = wave[audio.shape[0] :]
+                mx.eval(wave)
+
+                output.append(wave)
+
+            wave = mx.concatenate(output, axis=0)
+
+            generated_duration = wave.shape[0] / SAMPLE_RATE
+            logger.debug(
+                f"Generated {generated_duration:.2f}s of audio in {datetime.datetime.now() - start_date}."
+            )
+
+        # Save the generated audio
+        with BytesIO() as out:
+            with sf.SoundFile(
+                out, "w", SAMPLE_RATE, 1, format=response_format.upper()
+            ) as f:
+                f.write(np.array(wave))
+            return out.getvalue()
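
Judging from the speech() signature above, the MLX backend clones a voice from the prompt_speech/prompt_text kwargs and otherwise falls back to the bundled basic.toml reference clip. A hedged sketch of a direct call, assuming `model` is a loaded F5TTSMLXModel instance (the reference file and transcript are placeholders):

# Sketch only: kwargs mirror those popped in the method above;
# ref.wav and its transcript are placeholders, not part of the diff.
with open("ref.wav", "rb") as f:
    ref_bytes = f.read()

wav_bytes = model.speech(
    input="Text to read in the reference speaker's voice.",
    voice="",                  # accepted but unused by this backend
    response_format="wav",
    prompt_speech=ref_bytes,   # raw bytes of the reference clip
    prompt_text="Transcript of ref.wav.",
    steps=8,                   # ODE solver steps, the default above
)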
xinference/model/audio/fish_speech.py
@@ -11,10 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import gc
 import logging
 import os.path
-import queue
 import sys
 from io import BytesIO
 from typing import TYPE_CHECKING, Optional
@@ -60,6 +58,7 @@ class FishSpeechModel:
         self._device = device
         self._llama_queue = None
         self._model = None
+        self._engine = None
         self._kwargs = kwargs

     @property
@@ -72,6 +71,7 @@ class FishSpeechModel:
             0, os.path.join(os.path.dirname(__file__), "../../thirdparty/fish_speech")
         )

+        from tools.inference_engine import TTSInferenceEngine
         from tools.llama.generate import launch_thread_safe_queue
         from tools.vqgan.inference import load_model as load_decoder_model

@@ -81,6 +81,11 @@ class FishSpeechModel:
         if not is_device_available(self._device):
             raise ValueError(f"Device {self._device} is not available!")

+        # https://github.com/pytorch/pytorch/issues/129207
+        if self._device == "mps":
+            logger.warning("The Conv1d has bugs on MPS backend, fallback to CPU.")
+            self._device = "cpu"
+
         enable_compile = self._kwargs.get("compile", False)
         precision = self._kwargs.get("precision", torch.bfloat16)
         logger.info("Loading Llama model, compile=%s...", enable_compile)
@@ -102,102 +107,10 @@ class FishSpeechModel:
             device=self._device,
         )

-    @torch.inference_mode()
-    def _inference(
-        self,
-        text,
-        enable_reference_audio,
-        reference_audio,
-        reference_text,
-        max_new_tokens,
-        chunk_length,
-        top_p,
-        repetition_penalty,
-        temperature,
-        seed="0",
-        streaming=False,
-    ):
-        from fish_speech.utils import autocast_exclude_mps, set_seed
-        from tools.api import decode_vq_tokens, encode_reference
-        from tools.llama.generate import (
-            GenerateRequest,
-            GenerateResponse,
-            WrappedGenerateResponse,
-        )
-
-        seed = int(seed)
-        if seed != 0:
-            set_seed(seed)
-            logger.warning(f"set seed: {seed}")
-
-        # Parse reference audio aka prompt
-        prompt_tokens = encode_reference(
-            decoder_model=self._model,
-            reference_audio=reference_audio,
-            enable_reference_audio=enable_reference_audio,
-        )
-
-        # LLAMA Inference
-        request = dict(
-            device=self._model.device,
-            max_new_tokens=max_new_tokens,
-            text=text,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            temperature=temperature,
-            compile=self._kwargs.get("compile", False),
-            iterative_prompt=chunk_length > 0,
-            chunk_length=chunk_length,
-            max_length=2048,
-            prompt_tokens=prompt_tokens if enable_reference_audio else None,
-            prompt_text=reference_text if enable_reference_audio else None,
-        )
-
-        response_queue = queue.Queue()
-        self._llama_queue.put(
-            GenerateRequest(
-                request=request,
-                response_queue=response_queue,
-            )
+        self._engine = TTSInferenceEngine(
+            self._llama_queue, self._model, precision, enable_compile
         )

-        segments = []
-
-        while True:
-            result: WrappedGenerateResponse = response_queue.get()
-            if result.status == "error":
-                raise result.response
-
-            result: GenerateResponse = result.response
-            if result.action == "next":
-                break
-
-            with autocast_exclude_mps(
-                device_type=self._model.device.type,
-                dtype=self._kwargs.get("precision", torch.bfloat16),
-            ):
-                fake_audios = decode_vq_tokens(
-                    decoder_model=self._model,
-                    codes=result.codes,
-                )
-
-            fake_audios = fake_audios.float().cpu().numpy()
-            segments.append(fake_audios)
-
-            if streaming:
-                yield fake_audios, None, None
-
-        if len(segments) == 0:
-            raise Exception("No audio generated, please check the input text.")
-
-        # No matter streaming or not, we need to return the final audio
-        audio = np.concatenate(segments, axis=0)
-        yield None, (self._model.spec_transform.sample_rate, audio), None
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            gc.collect()
-
     def speech(
         self,
         input: str,
@@ -211,21 +124,31 @@ class FishSpeechModel:
         if speed != 1.0:
             logger.warning("Fish speech does not support setting speed: %s.", speed)
         import torchaudio
+        from tools.schema import ServeReferenceAudio, ServeTTSRequest

         prompt_speech = kwargs.get("prompt_speech")
-        result = self._inference(
-            text=input,
-            enable_reference_audio=kwargs.get(
-                "enable_reference_audio", prompt_speech is not None
-            ),
-            reference_audio=prompt_speech,
-            reference_text=kwargs.get("reference_text", ""),
-            max_new_tokens=kwargs.get("max_new_tokens", 1024),
-            chunk_length=kwargs.get("chunk_length", 200),
-            top_p=kwargs.get("top_p", 0.7),
-            repetition_penalty=kwargs.get("repetition_penalty", 1.2),
-            temperature=kwargs.get("temperature", 0.7),
-            streaming=stream,
+        prompt_text = kwargs.get("prompt_text", kwargs.get("reference_text", ""))
+        if prompt_speech is not None:
+            r = ServeReferenceAudio(audio=prompt_speech, text=prompt_text)
+            references = [r]
+        else:
+            references = []
+
+        assert self._engine is not None
+        result = self._engine.inference(
+            ServeTTSRequest(
+                text=input,
+                references=references,
+                reference_id=kwargs.get("reference_id"),
+                seed=kwargs.get("seed"),
+                max_new_tokens=kwargs.get("max_new_tokens", 1024),
+                chunk_length=kwargs.get("chunk_length", 200),
+                top_p=kwargs.get("top_p", 0.7),
+                repetition_penalty=kwargs.get("repetition_penalty", 1.2),
+                temperature=kwargs.get("temperature", 0.7),
+                streaming=stream,
+                format=response_format,
+            )
         )

         if stream:
@@ -241,7 +164,9 @@ class FishSpeechModel:
                last_pos = 0
                with writer.open():
                    for chunk in result:
-                       chunk = chunk[0]
+                       if chunk.code == "final":
+                           continue
+                       chunk = chunk.audio[1]
                        if chunk is not None:
                            chunk = chunk.reshape((chunk.shape[0], 1))
                            trans_chunk = torch.from_numpy(chunk)
@@ -256,7 +181,7 @@ class FishSpeechModel:
             return _stream_generator()
         else:
             result = list(result)
-            sample_rate, audio = result[0][1]
+            sample_rate, audio = result[0].audio
             audio = np.array([audio])

             # Save the generated audio
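
The net effect of the fish_speech hunks: the hand-rolled _inference() generator (queue plumbing, VQ decoding, CUDA cache cleanup) is gone, replaced by fish_speech's own TTSInferenceEngine driven through ServeTTSRequest, and the streaming path now skips the terminal chunk via chunk.code == "final". The caller-facing surface stays the same; a hedged sketch of the kwargs the refactored speech() reads, assuming `fish_model` is a loaded FishSpeechModel and `ref_bytes` holds reference audio bytes (both placeholders):

# Sketch only: parameter names are the ones read via kwargs.get() above.
# prompt_text is the new name; reference_text is still honored as a fallback.
audio = fish_model.speech(
    input="Hello from FishSpeech 1.5.",
    prompt_speech=ref_bytes,    # optional reference audio bytes
    prompt_text="Transcript of the reference clip.",
    max_new_tokens=1024,
    chunk_length=200,
    top_p=0.7,
    repetition_penalty=1.2,
    temperature=0.7,
)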
xinference/model/audio/model_spec.json
@@ -236,10 +236,34 @@
     "multilingual": true
   },
   {
-    "model_name": "FishSpeech-1.4",
+    "model_name": "CosyVoice2-0.5B",
+    "model_family": "CosyVoice",
+    "model_id": "mrfakename/CosyVoice2-0.5B",
+    "model_revision": "5676baabc8a76dc93ef60a88bbd2420deaa2f644",
+    "model_ability": "text-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "FishSpeech-1.5",
     "model_family": "FishAudio",
-    "model_id": "fishaudio/fish-speech-1.4",
-    "model_revision": "069c573759936b35191d3380deb89183c0656f59",
+    "model_id": "fishaudio/fish-speech-1.5",
+    "model_revision": "268b6ec86243dd683bc78dab7e9a6cedf9191f2a",
+    "model_ability": "text-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "F5-TTS",
+    "model_family": "F5-TTS",
+    "model_id": "SWivid/F5-TTS",
+    "model_revision": "4dcc16f297f2ff98a17b3726b16f5de5a5e45672",
+    "model_ability": "text-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "F5-TTS-MLX",
+    "model_family": "F5-TTS-MLX",
+    "model_id": "lucasnewman/f5-tts-mlx",
+    "model_revision": "7642bb232e3fcacf92c51c786edebb8624da6b93",
     "model_ability": "text-to-audio",
     "multilingual": true
   }
xinference/model/audio/model_spec_modelscope.json
@@ -73,5 +73,23 @@
     "model_revision": "master",
     "model_ability": "text-to-audio",
     "multilingual": true
+  },
+  {
+    "model_name": "CosyVoice2-0.5B",
+    "model_family": "CosyVoice",
+    "model_hub": "modelscope",
+    "model_id": "iic/CosyVoice2-0.5B",
+    "model_revision": "master",
+    "model_ability": "text-to-audio",
+    "multilingual": true
+  },
+  {
+    "model_name": "F5-TTS",
+    "model_family": "F5-TTS",
+    "model_hub": "modelscope",
+    "model_id": "SWivid/F5-TTS_Emilia-ZH-EN",
+    "model_revision": "master",
+    "model_ability": "text-to-audio",
+    "multilingual": true
   }
 ]
xinference/model/audio/utils.py
@@ -11,8 +11,40 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import io
+
+import numpy as np
+
 from .core import AudioModelFamilyV1


 def get_model_version(audio_model: AudioModelFamilyV1) -> str:
     return audio_model.model_name
+
+
+def ensure_sample_rate(
+    audio: np.ndarray, old_sample_rate: int, sample_rate: int
+) -> np.ndarray:
+    import soundfile as sf
+    from scipy.signal import resample
+
+    if old_sample_rate != sample_rate:
+        # Calculate the new data length
+        new_length = int(len(audio) * sample_rate / old_sample_rate)
+
+        # Resample the data
+        resampled_data = resample(audio, new_length)
+
+        # Use BytesIO to save the resampled data to memory
+        with io.BytesIO() as buffer:
+            # Write the resampled data to the memory buffer
+            sf.write(buffer, resampled_data, sample_rate, format="WAV")
+
+            # Reset the buffer position to the beginning
+            buffer.seek(0)
+
+            # Read the data from the memory buffer
+            audio, sr = sf.read(buffer, dtype="float32")
+
+    return audio
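
The new ensure_sample_rate() helper, used by the F5-TTS backends above, resamples with scipy.signal.resample when the rates differ and round-trips the result through an in-memory WAV so it reads back as float32. A quick usage sketch with a synthetic 44.1 kHz input clip:

import numpy as np

from xinference.model.audio.utils import ensure_sample_rate

# One second of a 440 Hz tone at 44.1 kHz, as a stand-in input clip.
sr_in, sr_out = 44100, 24000
t = np.linspace(0, 1.0, sr_in, endpoint=False)
clip = (0.5 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)

resampled = ensure_sample_rate(clip, sr_in, sr_out)
print(resampled.shape)  # (24000,): length scales by sr_out / sr_in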