xinference-1.0.1-py3-none-any.whl → xinference-1.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +5 -5
- xinference/core/model.py +6 -1
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/f5tts.py +195 -0
- xinference/model/audio/fish_speech.py +2 -1
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/embedding/core.py +203 -142
- xinference/model/embedding/model_spec.json +7 -0
- xinference/model/embedding/model_spec_modelscope.json +8 -0
- xinference/model/llm/__init__.py +2 -2
- xinference/model/llm/llm_family.json +172 -53
- xinference/model/llm/llm_family_modelscope.json +118 -20
- xinference/model/llm/mlx/core.py +230 -49
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +4 -1
- xinference/model/llm/vllm/core.py +5 -0
- xinference/thirdparty/f5_tts/__init__.py +0 -0
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/METADATA +33 -14
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/RECORD +85 -34
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/LICENSE +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/WHEEL +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/top_level.txt +0 -0
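
The headline addition in this release is the F5-TTS text-to-speech model: a new loader in `xinference/model/audio/f5tts.py`, registry entries in the audio `model_spec*.json` files, and a vendored copy of the upstream project under `xinference/thirdparty/f5_tts/`. As rough orientation, below is a minimal, hedged sketch of launching it through the Python client; the endpoint and output path are placeholders, and the exact keyword arguments that `speech()` accepts for F5-TTS (reference audio, voice-cloning options) should be checked against `f5tts.py` rather than taken from this sketch.

```python
from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # placeholder endpoint of a running xinference server

# Launch the audio model added in this release; the name matches the new model_spec.json entry.
model_uid = client.launch_model(model_name="F5-TTS", model_type="audio")
model = client.get_model(model_uid)

# TTS goes through the audio model's speech() interface and returns audio bytes;
# model-specific options (reference audio, voice) are omitted here.
audio_bytes = model.speech("Hello from xinference 1.1.0.")
with open("f5tts_demo.wav", "wb") as f:  # placeholder output path
    f.write(audio_bytes)
```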
xinference/model/llm/transformers/chatglm.py CHANGED

@@ -61,7 +61,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
 
     def _load_model(self, **kwargs):
         try:
-            from transformers import
+            from transformers import AutoModelForCausalLM, AutoTokenizer
         except ImportError:
             error_message = "Failed to import module 'transformers'"
             installation_guide = [
@@ -77,7 +77,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             encode_special_tokens=True,
             revision=kwargs["revision"],
         )
-        model =
+        model = AutoModelForCausalLM.from_pretrained(
             self.model_path,
             **kwargs,
         )
@@ -232,9 +232,11 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                 content = {
                     "name": function_name,
                     "arguments": json.dumps(
-
-
-
+                        (
+                            arguments_json
+                            if isinstance(arguments_json, dict)
+                            else arguments
+                        ),
                         ensure_ascii=False,
                     ),
                 }
@@ -331,6 +333,8 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         max_new_tokens = generate_config.get("max_tokens")
         if max_new_tokens is not None:
             kwargs["max_new_tokens"] = int(max_new_tokens)
+        else:
+            kwargs["max_new_tokens"] = 1024
         do_sample = generate_config.get("do_sample")
         if do_sample is not None:
             kwargs["do_sample"] = bool(do_sample)
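
The practical effect of the new `else` branch in the ChatGLM hunk above: requests that do not set `max_tokens` now generate at most 1024 new tokens instead of leaving the limit to the underlying generation defaults. A minimal sketch of that sanitization step (the helper name and config dicts below are illustrative, not from the repository):

```python
def _cap_max_new_tokens(generate_config: dict) -> dict:
    # Mirrors the hunk above: honor an explicit max_tokens, otherwise cap at 1024.
    kwargs = {}
    max_new_tokens = generate_config.get("max_tokens")
    if max_new_tokens is not None:
        kwargs["max_new_tokens"] = int(max_new_tokens)
    else:
        kwargs["max_new_tokens"] = 1024
    return kwargs

print(_cap_max_new_tokens({"max_tokens": 256}))  # {'max_new_tokens': 256}
print(_cap_max_new_tokens({}))                   # {'max_new_tokens': 1024}
```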
xinference/model/llm/transformers/utils.py CHANGED

@@ -156,6 +156,7 @@ def _get_completion(
     finish_reason: Optional[str],
     model_uid: str,
     r: InferenceRequest,
+    completion_tokens: int,
 ):
     completion_choice = CompletionChoice(
         text=output, index=0, logprobs=None, finish_reason=finish_reason
@@ -170,8 +171,8 @@ def _get_completion(
     )
     completion_usage = CompletionUsage(
         prompt_tokens=len(r.prompt_tokens),
-        completion_tokens=
-        total_tokens=len(r.prompt_tokens) +
+        completion_tokens=completion_tokens,
+        total_tokens=len(r.prompt_tokens) + completion_tokens,
     )
     completion = Completion(
         id=completion_chunk["id"],
@@ -371,7 +372,7 @@ def _batch_inference_one_step_internal(
         r.stopped = stopped
         r.finish_reason = finish_reason
 
-        if r.stopped and r not in stop_token_mapping
+        if r.stopped and r not in stop_token_mapping:
             stop_token_mapping[r] = _i + 1
 
         if r.stream:
@@ -446,12 +447,14 @@ def _batch_inference_one_step_internal(
         else:
             # last round, handle non-stream result
             if r.stopped and _i == decode_round - 1:
-                invalid_token_num =
+                invalid_token_num = (
+                    (decode_round - stop_token_mapping[r] + 1)
+                    if r.finish_reason == "stop"
+                    else (decode_round - stop_token_mapping[r])
+                )
                 outputs = (
                     tokenizer.decode(
-                        r.new_tokens[
-                        if r.finish_reason == "stop"
-                        else r.new_tokens[:-invalid_token_num],
+                        r.new_tokens[:-invalid_token_num],
                         skip_special_tokens=True,
                         spaces_between_special_tokens=False,
                         clean_up_tokenization_spaces=True,
@@ -460,7 +463,12 @@ def _batch_inference_one_step_internal(
                     else output_mapping[r]
                 )
                 completion = _get_completion(
-                    outputs,
+                    outputs,
+                    r.chunk_id,
+                    r.finish_reason,
+                    model_uid,
+                    r,
+                    len(r.new_tokens) - invalid_token_num,
                 )
                 r.completion = [completion]
 
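
The token-trimming rework above is easiest to follow with concrete numbers. The sketch below replays the same arithmetic with illustrative values (names mirror the hunk; `decode_round`, the stop step, and the token list are made up):

```python
# Illustrative values only; names mirror the hunk above.
decode_round = 16        # decode steps executed in this batching round
stop_step = 10           # stop_token_mapping[r]: the step at which the request stopped
finish_reason = "stop"   # a "stop" finish drops one more trailing token than a "length" finish

invalid_token_num = (
    (decode_round - stop_step + 1)
    if finish_reason == "stop"
    else (decode_round - stop_step)
)

new_tokens = list(range(120, 120 + decode_round))  # stand-in for r.new_tokens
kept = new_tokens[:-invalid_token_num]             # what tokenizer.decode() now receives
completion_tokens = len(new_tokens) - invalid_token_num  # value now passed to _get_completion

print(invalid_token_num, len(kept), completion_tokens)  # 7 9 9
```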
xinference/model/llm/utils.py CHANGED

@@ -324,7 +324,10 @@ class ChatModelMixin:
         """
         try:
             if isinstance(c, dict):
-
+                try:
+                    return [(None, c["name"], json.loads(c["arguments"]))]
+                except Exception:
+                    return [(None, c["name"], c["arguments"])]
         except KeyError:
             logger.error("Can't parse glm output: %s", c)
             return [(str(c), None, None)]
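
For context on what the new fallback in ChatModelMixin handles: GLM tool-call arguments can arrive either as a JSON string or as an already-parsed dict, and the added try/except accepts both. A minimal sketch (the helper name and tool-call dicts below are invented for illustration):

```python
import json

def _parse_glm_arguments(c: dict):
    # Mirrors the fallback added above: prefer json.loads on the arguments,
    # fall back to the raw value when it is not a JSON string.
    try:
        return [(None, c["name"], json.loads(c["arguments"]))]
    except Exception:
        return [(None, c["name"], c["arguments"])]

print(_parse_glm_arguments({"name": "get_weather", "arguments": '{"city": "Paris"}'}))
print(_parse_glm_arguments({"name": "get_weather", "arguments": {"city": "Paris"}}))
# Both calls yield [(None, 'get_weather', {'city': 'Paris'})].
```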
xinference/model/llm/vllm/core.py CHANGED

@@ -86,6 +86,7 @@ class VLLMGenerateConfig(TypedDict, total=False):
     stop: Optional[Union[str, List[str]]]
     stream: bool  # non-sampling param, should not be passed to the engine.
     stream_options: Optional[Union[dict, None]]
+    skip_special_tokens: Optional[bool]
     response_format: Optional[dict]
     guided_json: Optional[Union[str, dict]]
     guided_regex: Optional[str]
@@ -181,6 +182,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.5.3":
 if VLLM_INSTALLED and vllm.__version__ > "0.5.3":
     VLLM_SUPPORTED_MODELS.append("llama-3.1")
     VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.1-instruct")
+    VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.3-instruct")
 
 if VLLM_INSTALLED and vllm.__version__ >= "0.6.1":
     VLLM_SUPPORTED_VISION_MODEL_LIST.append("internvl2")
@@ -373,6 +375,9 @@ class VLLMModel(LLM):
         sanitized.setdefault(
             "stream_options", generate_config.get("stream_options", None)
         )
+        sanitized.setdefault(
+            "skip_special_tokens", generate_config.get("skip_special_tokens", True)
+        )
         sanitized.setdefault(
             "guided_json", generate_config.get("guided_json", guided_json)
         )
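
Taken together, the three vLLM hunks above make `skip_special_tokens` a recognized generate-config key that defaults to `True` when omitted. A hedged usage sketch through the Python client (the endpoint and model UID are placeholders):

```python
from xinference.client import Client

client = Client("http://127.0.0.1:9997")     # placeholder endpoint
model = client.get_model("my-vllm-model")    # placeholder UID of a vLLM-backed model

# Override the new default (True) to keep special tokens in the raw output.
completion = model.generate(
    "Explain KV caching in one sentence.",
    generate_config={"max_tokens": 64, "skip_special_tokens": False},
)
print(completion["choices"][0]["text"])
```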
xinference/thirdparty/f5_tts/__init__.py ADDED

File without changes (empty file)
xinference/thirdparty/f5_tts/api.py ADDED

@@ -0,0 +1,166 @@
+import random
+import sys
+from importlib.resources import files
+
+import soundfile as sf
+import tqdm
+from cached_path import cached_path
+
+from f5_tts.infer.utils_infer import (
+    hop_length,
+    infer_process,
+    load_model,
+    load_vocoder,
+    preprocess_ref_audio_text,
+    remove_silence_for_generated_wav,
+    save_spectrogram,
+    transcribe,
+    target_sample_rate,
+)
+from f5_tts.model import DiT, UNetT
+from f5_tts.model.utils import seed_everything
+
+
+class F5TTS:
+    def __init__(
+        self,
+        model_type="F5-TTS",
+        ckpt_file="",
+        vocab_file="",
+        ode_method="euler",
+        use_ema=True,
+        vocoder_name="vocos",
+        local_path=None,
+        device=None,
+        hf_cache_dir=None,
+    ):
+        # Initialize parameters
+        self.final_wave = None
+        self.target_sample_rate = target_sample_rate
+        self.hop_length = hop_length
+        self.seed = -1
+        self.mel_spec_type = vocoder_name
+
+        # Set device
+        if device is not None:
+            self.device = device
+        else:
+            import torch
+
+            self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+        # Load models
+        self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
+        self.load_ema_model(
+            model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
+        )
+
+    def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
+        self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)
+
+    def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
+        if model_type == "F5-TTS":
+            if not ckpt_file:
+                if mel_spec_type == "vocos":
+                    ckpt_file = str(
+                        cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
+                    )
+                elif mel_spec_type == "bigvgan":
+                    ckpt_file = str(
+                        cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
+                    )
+            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+            model_cls = DiT
+        elif model_type == "E2-TTS":
+            if not ckpt_file:
+                ckpt_file = str(
+                    cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
+                )
+            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+            model_cls = UNetT
+        else:
+            raise ValueError(f"Unknown model type: {model_type}")
+
+        self.ema_model = load_model(
+            model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
+        )
+
+    def transcribe(self, ref_audio, language=None):
+        return transcribe(ref_audio, language)
+
+    def export_wav(self, wav, file_wave, remove_silence=False):
+        sf.write(file_wave, wav, self.target_sample_rate)
+
+        if remove_silence:
+            remove_silence_for_generated_wav(file_wave)
+
+    def export_spectrogram(self, spect, file_spect):
+        save_spectrogram(spect, file_spect)
+
+    def infer(
+        self,
+        ref_file,
+        ref_text,
+        gen_text,
+        show_info=print,
+        progress=tqdm,
+        target_rms=0.1,
+        cross_fade_duration=0.15,
+        sway_sampling_coef=-1,
+        cfg_strength=2,
+        nfe_step=32,
+        speed=1.0,
+        fix_duration=None,
+        remove_silence=False,
+        file_wave=None,
+        file_spect=None,
+        seed=-1,
+    ):
+        if seed == -1:
+            seed = random.randint(0, sys.maxsize)
+        seed_everything(seed)
+        self.seed = seed
+
+        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
+
+        wav, sr, spect = infer_process(
+            ref_file,
+            ref_text,
+            gen_text,
+            self.ema_model,
+            self.vocoder,
+            self.mel_spec_type,
+            show_info=show_info,
+            progress=progress,
+            target_rms=target_rms,
+            cross_fade_duration=cross_fade_duration,
+            nfe_step=nfe_step,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+            speed=speed,
+            fix_duration=fix_duration,
+            device=self.device,
+        )
+
+        if file_wave is not None:
+            self.export_wav(wav, file_wave, remove_silence)
+
+        if file_spect is not None:
+            self.export_spectrogram(spect, file_spect)
+
+        return wav, sr, spect
+
+
+if __name__ == "__main__":
+    f5tts = F5TTS()
+
+    wav, sr, spect = f5tts.infer(
+        ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
+        ref_text="some call me nature, others call me mother nature.",
+        gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
+        file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
+        file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
+        seed=-1,  # random seed = -1
+    )
+
+    print("seed :", f5tts.seed)
xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml ADDED

@@ -0,0 +1,44 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN  # dataset name
+  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame  # "frame" or "sample"
+  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000  # warmup steps
+  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0  # gradient clipping
+  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: E2TTS_Base
+  tokenizer: pinyin
+  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 1024
+    depth: 24
+    heads: 16
+    ff_mult: 4
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False  # use local offline ckpt or not
+    local_path: None  # local vocoder path
+
+ckpts:
+  logger: wandb  # wandb | tensorboard | None
+  save_per_updates: 50000  # save checkpoint per steps
+  last_per_steps: 5000  # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml ADDED

@@ -0,0 +1,44 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN
+  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame  # "frame" or "sample"
+  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000  # warmup steps
+  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0
+  bnb_optimizer: False
+
+model:
+  name: E2TTS_Small
+  tokenizer: pinyin
+  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 768
+    depth: 20
+    heads: 12
+    ff_mult: 4
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False  # use local offline ckpt or not
+    local_path: None  # local vocoder path
+
+ckpts:
+  logger: wandb  # wandb | tensorboard | None
+  save_per_updates: 50000  # save checkpoint per steps
+  last_per_steps: 5000  # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml ADDED

@@ -0,0 +1,46 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN  # dataset name
+  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame  # "frame" or "sample"
+  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000  # warmup steps
+  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0  # gradient clipping
+  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: F5TTS_Base  # model name
+  tokenizer: pinyin  # tokenizer type
+  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 1024
+    depth: 22
+    heads: 16
+    ff_mult: 2
+    text_dim: 512
+    conv_layers: 4
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False  # use local offline ckpt or not
+    local_path: None  # local vocoder path
+
+ckpts:
+  logger: wandb  # wandb | tensorboard | None
+  save_per_updates: 50000  # save checkpoint per steps
+  last_per_steps: 5000  # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml ADDED

@@ -0,0 +1,46 @@
+hydra:
+  run:
+    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+
+datasets:
+  name: Emilia_ZH_EN
+  batch_size_per_gpu: 38400  # 8 GPUs, 8 * 38400 = 307200
+  batch_size_type: frame  # "frame" or "sample"
+  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
+  num_workers: 16
+
+optim:
+  epochs: 15
+  learning_rate: 7.5e-5
+  num_warmup_updates: 20000  # warmup steps
+  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
+  max_grad_norm: 1.0  # gradient clipping
+  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not
+
+model:
+  name: F5TTS_Small
+  tokenizer: pinyin
+  tokenizer_path: None  # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt)
+  arch:
+    dim: 768
+    depth: 18
+    heads: 12
+    ff_mult: 2
+    text_dim: 512
+    conv_layers: 4
+  mel_spec:
+    target_sample_rate: 24000
+    n_mel_channels: 100
+    hop_length: 256
+    win_length: 1024
+    n_fft: 1024
+    mel_spec_type: vocos  # 'vocos' or 'bigvgan'
+  vocoder:
+    is_local: False  # use local offline ckpt or not
+    local_path: None  # local vocoder path
+
+ckpts:
+  logger: wandb  # wandb | tensorboard | None
+  save_per_updates: 50000  # save checkpoint per steps
+  last_per_steps: 5000  # save last checkpoint per steps
+  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
xinference/thirdparty/f5_tts/eval/README.md ADDED

@@ -0,0 +1,49 @@
+
+# Evaluation
+
+Install packages for evaluation:
+
+```bash
+pip install -e .[eval]
+```
+
+## Generating Samples for Evaluation
+
+### Prepare Test Datasets
+
+1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
+2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
+3. Unzip the downloaded datasets and place them in the `data/` directory.
+4. Update the path for *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py`
+5. Our filtered LibriSpeech-PC 4-10s subset: `data/librispeech_pc_test_clean_cross_sentence.lst`
+
+### Batch Inference for Test Set
+
+To run batch inference for evaluations, execute the following commands:
+
+```bash
+# batch inference for evaluations
+accelerate config  # if not set before
+bash src/f5_tts/eval/eval_infer_batch.sh
+```
+
+## Objective Evaluation on Generated Results
+
+### Download Evaluation Model Checkpoints
+
+1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
+2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
+3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
+
+Then update in the following scripts with the paths you put evaluation model ckpts to.
+
+### Objective Evaluation
+
+Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
+```bash
+# Evaluation for Seed-TTS test set
+python src/f5_tts/eval/eval_seedtts_testset.py --gen_wav_dir <GEN_WAVE_DIR>
+
+# Evaluation for LibriSpeech-PC test-clean (cross-sentence)
+python src/f5_tts/eval/eval_librispeech_test_clean.py --gen_wav_dir <GEN_WAVE_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
+```