xinference 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +5 -5
- xinference/core/model.py +6 -1
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/f5tts.py +195 -0
- xinference/model/audio/fish_speech.py +2 -1
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/embedding/core.py +203 -142
- xinference/model/embedding/model_spec.json +7 -0
- xinference/model/embedding/model_spec_modelscope.json +8 -0
- xinference/model/llm/__init__.py +2 -2
- xinference/model/llm/llm_family.json +172 -53
- xinference/model/llm/llm_family_modelscope.json +118 -20
- xinference/model/llm/mlx/core.py +230 -49
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +4 -1
- xinference/model/llm/vllm/core.py +5 -0
- xinference/thirdparty/f5_tts/__init__.py +0 -0
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/METADATA +33 -14
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/RECORD +85 -34
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/LICENSE +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/WHEEL +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/top_level.txt +0 -0
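The headline change in this release is new F5-TTS support: an audio model implementation (`xinference/model/audio/f5tts.py`), matching entries in the audio `model_spec.json` files, and a vendored copy of the upstream project under `xinference/thirdparty/f5_tts`. As a rough illustration only, the sketch below shows how the new model might be launched through the existing Python client, assuming it is registered under the model name `F5-TTS`; the exact generation parameters (for example, a reference audio for voice cloning) should be checked against `f5tts.py`.

```python
# Hedged sketch: assumes a running xinference server and that the new audio
# model is registered as "F5-TTS"; the speech() arguments are illustrative only.
from xinference.client import Client

client = Client("http://localhost:9997")
model_uid = client.launch_model(model_name="F5-TTS", model_type="audio")
model = client.get_model(model_uid)

# Audio models expose a speech() endpoint returning audio bytes; F5-TTS may
# additionally expect a reference audio via kwargs -- see f5tts.py for details.
audio_bytes = model.speech("Hello from the new F5-TTS integration.")
with open("out.wav", "wb") as f:
    f.write(audio_bytes)
```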
xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py
@@ -0,0 +1,84 @@
# Evaluate with Seed-TTS testset

import sys
import os
import argparse

sys.path.append(os.getcwd())

import multiprocessing as mp
from importlib.resources import files

import numpy as np

from f5_tts.eval.utils_eval import (
    get_seed_tts_test,
    run_asr_wer,
    run_sim,
)

rel_path = str(files("f5_tts").joinpath("../../"))


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
    parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
    parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
    parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
    parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
    return parser.parse_args()


def main():
    args = get_args()
    eval_task = args.eval_task
    lang = args.lang
    gen_wav_dir = args.gen_wav_dir
    metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst"  # seed-tts testset

    # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
    # zh 1.254 seems a result of 4 workers wer_seed_tts
    gpus = list(range(args.gpu_nums))
    test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)

    local = args.local
    if local:  # use local custom checkpoint dir
        if lang == "zh":
            asr_ckpt_dir = "../checkpoints/funasr"  # paraformer-zh dir under funasr
        elif lang == "en":
            asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
    else:
        asr_ckpt_dir = ""  # auto download to cache dir
    wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"

    # --------------------------- WER ---------------------------

    if eval_task == "wer":
        wers = []
        with mp.Pool(processes=len(gpus)) as pool:
            args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
            results = pool.map(run_asr_wer, args)
            for wers_ in results:
                wers.extend(wers_)

        wer = round(np.mean(wers) * 100, 3)
        print(f"\nTotal {len(wers)} samples")
        print(f"WER : {wer}%")

    # --------------------------- SIM ---------------------------
    if eval_task == "sim":
        sim_list = []
        with mp.Pool(processes=len(gpus)) as pool:
            args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
            results = pool.map(run_sim, args)
            for sim_ in results:
                sim_list.extend(sim_)

        sim = round(sum(sim_list) / len(sim_list), 3)
        print(f"\nTotal {len(sim_list)} samples")
        print(f"SIM : {sim}")


if __name__ == "__main__":
    main()
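For reference, here is a hedged sketch of what a single-GPU run of the WER path above reduces to, using the helpers the script imports. The directory paths are placeholders; an empty checkpoint dir lets the ASR model auto-download.

```python
# Hypothetical single-GPU WER run; metalst and gen_wav_dir are placeholder paths.
import numpy as np
from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer

metalst = "data/seedtts_testset/en/meta.lst"
gen_wav_dir = "results/f5tts_seedtts_en"  # generated wavs named <utt>.wav
rank, sub_test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus=[0])[0]

wers = run_asr_wer((rank, "en", sub_test_set, ""))  # "" -> auto-download ASR checkpoint
print(f"Total {len(wers)} samples, WER: {round(np.mean(wers) * 100, 3)}%")
```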
xinference/thirdparty/f5_tts/eval/utils_eval.py
@@ -0,0 +1,405 @@
import math
import os
import random
import string

import torch
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm

from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import convert_char_to_pinyin


# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
def get_seedtts_testset_metainfo(metalst):
    f = open(metalst)
    lines = f.readlines()
    f.close()
    metainfo = []
    for line in lines:
        if len(line.strip().split("|")) == 5:
            utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
        elif len(line.strip().split("|")) == 4:
            utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
            gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
        if not os.path.isabs(prompt_wav):
            prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
        metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
    return metainfo


# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
    f = open(metalst)
    lines = f.readlines()
    f.close()
    metainfo = []
    for line in lines:
        ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")

        # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
        ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
        ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")

        # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.'  # if use librispeech test-clean (no-pc)
        gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
        gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")

        metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))

    return metainfo


# padded to max length mel batch
def padded_mel_batch(ref_mels):
    max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
    padded_ref_mels = []
    for mel in ref_mels:
        padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
        padded_ref_mels.append(padded_ref_mel)
    padded_ref_mels = torch.stack(padded_ref_mels)
    padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
    return padded_ref_mels


# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav


def get_inference_prompt(
    metainfo,
    speed=1.0,
    tokenizer="pinyin",
    polyphone=True,
    target_sample_rate=24000,
    n_fft=1024,
    win_length=1024,
    n_mel_channels=100,
    hop_length=256,
    mel_spec_type="vocos",
    target_rms=0.1,
    use_truth_duration=False,
    infer_batch_size=1,
    num_buckets=200,
    min_secs=3,
    max_secs=40,
):
    prompts_all = []

    min_tokens = min_secs * target_sample_rate // hop_length
    max_tokens = max_secs * target_sample_rate // hop_length

    batch_accum = [0] * num_buckets
    utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
        [[] for _ in range(num_buckets)] for _ in range(6)
    )

    mel_spectrogram = MelSpec(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mel_channels=n_mel_channels,
        target_sample_rate=target_sample_rate,
        mel_spec_type=mel_spec_type,
    )

    for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
        # Audio
        ref_audio, ref_sr = torchaudio.load(prompt_wav)
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
        if ref_rms < target_rms:
            ref_audio = ref_audio * target_rms / ref_rms
        assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio)

        # Text
        if len(prompt_text[-1].encode("utf-8")) == 1:
            prompt_text = prompt_text + " "
        text = [prompt_text + gt_text]
        if tokenizer == "pinyin":
            text_list = convert_char_to_pinyin(text, polyphone=polyphone)
        else:
            text_list = text

        # Duration, mel frame length
        ref_mel_len = ref_audio.shape[-1] // hop_length
        if use_truth_duration:
            gt_audio, gt_sr = torchaudio.load(gt_wav)
            if gt_sr != target_sample_rate:
                resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
                gt_audio = resampler(gt_audio)
            total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)

            # # test vocoder resynthesis
            # ref_audio = gt_audio
        else:
            ref_text_len = len(prompt_text.encode("utf-8"))
            gen_text_len = len(gt_text.encode("utf-8"))
            total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)

        # to mel spectrogram
        ref_mel = mel_spectrogram(ref_audio)
        ref_mel = ref_mel.squeeze(0)

        # deal with batch
        assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
        assert (
            min_tokens <= total_mel_len <= max_tokens
        ), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
        bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)

        utts[bucket_i].append(utt)
        ref_rms_list[bucket_i].append(ref_rms)
        ref_mels[bucket_i].append(ref_mel)
        ref_mel_lens[bucket_i].append(ref_mel_len)
        total_mel_lens[bucket_i].append(total_mel_len)
        final_text_list[bucket_i].extend(text_list)

        batch_accum[bucket_i] += total_mel_len

        if batch_accum[bucket_i] >= infer_batch_size:
            # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
            prompts_all.append(
                (
                    utts[bucket_i],
                    ref_rms_list[bucket_i],
                    padded_mel_batch(ref_mels[bucket_i]),
                    ref_mel_lens[bucket_i],
                    total_mel_lens[bucket_i],
                    final_text_list[bucket_i],
                )
            )
            batch_accum[bucket_i] = 0
            (
                utts[bucket_i],
                ref_rms_list[bucket_i],
                ref_mels[bucket_i],
                ref_mel_lens[bucket_i],
                total_mel_lens[bucket_i],
                final_text_list[bucket_i],
            ) = [], [], [], [], [], []

    # add residual
    for bucket_i, bucket_frames in enumerate(batch_accum):
        if bucket_frames > 0:
            prompts_all.append(
                (
                    utts[bucket_i],
                    ref_rms_list[bucket_i],
                    padded_mel_batch(ref_mels[bucket_i]),
                    ref_mel_lens[bucket_i],
                    total_mel_lens[bucket_i],
                    final_text_list[bucket_i],
                )
            )
    # not only leave easy work for last workers
    random.seed(666)
    random.shuffle(prompts_all)

    return prompts_all


# get wav_res_ref_text of seed-tts test metalst
# https://github.com/BytedanceSpeech/seed-tts-eval


def get_seed_tts_test(metalst, gen_wav_dir, gpus):
    f = open(metalst)
    lines = f.readlines()
    f.close()

    test_set_ = []
    for line in tqdm(lines):
        if len(line.strip().split("|")) == 5:
            utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
        elif len(line.strip().split("|")) == 4:
            utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")

        if not os.path.exists(os.path.join(gen_wav_dir, utt + ".wav")):
            continue
        gen_wav = os.path.join(gen_wav_dir, utt + ".wav")
        if not os.path.isabs(prompt_wav):
            prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)

        test_set_.append((gen_wav, prompt_wav, gt_text))

    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]

    wav_per_job = len(test_set_) // num_jobs + 1
    test_set = []
    for i in range(num_jobs):
        test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))

    return test_set


# get librispeech test-clean cross sentence test


def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth=False):
    f = open(metalst)
    lines = f.readlines()
    f.close()

    test_set_ = []
    for line in tqdm(lines):
        ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split("\t")

        if eval_ground_truth:
            gen_spk_id, gen_chaptr_id, _ = gen_utt.split("-")
            gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + ".flac")
        else:
            if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + ".wav")):
                raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
            gen_wav = os.path.join(gen_wav_dir, gen_utt + ".wav")

        ref_spk_id, ref_chaptr_id, _ = ref_utt.split("-")
        ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + ".flac")

        test_set_.append((gen_wav, ref_wav, gen_txt))

    num_jobs = len(gpus)
    if num_jobs == 1:
        return [(gpus[0], test_set_)]

    wav_per_job = len(test_set_) // num_jobs + 1
    test_set = []
    for i in range(num_jobs):
        test_set.append((gpus[i], test_set_[i * wav_per_job : (i + 1) * wav_per_job]))

    return test_set


# load asr model


def load_asr_model(lang, ckpt_dir=""):
    if lang == "zh":
        from funasr import AutoModel

        model = AutoModel(
            model=os.path.join(ckpt_dir, "paraformer-zh"),
            # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
            # punc_model = os.path.join(ckpt_dir, "ct-punc"),
            # spk_model = os.path.join(ckpt_dir, "cam++"),
            disable_update=True,
        )  # following seed-tts setting
    elif lang == "en":
        from faster_whisper import WhisperModel

        model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
        model = WhisperModel(model_size, device="cuda", compute_type="float16")
    return model


# WER Evaluation, the way Seed-TTS does


def run_asr_wer(args):
    rank, lang, test_set, ckpt_dir = args

    if lang == "zh":
        import zhconv

        torch.cuda.set_device(rank)
    elif lang == "en":
        os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    else:
        raise NotImplementedError(
            "lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now."
        )

    asr_model = load_asr_model(lang, ckpt_dir=ckpt_dir)

    from zhon.hanzi import punctuation

    punctuation_all = punctuation + string.punctuation
    wers = []

    from jiwer import compute_measures

    for gen_wav, prompt_wav, truth in tqdm(test_set):
        if lang == "zh":
            res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
            hypo = res[0]["text"]
            hypo = zhconv.convert(hypo, "zh-cn")
        elif lang == "en":
            segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
            hypo = ""
            for segment in segments:
                hypo = hypo + " " + segment.text

        # raw_truth = truth
        # raw_hypo = hypo

        for x in punctuation_all:
            truth = truth.replace(x, "")
            hypo = hypo.replace(x, "")

        truth = truth.replace("  ", " ")
        hypo = hypo.replace("  ", " ")

        if lang == "zh":
            truth = " ".join([x for x in truth])
            hypo = " ".join([x for x in hypo])
        elif lang == "en":
            truth = truth.lower()
            hypo = hypo.lower()

        measures = compute_measures(truth, hypo)
        wer = measures["wer"]

        # ref_list = truth.split(" ")
        # subs = measures["substitutions"] / len(ref_list)
        # dele = measures["deletions"] / len(ref_list)
        # inse = measures["insertions"] / len(ref_list)

        wers.append(wer)

    return wers


# SIM Evaluation


def run_sim(args):
    rank, test_set, ckpt_dir = args
    device = f"cuda:{rank}"

    model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type="wavlm_large", config_path=None)
    state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict["model"], strict=False)

    use_gpu = True if torch.cuda.is_available() else False
    if use_gpu:
        model = model.cuda(device)
    model.eval()

    sim_list = []
    for wav1, wav2, truth in tqdm(test_set):
        wav1, sr1 = torchaudio.load(wav1)
        wav2, sr2 = torchaudio.load(wav2)

        resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
        resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
        wav1 = resample1(wav1)
        wav2 = resample2(wav2)

        if use_gpu:
            wav1 = wav1.cuda(device)
            wav2 = wav2.cuda(device)
        with torch.no_grad():
            emb1 = model(wav1)
            emb2 = model(wav2)

        sim = F.cosine_similarity(emb1, emb2)[0].item()
        # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
        sim_list.append(sim)

    return sim_list
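A matching hedged sketch for the speaker-similarity path (`run_sim`), assuming the WavLM-large fine-tuned ECAPA-TDNN checkpoint has already been downloaded; all paths below are placeholders.

```python
# Hypothetical single-GPU SIM run; checkpoint and directory paths are placeholders.
from f5_tts.eval.utils_eval import get_seed_tts_test, run_sim

rank, sub_test_set = get_seed_tts_test(
    "data/seedtts_testset/zh/meta.lst", "results/f5tts_seedtts_zh", gpus=[0]
)[0]
wavlm_ckpt = "checkpoints/UniSpeech/wavlm_large_finetune.pth"

sims = run_sim((rank, sub_test_set, wavlm_ckpt))
print(f"Total {len(sims)} samples, SIM: {round(sum(sims) / len(sims), 3)}")
```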
xinference/thirdparty/f5_tts/infer/README.md
@@ -0,0 +1,191 @@
# Inference

The pretrained model checkpoints can be reached at [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS) and [🤖 Model Scope](https://www.modelscope.cn/models/SWivid/F5-TTS_Emilia-ZH-EN), or will be automatically downloaded when running inference scripts.

**More checkpoints with whole community efforts can be found in [SHARED.md](SHARED.md), supporting more languages.**

A single generation currently supports **up to 30s**, which is the **total length** of prompt and output audio combined. However, you can provide `infer_cli` and `infer_gradio` with longer text; they will automatically do chunked generation. Long reference audio will be **clipped to ~15s**.

To avoid possible inference failures, make sure you have read through the following instructions.

- Use reference audio <15s and leave some silence (e.g. 1s) at the end. Otherwise there is a risk of truncating in the middle of a word, leading to suboptimal generation.
- Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
- Add some spaces (blank: " ") or punctuation (e.g. "," ".") to explicitly introduce some pauses.
- Convert numbers to Chinese characters if you want them read in Chinese; otherwise they will be read in English.


## Gradio App

Currently supported features:

- Basic TTS with Chunk Inference
- Multi-Style / Multi-Speaker Generation
- Voice Chat powered by Qwen2.5-3B-Instruct

The CLI command `f5-tts_infer-gradio` is equivalent to `python src/f5_tts/infer/infer_gradio.py`, which launches a Gradio app (web interface) for inference.

The script will load model checkpoints from Hugging Face. You can also manually download files and update the path passed to `load_model()` in `infer_gradio.py`. Only the TTS models are loaded up front; the ASR model is loaded to do transcription if `ref_text` is not provided, and the LLM is loaded if Voice Chat is used.

It can also be used as a component of a larger application.
```python
import gradio as gr
from f5_tts.infer.infer_gradio import app

with gr.Blocks() as main_app:
    gr.Markdown("# This is an example of using F5-TTS within a bigger Gradio app")

    # ... other Gradio components

    app.render()

main_app.launch()
```


## CLI Inference

The CLI command `f5-tts_infer-cli` is equivalent to `python src/f5_tts/infer/infer_cli.py`, a command line tool for inference.

The script will load model checkpoints from Hugging Face. You can also manually download files and use `--ckpt_file` to specify the model you want to load, or directly update the path in `infer_cli.py`.

To change vocab.txt, use `--vocab_file` to provide your own `vocab.txt` file.

Basically, you can run inference with flags:
```bash
# Leaving --ref_text "" will have the ASR model transcribe the reference audio (extra GPU memory usage)
f5-tts_infer-cli \
--model "F5-TTS" \
--ref_audio "ref_audio.wav" \
--ref_text "The content, subtitle or transcription of reference audio." \
--gen_text "Some text you want TTS model generate for you."

# Choose Vocoder
f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
```

A `.toml` file allows more flexible usage.

```bash
f5-tts_infer-cli -c custom.toml
```

For example, you can use a `.toml` file to pass in variables; refer to `src/f5_tts/infer/examples/basic/basic.toml`:

```toml
# F5-TTS | E2-TTS
model = "F5-TTS"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
# If an empty "", transcribes the reference audio automatically.
ref_text = "Some call me nature, others call me mother nature."
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
# File with text to generate. Ignores the text above.
gen_file = ""
remove_silence = false
output_dir = "tests"
```

You can also leverage a `.toml` file to do multi-style generation; refer to `src/f5_tts/infer/examples/multi/story.toml`.

```toml
# F5-TTS | E2-TTS
model = "F5-TTS"
ref_audio = "infer/examples/multi/main.flac"
# If an empty "", transcribes the reference audio automatically.
ref_text = ""
gen_text = ""
# File with text to generate. Ignores the text above.
gen_file = "infer/examples/multi/story.txt"
remove_silence = true
output_dir = "tests"

[voices.town]
ref_audio = "infer/examples/multi/town.flac"
ref_text = ""

[voices.country]
ref_audio = "infer/examples/multi/country.flac"
ref_text = ""
```
Mark the voice with `[main]`, `[town]`, or `[country]` whenever you want to change the voice; refer to `src/f5_tts/infer/examples/multi/story.txt`.

## Speech Editing

To test speech editing capabilities, use the following command:

```bash
python src/f5_tts/infer/speech_edit.py
```

## Socket Realtime Client

To communicate with the socket server, you need to run:
```bash
python src/f5_tts/socket_server.py
```

<details>
<summary>Then create a client to communicate</summary>

```python
import socket
import numpy as np
import asyncio
import pyaudio

async def listen_to_voice(text, server_ip='localhost', server_port=9999):
    client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    client_socket.connect((server_ip, server_port))

    async def play_audio_stream():
        buffer = b''
        p = pyaudio.PyAudio()
        stream = p.open(format=pyaudio.paFloat32,
                        channels=1,
                        rate=24000,  # Ensure this matches the server's sampling rate
                        output=True,
                        frames_per_buffer=2048)

        try:
            while True:
                chunk = await asyncio.get_event_loop().run_in_executor(None, client_socket.recv, 1024)
                if not chunk:  # End of stream
                    break
                if b"END_OF_AUDIO" in chunk:
                    buffer += chunk.replace(b"END_OF_AUDIO", b"")
                    if buffer:
                        audio_array = np.frombuffer(buffer, dtype=np.float32).copy()  # Make a writable copy
                        stream.write(audio_array.tobytes())
                    break
                buffer += chunk
                if len(buffer) >= 4096:
                    audio_array = np.frombuffer(buffer[:4096], dtype=np.float32).copy()  # Make a writable copy
                    stream.write(audio_array.tobytes())
                    buffer = buffer[4096:]
        finally:
            stream.stop_stream()
            stream.close()
            p.terminate()

    try:
        # Send only the text to the server
        await asyncio.get_event_loop().run_in_executor(None, client_socket.sendall, text.encode('utf-8'))
        await play_audio_stream()
        print("Audio playback finished.")

    except Exception as e:
        print(f"Error in listen_to_voice: {e}")

    finally:
        client_socket.close()

# Example usage: Replace this with your actual server IP and port
async def main():
    await listen_to_voice("my name is jenny..", server_ip='localhost', server_port=9998)

# Run the main async function
asyncio.run(main())
```

</details>