xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +28 -6
- xinference/core/utils.py +10 -6
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/core.py +10 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +200 -0
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +36 -111
- xinference/model/audio/model_spec.json +27 -3
- xinference/model/audio/model_spec_modelscope.json +18 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/embedding/core.py +203 -142
- xinference/model/embedding/model_spec.json +7 -0
- xinference/model/embedding/model_spec_modelscope.json +8 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +127 -4
- xinference/model/image/model_spec_modelscope.json +130 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/__init__.py +2 -2
- xinference/model/llm/llm_family.json +219 -53
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +167 -20
- xinference/model/llm/mlx/core.py +287 -51
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/qwen2_vl.py +2 -0
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +5 -1
- xinference/model/llm/vllm/core.py +16 -2
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import datetime
|
|
16
|
+
import io
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
from io import BytesIO
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import TYPE_CHECKING, Literal, Optional, Union
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
from tqdm import tqdm
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from .core import AudioModelFamilyV1
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class F5TTSMLXModel:
    """Text-to-speech model backed by the MLX port of F5-TTS (``f5_tts_mlx``).

    ``mlx``, ``f5_tts_mlx`` and ``vocos_mlx`` are imported lazily inside
    :meth:`load` and :meth:`speech`, so constructing an instance does no I/O.
    """

    def __init__(
        self,
        model_uid: str,
        model_path: str,
        model_spec: "AudioModelFamilyV1",
        device: Optional[str] = None,
        **kwargs,
    ):
        self._model_uid = model_uid
        self._model_path = model_path
        self._model_spec = model_spec
        # Kept for parity with the other audio model wrappers; not read by
        # this class.
        self._device = device
        self._kwargs = kwargs
        # Set by ``load()``; ``speech()`` refuses to run while it is None.
        self._model = None

    @property
    def model_ability(self):
        """The ``model_ability`` declared in the model spec."""
        return self._model_spec.model_ability

    def load(self):
        """Build the F5-TTS network from the files under ``model_path``.

        Reads ``vocab.txt`` and ``model.safetensors`` and, if present,
        ``duration_v2.safetensors`` (duration predictor). The vocoder is
        obtained with ``Vocos.from_pretrained("lucasnewman/vocos-mel-24khz")``.

        Raises:
            ImportError: if the MLX dependencies cannot be imported.
            ValueError: if ``vocab.txt`` contains no tokens.
        """
        try:
            import mlx.core as mx
            from f5_tts_mlx.cfm import F5TTS
            from f5_tts_mlx.dit import DiT
            from f5_tts_mlx.duration import DurationPredictor, DurationTransformer
            from vocos_mlx import Vocos
        except ImportError as e:
            error_message = "Failed to import module 'f5_tts_mlx'"
            installation_guide = [
                "Please make sure 'f5_tts_mlx' is installed.\n",
            ]

            raise ImportError(
                f"{error_message}\n\n{''.join(installation_guide)}"
            ) from e

        path = Path(self._model_path)

        # vocab: one token per line, mapped to its line index.
        vocab_path = path / "vocab.txt"
        vocab = {
            v: i
            for i, v in enumerate(vocab_path.read_text(encoding="utf-8").split("\n"))
        }
        # ``split`` always yields at least one entry, so an emptiness check on
        # the dict could never fire; require at least one non-empty token.
        if not any(vocab):
            raise ValueError(f"Could not load vocab from {vocab_path}")
        # NOTE(review): embedding tables are sized ``len(vocab) - 1``,
        # presumably to discount the empty entry left by a trailing newline;
        # this must stay consistent with the checkpoint shapes.
        num_text_embeds = len(vocab) - 1

        # Optional duration predictor.
        duration_model_path = path / "duration_v2.safetensors"
        duration_predictor = None
        if duration_model_path.exists():
            duration_predictor = DurationPredictor(
                transformer=DurationTransformer(
                    dim=512,
                    depth=8,
                    heads=8,
                    text_dim=512,
                    ff_mult=2,
                    conv_layers=2,
                    text_num_embeds=num_text_embeds,
                ),
                vocab_char_map=vocab,
            )
            weights = mx.load(duration_model_path.as_posix(), format="safetensors")
            duration_predictor.load_weights(list(weights.items()))

        # Vocoder, resolved by name through ``Vocos.from_pretrained``.
        vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz")

        # Main model; these hyper-parameters presumably have to match the
        # shapes stored in ``model.safetensors``.
        model_path = path / "model.safetensors"
        f5tts = F5TTS(
            transformer=DiT(
                dim=1024,
                depth=22,
                heads=16,
                ff_mult=2,
                text_dim=512,
                conv_layers=4,
                text_num_embeds=num_text_embeds,
            ),
            vocab_char_map=vocab,
            vocoder=vocos.decode,
            duration_predictor=duration_predictor,
        )

        weights = mx.load(model_path.as_posix(), format="safetensors")
        f5tts.load_weights(list(weights.items()))
        # MLX evaluates lazily; materialize the parameters now rather than on
        # the first synthesis request.
        mx.eval(f5tts.parameters())

        self._model = f5tts

    def _load_reference_audio(
        self, prompt_speech: Optional[bytes], prompt_text: Optional[str]
    ):
        """Return ``(reference_waveform, reference_text)`` for voice cloning.

        If ``prompt_speech`` is None, the bundled example reference described by
        ``thirdparty/f5_tts/infer/examples/basic/basic.toml`` is used, including
        its transcript. In that case any ``prompt_text`` argument is replaced.

        The waveform is returned as a 1-D float32 ``mx.array``. It is resampled
        to ``SAMPLE_RATE`` and boosted to ``TARGET_RMS`` if quieter than that.

        Raises:
            ValueError: if custom ``prompt_speech`` is given without
                ``prompt_text``.
        """
        import mlx.core as mx
        import soundfile as sf
        import tomli
        from f5_tts_mlx.generate import SAMPLE_RATE, TARGET_RMS

        from .utils import ensure_sample_rate

        prompt_speech_path: Union[str, io.BytesIO]
        if prompt_speech is None:
            base = os.path.join(os.path.dirname(__file__), "../../thirdparty/f5_tts")
            config = os.path.join(base, "infer/examples/basic/basic.toml")
            with open(config, "rb") as f:
                config_dict = tomli.load(f)
            prompt_speech_path = os.path.join(base, config_dict["ref_audio"])
            prompt_text = config_dict["ref_text"]
        else:
            prompt_speech_path = io.BytesIO(prompt_speech)

        if prompt_text is None:
            raise ValueError("`prompt_text` cannot be empty")

        audio, sr = sf.read(prompt_speech_path, dtype="float32")
        if audio.ndim > 1:
            # Down-mix multi-channel references to mono: the generation code
            # treats axis 0 as time and only adds a batch axis in front.
            audio = audio.mean(axis=1)
        audio = ensure_sample_rate(audio, sr, SAMPLE_RATE)

        audio = mx.array(audio)
        ref_audio_duration = audio.shape[0] / SAMPLE_RATE
        logger.debug(
            f"Got reference audio with duration: {ref_audio_duration:.2f} seconds"
        )

        # Boost quiet references to TARGET_RMS; a silent one (rms == 0) is
        # left as-is instead of being divided by zero.
        rms = mx.sqrt(mx.mean(mx.square(audio))).item()
        if 0 < rms < TARGET_RMS:
            audio = audio * TARGET_RMS / rms

        return audio, prompt_text

    def speech(
        self,
        input: str,
        voice: str,
        response_format: str = "mp3",
        speed: float = 1.0,
        stream: bool = False,
        **kwargs,
    ):
        """Synthesize ``input`` and return it encoded as ``response_format``.

        Args:
            input: text to speak.
            voice: accepted for API compatibility; not used.
            response_format: format name, upper-cased and passed to
                ``soundfile.SoundFile`` as ``format``.
            speed: forwarded to the model's ``sample``.
            stream: must be False; streaming is not supported.
            **kwargs: optional ``prompt_speech`` (encoded audio bytes),
                ``prompt_text``, ``duration`` (seconds), ``steps`` (default 8),
                ``cfg_strength`` (2.0), ``method`` ("rk4"),
                ``sway_sampling_coef`` (-1.0) and ``seed`` (None).

        Returns:
            The encoded audio as ``bytes``.

        Raises:
            Exception: if ``stream`` is True.
            RuntimeError: if :meth:`load` has not been called.
            ValueError: if ``prompt_speech`` is given without ``prompt_text``.
        """
        if stream:
            raise Exception("F5-TTS does not support stream generation.")
        if self._model is None:
            raise RuntimeError(
                "F5-TTS-MLX model is not loaded, call `load()` before `speech()`."
            )

        import mlx.core as mx
        import soundfile as sf
        from f5_tts_mlx.generate import (
            FRAMES_PER_SEC,
            SAMPLE_RATE,
            convert_char_to_pinyin,
            split_sentences,
        )

        prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
        prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
        duration: Optional[float] = kwargs.pop("duration", None)
        steps: Optional[int] = kwargs.pop("steps", 8)
        cfg_strength: Optional[float] = kwargs.pop("cfg_strength", 2.0)
        method: Literal["euler", "midpoint", "rk4"] = kwargs.pop("method", "rk4")
        sway_sampling_coef: float = kwargs.pop("sway_sampling_coef", -1.0)
        seed: Optional[int] = kwargs.pop("seed", None)

        audio, prompt_text = self._load_reference_audio(prompt_speech, prompt_text)

        sentences = split_sentences(input)
        if len(sentences) <= 1 or duration is not None:
            # One pass over the whole input. An explicit duration (seconds)
            # only applies to a single pass, so it is converted to frames once.
            chunks = [input]
            frames = None if duration is None else int(duration * FRAMES_PER_SEC)
        else:
            # Sentence-by-sentence generation, concatenated afterwards.
            chunks = sentences
            frames = None

        start_date = datetime.datetime.now()

        ref_length = audio.shape[0]
        ref_batch = mx.expand_dims(audio, axis=0)
        output = []
        for chunk_text in tqdm(chunks) if len(chunks) > 1 else chunks:
            text = convert_char_to_pinyin([prompt_text + " " + chunk_text])  # type: ignore
            wave, _ = self._model.sample(  # type: ignore
                ref_batch,
                text=text,
                duration=frames,
                steps=steps,
                method=method,
                speed=speed,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
                seed=seed,
            )
            # trim the reference audio
            wave = wave[ref_length:]
            mx.eval(wave)
            output.append(wave)

        wave = output[0] if len(output) == 1 else mx.concatenate(output, axis=0)

        generated_duration = wave.shape[0] / SAMPLE_RATE
        logger.debug(
            f"Generated {generated_duration:.2f}s of audio in {datetime.datetime.now() - start_date}."
        )

        # Save the generated audio
        with BytesIO() as out:
            with sf.SoundFile(
                out, "w", SAMPLE_RATE, 1, format=response_format.upper()
            ) as f:
                f.write(np.array(wave))
            return out.getvalue()
