xinference 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +5 -5
- xinference/core/model.py +6 -1
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/f5tts.py +195 -0
- xinference/model/audio/fish_speech.py +2 -1
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/embedding/core.py +203 -142
- xinference/model/embedding/model_spec.json +7 -0
- xinference/model/embedding/model_spec_modelscope.json +8 -0
- xinference/model/llm/__init__.py +2 -2
- xinference/model/llm/llm_family.json +172 -53
- xinference/model/llm/llm_family_modelscope.json +118 -20
- xinference/model/llm/mlx/core.py +230 -49
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +4 -1
- xinference/model/llm/vllm/core.py +5 -0
- xinference/thirdparty/f5_tts/__init__.py +0 -0
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/METADATA +33 -14
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/RECORD +85 -34
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/LICENSE +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/WHEEL +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-
+ "date": "2024-12-13T18:21:03+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "
- "version": "1.0
+ "full-revisionid": "b132fca91f3e1b11b111f9b89f68a55e4b7605c6",
+ "version": "1.1.0"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py
CHANGED
@@ -94,9 +94,9 @@ class CreateCompletionRequest(CreateCompletion):
 
 class CreateEmbeddingRequest(BaseModel):
     model: str
-    input: Union[
-
-    )
+    input: Union[
+        str, List[str], List[int], List[List[int]], Dict[str, str], List[Dict[str, str]]
+    ] = Field(description="The input to embed.")
     user: Optional[str] = None
 
     class Config:

@@ -2044,7 +2044,6 @@ class RESTfulAPI(CancelMixin):
         )
         if body.tools and body.stream:
             is_vllm = await model.is_vllm_backend()
-
             if not (
                 (is_vllm and model_family in QWEN_TOOL_CALL_FAMILY)
                 or (not is_vllm and model_family in GLM4_TOOL_CALL_FAMILY)

@@ -2054,7 +2053,8 @@ class RESTfulAPI(CancelMixin):
                 detail="Streaming support for tool calls is available only when using "
                 "Qwen models with vLLM backend or GLM4-chat models without vLLM backend.",
             )
-
+        if "skip_special_tokens" in raw_kwargs and await model.is_vllm_backend():
+            kwargs["skip_special_tokens"] = raw_kwargs["skip_special_tokens"]
         if body.stream:
 
             async def stream_results():
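Taken together, these API changes widen the accepted embedding input types and let callers forward `skip_special_tokens` to a vLLM backend. A minimal usage sketch against the OpenAI-compatible endpoint follows; the host, port, and model UID are assumptions for illustration, not values from this diff:

```python
# Hedged sketch: forwarding the new skip_special_tokens kwarg through
# /v1/chat/completions. Host, port, and model UID are assumed.
import requests

resp = requests.post(
    "http://127.0.0.1:9997/v1/chat/completions",
    json={
        "model": "my-qwen-model",  # assumed model UID
        "messages": [{"role": "user", "content": "Hello"}],
        # Picked up from raw_kwargs and forwarded only when the backend is vLLM.
        "skip_special_tokens": False,
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```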
xinference/core/model.py
CHANGED
@@ -78,6 +78,7 @@ XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = [
 ]
 
 XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS = ["FLUX.1-dev", "FLUX.1-schnell"]
+XINFERENCE_BATCHING_BLACK_LIST = ["glm4-chat"]
 
 
 def request_limit(fn):

@@ -372,7 +373,11 @@ class ModelActor(xo.StatelessActor, CancelMixin):
                 f"Your model {self._model.model_family.model_name} with model family {self._model.model_family.model_family} is disqualified."
             )
             return False
-        return
+        return (
+            condition
+            and self._model.model_family.model_name
+            not in XINFERENCE_BATCHING_BLACK_LIST
+        )
 
     def allow_batching_for_text_to_image(self) -> bool:
         from ..model.image.stable_diffusion.core import DiffusionModel
xinference/deploy/cmdline.py
CHANGED
@@ -846,7 +846,9 @@ def model_launch(
     kwargs = {}
     for i in range(0, len(ctx.args), 2):
         if not ctx.args[i].startswith("--"):
-            raise ValueError(
+            raise ValueError(
+                f"You must specify extra kwargs with `--` prefix. There is an error in parameter passing that is {ctx.args[i]}."
+            )
         kwargs[ctx.args[i][2:]] = handle_click_args_type(ctx.args[i + 1])
     print(f"Launch model name: {model_name} with kwargs: {kwargs}", file=sys.stderr)

xinference/deploy/test/test_cmdline.py
CHANGED

@@ -23,6 +23,7 @@ from ..cmdline import (
     list_model_registrations,
     model_chat,
     model_generate,
+    model_launch,
     model_list,
     model_terminate,
     register_model,

@@ -311,3 +312,58 @@ def test_remove_cache(setup):
 
     assert result.exit_code == 0
     assert "Cache directory qwen1.5-chat has been deleted."
+
+
+def test_launch_error_in_passing_parameters():
+    runner = CliRunner()
+
+    # Known parameter but not provided with value.
+    result = runner.invoke(
+        model_launch,
+        [
+            "--model-engine",
+            "transformers",
+            "--model-name",
+            "qwen2.5-instruct",
+            "--model-uid",
+            "-s",
+            "0.5",
+            "-f",
+            "gptq",
+            "-q",
+            "INT4",
+            "111",
+            "-l",
+        ],
+    )
+    assert result.exit_code == 1
+    assert (
+        "You must specify extra kwargs with `--` prefix. There is an error in parameter passing that is 0.5."
+        in str(result)
+    )
+
+    # Unknown parameter
+    result = runner.invoke(
+        model_launch,
+        [
+            "--model-engine",
+            "transformers",
+            "--model-name",
+            "qwen2.5-instruct",
+            "--model-uid",
+            "123",
+            "-s",
+            "0.5",
+            "-f",
+            "gptq",
+            "-q",
+            "INT4",
+            "-l",
+            "111",
+        ],
+    )
+    assert result.exit_code == 1
+    assert (
+        "You must specify extra kwargs with `--` prefix. There is an error in parameter passing that is -l."
+        in str(result)
+    )
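For contrast with the failing invocations above, a well-formed launch passes every extra kwarg as a `--key value` pair, which the parsing loop consumes two tokens at a time. A hedged sketch; the extra kwarg name below is an assumption for illustration:

```python
# Hedged sketch of an invocation the parsing loop accepts: extra kwargs
# arrive strictly as `--key value` pairs after the known options.
from click.testing import CliRunner
from xinference.deploy.cmdline import model_launch

runner = CliRunner()
result = runner.invoke(
    model_launch,
    [
        "--model-engine", "transformers",
        "--model-name", "qwen2.5-instruct",
        "-s", "0.5",
        "-f", "gptq",
        "-q", "INT4",
        # Extra kwarg: parsed into {"max_model_len": 4096} by the loop above.
        "--max_model_len", "4096",
    ],
)
```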
xinference/isolation.py
CHANGED
@@ -37,6 +37,30 @@ class Isolation:
         asyncio.set_event_loop(self._loop)
         self._stopped = asyncio.Event()
         self._loop.run_until_complete(self._stopped.wait())
+        self._cancel_all_tasks(self._loop)
+
+    @staticmethod
+    def _cancel_all_tasks(loop):
+        to_cancel = asyncio.all_tasks(loop)
+        if not to_cancel:
+            return
+
+        for task in to_cancel:
+            task.cancel()
+
+        loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
+
+        for task in to_cancel:
+            if task.cancelled():
+                continue
+            if task.exception() is not None:
+                loop.call_exception_handler(
+                    {
+                        "message": "unhandled exception during asyncio.run() shutdown",
+                        "exception": task.exception(),
+                        "task": task,
+                    }
+                )
 
     def start(self):
         if self._threaded:
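The new helper mirrors the task cleanup `asyncio.run()` performs on shutdown, so tasks still pending when the isolation loop stops are cancelled and drained instead of leaking. A standalone sketch of the same pattern:

```python
# Standalone sketch of the cancel-then-drain shutdown pattern used above.
import asyncio

async def lingering():
    await asyncio.sleep(3600)  # a task that would otherwise outlive the loop

loop = asyncio.new_event_loop()
loop.create_task(lingering())
loop.run_until_complete(asyncio.sleep(0))  # let the task get scheduled

to_cancel = asyncio.all_tasks(loop)
for task in to_cancel:
    task.cancel()
# return_exceptions=True keeps CancelledError from propagating out of gather.
loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
loop.close()
```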
xinference/model/audio/core.py
CHANGED
@@ -21,6 +21,7 @@ from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
 from .cosyvoice import CosyVoiceModel
+from .f5tts import F5TTSModel
 from .fish_speech import FishSpeechModel
 from .funasr import FunASRModel
 from .whisper import WhisperModel

@@ -169,6 +170,7 @@ def create_audio_model_instance(
         ChatTTSModel,
         CosyVoiceModel,
         FishSpeechModel,
+        F5TTSModel,
     ],
     AudioModelDescription,
 ]:

@@ -182,6 +184,7 @@ def create_audio_model_instance(
         ChatTTSModel,
         CosyVoiceModel,
         FishSpeechModel,
+        F5TTSModel,
     ]
     if model_spec.model_family == "whisper":
         if not model_spec.engine:

@@ -196,6 +199,8 @@ def create_audio_model_instance(
         model = CosyVoiceModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "FishAudio":
         model = FishSpeechModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "F5-TTS":
+        model = F5TTSModel(model_uid, model_path, model_spec, **kwargs)
     else:
         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(

xinference/model/audio/f5tts.py
ADDED

@@ -0,0 +1,195 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import re
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+class F5TTSModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._vocoder = None
+        self._kwargs = kwargs
+
+    @property
+    def model_ability(self):
+        return self._model_spec.model_ability
+
+    def load(self):
+        import os
+        import sys
+
+        # The yaml config loaded from model has hard-coded the import paths. please refer to: load_hyperpyyaml
+        sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../thirdparty"))
+
+        from f5_tts.infer.utils_infer import load_model, load_vocoder
+        from f5_tts.model import DiT
+
+        vocoder_name = self._kwargs.get("vocoder_name", "vocos")
+        vocoder_path = self._kwargs.get("vocoder_path")
+
+        if vocoder_name not in ["vocos", "bigvgan"]:
+            raise Exception(f"Unsupported vocoder name: {vocoder_name}")
+
+        if vocoder_path is not None:
+            self._vocoder = load_vocoder(
+                vocoder_name=vocoder_name, is_local=True, local_path=vocoder_path
+            )
+        else:
+            self._vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=False)
+
+        model_cls = DiT
+        model_cfg = dict(
+            dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
+        )
+        if vocoder_name == "vocos":
+            exp_name = "F5TTS_Base"
+            ckpt_step = 1200000
+        elif vocoder_name == "bigvgan":
+            exp_name = "F5TTS_Base_bigvgan"
+            ckpt_step = 1250000
+        else:
+            assert False
+        ckpt_file = os.path.join(
+            self._model_path, exp_name, f"model_{ckpt_step}.safetensors"
+        )
+        logger.info(f"Loading %s...", ckpt_file)
+        self._model = load_model(
+            model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name
+        )
+
+    def _infer(self, ref_audio, ref_text, text_gen, model_obj, mel_spec_type, speed):
+        import numpy as np
+        from f5_tts.infer.utils_infer import infer_process, preprocess_ref_audio_text
+
+        config = {}
+        main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
+        if "voices" not in config:
+            voices = {"main": main_voice}
+        else:
+            voices = config["voices"]
+            voices["main"] = main_voice
+        for voice in voices:
+            (
+                voices[voice]["ref_audio"],
+                voices[voice]["ref_text"],
+            ) = preprocess_ref_audio_text(
+                voices[voice]["ref_audio"], voices[voice]["ref_text"]
+            )
+            print("Voice:", voice)
+            print("Ref_audio:", voices[voice]["ref_audio"])
+            print("Ref_text:", voices[voice]["ref_text"])
+
+        final_sample_rate = None
+        generated_audio_segments = []
+        reg1 = r"(?=\[\w+\])"
+        chunks = re.split(reg1, text_gen)
+        reg2 = r"\[(\w+)\]"
+        for text in chunks:
+            if not text.strip():
+                continue
+            match = re.match(reg2, text)
+            if match:
+                voice = match[1]
+            else:
+                print("No voice tag found, using main.")
+                voice = "main"
+            if voice not in voices:
+                print(f"Voice {voice} not found, using main.")
+                voice = "main"
+            text = re.sub(reg2, "", text)
+            gen_text = text.strip()
+            ref_audio = voices[voice]["ref_audio"]
+            ref_text = voices[voice]["ref_text"]
+            print(f"Voice: {voice}")
+            audio, final_sample_rate, spectragram = infer_process(
+                ref_audio,
+                ref_text,
+                gen_text,
+                model_obj,
+                self._vocoder,
+                mel_spec_type=mel_spec_type,
+                speed=speed,
+            )
+            generated_audio_segments.append(audio)
+
+        if generated_audio_segments:
+            final_wave = np.concatenate(generated_audio_segments)
+            return final_sample_rate, final_wave
+        return None, None
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        import f5_tts
+        import soundfile
+        import tomli
+
+        if stream:
+            raise Exception("F5-TTS does not support stream generation.")
+
+        prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
+        prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
+
+        if prompt_speech is None:
+            base = os.path.dirname(f5_tts.__file__)
+            config = os.path.join(base, "infer/examples/basic/basic.toml")
+            with open(config, "rb") as f:
+                config_dict = tomli.load(f)
+            prompt_speech = os.path.join(base, config_dict["ref_audio"])
+            prompt_text = config_dict["ref_text"]
+
+        assert self._model is not None
+        vocoder_name = self._kwargs.get("vocoder_name", "vocos")
+        sample_rate, wav = self._infer(
+            ref_audio=prompt_speech,
+            ref_text=prompt_text,
+            text_gen=input,
+            model_obj=self._model,
+            mel_spec_type=vocoder_name,
+            speed=speed,
+        )
+
+        # Save the generated audio
+        with BytesIO() as out:
+            with soundfile.SoundFile(
+                out, "w", sample_rate, 1, format=response_format.upper()
+            ) as f:
+                f.write(wav)
+            return out.getvalue()
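Once launched, the new model is reachable through the standard audio client API. A hedged usage sketch; the endpoint and file names are assumptions for illustration:

```python
# Hedged usage sketch for the new F5-TTS model via the Xinference client.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed endpoint
model_uid = client.launch_model(model_name="F5-TTS", model_type="audio")
model = client.get_model(model_uid)

# speech() follows the OpenAI-style audio API; stream=True would raise,
# since the implementation above rejects streaming.
audio_bytes = model.speech("Hello from F5-TTS!", response_format="mp3")
with open("out.mp3", "wb") as f:
    f.write(audio_bytes)
```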
xinference/model/audio/fish_speech.py
CHANGED

@@ -213,13 +213,14 @@ class FishSpeechModel:
         import torchaudio
 
         prompt_speech = kwargs.get("prompt_speech")
+        prompt_text = kwargs.get("prompt_text", kwargs.get("reference_text", ""))
         result = self._inference(
             text=input,
             enable_reference_audio=kwargs.get(
                 "enable_reference_audio", prompt_speech is not None
             ),
             reference_audio=prompt_speech,
-            reference_text=
+            reference_text=prompt_text,
             max_new_tokens=kwargs.get("max_new_tokens", 1024),
             chunk_length=kwargs.get("chunk_length", 200),
             top_p=kwargs.get("top_p", 0.7),
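With this change, voice-cloning callers can pass the reference transcript as `prompt_text`, and the older `reference_text` spelling is still honored as a fallback. A hedged client-side sketch; the endpoint, model UID, and file are assumptions:

```python
# Hedged sketch: supplying a reference clip plus its transcript to FishSpeech.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")          # assumed endpoint
model = client.get_model("my-fishspeech-model")   # assumed model UID

with open("reference.wav", "rb") as f:            # assumed reference clip
    ref_audio = f.read()

audio = model.speech(
    "Text to synthesize",
    prompt_speech=ref_audio,
    prompt_text="Transcript of the reference audio",  # or legacy reference_text
)
```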
xinference/model/audio/model_spec.json
CHANGED

@@ -242,5 +242,13 @@
         "model_revision": "069c573759936b35191d3380deb89183c0656f59",
         "model_ability": "text-to-audio",
         "multilingual": true
+    },
+    {
+        "model_name": "F5-TTS",
+        "model_family": "F5-TTS",
+        "model_id": "SWivid/F5-TTS",
+        "model_revision": "4dcc16f297f2ff98a17b3726b16f5de5a5e45672",
+        "model_ability": "text-to-audio",
+        "multilingual": true
     }
 ]
xinference/model/audio/model_spec_modelscope.json
CHANGED

@@ -73,5 +73,14 @@
         "model_revision": "master",
         "model_ability": "text-to-audio",
         "multilingual": true
+    },
+    {
+        "model_name": "F5-TTS",
+        "model_family": "F5-TTS",
+        "model_hub": "modelscope",
+        "model_id": "SWivid/F5-TTS_Emilia-ZH-EN",
+        "model_revision": "master",
+        "model_ability": "text-to-audio",
+        "multilingual": true
     }
 ]