xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +28 -6
- xinference/core/utils.py +10 -6
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/core.py +10 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +200 -0
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +36 -111
- xinference/model/audio/model_spec.json +27 -3
- xinference/model/audio/model_spec_modelscope.json +18 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/embedding/core.py +203 -142
- xinference/model/embedding/model_spec.json +7 -0
- xinference/model/embedding/model_spec_modelscope.json +8 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +127 -4
- xinference/model/image/model_spec_modelscope.json +130 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/__init__.py +2 -2
- xinference/model/llm/llm_family.json +219 -53
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +167 -20
- xinference/model/llm/mlx/core.py +287 -51
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/qwen2_vl.py +2 -0
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +5 -1
- xinference/model/llm/vllm/core.py +16 -2
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py (new file)
@@ -0,0 +1,65 @@
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+import json
+from importlib.resources import files
+from pathlib import Path
+from tqdm import tqdm
+import soundfile as sf
+from datasets.arrow_writer import ArrowWriter
+
+
+def main():
+    result = []
+    duration_list = []
+    text_vocab_set = set()
+
+    with open(meta_info, "r") as f:
+        lines = f.readlines()
+        for line in tqdm(lines):
+            uttr, text, norm_text = line.split("|")
+            norm_text = norm_text.strip()
+            wav_path = Path(dataset_dir) / "wavs" / f"{uttr}.wav"
+            duration = sf.info(wav_path).duration
+            if duration < 0.4 or duration > 30:
+                continue
+            result.append({"audio_path": str(wav_path), "text": norm_text, "duration": duration})
+            duration_list.append(duration)
+            text_vocab_set.update(list(norm_text))
+
+    # save preprocessed dataset to disk
+    if not os.path.exists(f"{save_dir}"):
+        os.makedirs(f"{save_dir}")
+    print(f"\nSaving to {save_dir} ...")
+
+    with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
+        for line in tqdm(result, desc="Writing to raw.arrow ..."):
+            writer.write(line)
+
+    # dup a json separately saving duration in case for DynamicBatchSampler ease
+    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
+        json.dump({"duration": duration_list}, f, ensure_ascii=False)
+
+    # vocab map, i.e. tokenizer
+    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
+    with open(f"{save_dir}/vocab.txt", "w") as f:
+        for vocab in sorted(text_vocab_set):
+            f.write(vocab + "\n")
+
+    print(f"\nFor {dataset_name}, sample count: {len(result)}")
+    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
+    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
+
+
+if __name__ == "__main__":
+    tokenizer = "char"  # "pinyin" | "char"
+
+    dataset_dir = "<SOME_PATH>/LJSpeech-1.1"
+    dataset_name = f"LJSpeech_{tokenizer}"
+    meta_info = os.path.join(dataset_dir, "metadata.csv")
+    save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
+    print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
+
+    main()
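The script above filters LJSpeech clips to the 0.4–30 s range and writes three artifacts: `raw.arrow` (one record per clip, via `datasets.arrow_writer.ArrowWriter`), `duration.json` (per-clip durations kept separately for the DynamicBatchSampler), and `vocab.txt` (the character vocabulary). A minimal sketch, not part of the diff, of reading those artifacts back with the `datasets` library; the `data/LJSpeech_char` path is an assumed save location matching the script's defaults:

```python
import json

from datasets import Dataset

save_dir = "data/LJSpeech_char"  # assumed output dir (default "char" tokenizer)

# raw.arrow was written by ArrowWriter, so it can be memory-mapped directly
dataset = Dataset.from_file(f"{save_dir}/raw.arrow")
print(dataset[0])  # {'audio_path': ..., 'text': ..., 'duration': ...}

# duration.json duplicates the per-clip durations for batch planning
with open(f"{save_dir}/duration.json", encoding="utf-8") as f:
    durations = json.load(f)["duration"]
print(f"{len(durations)} clips, {sum(durations) / 3600:.2f} hours")
```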
xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py (new file)
@@ -0,0 +1,125 @@
+# generate audio text map for WenetSpeech4TTS
+# evaluate for vocab size
+
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+import json
+from concurrent.futures import ProcessPoolExecutor
+from importlib.resources import files
+from tqdm import tqdm
+
+import torchaudio
+from datasets import Dataset
+
+from f5_tts.model.utils import convert_char_to_pinyin
+
+
+def deal_with_sub_path_files(dataset_path, sub_path):
+    print(f"Dealing with: {sub_path}")
+
+    text_dir = os.path.join(dataset_path, sub_path, "txts")
+    audio_dir = os.path.join(dataset_path, sub_path, "wavs")
+    text_files = os.listdir(text_dir)
+
+    audio_paths, texts, durations = [], [], []
+    for text_file in tqdm(text_files):
+        with open(os.path.join(text_dir, text_file), "r", encoding="utf-8") as file:
+            first_line = file.readline().split("\t")
+            audio_nm = first_line[0]
+            audio_path = os.path.join(audio_dir, audio_nm + ".wav")
+            text = first_line[1].strip()
+
+            audio_paths.append(audio_path)
+
+            if tokenizer == "pinyin":
+                texts.extend(convert_char_to_pinyin([text], polyphone=polyphone))
+            elif tokenizer == "char":
+                texts.append(text)
+
+            audio, sample_rate = torchaudio.load(audio_path)
+            durations.append(audio.shape[-1] / sample_rate)
+
+    return audio_paths, texts, durations
+
+
+def main():
+    assert tokenizer in ["pinyin", "char"]
+
+    audio_path_list, text_list, duration_list = [], [], []
+
+    executor = ProcessPoolExecutor(max_workers=max_workers)
+    futures = []
+    for dataset_path in dataset_paths:
+        sub_items = os.listdir(dataset_path)
+        sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))]
+        for sub_path in sub_paths:
+            futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path))
+    for future in tqdm(futures, total=len(futures)):
+        audio_paths, texts, durations = future.result()
+        audio_path_list.extend(audio_paths)
+        text_list.extend(texts)
+        duration_list.extend(durations)
+    executor.shutdown()
+
+    if not os.path.exists("data"):
+        os.makedirs("data")
+
+    print(f"\nSaving to {save_dir} ...")
+    dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
+    dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB")  # arrow format
+
+    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
+        json.dump(
+            {"duration": duration_list}, f, ensure_ascii=False
+        )  # dup a json separately saving duration in case for DynamicBatchSampler ease
+
+    print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
+    text_vocab_set = set()
+    for text in tqdm(text_list):
+        text_vocab_set.update(list(text))
+
+    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
+    if tokenizer == "pinyin":
+        text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
+
+    with open(f"{save_dir}/vocab.txt", "w") as f:
+        for vocab in sorted(text_vocab_set):
+            f.write(vocab + "\n")
+    print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
+    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
+
+
+if __name__ == "__main__":
+    max_workers = 32
+
+    tokenizer = "pinyin"  # "pinyin" | "char"
+    polyphone = True
+    dataset_choice = 1  # 1: Premium, 2: Standard, 3: Basic
+
+    dataset_name = (
+        ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice - 1]
+        + "_"
+        + tokenizer
+    )
+    dataset_paths = [
+        "<SOME_PATH>/WenetSpeech4TTS/Basic",
+        "<SOME_PATH>/WenetSpeech4TTS/Standard",
+        "<SOME_PATH>/WenetSpeech4TTS/Premium",
+    ][-dataset_choice:]
+    save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
+    print(f"\nChoose Dataset: {dataset_name}, will save to {save_dir}\n")
+
+    main()
+
+# Results (if adding alphabets with accents and symbols):
+# WenetSpeech4TTS         Basic    Standard    Premium
+#   samples count       3932473     1941220     407494
+#   pinyin vocab size      1349        1348       1344   (no polyphone)
+#                             -           -       1459   (polyphone)
+#   char vocab size        5264        5219       5042
+
+# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
+# please be careful if using pretrained model, make sure the vocab.txt is same
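Unlike the LJSpeech script, which writes a single `raw.arrow` file, this one saves a sharded Arrow dataset via `Dataset.save_to_disk`, so it is read back with `load_from_disk` rather than `Dataset.from_file`. A minimal sketch, not part of the diff, of inspecting the output; the path assumes the defaults above (`dataset_choice = 1`, i.e. Premium, with the pinyin tokenizer):

```python
from datasets import load_from_disk

save_dir = "data/WenetSpeech4TTS_Premium_pinyin"  # assumed output dir

# sharded Arrow dataset written by save_to_disk(..., max_shard_size="2GB")
dataset = load_from_disk(f"{save_dir}/raw")
print(dataset.column_names)  # ['audio_path', 'text', 'duration']
print(len(dataset))

# vocab.txt holds one token per line (characters, or pinyin syllables)
with open(f"{save_dir}/vocab.txt", encoding="utf-8") as f:
    vocab = [line.rstrip("\n") for line in f]
print(f"vocab size: {len(vocab)}")
```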
xinference/thirdparty/f5_tts/train/finetune_cli.py (new file)
@@ -0,0 +1,174 @@
+import argparse
+import os
+import shutil
+
+from cached_path import cached_path
+from f5_tts.model import CFM, UNetT, DiT, Trainer
+from f5_tts.model.utils import get_tokenizer
+from f5_tts.model.dataset import load_dataset
+from importlib.resources import files
+
+
+# -------------------------- Dataset Settings --------------------------- #
+target_sample_rate = 24000
+n_mel_channels = 100
+hop_length = 256
+win_length = 1024
+n_fft = 1024
+mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'
+
+
+# -------------------------- Argument Parsing --------------------------- #
+def parse_args():
+    # batch_size_per_gpu = 1000 settting for gpu 8GB
+    # batch_size_per_gpu = 1600 settting for gpu 12GB
+    # batch_size_per_gpu = 2000 settting for gpu 16GB
+    # batch_size_per_gpu = 3200 settting for gpu 24GB
+
+    # num_warmup_updates = 300 for 5000 sample about 10 hours
+
+    # change save_per_updates , last_per_steps change this value what you need ,
+
+    parser = argparse.ArgumentParser(description="Train CFM Model")
+
+    parser.add_argument(
+        "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
+    )
+    parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
+    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
+    parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
+    parser.add_argument(
+        "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
+    )
+    parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
+    parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
+    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
+    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
+    parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps")
+    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
+    parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
+    parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
+    parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
+    parser.add_argument(
+        "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
+    )
+    parser.add_argument(
+        "--tokenizer_path",
+        type=str,
+        default=None,
+        help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
+    )
+    parser.add_argument(
+        "--log_samples",
+        type=bool,
+        default=False,
+        help="Log inferenced samples per ckpt save steps",
+    )
+    parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
+    parser.add_argument(
+        "--bnb_optimizer",
+        type=bool,
+        default=False,
+        help="Use 8-bit Adam optimizer from bitsandbytes",
+    )
+
+    return parser.parse_args()
+
+
+# -------------------------- Training Settings -------------------------- #
+
+
+def main():
+    args = parse_args()
+
+    checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
+
+    # Model parameters based on experiment name
+    if args.exp_name == "F5TTS_Base":
+        wandb_resume_id = None
+        model_cls = DiT
+        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+        if args.finetune:
+            if args.pretrain is None:
+                ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+            else:
+                ckpt_path = args.pretrain
+    elif args.exp_name == "E2TTS_Base":
+        wandb_resume_id = None
+        model_cls = UNetT
+        model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+        if args.finetune:
+            if args.pretrain is None:
+                ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+            else:
+                ckpt_path = args.pretrain
+
+    if args.finetune:
+        if not os.path.isdir(checkpoint_path):
+            os.makedirs(checkpoint_path, exist_ok=True)
+
+        file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path))
+        if not os.path.isfile(file_checkpoint):
+            shutil.copy2(ckpt_path, file_checkpoint)
+            print("copy checkpoint for finetune")
+
+    # Use the tokenizer and tokenizer_path provided in the command line arguments
+    tokenizer = args.tokenizer
+    if tokenizer == "custom":
+        if not args.tokenizer_path:
+            raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
+        tokenizer_path = args.tokenizer_path
+    else:
+        tokenizer_path = args.dataset_name
+
+    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
+
+    print("\nvocab : ", vocab_size)
+    print("\nvocoder : ", mel_spec_type)
+
+    mel_spec_kwargs = dict(
+        n_fft=n_fft,
+        hop_length=hop_length,
+        win_length=win_length,
+        n_mel_channels=n_mel_channels,
+        target_sample_rate=target_sample_rate,
+        mel_spec_type=mel_spec_type,
+    )
+
+    model = CFM(
+        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
+        mel_spec_kwargs=mel_spec_kwargs,
+        vocab_char_map=vocab_char_map,
+    )
+
+    trainer = Trainer(
+        model,
+        args.epochs,
+        args.learning_rate,
+        num_warmup_updates=args.num_warmup_updates,
+        save_per_updates=args.save_per_updates,
+        checkpoint_path=checkpoint_path,
+        batch_size=args.batch_size_per_gpu,
+        batch_size_type=args.batch_size_type,
+        max_samples=args.max_samples,
+        grad_accumulation_steps=args.grad_accumulation_steps,
+        max_grad_norm=args.max_grad_norm,
+        logger=args.logger,
+        wandb_project=args.dataset_name,
+        wandb_run_name=args.exp_name,
+        wandb_resume_id=wandb_resume_id,
+        log_samples=args.log_samples,
+        last_per_steps=args.last_per_steps,
+        bnb_optimizer=args.bnb_optimizer,
+    )
+
+    train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
+
+    trainer.train(
+        train_dataset,
+        resumable_with_seed=666,  # seed for shuffling dataset
+    )
+
+
+if __name__ == "__main__":
+    main()
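The VRAM comments in `parse_args` size batches in mel-spectrogram frames (`batch_size_type="frame"` is the default). Since the script fixes `hop_length = 256` at a 24 kHz sample rate, one frame covers 256 / 24000 s of audio, so each suggested `batch_size_per_gpu` maps directly to seconds of audio per batch. A small sketch, not part of the diff, of that arithmetic using the script's own constants:

```python
# Frame-to-seconds arithmetic using the script's own mel settings.
target_sample_rate = 24000  # Hz
hop_length = 256  # audio samples per mel frame

for frames, gpu in [(1000, "8GB"), (1600, "12GB"), (2000, "16GB"), (3200, "24GB")]:
    seconds = frames * hop_length / target_sample_rate
    print(f"batch_size_per_gpu={frames} ({gpu} GPU) ~= {seconds:.1f} s of audio per batch")
# e.g. the 24GB default of 3200 frames corresponds to roughly 34 s of audio.
```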