|
|
@@ -11,10 +11,8 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
import gc
|
|
15
14
|
import logging
|
|
16
15
|
import os.path
|
|
17
|
-
import queue
|
|
18
16
|
import sys
|
|
19
17
|
from io import BytesIO
|
|
20
18
|
from typing import TYPE_CHECKING, Optional
|
|
@@ -60,6 +58,7 @@ class FishSpeechModel:
|
|
|
60
58
|
self._device = device
|
|
61
59
|
self._llama_queue = None
|
|
62
60
|
self._model = None
|
|
61
|
+
self._engine = None
|
|
63
62
|
self._kwargs = kwargs
|
|
64
63
|
|
|
65
64
|
@property
|
|
@@ -72,6 +71,7 @@ class FishSpeechModel:
|
|
|
72
71
|
0, os.path.join(os.path.dirname(__file__), "../../thirdparty/fish_speech")
|
|
73
72
|
)
|
|
74
73
|
|
|
74
|
+
from tools.inference_engine import TTSInferenceEngine
|
|
75
75
|
from tools.llama.generate import launch_thread_safe_queue
|
|
76
76
|
from tools.vqgan.inference import load_model as load_decoder_model
|
|
77
77
|
|
|
@@ -81,6 +81,11 @@ class FishSpeechModel:
|
|
|
81
81
|
if not is_device_available(self._device):
|
|
82
82
|
raise ValueError(f"Device {self._device} is not available!")
|
|
83
83
|
|
|
84
|
+
# https://github.com/pytorch/pytorch/issues/129207
|
|
85
|
+
if self._device == "mps":
|
|
86
|
+
logger.warning("The Conv1d has bugs on MPS backend, fallback to CPU.")
|
|
87
|
+
self._device = "cpu"
|
|
88
|
+
|
|
84
89
|
enable_compile = self._kwargs.get("compile", False)
|
|
85
90
|
precision = self._kwargs.get("precision", torch.bfloat16)
|
|
86
91
|
logger.info("Loading Llama model, compile=%s...", enable_compile)
|
|
@@ -102,102 +107,10 @@ class FishSpeechModel:
|
|
|
102
107
|
device=self._device,
|
|
103
108
|
)
|
|
104
109
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
self,
|
|
108
|
-
text,
|
|
109
|
-
enable_reference_audio,
|
|
110
|
-
reference_audio,
|
|
111
|
-
reference_text,
|
|
112
|
-
max_new_tokens,
|
|
113
|
-
chunk_length,
|
|
114
|
-
top_p,
|
|
115
|
-
repetition_penalty,
|
|
116
|
-
temperature,
|
|
117
|
-
seed="0",
|
|
118
|
-
streaming=False,
|
|
119
|
-
):
|
|
120
|
-
from fish_speech.utils import autocast_exclude_mps, set_seed
|
|
121
|
-
from tools.api import decode_vq_tokens, encode_reference
|
|
122
|
-
from tools.llama.generate import (
|
|
123
|
-
GenerateRequest,
|
|
124
|
-
GenerateResponse,
|
|
125
|
-
WrappedGenerateResponse,
|
|
126
|
-
)
|
|
127
|
-
|
|
128
|
-
seed = int(seed)
|
|
129
|
-
if seed != 0:
|
|
130
|
-
set_seed(seed)
|
|
131
|
-
logger.warning(f"set seed: {seed}")
|
|
132
|
-
|
|
133
|
-
# Parse reference audio aka prompt
|
|
134
|
-
prompt_tokens = encode_reference(
|
|
135
|
-
decoder_model=self._model,
|
|
136
|
-
reference_audio=reference_audio,
|
|
137
|
-
enable_reference_audio=enable_reference_audio,
|
|
138
|
-
)
|
|
139
|
-
|
|
140
|
-
# LLAMA Inference
|
|
141
|
-
request = dict(
|
|
142
|
-
device=self._model.device,
|
|
143
|
-
max_new_tokens=max_new_tokens,
|
|
144
|
-
text=text,
|
|
145
|
-
top_p=top_p,
|
|
146
|
-
repetition_penalty=repetition_penalty,
|
|
147
|
-
temperature=temperature,
|
|
148
|
-
compile=self._kwargs.get("compile", False),
|
|
149
|
-
iterative_prompt=chunk_length > 0,
|
|
150
|
-
chunk_length=chunk_length,
|
|
151
|
-
max_length=2048,
|
|
152
|
-
prompt_tokens=prompt_tokens if enable_reference_audio else None,
|
|
153
|
-
prompt_text=reference_text if enable_reference_audio else None,
|
|
154
|
-
)
|
|
155
|
-
|
|
156
|
-
response_queue = queue.Queue()
|
|
157
|
-
self._llama_queue.put(
|
|
158
|
-
GenerateRequest(
|
|
159
|
-
request=request,
|
|
160
|
-
response_queue=response_queue,
|
|
161
|
-
)
|
|
110
|
+
self._engine = TTSInferenceEngine(
|
|
111
|
+
self._llama_queue, self._model, precision, enable_compile
|
|
162
112
|
)
|
|
163
113
|
|
|
164
|
-
segments = []
|
|
165
|
-
|
|
166
|
-
while True:
|
|
167
|
-
result: WrappedGenerateResponse = response_queue.get()
|
|
168
|
-
if result.status == "error":
|
|
169
|
-
raise result.response
|
|
170
|
-
|
|
171
|
-
result: GenerateResponse = result.response
|
|
172
|
-
if result.action == "next":
|
|
173
|
-
break
|
|
174
|
-
|
|
175
|
-
with autocast_exclude_mps(
|
|
176
|
-
device_type=self._model.device.type,
|
|
177
|
-
dtype=self._kwargs.get("precision", torch.bfloat16),
|
|
178
|
-
):
|
|
179
|
-
fake_audios = decode_vq_tokens(
|
|
180
|
-
decoder_model=self._model,
|
|
181
|
-
codes=result.codes,
|
|
182
|
-
)
|
|
183
|
-
|
|
184
|
-
fake_audios = fake_audios.float().cpu().numpy()
|
|
185
|
-
segments.append(fake_audios)
|
|
186
|
-
|
|
187
|
-
if streaming:
|
|
188
|
-
yield fake_audios, None, None
|
|
189
|
-
|
|
190
|
-
if len(segments) == 0:
|
|
191
|
-
raise Exception("No audio generated, please check the input text.")
|
|
192
|
-
|
|
193
|
-
# No matter streaming or not, we need to return the final audio
|
|
194
|
-
audio = np.concatenate(segments, axis=0)
|
|
195
|
-
yield None, (self._model.spec_transform.sample_rate, audio), None
|
|
196
|
-
|
|
197
|
-
if torch.cuda.is_available():
|
|
198
|
-
torch.cuda.empty_cache()
|
|
199
|
-
gc.collect()
|
|
200
|
-
|
|
201
114
|
def speech(
|
|
202
115
|
self,
|
|
203
116
|
input: str,
|
|
@@ -211,21 +124,31 @@ class FishSpeechModel:
|
|
|
211
124
|
if speed != 1.0:
|
|
212
125
|
logger.warning("Fish speech does not support setting speed: %s.", speed)
|
|
213
126
|
import torchaudio
|
|
127
|
+
from tools.schema import ServeReferenceAudio, ServeTTSRequest
|
|
214
128
|
|
|
215
129
|
prompt_speech = kwargs.get("prompt_speech")
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
130
|
+
prompt_text = kwargs.get("prompt_text", kwargs.get("reference_text", ""))
|
|
131
|
+
if prompt_speech is not None:
|
|
132
|
+
r = ServeReferenceAudio(audio=prompt_speech, text=prompt_text)
|
|
133
|
+
references = [r]
|
|
134
|
+
else:
|
|
135
|
+
references = []
|
|
136
|
+
|
|
137
|
+
assert self._engine is not None
|
|
138
|
+
result = self._engine.inference(
|
|
139
|
+
ServeTTSRequest(
|
|
140
|
+
text=input,
|
|
141
|
+
references=references,
|
|
142
|
+
reference_id=kwargs.get("reference_id"),
|
|
143
|
+
seed=kwargs.get("seed"),
|
|
144
|
+
max_new_tokens=kwargs.get("max_new_tokens", 1024),
|
|
145
|
+
chunk_length=kwargs.get("chunk_length", 200),
|
|
146
|
+
top_p=kwargs.get("top_p", 0.7),
|
|
147
|
+
repetition_penalty=kwargs.get("repetition_penalty", 1.2),
|
|
148
|
+
temperature=kwargs.get("temperature", 0.7),
|
|
149
|
+
streaming=stream,
|
|
150
|
+
format=response_format,
|
|
151
|
+
)
|
|
229
152
|
)
|
|
230
153
|
|
|
231
154
|
if stream:
|
|
@@ -241,7 +164,9 @@ class FishSpeechModel:
|
|
|
241
164
|
last_pos = 0
|
|
242
165
|
with writer.open():
|
|
243
166
|
for chunk in result:
|
|
244
|
-
chunk
|
|
167
|
+
if chunk.code == "final":
|
|
168
|
+
continue
|
|
169
|
+
chunk = chunk.audio[1]
|
|
245
170
|
if chunk is not None:
|
|
246
171
|
chunk = chunk.reshape((chunk.shape[0], 1))
|
|
247
172
|
trans_chunk = torch.from_numpy(chunk)
|
|
@@ -256,7 +181,7 @@ class FishSpeechModel:
|
|
|
256
181
|
return _stream_generator()
|
|
257
182
|
else:
|
|
258
183
|
result = list(result)
|
|
259
|
-
sample_rate, audio = result[0]
|
|
184
|
+
sample_rate, audio = result[0].audio
|
|
260
185
|
audio = np.array([audio])
|
|
261
186
|
|
|
262
187
|
# Save the generated audio
|
|
@@ -236,10 +236,34 @@
|
|
|
236
236
|
"multilingual": true
|
|
237
237
|
},
|
|
238
238
|
{
|
|
239
|
-
"model_name": "
|
|
239
|
+
"model_name": "CosyVoice2-0.5B",
|
|
240
|
+
"model_family": "CosyVoice",
|
|
241
|
+
"model_id": "mrfakename/CosyVoice2-0.5B",
|
|
242
|
+
"model_revision": "5676baabc8a76dc93ef60a88bbd2420deaa2f644",
|
|
243
|
+
"model_ability": "text-to-audio",
|
|
244
|
+
"multilingual": true
|
|
245
|
+
},
|
|
246
|
+
{
|
|
247
|
+
"model_name": "FishSpeech-1.5",
|
|
240
248
|
"model_family": "FishAudio",
|
|
241
|
-
"model_id": "fishaudio/fish-speech-1.
|
|
242
|
-
"model_revision": "
|
|
249
|
+
"model_id": "fishaudio/fish-speech-1.5",
|
|
250
|
+
"model_revision": "268b6ec86243dd683bc78dab7e9a6cedf9191f2a",
|
|
251
|
+
"model_ability": "text-to-audio",
|
|
252
|
+
"multilingual": true
|
|
253
|
+
},
|
|
254
|
+
{
|
|
255
|
+
"model_name": "F5-TTS",
|
|
256
|
+
"model_family": "F5-TTS",
|
|
257
|
+
"model_id": "SWivid/F5-TTS",
|
|
258
|
+
"model_revision": "4dcc16f297f2ff98a17b3726b16f5de5a5e45672",
|
|
259
|
+
"model_ability": "text-to-audio",
|
|
260
|
+
"multilingual": true
|
|
261
|
+
},
|
|
262
|
+
{
|
|
263
|
+
"model_name": "F5-TTS-MLX",
|
|
264
|
+
"model_family": "F5-TTS-MLX",
|
|
265
|
+
"model_id": "lucasnewman/f5-tts-mlx",
|
|
266
|
+
"model_revision": "7642bb232e3fcacf92c51c786edebb8624da6b93",
|
|
243
267
|
"model_ability": "text-to-audio",
|
|
244
268
|
"multilingual": true
|
|
245
269
|
}
|
|
@@ -73,5 +73,23 @@
|
|
|
73
73
|
"model_revision": "master",
|
|
74
74
|
"model_ability": "text-to-audio",
|
|
75
75
|
"multilingual": true
|
|
76
|
+
},
|
|
77
|
+
{
|
|
78
|
+
"model_name": "CosyVoice2-0.5B",
|
|
79
|
+
"model_family": "CosyVoice",
|
|
80
|
+
"model_hub": "modelscope",
|
|
81
|
+
"model_id": "iic/CosyVoice2-0.5B",
|
|
82
|
+
"model_revision": "master",
|
|
83
|
+
"model_ability": "text-to-audio",
|
|
84
|
+
"multilingual": true
|
|
85
|
+
},
|
|
86
|
+
{
|
|
87
|
+
"model_name": "F5-TTS",
|
|
88
|
+
"model_family": "F5-TTS",
|
|
89
|
+
"model_hub": "modelscope",
|
|
90
|
+
"model_id": "SWivid/F5-TTS_Emilia-ZH-EN",
|
|
91
|
+
"model_revision": "master",
|
|
92
|
+
"model_ability": "text-to-audio",
|
|
93
|
+
"multilingual": true
|
|
76
94
|
}
|
|
77
95
|
]
|
xinference/model/audio/utils.py
CHANGED
|
@@ -11,8 +11,40 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import io
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
14
19
|
from .core import AudioModelFamilyV1
|
|
15
20
|
|
|
16
21
|
|
|
17
22
|
def get_model_version(audio_model: AudioModelFamilyV1) -> str:
|
|
18
23
|
return audio_model.model_name
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def ensure_sample_rate(
    audio: np.ndarray, old_sample_rate: int, sample_rate: int
) -> np.ndarray:
    """Resample ``audio`` from ``old_sample_rate`` to ``sample_rate``.

    Resampling uses :func:`scipy.signal.resample` (FFT based) along axis 0, so
    multi-channel input shaped ``(frames, channels)`` keeps its channel axis.

    Args:
        audio: samples with time on axis 0.
        old_sample_rate: rate ``audio`` was recorded at.
        sample_rate: desired output rate.

    Returns:
        ``audio`` itself when the rates already match. Otherwise a new
        ``float32`` array of ``int(len(audio) * sample_rate / old_sample_rate)``
        frames.
    """
    from scipy.signal import resample

    if old_sample_rate == sample_rate:
        return audio

    # Keep the duration constant: new_length / sample_rate == len / old_rate.
    new_length = int(len(audio) * sample_rate / old_sample_rate)
    resampled_data = resample(audio, new_length)

    # Cast directly rather than round-tripping through an encoded WAV buffer:
    # soundfile's default WAV subtype is 16-bit PCM, which would quantize the
    # signal and clip any sample outside [-1, 1].
    return np.asarray(resampled_data, dtype=np.float32)
|