xinference 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +5 -5
- xinference/core/model.py +6 -1
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/f5tts.py +195 -0
- xinference/model/audio/fish_speech.py +2 -1
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/embedding/core.py +203 -142
- xinference/model/embedding/model_spec.json +7 -0
- xinference/model/embedding/model_spec_modelscope.json +8 -0
- xinference/model/llm/__init__.py +2 -2
- xinference/model/llm/llm_family.json +172 -53
- xinference/model/llm/llm_family_modelscope.json +118 -20
- xinference/model/llm/mlx/core.py +230 -49
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +4 -1
- xinference/model/llm/vllm/core.py +5 -0
- xinference/thirdparty/f5_tts/__init__.py +0 -0
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/METADATA +33 -14
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/RECORD +85 -34
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/LICENSE +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/WHEEL +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/top_level.txt +0 -0
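The headline change in this diff is F5-TTS support: a vendored copy of the upstream project is added under xinference/thirdparty/f5_tts, and xinference/model/audio/f5tts.py plus the new audio model_spec entries register it as an audio model family. As a rough, unverified sketch of how the new model would be consumed through the existing Python client (the registered model name "F5-TTS", the speech() options, and the returned payload are assumptions, not confirmed by this diff):

from xinference.client import Client

# Assumes a local supervisor started via `xinference-local` on the default port.
client = Client("http://127.0.0.1:9997")

# "F5-TTS" is the presumed name registered by the new audio model_spec.json entry.
model_uid = client.launch_model(model_name="F5-TTS", model_type="audio")
model = client.get_model(model_uid)

# Audio models expose a speech() endpoint; F5-TTS-specific options such as a
# reference voice clip would go through extra keyword arguments (names hypothetical).
audio_bytes = model.speech("Hello from xinference 1.1.0")
with open("output.wav", "wb") as f:
    f.write(audio_bytes)

Three of the new-file hunks from the vendored f5_tts package are reproduced below.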
xinference/thirdparty/f5_tts/infer/utils_infer.py
@@ -0,0 +1,538 @@
# A unified script for inference process
# Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
import os
import sys

os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility
sys.path.append(f"../../{os.path.dirname(os.path.abspath(__file__))}/third_party/BigVGAN/")

import hashlib
import re
import tempfile
from importlib.resources import files

# import matplotlib

# matplotlib.use("Agg")
#
# import matplotlib.pylab as plt
import numpy as np
import torch
import torchaudio
import tqdm
from huggingface_hub import snapshot_download, hf_hub_download
from pydub import AudioSegment, silence
from transformers import pipeline
from vocos import Vocos

from f5_tts.model import CFM
from f5_tts.model.utils import (
    get_tokenizer,
    convert_char_to_pinyin,
)

_ref_audio_cache = {}

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# -----------------------------------------

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
win_length = 1024
n_fft = 1024
mel_spec_type = "vocos"
target_rms = 0.1
cross_fade_duration = 0.15
ode_method = "euler"
nfe_step = 32  # 16, 32
cfg_strength = 2.0
sway_sampling_coef = -1.0
speed = 1.0
fix_duration = None

# -----------------------------------------


# chunk text into smaller pieces


def chunk_text(text, max_chars=135):
    """
    Splits the input text into chunks, each with a maximum number of characters.

    Args:
        text (str): The text to be split.
        max_chars (int): The maximum number of characters per chunk.

    Returns:
        List[str]: A list of text chunks.
    """
    chunks = []
    current_chunk = ""
    # Split the text into sentences based on punctuation followed by whitespace
    sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)

    for sentence in sentences:
        if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
            current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence

    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks


# load vocoder
def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device, hf_cache_dir=None):
    if vocoder_name == "vocos":
        # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
        if is_local:
            print(f"Load vocos from local path {local_path}")
            config_path = f"{local_path}/config.yaml"
            model_path = f"{local_path}/pytorch_model.bin"
        else:
            print("Download Vocos from huggingface charactr/vocos-mel-24khz")
            repo_id = "charactr/vocos-mel-24khz"
            config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
            model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
        vocoder = Vocos.from_hparams(config_path)
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        from vocos.feature_extractors import EncodecFeatures

        if isinstance(vocoder.feature_extractor, EncodecFeatures):
            encodec_parameters = {
                "feature_extractor.encodec." + key: value
                for key, value in vocoder.feature_extractor.encodec.state_dict().items()
            }
            state_dict.update(encodec_parameters)
        vocoder.load_state_dict(state_dict)
        vocoder = vocoder.eval().to(device)
    elif vocoder_name == "bigvgan":
        try:
            from third_party.BigVGAN import bigvgan
        except ImportError:
            print("You need to follow the README to init submodule and change the BigVGAN source code.")
        if is_local:
            """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
            vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
        else:
            local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=hf_cache_dir)
            vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)

        vocoder.remove_weight_norm()
        vocoder = vocoder.eval().to(device)
    return vocoder


# load asr pipeline

asr_pipe = None


def initialize_asr_pipeline(device: str = device, dtype=None):
    if dtype is None:
        dtype = (
            torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
        )
    global asr_pipe
    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3-turbo",
        torch_dtype=dtype,
        device=device,
    )


# transcribe


def transcribe(ref_audio, language=None):
    global asr_pipe
    if asr_pipe is None:
        initialize_asr_pipeline(device=device)
    return asr_pipe(
        ref_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"},
        return_timestamps=False,
    )["text"].strip()


# load model checkpoint for inference


def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
    if dtype is None:
        dtype = (
            torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
        )
    model = model.to(dtype)

    ckpt_type = ckpt_path.split(".")[-1]
    if ckpt_type == "safetensors":
        from safetensors.torch import load_file

        checkpoint = load_file(ckpt_path, device=device)
    else:
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)

    if use_ema:
        if ckpt_type == "safetensors":
            checkpoint = {"ema_model_state_dict": checkpoint}
        checkpoint["model_state_dict"] = {
            k.replace("ema_model.", ""): v
            for k, v in checkpoint["ema_model_state_dict"].items()
            if k not in ["initted", "step"]
        }

        # patch for backward compatibility, 305e3ea
        for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
            if key in checkpoint["model_state_dict"]:
                del checkpoint["model_state_dict"][key]

        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        if ckpt_type == "safetensors":
            checkpoint = {"model_state_dict": checkpoint}
        model.load_state_dict(checkpoint["model_state_dict"])

    del checkpoint
    torch.cuda.empty_cache()

    return model.to(device)


# load model for inference


def load_model(
    model_cls,
    model_cfg,
    ckpt_path,
    mel_spec_type=mel_spec_type,
    vocab_file="",
    ode_method=ode_method,
    use_ema=True,
    device=device,
):
    if vocab_file == "":
        vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
    tokenizer = "custom"

    print("\nvocab : ", vocab_file)
    print("token : ", tokenizer)
    print("model : ", ckpt_path, "\n")

    vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
    model = CFM(
        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
            mel_spec_type=mel_spec_type,
        ),
        odeint_kwargs=dict(
            method=ode_method,
        ),
        vocab_char_map=vocab_char_map,
    ).to(device)

    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
    model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

    return model


def remove_silence_edges(audio, silence_threshold=-42):
    # Remove silence from the start
    non_silent_start_idx = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
    audio = audio[non_silent_start_idx:]

    # Remove silence from the end
    non_silent_end_duration = audio.duration_seconds
    for ms in reversed(audio):
        if ms.dBFS > silence_threshold:
            break
        non_silent_end_duration -= 0.001
    trimmed_audio = audio[: int(non_silent_end_duration * 1000)]

    return trimmed_audio


# preprocess reference audio and text


def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
    show_info("Converting audio...")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        aseg = AudioSegment.from_file(ref_audio_orig)

        if clip_short:
            # 1. try to find long silence for clipping
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
                    show_info("Audio is over 15s, clipping short. (1)")
                    break
                non_silent_wave += non_silent_seg

            # 2. try to find short silence for clipping if 1. failed
            if len(non_silent_wave) > 15000:
                non_silent_segs = silence.split_on_silence(
                    aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
                )
                non_silent_wave = AudioSegment.silent(duration=0)
                for non_silent_seg in non_silent_segs:
                    if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
                        show_info("Audio is over 15s, clipping short. (2)")
                        break
                    non_silent_wave += non_silent_seg

            aseg = non_silent_wave

            # 3. if no proper silence found for clipping
            if len(aseg) > 15000:
                aseg = aseg[:15000]
                show_info("Audio is over 15s, clipping short. (3)")

        aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
        aseg.export(f.name, format="wav")
        ref_audio = f.name

    # Compute a hash of the reference audio file
    with open(ref_audio, "rb") as audio_file:
        audio_data = audio_file.read()
        audio_hash = hashlib.md5(audio_data).hexdigest()

    if not ref_text.strip():
        global _ref_audio_cache
        if audio_hash in _ref_audio_cache:
            # Use cached asr transcription
            show_info("Using cached reference text...")
            ref_text = _ref_audio_cache[audio_hash]
        else:
            show_info("No reference text provided, transcribing reference audio...")
            ref_text = transcribe(ref_audio)
            # Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
            _ref_audio_cache[audio_hash] = ref_text
    else:
        show_info("Using custom reference text...")

    # Ensure ref_text ends with a proper sentence-ending punctuation
    if not ref_text.endswith(". ") and not ref_text.endswith("。"):
        if ref_text.endswith("."):
            ref_text += " "
        else:
            ref_text += ". "

    print("ref_text ", ref_text)

    return ref_audio, ref_text


# infer process: chunk text -> infer batches [i.e. infer_batch_process()]


def infer_process(
    ref_audio,
    ref_text,
    gen_text,
    model_obj,
    vocoder,
    mel_spec_type=mel_spec_type,
    show_info=print,
    progress=tqdm,
    target_rms=target_rms,
    cross_fade_duration=cross_fade_duration,
    nfe_step=nfe_step,
    cfg_strength=cfg_strength,
    sway_sampling_coef=sway_sampling_coef,
    speed=speed,
    fix_duration=fix_duration,
    device=device,
):
    # Split the input text into batches
    audio, sr = torchaudio.load(ref_audio)
    max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
    for i, gen_text in enumerate(gen_text_batches):
        print(f"gen_text {i}", gen_text)

    show_info(f"Generating audio in {len(gen_text_batches)} batches...")
    return infer_batch_process(
        (audio, sr),
        ref_text,
        gen_text_batches,
        model_obj,
        vocoder,
        mel_spec_type=mel_spec_type,
        progress=progress,
        target_rms=target_rms,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        cfg_strength=cfg_strength,
        sway_sampling_coef=sway_sampling_coef,
        speed=speed,
        fix_duration=fix_duration,
        device=device,
    )


# infer batches


def infer_batch_process(
    ref_audio,
    ref_text,
    gen_text_batches,
    model_obj,
    vocoder,
    mel_spec_type="vocos",
    progress=tqdm,
    target_rms=0.1,
    cross_fade_duration=0.15,
    nfe_step=32,
    cfg_strength=2.0,
    sway_sampling_coef=-1,
    speed=1,
    fix_duration=None,
    device=None,
):
    audio, sr = ref_audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)
    audio = audio.to(device)

    generated_waves = []
    spectrograms = []

    if len(ref_text[-1].encode("utf-8")) == 1:
        ref_text = ref_text + " "
    for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
        # Prepare the text
        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)

        ref_audio_len = audio.shape[-1] // hop_length
        if fix_duration is not None:
            duration = int(fix_duration * target_sample_rate / hop_length)
        else:
            # Calculate duration
            ref_text_len = len(ref_text.encode("utf-8"))
            gen_text_len = len(gen_text.encode("utf-8"))
            duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

        # inference
        with torch.inference_mode():
            generated, _ = model_obj.sample(
                cond=audio,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
            )

            generated = generated.to(torch.float32)
            generated = generated[:, ref_audio_len:, :]
            generated_mel_spec = generated.permute(0, 2, 1)
            if mel_spec_type == "vocos":
                generated_wave = vocoder.decode(generated_mel_spec)
            elif mel_spec_type == "bigvgan":
                generated_wave = vocoder(generated_mel_spec)
            if rms < target_rms:
                generated_wave = generated_wave * rms / target_rms

            # wav -> numpy
            generated_wave = generated_wave.squeeze().cpu().numpy()

            generated_waves.append(generated_wave)
            spectrograms.append(generated_mel_spec[0].cpu().numpy())

    # Combine all generated waves with cross-fading
    if cross_fade_duration <= 0:
        # Simply concatenate
        final_wave = np.concatenate(generated_waves)
    else:
        final_wave = generated_waves[0]
        for i in range(1, len(generated_waves)):
            prev_wave = final_wave
            next_wave = generated_waves[i]

            # Calculate cross-fade samples, ensuring it does not exceed wave lengths
            cross_fade_samples = int(cross_fade_duration * target_sample_rate)
            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))

            if cross_fade_samples <= 0:
                # No overlap possible, concatenate
                final_wave = np.concatenate([prev_wave, next_wave])
                continue

            # Overlapping parts
            prev_overlap = prev_wave[-cross_fade_samples:]
            next_overlap = next_wave[:cross_fade_samples]

            # Fade out and fade in
            fade_out = np.linspace(1, 0, cross_fade_samples)
            fade_in = np.linspace(0, 1, cross_fade_samples)

            # Cross-faded overlap
            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in

            # Combine
            new_wave = np.concatenate(
                [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
            )

            final_wave = new_wave

    # Create a combined spectrogram
    combined_spectrogram = np.concatenate(spectrograms, axis=1)

    return final_wave, target_sample_rate, combined_spectrogram


# remove silence from generated wav


def remove_silence_for_generated_wav(filename):
    aseg = AudioSegment.from_file(filename)
    non_silent_segs = silence.split_on_silence(
        aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500, seek_step=10
    )
    non_silent_wave = AudioSegment.silent(duration=0)
    for non_silent_seg in non_silent_segs:
        non_silent_wave += non_silent_seg
    aseg = non_silent_wave
    aseg.export(filename, format="wav")


# save spectrogram


def save_spectrogram(spectrogram, path):
    plt.figure(figsize=(12, 4))
    plt.imshow(spectrogram, origin="lower", aspect="auto")
    plt.colorbar()
    plt.savefig(path)
    plt.close()
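Taken together, the helpers above cover the whole offline inference path: load_vocoder() and load_model() build the vocoder and the CFM-wrapped backbone, preprocess_ref_audio_text() clips, normalizes and (when no reference text is given) transcribes the prompt clip, and infer_process() chunks the target text and cross-fades the per-chunk waveforms. A minimal sketch of that flow, assuming f5_tts is importable (i.e. the thirdparty directory is on sys.path) and using the F5TTS_Base DiT hyperparameters and checkpoint location from the upstream project (both are assumptions, not shown in this diff):

import soundfile as sf
from huggingface_hub import hf_hub_download

from f5_tts.model import DiT
from f5_tts.infer.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
)

# Assumed F5TTS_Base hyperparameters and checkpoint path (taken from the upstream repo).
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
ckpt_path = hf_hub_download("SWivid/F5-TTS", "F5TTS_Base/model_1200000.safetensors")

vocoder = load_vocoder("vocos")
model = load_model(DiT, model_cfg, ckpt_path, mel_spec_type="vocos")

# An empty ref_text triggers the Whisper transcription branch in preprocess_ref_audio_text().
ref_audio, ref_text = preprocess_ref_audio_text("prompt.wav", "")
wave, sample_rate, _spect = infer_process(ref_audio, ref_text, "Text to synthesize.", model, vocoder)

sf.write("generated.wav", wave, sample_rate)

Note that save_spectrogram() uses plt from the matplotlib import that is commented out at the top of the file, so plotting the returned spectrogram requires re-enabling that import.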
xinference/thirdparty/f5_tts/model/__init__.py
@@ -0,0 +1,10 @@
from f5_tts.model.cfm import CFM

from f5_tts.model.backbones.unett import UNetT
from f5_tts.model.backbones.dit import DiT
from f5_tts.model.backbones.mmdit import MMDiT

# from f5_tts.model.trainer import Trainer


__all__ = ["CFM", "UNetT", "DiT", "MMDiT"]  # , "Trainer"]
xinference/thirdparty/f5_tts/model/backbones/README.md
@@ -0,0 +1,20 @@
## Backbones quick introduction


### unett.py
- flat unet transformer
- structure same as in e2-tts & voicebox paper except using rotary pos emb
- update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat

### dit.py
- adaln-zero dit
- embedded timestep as condition
- concatted noised_input + masked_cond + embedded_text, linear proj in
- possible abs pos emb & convnextv2 blocks for embedded text before concat
- possible long skip connection (first layer to last layer)

### mmdit.py
- sd3 structure
- timestep as condition
- left stream: text embedded and applied a abs pos emb
- right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
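The three backbones described above are interchangeable behind CFM: load_model() in utils_infer.py instantiates whichever class it is given via model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels), so switching architectures only means passing a different class and config dict. A sketch of the Base-size configurations (the hyperparameter values are assumptions taken from the upstream training configs such as F5TTS_Base_train.yaml, not shown in this section):

from f5_tts.model import DiT, UNetT  # re-exported by model/__init__.py above

# adaLN-zero DiT (F5-TTS Base); conv_layers presumably controls the ConvNeXt V2 text
# blocks mentioned above, and the long skip connection is an optional constructor flag.
f5tts_base_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

# Flat UNet transformer (E2-TTS Base), where text and audio are concatenated up front.
e2tts_base_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)

# Either pairing is then handed to load_model() together with a matching checkpoint:
#   model = load_model(DiT, f5tts_base_cfg, f5tts_ckpt_path)
#   model = load_model(UNetT, e2tts_base_cfg, e2tts_ckpt_path)

The F5TTS_Base_train.yaml and E2TTS_Base_train.yaml files in the file list above presumably carry the full training-time versions of these configurations.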