xinference 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +23 -1
- xinference/core/model.py +1 -6
- xinference/core/utils.py +10 -6
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +15 -10
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +35 -111
- xinference/model/audio/model_spec.json +19 -3
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +127 -4
- xinference/model/image/model_spec_modelscope.json +130 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/llm_family.json +47 -0
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +49 -0
- xinference/model/llm/mlx/core.py +68 -13
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/qwen2_vl.py +2 -0
- xinference/model/llm/utils.py +1 -0
- xinference/model/llm/vllm/core.py +11 -2
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/METADATA +11 -6
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/RECORD +95 -74
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.1.0.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/_compat.py
CHANGED
@@ -72,6 +72,7 @@ OpenAIChatCompletionToolParam = create_model_from_typeddict(ChatCompletionToolPa
 OpenAIChatCompletionNamedToolChoiceParam = create_model_from_typeddict(
     ChatCompletionNamedToolChoiceParam
 )
+from openai._types import Body


 class JSONSchema(BaseModel):
@@ -120,4 +121,5 @@ class CreateChatCompletionOpenAI(BaseModel):
     tools: Optional[Iterable[OpenAIChatCompletionToolParam]]  # type: ignore
     top_logprobs: Optional[int]
     top_p: Optional[float]
+    extra_body: Optional[Body]
     user: Optional[str]
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json

 version_json = '''
 {
- "date": "2024-12-
+ "date": "2024-12-27T18:14:37+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "
- "version": "1.1.
+ "full-revisionid": "d3428697115cc4666b38b32925ba28bdc1a21957",
+ "version": "1.1.1"
 }
 '''  # END VERSION_JSON

xinference/api/restful_api.py
CHANGED
@@ -2346,7 +2346,8 @@ class RESTfulAPI(CancelMixin):
     @staticmethod
     def extract_guided_params(raw_body: dict) -> dict:
         kwargs = {}
-
+        raw_extra_body: dict = raw_body.get("extra_body")  # type: ignore
+        if raw_body.get("guided_json"):
             kwargs["guided_json"] = raw_body.get("guided_json")
         if raw_body.get("guided_regex") is not None:
             kwargs["guided_regex"] = raw_body.get("guided_regex")
@@ -2362,6 +2363,27 @@ class RESTfulAPI(CancelMixin):
             kwargs["guided_whitespace_pattern"] = raw_body.get(
                 "guided_whitespace_pattern"
             )
+        # Parse OpenAI extra_body
+        if raw_extra_body is not None:
+            if raw_extra_body.get("guided_json"):
+                kwargs["guided_json"] = raw_extra_body.get("guided_json")
+            if raw_extra_body.get("guided_regex") is not None:
+                kwargs["guided_regex"] = raw_extra_body.get("guided_regex")
+            if raw_extra_body.get("guided_choice") is not None:
+                kwargs["guided_choice"] = raw_extra_body.get("guided_choice")
+            if raw_extra_body.get("guided_grammar") is not None:
+                kwargs["guided_grammar"] = raw_extra_body.get("guided_grammar")
+            if raw_extra_body.get("guided_json_object") is not None:
+                kwargs["guided_json_object"] = raw_extra_body.get("guided_json_object")
+            if raw_extra_body.get("guided_decoding_backend") is not None:
+                kwargs["guided_decoding_backend"] = raw_extra_body.get(
+                    "guided_decoding_backend"
+                )
+            if raw_extra_body.get("guided_whitespace_pattern") is not None:
+                kwargs["guided_whitespace_pattern"] = raw_extra_body.get(
+                    "guided_whitespace_pattern"
+                )
+
         return kwargs

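Together with the new `extra_body` field in `_compat.py`, this lets vLLM-style guided decoding options ride along on a standard OpenAI-compatible chat completion request. A minimal sketch, assuming an Xinference server at http://127.0.0.1:9997 and a launched model named "qwen2.5-instruct" (both placeholders, not part of this diff):

import openai

# The guided_* keys mirror the ones extract_guided_params now reads from extra_body.
client = openai.Client(api_key="not empty", base_url="http://127.0.0.1:9997/v1")
response = client.chat.completions.create(
    model="qwen2.5-instruct",
    messages=[{"role": "user", "content": "Pick a color."}],
    extra_body={"guided_choice": ["red", "green", "blue"]},
)
print(response.choices[0].message.content)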
xinference/core/model.py
CHANGED
@@ -78,7 +78,6 @@ XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = [
 ]

 XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS = ["FLUX.1-dev", "FLUX.1-schnell"]
-XINFERENCE_BATCHING_BLACK_LIST = ["glm4-chat"]


 def request_limit(fn):
@@ -373,11 +372,7 @@ class ModelActor(xo.StatelessActor, CancelMixin):
                 f"Your model {self._model.model_family.model_name} with model family {self._model.model_family.model_family} is disqualified."
             )
             return False
-        return (
-            condition
-            and self._model.model_family.model_name
-            not in XINFERENCE_BATCHING_BLACK_LIST
-        )
+        return condition

     def allow_batching_for_text_to_image(self) -> bool:
         from ..model.image.stable_diffusion.core import DiffusionModel
xinference/core/utils.py
CHANGED
@@ -62,12 +62,16 @@ def log_async(

     @wraps(func)
     async def wrapped(*args, **kwargs):
-
-
-
-
-
-
+        request_id_str = kwargs.get("request_id")
+        if not request_id_str:
+            # sometimes `request_id` not in kwargs
+            # we try to bind the arguments
+            try:
+                bound_args = sig.bind_partial(*args, **kwargs)
+                arguments = bound_args.arguments
+            except TypeError:
+                arguments = {}
+            request_id_str = arguments.get("request_id", "")
         if not request_id_str:
             request_id_str = uuid.uuid1()
         if func_name == "text_to_image":
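The rewritten `log_async` wrapper falls back to `inspect.Signature.bind_partial` when `request_id` arrives positionally rather than as a keyword. A standalone sketch of that mechanism, using a made-up function `handle` purely for illustration:

import inspect

def handle(prompt, request_id=None):
    pass

sig = inspect.signature(handle)

def find_request_id(*args, **kwargs):
    # Same lookup order as the new wrapper: kwargs first, then bound arguments.
    request_id = kwargs.get("request_id")
    if not request_id:
        try:
            arguments = sig.bind_partial(*args, **kwargs).arguments
        except TypeError:
            arguments = {}
        request_id = arguments.get("request_id", "")
    return request_id

print(find_request_id("hi", "req-123"))             # req-123 (positional)
print(find_request_id("hi", request_id="req-456"))  # req-456 (keyword)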
xinference/model/audio/core.py
CHANGED
@@ -22,6 +22,7 @@ from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
 from .cosyvoice import CosyVoiceModel
 from .f5tts import F5TTSModel
+from .f5tts_mlx import F5TTSMLXModel
 from .fish_speech import FishSpeechModel
 from .funasr import FunASRModel
 from .whisper import WhisperModel
@@ -171,6 +172,7 @@ def create_audio_model_instance(
         CosyVoiceModel,
         FishSpeechModel,
         F5TTSModel,
+        F5TTSMLXModel,
     ],
     AudioModelDescription,
 ]:
@@ -185,6 +187,7 @@ def create_audio_model_instance(
         CosyVoiceModel,
         FishSpeechModel,
         F5TTSModel,
+        F5TTSMLXModel,
     ]
     if model_spec.model_family == "whisper":
         if not model_spec.engine:
@@ -201,6 +204,8 @@ def create_audio_model_instance(
         model = FishSpeechModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "F5-TTS":
         model = F5TTSModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "F5-TTS-MLX":
+        model = F5TTSMLXModel(model_uid, model_path, model_spec, **kwargs)
     else:
         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(
xinference/model/audio/cosyvoice.py
CHANGED
@@ -39,6 +39,7 @@ class CosyVoiceModel:
         self._device = device
         self._model = None
         self._kwargs = kwargs
+        self._is_cosyvoice2 = False

     @property
     def model_ability(self):
@@ -51,7 +52,14 @@ class CosyVoiceModel:
         # The yaml config loaded from model has hard-coded the import paths. please refer to: load_hyperpyyaml
         sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../thirdparty"))

-
+        if "CosyVoice2" in self._model_spec.model_name:
+            from cosyvoice.cli.cosyvoice import CosyVoice2 as CosyVoice
+
+            self._is_cosyvoice2 = True
+        else:
+            from cosyvoice.cli.cosyvoice import CosyVoice
+
+            self._is_cosyvoice2 = False

         self._model = CosyVoice(
             self._model_path, load_jit=self._kwargs.get("load_jit", False)
@@ -78,12 +86,22 @@ class CosyVoiceModel:
                 output = self._model.inference_zero_shot(
                     input, prompt_text, prompt_speech_16k, stream=stream
                 )
+            elif instruct_text:
+                assert self._is_cosyvoice2
+                logger.info("CosyVoice inference_instruct")
+                output = self._model.inference_instruct2(
+                    input,
+                    instruct_text=instruct_text,
+                    prompt_speech_16k=prompt_speech_16k,
+                    stream=stream,
+                )
             else:
                 logger.info("CosyVoice inference_cross_lingual")
                 output = self._model.inference_cross_lingual(
                     input, prompt_speech_16k, stream=stream
                 )
         else:
+            assert not self._is_cosyvoice2
             available_speakers = self._model.list_avaliable_spks()
             if not voice:
                 voice = available_speakers[0]
@@ -106,7 +124,9 @@ class CosyVoiceModel:
         def _generator_stream():
             with BytesIO() as out:
                 writer = torchaudio.io.StreamWriter(out, format=response_format)
-                writer.add_audio_stream(
+                writer.add_audio_stream(
+                    sample_rate=self._model.sample_rate, num_channels=1
+                )
                 i = 0
                 last_pos = 0
                 with writer.open():
@@ -125,7 +145,7 @@ class CosyVoiceModel:
             chunks = [o["tts_speech"] for o in output]
             t = torch.cat(chunks, dim=1)
             with BytesIO() as out:
-                torchaudio.save(out, t,
+                torchaudio.save(out, t, self._model.sample_rate, format=response_format)
                 return out.getvalue()

         return _generator_stream() if stream else _generator_block()
@@ -163,6 +183,8 @@ class CosyVoiceModel:
            assert (
                prompt_text is None
            ), "CosyVoice Instruct model does not support prompt_text"
+        elif self._is_cosyvoice2:
+            assert prompt_speech is not None, "CosyVoice2 requires prompt_speech"
         else:
             # inference_zero_shot
             # inference_cross_lingual
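The dispatch this change adds to `CosyVoiceModel` can be summarized as a small decision function. The sketch below only illustrates the branching visible in the hunks above; it is not code from the package:

from typing import Optional

def select_inference_method(
    is_cosyvoice2: bool,
    prompt_speech: Optional[bytes],
    prompt_text: Optional[str],
    instruct_text: Optional[str],
) -> str:
    # Mirrors the branching in CosyVoiceModel after this release.
    if prompt_speech:
        if prompt_text:
            return "inference_zero_shot"
        if instruct_text:
            # inference_instruct2 exists only for the CosyVoice2 family.
            assert is_cosyvoice2
            return "inference_instruct2"
        return "inference_cross_lingual"
    # No reference speech: fall back to a built-in speaker, which
    # CosyVoice2 does not support (it requires prompt_speech).
    assert not is_cosyvoice2
    return "built-in speaker (list_avaliable_spks)"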
xinference/model/audio/f5tts.py
CHANGED
@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import io
 import logging
 import os
 import re
 from io import BytesIO
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union

 if TYPE_CHECKING:
     from .core import AudioModelFamilyV1
@@ -106,9 +106,9 @@ class F5TTSModel:
         ) = preprocess_ref_audio_text(
             voices[voice]["ref_audio"], voices[voice]["ref_text"]
         )
-
-
-
+        logger.info("Voice:", voice)
+        logger.info("Ref_audio:", voices[voice]["ref_audio"])
+        logger.info("Ref_text:", voices[voice]["ref_text"])

         final_sample_rate = None
         generated_audio_segments = []
@@ -122,16 +122,16 @@ class F5TTSModel:
             if match:
                 voice = match[1]
             else:
-
+                logger.info("No voice tag found, using main.")
                 voice = "main"
             if voice not in voices:
-
+                logger.info(f"Voice {voice} not found, using main.")
                 voice = "main"
             text = re.sub(reg2, "", text)
             gen_text = text.strip()
             ref_audio = voices[voice]["ref_audio"]
             ref_text = voices[voice]["ref_text"]
-
+            logger.info(f"Voice: {voice}")
             audio, final_sample_rate, spectragram = infer_process(
                 ref_audio,
                 ref_text,
@@ -167,18 +167,23 @@ class F5TTSModel:
         prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
         prompt_text: Optional[str] = kwargs.pop("prompt_text", None)

+        ref_audio: Union[str, io.BytesIO]
         if prompt_speech is None:
             base = os.path.dirname(f5_tts.__file__)
             config = os.path.join(base, "infer/examples/basic/basic.toml")
             with open(config, "rb") as f:
                 config_dict = tomli.load(f)
-
+            ref_audio = os.path.join(base, config_dict["ref_audio"])
             prompt_text = config_dict["ref_text"]
+        else:
+            ref_audio = io.BytesIO(prompt_speech)
+            if prompt_text is None:
+                raise ValueError("`prompt_text` cannot be empty")

         assert self._model is not None
         vocoder_name = self._kwargs.get("vocoder_name", "vocos")
         sample_rate, wav = self._infer(
-            ref_audio=
+            ref_audio=ref_audio,
             ref_text=prompt_text,
             text_gen=input,
             model_obj=self._model,
xinference/model/audio/f5tts_mlx.py
ADDED
@@ -0,0 +1,260 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import io
+import logging
+import os
+from io import BytesIO
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal, Optional, Union
+
+import numpy as np
+from tqdm import tqdm
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+class F5TTSMLXModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._kwargs = kwargs
+        self._model = None
+
+    @property
+    def model_ability(self):
+        return self._model_spec.model_ability
+
+    def load(self):
+        try:
+            import mlx.core as mx
+            from f5_tts_mlx.cfm import F5TTS
+            from f5_tts_mlx.dit import DiT
+            from f5_tts_mlx.duration import DurationPredictor, DurationTransformer
+            from vocos_mlx import Vocos
+        except ImportError:
+            error_message = "Failed to import module 'f5_tts_mlx'"
+            installation_guide = [
+                "Please make sure 'f5_tts_mlx' is installed.\n",
+            ]
+
+            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+        path = Path(self._model_path)
+        # vocab
+
+        vocab_path = path / "vocab.txt"
+        vocab = {v: i for i, v in enumerate(Path(vocab_path).read_text().split("\n"))}
+        if len(vocab) == 0:
+            raise ValueError(f"Could not load vocab from {vocab_path}")
+
+        # duration predictor
+
+        duration_model_path = path / "duration_v2.safetensors"
+        duration_predictor = None
+
+        if duration_model_path.exists():
+            duration_predictor = DurationPredictor(
+                transformer=DurationTransformer(
+                    dim=512,
+                    depth=8,
+                    heads=8,
+                    text_dim=512,
+                    ff_mult=2,
+                    conv_layers=2,
+                    text_num_embeds=len(vocab) - 1,
+                ),
+                vocab_char_map=vocab,
+            )
+            weights = mx.load(duration_model_path.as_posix(), format="safetensors")
+            duration_predictor.load_weights(list(weights.items()))
+
+        # vocoder
+
+        vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz")
+
+        # model
+
+        model_path = path / "model.safetensors"
+
+        f5tts = F5TTS(
+            transformer=DiT(
+                dim=1024,
+                depth=22,
+                heads=16,
+                ff_mult=2,
+                text_dim=512,
+                conv_layers=4,
+                text_num_embeds=len(vocab) - 1,
+            ),
+            vocab_char_map=vocab,
+            vocoder=vocos.decode,
+            duration_predictor=duration_predictor,
+        )
+
+        weights = mx.load(model_path.as_posix(), format="safetensors")
+        f5tts.load_weights(list(weights.items()))
+        mx.eval(f5tts.parameters())
+
+        self._model = f5tts
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        import mlx.core as mx
+        import soundfile as sf
+        import tomli
+        from f5_tts_mlx.generate import (
+            FRAMES_PER_SEC,
+            SAMPLE_RATE,
+            TARGET_RMS,
+            convert_char_to_pinyin,
+            split_sentences,
+        )
+
+        from .utils import ensure_sample_rate
+
+        if stream:
+            raise Exception("F5-TTS does not support stream generation.")
+
+        prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
+        prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
+        duration: Optional[float] = kwargs.pop("duration", None)
+        steps: Optional[int] = kwargs.pop("steps", 8)
+        cfg_strength: Optional[float] = kwargs.pop("cfg_strength", 2.0)
+        method: Literal["euler", "midpoint"] = kwargs.pop("method", "rk4")
+        sway_sampling_coef: float = kwargs.pop("sway_sampling_coef", -1.0)
+        seed: Optional[int] = kwargs.pop("seed", None)
+
+        prompt_speech_path: Union[str, io.BytesIO]
+        if prompt_speech is None:
+            base = os.path.join(os.path.dirname(__file__), "../../thirdparty/f5_tts")
+            config = os.path.join(base, "infer/examples/basic/basic.toml")
+            with open(config, "rb") as f:
+                config_dict = tomli.load(f)
+            prompt_speech_path = os.path.join(base, config_dict["ref_audio"])
+            prompt_text = config_dict["ref_text"]
+        else:
+            prompt_speech_path = io.BytesIO(prompt_speech)
+
+        if prompt_text is None:
+            raise ValueError("`prompt_text` cannot be empty")
+
+        audio, sr = sf.read(prompt_speech_path)
+        audio = ensure_sample_rate(audio, sr, SAMPLE_RATE)
+
+        audio = mx.array(audio)
+        ref_audio_duration = audio.shape[0] / SAMPLE_RATE
+        logger.debug(
+            f"Got reference audio with duration: {ref_audio_duration:.2f} seconds"
+        )
+
+        rms = mx.sqrt(mx.mean(mx.square(audio)))
+        if rms < TARGET_RMS:
+            audio = audio * TARGET_RMS / rms
+
+        sentences = split_sentences(input)
+        is_single_generation = len(sentences) <= 1 or duration is not None
+
+        if is_single_generation:
+            generation_text = convert_char_to_pinyin([prompt_text + " " + input])  # type: ignore
+
+            if duration is not None:
+                duration = int(duration * FRAMES_PER_SEC)
+
+            start_date = datetime.datetime.now()
+
+            wave, _ = self._model.sample(  # type: ignore
+                mx.expand_dims(audio, axis=0),
+                text=generation_text,
+                duration=duration,
+                steps=steps,
+                method=method,
+                speed=speed,
+                cfg_strength=cfg_strength,
+                sway_sampling_coef=sway_sampling_coef,
+                seed=seed,
+            )
+
+            wave = wave[audio.shape[0] :]
+            mx.eval(wave)
+
+            generated_duration = wave.shape[0] / SAMPLE_RATE
+            print(
+                f"Generated {generated_duration:.2f}s of audio in {datetime.datetime.now() - start_date}."
+            )
+
+        else:
+            start_date = datetime.datetime.now()
+
+            output = []
+
+            for sentence_text in tqdm(split_sentences(input)):
+                text = convert_char_to_pinyin([prompt_text + " " + sentence_text])  # type: ignore
+
+                if duration is not None:
+                    duration = int(duration * FRAMES_PER_SEC)
+
+                wave, _ = self._model.sample(  # type: ignore
+                    mx.expand_dims(audio, axis=0),
+                    text=text,
+                    duration=duration,
+                    steps=steps,
+                    method=method,
+                    speed=speed,
+                    cfg_strength=cfg_strength,
+                    sway_sampling_coef=sway_sampling_coef,
+                    seed=seed,
+                )
+
+                # trim the reference audio
+                wave = wave[audio.shape[0] :]
+                mx.eval(wave)
+
+                output.append(wave)
+
+            wave = mx.concatenate(output, axis=0)
+
+            generated_duration = wave.shape[0] / SAMPLE_RATE
+            logger.debug(
+                f"Generated {generated_duration:.2f}s of audio in {datetime.datetime.now() - start_date}."
+            )
+
+        # Save the generated audio
+        with BytesIO() as out:
+            with sf.SoundFile(
+                out, "w", SAMPLE_RATE, 1, format=response_format.upper()
+            ) as f:
+                f.write(np.array(wave))
+            return out.getvalue()
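For reference, these are the generation options that the new `F5TTSMLXModel.speech` pops from `**kwargs`, with the defaults used above; passing them through a served model handle is an assumption for illustration and is not shown in this diff:

f5tts_mlx_options = {
    "prompt_speech": None,       # reference audio bytes; falls back to the bundled basic.toml example
    "prompt_text": None,         # transcript of the reference audio; required when prompt_speech is given
    "duration": None,            # target duration in seconds, converted to frames via FRAMES_PER_SEC
    "steps": 8,                  # ODE solver steps
    "cfg_strength": 2.0,         # classifier-free guidance strength
    "method": "rk4",             # annotated Literal["euler", "midpoint"] but defaults to "rk4"
    "sway_sampling_coef": -1.0,
    "seed": None,
}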