xinference 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_compat.py +22 -2
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +91 -6
- xinference/client/restful/restful_client.py +39 -0
- xinference/core/model.py +41 -13
- xinference/deploy/cmdline.py +3 -1
- xinference/deploy/test/test_cmdline.py +56 -0
- xinference/isolation.py +24 -0
- xinference/model/audio/__init__.py +12 -0
- xinference/model/audio/core.py +26 -4
- xinference/model/audio/f5tts.py +195 -0
- xinference/model/audio/fish_speech.py +71 -35
- xinference/model/audio/model_spec.json +88 -0
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/audio/whisper_mlx.py +208 -0
- xinference/model/embedding/core.py +322 -6
- xinference/model/embedding/model_spec.json +8 -1
- xinference/model/embedding/model_spec_modelscope.json +9 -1
- xinference/model/llm/__init__.py +4 -2
- xinference/model/llm/llm_family.json +479 -53
- xinference/model/llm/llm_family_modelscope.json +423 -17
- xinference/model/llm/mlx/core.py +230 -50
- xinference/model/llm/sglang/core.py +2 -0
- xinference/model/llm/transformers/chatglm.py +9 -5
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/glm_edge_v.py +230 -0
- xinference/model/llm/transformers/utils.py +16 -8
- xinference/model/llm/utils.py +23 -1
- xinference/model/llm/vllm/core.py +89 -2
- xinference/thirdparty/f5_tts/__init__.py +0 -0
- xinference/thirdparty/f5_tts/api.py +166 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
- xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
- xinference/thirdparty/f5_tts/eval/README.md +49 -0
- xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
- xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
- xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
- xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
- xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
- xinference/thirdparty/f5_tts/infer/README.md +191 -0
- xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
- xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
- xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
- xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
- xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
- xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
- xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
- xinference/thirdparty/f5_tts/model/__init__.py +10 -0
- xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
- xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
- xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
- xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
- xinference/thirdparty/f5_tts/model/cfm.py +285 -0
- xinference/thirdparty/f5_tts/model/dataset.py +319 -0
- xinference/thirdparty/f5_tts/model/modules.py +658 -0
- xinference/thirdparty/f5_tts/model/trainer.py +366 -0
- xinference/thirdparty/f5_tts/model/utils.py +185 -0
- xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
- xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
- xinference/thirdparty/f5_tts/socket_server.py +159 -0
- xinference/thirdparty/f5_tts/train/README.md +77 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
- xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
- xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
- xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
- xinference/thirdparty/f5_tts/train/train.py +75 -0
- xinference/types.py +2 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/METADATA +39 -18
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/RECORD +92 -39
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
- /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/LICENSE +0 -0
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/f5_tts/socket_server.py (new file)
@@ -0,0 +1,159 @@
+import socket
+import struct
+import torch
+import torchaudio
+from threading import Thread
+
+
+import gc
+import traceback
+
+
+from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
+from model.backbones.dit import DiT
+
+
+class TTSStreamingProcessor:
+    def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Load the model using the provided checkpoint and vocab files
+        self.model = load_model(
+            model_cls=DiT,
+            model_cfg=dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
+            ckpt_path=ckpt_file,
+            mel_spec_type="vocos",  # or "bigvgan" depending on vocoder
+            vocab_file=vocab_file,
+            ode_method="euler",
+            use_ema=True,
+            device=self.device,
+        ).to(self.device, dtype=dtype)
+
+        # Load the vocoder
+        self.vocoder = load_vocoder(is_local=False)
+
+        # Set sampling rate for streaming
+        self.sampling_rate = 24000  # Consistency with client
+
+        # Set reference audio and text
+        self.ref_audio = ref_audio
+        self.ref_text = ref_text
+
+        # Warm up the model
+        self._warm_up()
+
+    def _warm_up(self):
+        """Warm up the model with a dummy input to ensure it's ready for real-time processing."""
+        print("Warming up the model...")
+        ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
+        audio, sr = torchaudio.load(ref_audio)
+        gen_text = "Warm-up text for the model."
+
+        # Pass the vocoder as an argument here
+        infer_batch_process((audio, sr), ref_text, [gen_text], self.model, self.vocoder, device=self.device)
+        print("Warm-up completed.")
+
+    def generate_stream(self, text, play_steps_in_s=0.5):
+        """Generate audio in chunks and yield them in real-time."""
+        # Preprocess the reference audio and text
+        ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
+
+        # Load reference audio
+        audio, sr = torchaudio.load(ref_audio)
+
+        # Run inference for the input text
+        audio_chunk, final_sample_rate, _ = infer_batch_process(
+            (audio, sr),
+            ref_text,
+            [text],
+            self.model,
+            self.vocoder,
+            device=self.device,  # Pass vocoder here
+        )
+
+        # Break the generated audio into chunks and send them
+        chunk_size = int(final_sample_rate * play_steps_in_s)
+
+        if len(audio_chunk) < chunk_size:
+            packed_audio = struct.pack(f"{len(audio_chunk)}f", *audio_chunk)
+            yield packed_audio
+            return
+
+        for i in range(0, len(audio_chunk), chunk_size):
+            chunk = audio_chunk[i : i + chunk_size]
+
+            # Check if it's the final chunk
+            if i + chunk_size >= len(audio_chunk):
+                chunk = audio_chunk[i:]
+
+            # Send the chunk if it is not empty
+            if len(chunk) > 0:
+                packed_audio = struct.pack(f"{len(chunk)}f", *chunk)
+                yield packed_audio
+
+
+def handle_client(client_socket, processor):
+    try:
+        while True:
+            # Receive data from the client
+            data = client_socket.recv(1024).decode("utf-8")
+            if not data:
+                break
+
+            try:
+                # The client sends the text input
+                text = data.strip()
+
+                # Generate and stream audio chunks
+                for audio_chunk in processor.generate_stream(text):
+                    client_socket.sendall(audio_chunk)
+
+                # Send end-of-audio signal
+                client_socket.sendall(b"END_OF_AUDIO")
+
+            except Exception as inner_e:
+                print(f"Error during processing: {inner_e}")
+                traceback.print_exc()  # Print the full traceback to diagnose the issue
+                break
+
+    except Exception as e:
+        print(f"Error handling client: {e}")
+        traceback.print_exc()
+    finally:
+        client_socket.close()
+
+
+def start_server(host, port, processor):
+    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    server.bind((host, port))
+    server.listen(5)
+    print(f"Server listening on {host}:{port}")
+
+    while True:
+        client_socket, addr = server.accept()
+        print(f"Accepted connection from {addr}")
+        client_handler = Thread(target=handle_client, args=(client_socket, processor))
+        client_handler.start()
+
+
+if __name__ == "__main__":
+    try:
+        # Load the model and vocoder using the provided files
+        ckpt_file = ""  # pointing your checkpoint "ckpts/model/model_1096.pt"
+        vocab_file = ""  # Add vocab file path if needed
+        ref_audio = ""  # add ref audio "./tests/ref_audio/reference.wav"
+        ref_text = ""
+
+        # Initialize the processor with the model and vocoder
+        processor = TTSStreamingProcessor(
+            ckpt_file=ckpt_file,
+            vocab_file=vocab_file,
+            ref_audio=ref_audio,
+            ref_text=ref_text,
+            dtype=torch.float32,
+        )
+
+        # Start the server
+        start_server("0.0.0.0", 9998, processor)
+    except KeyboardInterrupt:
+        gc.collect()
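For context, a minimal client sketch for the streaming protocol that `handle_client` above implements: the client sends one line of text, then reads raw native float32 samples until the `END_OF_AUDIO` marker. This helper is illustrative only and not part of the package; the host, port, and function name are assumptions matching the server defaults.

```python
# Hypothetical client for socket_server.py above (not shipped with xinference).
# Assumes the server defaults: port 9998 and 24 kHz output audio.
import socket

import numpy as np


def request_tts(text: str, host: str = "127.0.0.1", port: int = 9998):
    marker = b"END_OF_AUDIO"
    with socket.create_connection((host, port)) as sock:
        sock.sendall(text.encode("utf-8"))  # the server reads up to 1024 bytes per request
        buffer = b""
        while True:
            data = sock.recv(4096)
            if not data:
                break
            buffer += data
            if buffer.endswith(marker):
                buffer = buffer[: -len(marker)]
                break
    # Chunks were packed with struct.pack(f"{n}f", ...), i.e. native float32 samples.
    return np.frombuffer(buffer, dtype=np.float32), 24000


# samples, sr = request_tts("Hello from the F5-TTS socket server.")
```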
xinference/thirdparty/f5_tts/train/README.md (new file)
@@ -0,0 +1,77 @@
+# Training
+
+## Prepare Dataset
+
+Example data processing scripts are provided below, and you may tailor your own along with a Dataset class in `src/f5_tts/model/dataset.py`.
+
+### 1. Preparation scripts for some specific datasets
+Download the corresponding dataset first, and fill in its path in the script.
+
+```bash
+# Prepare the Emilia dataset
+python src/f5_tts/train/datasets/prepare_emilia.py
+
+# Prepare the Wenetspeech4TTS dataset
+python src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
+
+# Prepare the LibriTTS dataset
+python src/f5_tts/train/datasets/prepare_libritts.py
+
+# Prepare the LJSpeech dataset
+python src/f5_tts/train/datasets/prepare_ljspeech.py
+```
+
+### 2. Create custom dataset with metadata.csv
+For usage guidance, see [#57 here](https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029).
+
+```bash
+python src/f5_tts/train/datasets/prepare_csv_wavs.py
+```
+
+## Training & Finetuning
+
+Once your datasets are prepared, you can start the training process.
+
+### 1. Training script used for pretrained model
+
+```bash
+# setup accelerate config, e.g. use multi-gpu ddp, fp16
+# will be saved to: ~/.cache/huggingface/accelerate/default_config.yaml
+accelerate config
+
+# .yaml files are under the src/f5_tts/configs directory
+accelerate launch src/f5_tts/train/train.py --config-name F5TTS_Base_train.yaml
+```
+
+### 2. Finetuning practice
+Discussion board for finetuning: [#57](https://github.com/SWivid/F5-TTS/discussions/57).
+
+For Gradio UI training/finetuning with `src/f5_tts/train/finetune_gradio.py`, see [#143](https://github.com/SWivid/F5-TTS/discussions/143).
+
+### 3. Wandb Logging
+
+The `wandb/` directory will be created under the path from which you run the training/finetuning scripts.
+
+By default, the training script does NOT use logging (assuming you didn't manually log in using `wandb login`).
+
+To turn on wandb logging, you can either:
+
+1. Manually log in with `wandb login`: learn more [here](https://docs.wandb.ai/ref/cli/wandb-login)
+2. Log in programmatically by setting an environment variable: get an API key at https://wandb.ai/site/ and set it as follows:
+
+On Mac & Linux:
+
+```
+export WANDB_API_KEY=<YOUR WANDB API KEY>
+```
+
+On Windows:
+
+```
+set WANDB_API_KEY=<YOUR WANDB API KEY>
+```
+Moreover, if you can't access Wandb and want to log metrics offline, you can set the environment variable as follows:
+
+```
+export WANDB_MODE=offline
+```
xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py (new file)
@@ -0,0 +1,139 @@
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+import argparse
+import csv
+import json
+import shutil
+from importlib.resources import files
+from pathlib import Path
+
+import torchaudio
+from tqdm import tqdm
+from datasets.arrow_writer import ArrowWriter
+
+from f5_tts.model.utils import (
+    convert_char_to_pinyin,
+)
+
+
+PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
+
+
+def is_csv_wavs_format(input_dataset_dir):
+    fpath = Path(input_dataset_dir)
+    metadata = fpath / "metadata.csv"
+    wavs = fpath / "wavs"
+    return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
+
+
+def prepare_csv_wavs_dir(input_dir):
+    assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
+    input_dir = Path(input_dir)
+    metadata_path = input_dir / "metadata.csv"
+    audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
+
+    sub_result, durations = [], []
+    vocab_set = set()
+    polyphone = True
+    for audio_path, text in audio_path_text_pairs:
+        if not Path(audio_path).exists():
+            print(f"audio {audio_path} not found, skipping")
+            continue
+        audio_duration = get_audio_duration(audio_path)
+        # assume tokenizer = "pinyin" ("pinyin" | "char")
+        text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
+        sub_result.append({"audio_path": audio_path, "text": text, "duration": audio_duration})
+        durations.append(audio_duration)
+        vocab_set.update(list(text))
+
+    return sub_result, durations, vocab_set
+
+
+def get_audio_duration(audio_path):
+    audio, sample_rate = torchaudio.load(audio_path)
+    return audio.shape[1] / sample_rate
+
+
+def read_audio_text_pairs(csv_file_path):
+    audio_text_pairs = []
+
+    parent = Path(csv_file_path).parent
+    with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
+        reader = csv.reader(csvfile, delimiter="|")
+        next(reader)  # Skip the header row
+        for row in reader:
+            if len(row) >= 2:
+                audio_file = row[0].strip()  # First column: audio file path
+                text = row[1].strip()  # Second column: text
+                audio_file_path = parent / audio_file
+                audio_text_pairs.append((audio_file_path.as_posix(), text))
+
+    return audio_text_pairs
+
+
+def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
+    out_dir = Path(out_dir)
+    # save preprocessed dataset to disk
+    out_dir.mkdir(exist_ok=True, parents=True)
+    print(f"\nSaving to {out_dir} ...")
+
+    # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})  # oom
+    # dataset.save_to_disk(f"{out_dir}/raw", max_shard_size="2GB")
+    raw_arrow_path = out_dir / "raw.arrow"
+    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
+        for line in tqdm(result, desc="Writing to raw.arrow ..."):
+            writer.write(line)
+
+    # dup a json separately saving duration in case for DynamicBatchSampler ease
+    dur_json_path = out_dir / "duration.json"
+    with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
+        json.dump({"duration": duration_list}, f, ensure_ascii=False)
+
+    # vocab map, i.e. tokenizer
+    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
+    # if tokenizer == "pinyin":
+    #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
+    voca_out_path = out_dir / "vocab.txt"
+    with open(voca_out_path.as_posix(), "w") as f:
+        for vocab in sorted(text_vocab_set):
+            f.write(vocab + "\n")
+
+    if is_finetune:
+        file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
+        shutil.copy2(file_vocab_finetune, voca_out_path)
+    else:
+        with open(voca_out_path, "w") as f:
+            for vocab in sorted(text_vocab_set):
+                f.write(vocab + "\n")
+
+    dataset_name = out_dir.stem
+    print(f"\nFor {dataset_name}, sample count: {len(result)}")
+    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
+    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
+
+
+def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
+    if is_finetune:
+        assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
+    sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir)
+    save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
+
+
+def cli():
+    # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
+    # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
+    parser = argparse.ArgumentParser(description="Prepare and save dataset.")
+    parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
+    parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
+    parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
+
+    args = parser.parse_args()
+
+    prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
+
+
+if __name__ == "__main__":
+    cli()
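For reference, `read_audio_text_pairs` above expects `metadata.csv` to be pipe-delimited with one header row, audio paths (relative to the csv's directory) in the first column and transcripts in the second. A hypothetical layout (file names are illustrative):

```
audio_file|text
wavs/utt_0001.wav|Hello, this is the first sample sentence.
wavs/utt_0002.wav|And this is the second one.
```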
xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py (new file)
@@ -0,0 +1,230 @@
+# Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
+# if use updated new version, i.e. WebDataset, feel free to modify / draft your own script
+
+# generate audio text map for Emilia ZH & EN
+# evaluate for vocab size
+
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+import json
+from concurrent.futures import ProcessPoolExecutor
+from importlib.resources import files
+from pathlib import Path
+from tqdm import tqdm
+
+from datasets.arrow_writer import ArrowWriter
+
+from f5_tts.model.utils import (
+    repetition_found,
+    convert_char_to_pinyin,
+)
+
+
+out_zh = {
+    "ZH_B00041_S06226",
+    "ZH_B00042_S09204",
+    "ZH_B00065_S09430",
+    "ZH_B00065_S09431",
+    "ZH_B00066_S09327",
+    "ZH_B00066_S09328",
+}
+zh_filters = ["い", "て"]
+# seems synthesized audios, or heavily code-switched
+out_en = {
+    "EN_B00013_S00913",
+    "EN_B00042_S00120",
+    "EN_B00055_S04111",
+    "EN_B00061_S00693",
+    "EN_B00061_S01494",
+    "EN_B00061_S03375",
+    "EN_B00059_S00092",
+    "EN_B00111_S04300",
+    "EN_B00100_S03759",
+    "EN_B00087_S03811",
+    "EN_B00059_S00950",
+    "EN_B00089_S00946",
+    "EN_B00078_S05127",
+    "EN_B00070_S04089",
+    "EN_B00074_S09659",
+    "EN_B00061_S06983",
+    "EN_B00061_S07060",
+    "EN_B00059_S08397",
+    "EN_B00082_S06192",
+    "EN_B00091_S01238",
+    "EN_B00089_S07349",
+    "EN_B00070_S04343",
+    "EN_B00061_S02400",
+    "EN_B00076_S01262",
+    "EN_B00068_S06467",
+    "EN_B00076_S02943",
+    "EN_B00064_S05954",
+    "EN_B00061_S05386",
+    "EN_B00066_S06544",
+    "EN_B00076_S06944",
+    "EN_B00072_S08620",
+    "EN_B00076_S07135",
+    "EN_B00076_S09127",
+    "EN_B00065_S00497",
+    "EN_B00059_S06227",
+    "EN_B00063_S02859",
+    "EN_B00075_S01547",
+    "EN_B00061_S08286",
+    "EN_B00079_S02901",
+    "EN_B00092_S03643",
+    "EN_B00096_S08653",
+    "EN_B00063_S04297",
+    "EN_B00063_S04614",
+    "EN_B00079_S04698",
+    "EN_B00104_S01666",
+    "EN_B00061_S09504",
+    "EN_B00061_S09694",
+    "EN_B00065_S05444",
+    "EN_B00063_S06860",
+    "EN_B00065_S05725",
+    "EN_B00069_S07628",
+    "EN_B00083_S03875",
+    "EN_B00071_S07665",
+    "EN_B00071_S07665",
+    "EN_B00062_S04187",
+    "EN_B00065_S09873",
+    "EN_B00065_S09922",
+    "EN_B00084_S02463",
+    "EN_B00067_S05066",
+    "EN_B00106_S08060",
+    "EN_B00073_S06399",
+    "EN_B00073_S09236",
+    "EN_B00087_S00432",
+    "EN_B00085_S05618",
+    "EN_B00064_S01262",
+    "EN_B00072_S01739",
+    "EN_B00059_S03913",
+    "EN_B00069_S04036",
+    "EN_B00067_S05623",
+    "EN_B00060_S05389",
+    "EN_B00060_S07290",
+    "EN_B00062_S08995",
+}
+en_filters = ["ا", "い", "て"]
+
+
+def deal_with_audio_dir(audio_dir):
+    audio_jsonl = audio_dir.with_suffix(".jsonl")
+    sub_result, durations = [], []
+    vocab_set = set()
+    bad_case_zh = 0
+    bad_case_en = 0
+    with open(audio_jsonl, "r") as f:
+        lines = f.readlines()
+        for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
+            obj = json.loads(line)
+            text = obj["text"]
+            if obj["language"] == "zh":
+                if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
+                    bad_case_zh += 1
+                    continue
+                else:
+                    text = text.translate(
+                        str.maketrans({"，": ",", "！": "!", "？": "?"})
+                    )  # not "。" cuz much code-switched
+            if obj["language"] == "en":
+                if (
+                    obj["wav"].split("/")[1] in out_en
+                    or any(f in text for f in en_filters)
+                    or repetition_found(text, length=4)
+                ):
+                    bad_case_en += 1
+                    continue
+            if tokenizer == "pinyin":
+                text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
+            duration = obj["duration"]
+            sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
+            durations.append(duration)
+            vocab_set.update(list(text))
+    return sub_result, durations, vocab_set, bad_case_zh, bad_case_en
+
+
+def main():
+    assert tokenizer in ["pinyin", "char"]
+    result = []
+    duration_list = []
+    text_vocab_set = set()
+    total_bad_case_zh = 0
+    total_bad_case_en = 0
+
+    # process raw data
+    executor = ProcessPoolExecutor(max_workers=max_workers)
+    futures = []
+    for lang in langs:
+        dataset_path = Path(os.path.join(dataset_dir, lang))
+        [
+            futures.append(executor.submit(deal_with_audio_dir, audio_dir))
+            for audio_dir in dataset_path.iterdir()
+            if audio_dir.is_dir()
+        ]
+    for futures in tqdm(futures, total=len(futures)):
+        sub_result, durations, vocab_set, bad_case_zh, bad_case_en = futures.result()
+        result.extend(sub_result)
+        duration_list.extend(durations)
+        text_vocab_set.update(vocab_set)
+        total_bad_case_zh += bad_case_zh
+        total_bad_case_en += bad_case_en
+    executor.shutdown()
+
+    # save preprocessed dataset to disk
+    if not os.path.exists(f"{save_dir}"):
+        os.makedirs(f"{save_dir}")
+    print(f"\nSaving to {save_dir} ...")
+
+    # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})  # oom
+    # dataset.save_to_disk(f"{save_dir}/raw", max_shard_size="2GB")
+    with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
+        for line in tqdm(result, desc="Writing to raw.arrow ..."):
+            writer.write(line)
+
+    # dup a json separately saving duration in case for DynamicBatchSampler ease
+    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
+        json.dump({"duration": duration_list}, f, ensure_ascii=False)
+
+    # vocab map, i.e. tokenizer
+    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
+    # if tokenizer == "pinyin":
+    #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
+    with open(f"{save_dir}/vocab.txt", "w") as f:
+        for vocab in sorted(text_vocab_set):
+            f.write(vocab + "\n")
+
+    print(f"\nFor {dataset_name}, sample count: {len(result)}")
+    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
+    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
+    if "ZH" in langs:
+        print(f"Bad zh transcription case: {total_bad_case_zh}")
+    if "EN" in langs:
+        print(f"Bad en transcription case: {total_bad_case_en}\n")
+
+
+if __name__ == "__main__":
+    max_workers = 32
+
+    tokenizer = "pinyin"  # "pinyin" | "char"
+    polyphone = True
+
+    langs = ["ZH", "EN"]
+    dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
+    dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
+    save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
+    print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
+
+    main()
+
+# Emilia ZH & EN
+# samples count 37837916 (after removal)
+# pinyin vocab size 2543 (polyphone)
+# total duration 95281.87 (hours)
+# bad zh asr cnt 230435 (samples)
+# bad eh asr cnt 37217 (samples)
+
+# vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
+# please be careful if using pretrained model, make sure the vocab.txt is same
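For reference, `deal_with_audio_dir` above reads one JSON object per line from each `<dir>.jsonl` manifest and relies only on the `wav`, `text`, `language`, and `duration` fields, matching `wav.split("/")[1]` against the exclusion sets. A hypothetical manifest line (the exact path layout of the raw Emilia release may differ):

```json
{"wav": "ZH_B00000/ZH_B00000_S00001/mp3/ZH_B00000_S00001_W000000.mp3", "text": "大家好，欢迎收听本期节目。", "language": "zh", "duration": 3.52}
```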
xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py (new file)
@@ -0,0 +1,92 @@
+import os
+import sys
+
+sys.path.append(os.getcwd())
+
+import json
+from concurrent.futures import ProcessPoolExecutor
+from importlib.resources import files
+from pathlib import Path
+from tqdm import tqdm
+import soundfile as sf
+from datasets.arrow_writer import ArrowWriter
+
+
+def deal_with_audio_dir(audio_dir):
+    sub_result, durations = [], []
+    vocab_set = set()
+    audio_lists = list(audio_dir.rglob("*.wav"))
+
+    for line in audio_lists:
+        text_path = line.with_suffix(".normalized.txt")
+        text = open(text_path, "r").read().strip()
+        duration = sf.info(line).duration
+        if duration < 0.4 or duration > 30:
+            continue
+        sub_result.append({"audio_path": str(line), "text": text, "duration": duration})
+        durations.append(duration)
+        vocab_set.update(list(text))
+    return sub_result, durations, vocab_set
+
+
+def main():
+    result = []
+    duration_list = []
+    text_vocab_set = set()
+
+    # process raw data
+    executor = ProcessPoolExecutor(max_workers=max_workers)
+    futures = []
+
+    for subset in tqdm(SUB_SET):
+        dataset_path = Path(os.path.join(dataset_dir, subset))
+        [
+            futures.append(executor.submit(deal_with_audio_dir, audio_dir))
+            for audio_dir in dataset_path.iterdir()
+            if audio_dir.is_dir()
+        ]
+    for future in tqdm(futures, total=len(futures)):
+        sub_result, durations, vocab_set = future.result()
+        result.extend(sub_result)
+        duration_list.extend(durations)
+        text_vocab_set.update(vocab_set)
+    executor.shutdown()
+
+    # save preprocessed dataset to disk
+    if not os.path.exists(f"{save_dir}"):
+        os.makedirs(f"{save_dir}")
+    print(f"\nSaving to {save_dir} ...")
+
+    with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
+        for line in tqdm(result, desc="Writing to raw.arrow ..."):
+            writer.write(line)
+
+    # dup a json separately saving duration in case for DynamicBatchSampler ease
+    with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
+        json.dump({"duration": duration_list}, f, ensure_ascii=False)
+
+    # vocab map, i.e. tokenizer
+    with open(f"{save_dir}/vocab.txt", "w") as f:
+        for vocab in sorted(text_vocab_set):
+            f.write(vocab + "\n")
+
+    print(f"\nFor {dataset_name}, sample count: {len(result)}")
+    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
+    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
+
+
+if __name__ == "__main__":
+    max_workers = 36
+
+    tokenizer = "char"  # "pinyin" | "char"
+
+    SUB_SET = ["train-clean-100", "train-clean-360", "train-other-500"]
+    dataset_dir = "<SOME_PATH>/LibriTTS"
+    dataset_name = f"LibriTTS_{'_'.join(SUB_SET)}_{tokenizer}".replace("train-clean-", "").replace("train-other-", "")
+    save_dir = str(files("f5_tts").joinpath("../../")) + f"/data/{dataset_name}"
+    print(f"\nPrepare for {dataset_name}, will save to {save_dir}\n")
+    main()
+
+# For LibriTTS_100_360_500_char, sample count: 354218
+# For LibriTTS_100_360_500_char, vocab size is: 78
+# For LibriTTS_100_360_500_char, total 554.09 hours